path: root/onnxruntime-specify-provider.diff
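Since onnxruntime 1.9, ort.InferenceSession raises a ValueError when the
installed build ships more than one execution provider and no providers
argument is given. This patch pins every session created by onnx2pytorch
and its test suite to the CPU provider. A minimal sketch of the call
pattern follows (model_bytes is a hypothetical placeholder for any
serialized ONNX model):

    import onnxruntime as ort

    # model_bytes: placeholder for a serialized ONNX model (hypothetical)
    # Fails on onnxruntime >= 1.9 builds with GPU providers enabled:
    #   sess = ort.InferenceSession(model_bytes)
    # Works on any build by selecting the CPU provider explicitly:
    sess = ort.InferenceSession(model_bytes, providers=["CPUExecutionProvider"])
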
diff --git a/onnx2pytorch/utils.py b/onnx2pytorch/utils.py
index 869e0e1..395b7de 100644
--- a/onnx2pytorch/utils.py
+++ b/onnx2pytorch/utils.py
@@ -169,7 +169,7 @@ def get_activation_value(onnx_model, inputs, activation_names):
     onnx.save(onnx_model, buffer)
     buffer.seek(0)
     onnx_model_new = onnx.load(buffer)
-    sess = ort.InferenceSession(onnx_model_new.SerializeToString())
+    sess = ort.InferenceSession(onnx_model_new.SerializeToString(), providers=["CPUExecutionProvider"])
 
     input_names = [x.name for x in sess.get_inputs()]
     if not isinstance(inputs, list):
@@ -190,7 +190,7 @@ def get_inputs_sample(onnx_model, to_torch=False):
     """Get inputs sample from onnx model."""
     assert ort is not None, "onnxruntime needed. pip install onnxruntime"
 
-    sess = ort.InferenceSession(onnx_model.SerializeToString())
+    sess = ort.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
     inputs = sess.get_inputs()
     input_names = get_inputs_names(onnx_model.graph)
     input_tensors = [
diff --git a/tests/onnx2pytorch/conftest.py b/tests/onnx2pytorch/conftest.py
index 711560d..2916552 100644
--- a/tests/onnx2pytorch/conftest.py
+++ b/tests/onnx2pytorch/conftest.py
@@ -33,6 +33,6 @@ def onnx_inputs(onnx_model):
 
 @pytest.fixture
 def onnx_model_outputs(onnx_model_path, onnx_model, onnx_inputs):
-    ort_session = ort.InferenceSession(onnx_model_path)
+    ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
     onnx_output = ort_session.run(None, onnx_inputs)
     return onnx_output
diff --git a/tests/onnx2pytorch/convert/test_loop.py b/tests/onnx2pytorch/convert/test_loop.py
index d2e6890..d29142a 100644
--- a/tests/onnx2pytorch/convert/test_loop.py
+++ b/tests/onnx2pytorch/convert/test_loop.py
@@ -128,7 +128,7 @@ def test_loop_sum():
     exp_res_y = np.array([13]).astype(np.float32)
     exp_res_scan = np.array([-1, 1, 4, 8, 13]).astype(np.float32).reshape((5, 1))
 
-    ort_session = ort.InferenceSession(bitstream_data)
+    ort_session = ort.InferenceSession(bitstream_data, providers=["CPUExecutionProvider"])
     ort_inputs = {"trip_count": trip_count_input, "cond": cond_input, "y": y_input}
     ort_outputs = ort_session.run(None, ort_inputs)
     ort_res_y, ort_res_scan = ort_outputs
diff --git a/tests/onnx2pytorch/test_onnx2pytorch.py b/tests/onnx2pytorch/test_onnx2pytorch.py
index 29f7053..a474624 100644
--- a/tests/onnx2pytorch/test_onnx2pytorch.py
+++ b/tests/onnx2pytorch/test_onnx2pytorch.py
@@ -50,7 +50,7 @@ def test_onnx2pytorch2onnx(onnx_model, onnx_model_outputs, onnx_inputs):
         onnx_model = onnx.ModelProto.FromString(bitstream.getvalue())
         onnx.checker.check_model(onnx_model)
 
-    ort_session = ort.InferenceSession(bitstream.getvalue())
+    ort_session = ort.InferenceSession(bitstream.getvalue(), providers=["CPUExecutionProvider"])
     outputs = ort_session.run(None, onnx_inputs)
 
     for output, onnx_model_output in zip(outputs, onnx_model_outputs):