Skip to content

Commit

Permalink
add unwrapped to env in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Kallinteris-Andreas committed Feb 14, 2024
1 parent b6bbe94 commit 0db8416
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/envs/franka_kitchen/test_kitchen_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_task_completion(remove_task_when_completed, terminate_on_tasks_complete
# Complete a task sequentially for each environment step
for task in TASKS:
# Force task to be achieved
env.data.qpos[OBS_ELEMENT_INDICES[task]] = OBS_ELEMENT_GOALS[task]
env.unwrapped.data.qpos[OBS_ELEMENT_INDICES[task]] = OBS_ELEMENT_GOALS[task]
_, _, terminated, _, info = env.step(env.action_space.sample())
completed_tasks.add(task)

Expand Down Expand Up @@ -91,7 +91,7 @@ def test_task_completion(remove_task_when_completed, terminate_on_tasks_complete
# Complete a task sequentially for each environment step
for task in TASKS:
# Force task to be achieved
env.data.qpos[OBS_ELEMENT_INDICES[task]] = OBS_ELEMENT_GOALS[task]
env.unwrapped.data.qpos[OBS_ELEMENT_INDICES[task]] = OBS_ELEMENT_GOALS[task]
completed_tasks.add(task)

_, _, terminated, _, info = env.step(env.action_space.sample())
Expand Down
2 changes: 1 addition & 1 deletion tests/envs/hand/test_manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_serialize_deserialize(environment_id):
env1.reset()
env2 = pickle.loads(pickle.dumps(env1))

assert env1.target_position == env2.target_position, (
assert env1.unwrapped.target_position == env2.unwrapped.target_position, (
env1.target_position,
env2.target_position,
)
2 changes: 1 addition & 1 deletion tests/envs/hand/test_manipulate_touch_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_serialize_deserialize(environment_id):
env1.reset()
env2 = pickle.loads(pickle.dumps(env1))

assert env1.target_position == env2.target_position, (
assert env1.unwrapped.target_position == env2.unwrapped.target_position, (
env1.target_position,
env2.target_position,
)
2 changes: 1 addition & 1 deletion tests/envs/hand/test_reach.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_serialize_deserialize():
env1.reset()
env2 = pickle.loads(pickle.dumps(env1))

assert env1.distance_threshold == env2.distance_threshold, (
assert env1.unwrapped.distance_threshold == env2.unwrapped.distance_threshold, (
env1.distance_threshold,
env2.distance_threshold,
)

0 comments on commit 0db8416

Please sign in to comment.