Skip to content

Commit

Permalink
Fix test_dict_info_to_list test on Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
traversaro authored May 24, 2024
1 parent 596bdcc commit 23f8958
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions tests/wrappers/vector/test_dict_info_to_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ def test_update_info():
"_e": np.array([True]),
}
_, list_info = env.reset(options=vector_infos)

# The return dtype of np.array([0]) is platform dependent
np_array_int_default_dtype = np.array([0]).dtype.type

expected_list_info = [
{
"a": np.int64(0),
"a": np_array_int_default_dtype(0),
"b": np.float64(0.0),
"c": None,
"d": np.zeros((2,)),
Expand Down Expand Up @@ -90,21 +94,21 @@ def test_update_info():
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{
"a": np.int64(0),
"a": np_array_int_default_dtype(0),
"b": np.float64(0.0),
"c": None,
"d": np.zeros((2,)),
"e": Discrete(1),
},
{
"a": np.int64(1),
"a": np_array_int_default_dtype(1),
"b": np.float64(1.0),
"c": None,
"d": np.zeros((2,)),
"e": Discrete(2),
},
{
"a": np.int64(2),
"a": np_array_int_default_dtype(2),
"b": np.float64(2.0),
"c": None,
"d": np.zeros((2,)),
Expand Down Expand Up @@ -134,7 +138,7 @@ def test_update_info():
}
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{"a": np.int64(1), "b": np.float64(1.0)},
{"a": np_array_int_default_dtype(1), "b": np.float64(1.0)},
{"c": None, "d": np.zeros((2,))},
{"e": Discrete(3)},
]
Expand All @@ -156,8 +160,8 @@ def test_update_info():
}
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{"episode": {"a": np.int64(1), "b": np.float64(1.0)}},
{"episode": {"a": np.int64(2), "b": np.float64(2.0)}, "a": np.int64(1)},
{"a": np.int64(2)},
{"episode": {"a": np_array_int_default_dtype(1), "b": np.float64(1.0)}},
{"episode": {"a": np_array_int_default_dtype(2), "b": np.float64(2.0)}, "a": np_array_int_default_dtype(1)},
{"a": np_array_int_default_dtype(2)},
]
assert data_equivalence(list_info, expected_list_info)

0 comments on commit 23f8958

Please sign in to comment.