diff options
author | Chih-Hsuan Yen | 2022-05-05 20:04:14 +0800 |
---|---|---|
committer | Chih-Hsuan Yen | 2022-05-05 20:04:14 +0800 |
commit | 7b21f8b1fb2fb1307887c1055bb1f21d4b253047 (patch) | |
tree | ab03a5050c01804028433b698b075bbe8b390a73 /onnxruntime-specify-provider.diff | |
download | aur-python-onnx2pytorch.tar.gz |
new package
Diffstat (limited to 'onnxruntime-specify-provider.diff')
-rw-r--r-- | onnxruntime-specify-provider.diff | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/onnxruntime-specify-provider.diff b/onnxruntime-specify-provider.diff new file mode 100644 index 000000000000..f3542496d1a6 --- /dev/null +++ b/onnxruntime-specify-provider.diff @@ -0,0 +1,60 @@ +diff --git a/onnx2pytorch/utils.py b/onnx2pytorch/utils.py +index 869e0e1..395b7de 100644 +--- a/onnx2pytorch/utils.py ++++ b/onnx2pytorch/utils.py +@@ -169,7 +169,7 @@ def get_activation_value(onnx_model, inputs, activation_names): + onnx.save(onnx_model, buffer) + buffer.seek(0) + onnx_model_new = onnx.load(buffer) +- sess = ort.InferenceSession(onnx_model_new.SerializeToString()) ++ sess = ort.InferenceSession(onnx_model_new.SerializeToString(), providers=["CPUExecutionProvider"]) + + input_names = [x.name for x in sess.get_inputs()] + if not isinstance(inputs, list): +@@ -190,7 +190,7 @@ def get_inputs_sample(onnx_model, to_torch=False): + """Get inputs sample from onnx model.""" + assert ort is not None, "onnxruntime needed. pip install onnxruntime" + +- sess = ort.InferenceSession(onnx_model.SerializeToString()) ++ sess = ort.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]) + inputs = sess.get_inputs() + input_names = get_inputs_names(onnx_model.graph) + input_tensors = [ +diff --git a/tests/onnx2pytorch/conftest.py b/tests/onnx2pytorch/conftest.py +index 711560d..2916552 100644 +--- a/tests/onnx2pytorch/conftest.py ++++ b/tests/onnx2pytorch/conftest.py +@@ -33,6 +33,6 @@ def onnx_inputs(onnx_model): + + @pytest.fixture + def onnx_model_outputs(onnx_model_path, onnx_model, onnx_inputs): +- ort_session = ort.InferenceSession(onnx_model_path) ++ ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) + onnx_output = ort_session.run(None, onnx_inputs) + return onnx_output +diff --git a/tests/onnx2pytorch/convert/test_loop.py b/tests/onnx2pytorch/convert/test_loop.py +index d2e6890..d29142a 100644 +--- a/tests/onnx2pytorch/convert/test_loop.py ++++ b/tests/onnx2pytorch/convert/test_loop.py +@@ -128,7 +128,7 @@ def test_loop_sum(): + exp_res_y = np.array([13]).astype(np.float32) + exp_res_scan = np.array([-1, 1, 4, 8, 13]).astype(np.float32).reshape((5, 1)) + +- ort_session = ort.InferenceSession(bitstream_data) ++ ort_session = ort.InferenceSession(bitstream_data, providers=["CPUExecutionProvider"]) + ort_inputs = {"trip_count": trip_count_input, "cond": cond_input, "y": y_input} + ort_outputs = ort_session.run(None, ort_inputs) + ort_res_y, ort_res_scan = ort_outputs +diff --git a/tests/onnx2pytorch/test_onnx2pytorch.py b/tests/onnx2pytorch/test_onnx2pytorch.py +index 29f7053..a474624 100644 +--- a/tests/onnx2pytorch/test_onnx2pytorch.py ++++ b/tests/onnx2pytorch/test_onnx2pytorch.py +@@ -50,7 +50,7 @@ def test_onnx2pytorch2onnx(onnx_model, onnx_model_outputs, onnx_inputs): + onnx_model = onnx.ModelProto.FromString(bitstream.getvalue()) + onnx.checker.check_model(onnx_model) + +- ort_session = ort.InferenceSession(bitstream.getvalue()) ++ ort_session = ort.InferenceSession(bitstream.getvalue(), providers=["CPUExecutionProvider"]) + outputs = ort_session.run(None, onnx_inputs) + + for output, onnx_model_output in zip(outputs, onnx_model_outputs): |