Skip to content

Commit

Permalink
Fix onnx test
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Sep 23, 2023
1 parent 68b25f2 commit 43e6769
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/after_training_policies.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ ONNX is a widely used machine learning model format that is supported by numerou
import onnxruntime as ort
# load ONNX policy via onnxruntime
ort_session = ort.InferenceSession('policy.onnx')
ort_session = ort.InferenceSession('policy.onnx', providers=["CPUExecutionProvider"])
# observation
observation = np.random.rand(1, 3).astype(np.float32)
Expand Down
5 changes: 4 additions & 1 deletion tests/algos/qlearning/algo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,10 @@ def save_policy_tester(

# check save_policy as ONNX
algo.save_policy(os.path.join("test_data", "model.onnx"))
ort_session = ort.InferenceSession(os.path.join("test_data", "model.onnx"))
ort_session = ort.InferenceSession(
os.path.join("test_data", "model.onnx"),
providers=["CPUExecutionProvider"],
)
observations = np.random.rand(1, *observation_shape).astype("f4")
action = ort_session.run(None, {"input_0": observations})[0]
if algo.get_action_type() == ActionSpace.DISCRETE:
Expand Down

0 comments on commit 43e6769

Please sign in to comment.