diff --git a/onnx2pytorch/utils.py b/onnx2pytorch/utils.py index 869e0e1..395b7de 100644 --- a/onnx2pytorch/utils.py +++ b/onnx2pytorch/utils.py @@ -169,7 +169,7 @@ def get_activation_value(onnx_model, inputs, activation_names): onnx.save(onnx_model, buffer) buffer.seek(0) onnx_model_new = onnx.load(buffer) - sess = ort.InferenceSession(onnx_model_new.SerializeToString()) + sess = ort.InferenceSession(onnx_model_new.SerializeToString(), providers=["CPUExecutionProvider"]) input_names = [x.name for x in sess.get_inputs()] if not isinstance(inputs, list): @@ -190,7 +190,7 @@ def get_inputs_sample(onnx_model, to_torch=False): """Get inputs sample from onnx model.""" assert ort is not None, "onnxruntime needed. pip install onnxruntime" - sess = ort.InferenceSession(onnx_model.SerializeToString()) + sess = ort.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]) inputs = sess.get_inputs() input_names = get_inputs_names(onnx_model.graph) input_tensors = [ diff --git a/tests/onnx2pytorch/conftest.py b/tests/onnx2pytorch/conftest.py index 711560d..2916552 100644 --- a/tests/onnx2pytorch/conftest.py +++ b/tests/onnx2pytorch/conftest.py @@ -33,6 +33,6 @@ def onnx_inputs(onnx_model): @pytest.fixture def onnx_model_outputs(onnx_model_path, onnx_model, onnx_inputs): - ort_session = ort.InferenceSession(onnx_model_path) + ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) onnx_output = ort_session.run(None, onnx_inputs) return onnx_output diff --git a/tests/onnx2pytorch/convert/test_loop.py b/tests/onnx2pytorch/convert/test_loop.py index d2e6890..d29142a 100644 --- a/tests/onnx2pytorch/convert/test_loop.py +++ b/tests/onnx2pytorch/convert/test_loop.py @@ -128,7 +128,7 @@ def test_loop_sum(): exp_res_y = np.array([13]).astype(np.float32) exp_res_scan = np.array([-1, 1, 4, 8, 13]).astype(np.float32).reshape((5, 1)) - ort_session = ort.InferenceSession(bitstream_data) + ort_session = ort.InferenceSession(bitstream_data, providers=["CPUExecutionProvider"]) ort_inputs = {"trip_count": trip_count_input, "cond": cond_input, "y": y_input} ort_outputs = ort_session.run(None, ort_inputs) ort_res_y, ort_res_scan = ort_outputs diff --git a/tests/onnx2pytorch/test_onnx2pytorch.py b/tests/onnx2pytorch/test_onnx2pytorch.py index 29f7053..a474624 100644 --- a/tests/onnx2pytorch/test_onnx2pytorch.py +++ b/tests/onnx2pytorch/test_onnx2pytorch.py @@ -50,7 +50,7 @@ def test_onnx2pytorch2onnx(onnx_model, onnx_model_outputs, onnx_inputs): onnx_model = onnx.ModelProto.FromString(bitstream.getvalue()) onnx.checker.check_model(onnx_model) - ort_session = ort.InferenceSession(bitstream.getvalue()) + ort_session = ort.InferenceSession(bitstream.getvalue(), providers=["CPUExecutionProvider"]) outputs = ort_session.run(None, onnx_inputs) for output, onnx_model_output in zip(outputs, onnx_model_outputs):