path: root/onnxruntime-specify-provider.diff
author     Chih-Hsuan Yen  2022-05-05 20:04:14 +0800
committer  Chih-Hsuan Yen  2022-05-05 20:04:14 +0800
commit     7b21f8b1fb2fb1307887c1055bb1f21d4b253047 (patch)
tree       ab03a5050c01804028433b698b075bbe8b390a73 /onnxruntime-specify-provider.diff
new package
Diffstat (limited to 'onnxruntime-specify-provider.diff')
-rw-r--r--  onnxruntime-specify-provider.diff  60
1 file changed, 60 insertions(+), 0 deletions(-)
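The patch below pins every ort.InferenceSession in onnx2pytorch to the CPU execution provider. Since onnxruntime 1.9, constructing a session without an explicit providers argument raises a ValueError on builds that ship execution providers beyond the plain CPU one, which is why each call site gains providers=["CPUExecutionProvider"].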
diff --git a/onnxruntime-specify-provider.diff b/onnxruntime-specify-provider.diff
new file mode 100644
index 000000000000..f3542496d1a6
--- /dev/null
+++ b/onnxruntime-specify-provider.diff
@@ -0,0 +1,60 @@
+diff --git a/onnx2pytorch/utils.py b/onnx2pytorch/utils.py
+index 869e0e1..395b7de 100644
+--- a/onnx2pytorch/utils.py
++++ b/onnx2pytorch/utils.py
+@@ -169,7 +169,7 @@ def get_activation_value(onnx_model, inputs, activation_names):
+ onnx.save(onnx_model, buffer)
+ buffer.seek(0)
+ onnx_model_new = onnx.load(buffer)
+- sess = ort.InferenceSession(onnx_model_new.SerializeToString())
++ sess = ort.InferenceSession(onnx_model_new.SerializeToString(), providers=["CPUExecutionProvider"])
+
+ input_names = [x.name for x in sess.get_inputs()]
+ if not isinstance(inputs, list):
+@@ -190,7 +190,7 @@ def get_inputs_sample(onnx_model, to_torch=False):
+ """Get inputs sample from onnx model."""
+ assert ort is not None, "onnxruntime needed. pip install onnxruntime"
+
+- sess = ort.InferenceSession(onnx_model.SerializeToString())
++ sess = ort.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
+ inputs = sess.get_inputs()
+ input_names = get_inputs_names(onnx_model.graph)
+ input_tensors = [
+diff --git a/tests/onnx2pytorch/conftest.py b/tests/onnx2pytorch/conftest.py
+index 711560d..2916552 100644
+--- a/tests/onnx2pytorch/conftest.py
++++ b/tests/onnx2pytorch/conftest.py
+@@ -33,6 +33,6 @@ def onnx_inputs(onnx_model):
+
+ @pytest.fixture
+ def onnx_model_outputs(onnx_model_path, onnx_model, onnx_inputs):
+- ort_session = ort.InferenceSession(onnx_model_path)
++ ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
+ onnx_output = ort_session.run(None, onnx_inputs)
+ return onnx_output
+diff --git a/tests/onnx2pytorch/convert/test_loop.py b/tests/onnx2pytorch/convert/test_loop.py
+index d2e6890..d29142a 100644
+--- a/tests/onnx2pytorch/convert/test_loop.py
++++ b/tests/onnx2pytorch/convert/test_loop.py
+@@ -128,7 +128,7 @@ def test_loop_sum():
+ exp_res_y = np.array([13]).astype(np.float32)
+ exp_res_scan = np.array([-1, 1, 4, 8, 13]).astype(np.float32).reshape((5, 1))
+
+- ort_session = ort.InferenceSession(bitstream_data)
++ ort_session = ort.InferenceSession(bitstream_data, providers=["CPUExecutionProvider"])
+ ort_inputs = {"trip_count": trip_count_input, "cond": cond_input, "y": y_input}
+ ort_outputs = ort_session.run(None, ort_inputs)
+ ort_res_y, ort_res_scan = ort_outputs
+diff --git a/tests/onnx2pytorch/test_onnx2pytorch.py b/tests/onnx2pytorch/test_onnx2pytorch.py
+index 29f7053..a474624 100644
+--- a/tests/onnx2pytorch/test_onnx2pytorch.py
++++ b/tests/onnx2pytorch/test_onnx2pytorch.py
+@@ -50,7 +50,7 @@ def test_onnx2pytorch2onnx(onnx_model, onnx_model_outputs, onnx_inputs):
+ onnx_model = onnx.ModelProto.FromString(bitstream.getvalue())
+ onnx.checker.check_model(onnx_model)
+
+- ort_session = ort.InferenceSession(bitstream.getvalue())
++ ort_session = ort.InferenceSession(bitstream.getvalue(), providers=["CPUExecutionProvider"])
+ outputs = ort_session.run(None, onnx_inputs)
+
+ for output, onnx_model_output in zip(outputs, onnx_model_outputs):
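A minimal sketch of the same fix applied outside the patch, assuming a hypothetical local model.onnx; instead of hard-coding the CPU provider, ort.get_available_providers() lets a caller prefer an accelerated provider when the installed build offers one:

    import onnxruntime as ort

    # Query the providers this onnxruntime build actually ships, then
    # prefer CUDA when available and fall back to the CPU provider.
    available = ort.get_available_providers()
    providers = [
        p
        for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
        if p in available
    ]

    # Since onnxruntime 1.9, omitting providers= raises a ValueError on
    # builds that bundle more than the plain CPU provider.
    sess = ort.InferenceSession("model.onnx", providers=providers)
    print([inp.name for inp in sess.get_inputs()])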