From 18118a28c06893da0f363786696cc792457b062b Mon Sep 17 00:00:00 2001 From: "K.R. Zentner" <41180126+krzentner@users.noreply.github.com> Date: Tue, 21 Jun 2022 05:43:03 -0700 Subject: [PATCH] Add seeded_rand_vec flag (#370) --- metaworld/envs/mujoco/env_dict.py | 2 ++ metaworld/envs/mujoco/mujoco_env.py | 8 +++-- .../envs/mujoco/sawyer_xyz/sawyer_xyz_env.py | 8 +++++ .../mujoco/sawyer_xyz/test_seeded_rand_vec.py | 29 +++++++++++++++++++ 4 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 tests/metaworld/envs/mujoco/sawyer_xyz/test_seeded_rand_vec.py diff --git a/metaworld/envs/mujoco/env_dict.py b/metaworld/envs/mujoco/env_dict.py index e7f4171db..2ca6e84ef 100644 --- a/metaworld/envs/mujoco/env_dict.py +++ b/metaworld/envs/mujoco/env_dict.py @@ -592,6 +592,7 @@ def initialize(env, seed=None): env.reset() env._freeze_rand_vec = True if seed is not None: + env.seed(seed) np.random.set_state(st0) d['__init__'] = initialize @@ -622,6 +623,7 @@ def initialize(env, seed=None): env.reset() env._freeze_rand_vec = True if seed is not None: + env.seed(seed) np.random.set_state(st0) d['__init__'] = initialize diff --git a/metaworld/envs/mujoco/mujoco_env.py b/metaworld/envs/mujoco/mujoco_env.py index 47dd2a58d..94c8f121a 100644 --- a/metaworld/envs/mujoco/mujoco_env.py +++ b/metaworld/envs/mujoco/mujoco_env.py @@ -58,10 +58,14 @@ def __init__(self, model_path, frame_skip): self._did_see_sim_exception = False - self.seed() + self.np_random, _ = seeding.np_random(None) - def seed(self, seed=None): + def seed(self, seed): + assert seed is not None self.np_random, seed = seeding.np_random(seed) + self.action_space.seed(seed) + self.observation_space.seed(seed) + self.goal_space.seed(seed) return [seed] @abc.abstractmethod diff --git a/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py b/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py index 091edf3a8..6ae933a37 100644 --- a/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py +++ b/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py @@ -111,6 +111,7 @@ def __init__( self.mocap_low = np.hstack(mocap_low) self.mocap_high = np.hstack(mocap_high) self.curr_path_length = 0 + self.seeded_rand_vec = False self._freeze_rand_vec = True self._last_rand_vec = None @@ -469,6 +470,13 @@ def _get_state_rand_vec(self): if self._freeze_rand_vec: assert self._last_rand_vec is not None return self._last_rand_vec + elif self.seeded_rand_vec: + rand_vec = self.np_random.uniform( + self._random_reset_space.low, + self._random_reset_space.high, + size=self._random_reset_space.low.size) + self._last_rand_vec = rand_vec + return rand_vec else: rand_vec = np.random.uniform( self._random_reset_space.low, diff --git a/tests/metaworld/envs/mujoco/sawyer_xyz/test_seeded_rand_vec.py b/tests/metaworld/envs/mujoco/sawyer_xyz/test_seeded_rand_vec.py new file mode 100644 index 000000000..95aef7daf --- /dev/null +++ b/tests/metaworld/envs/mujoco/sawyer_xyz/test_seeded_rand_vec.py @@ -0,0 +1,29 @@ +import random + +import pytest +import numpy as np + +from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE + + +@pytest.mark.parametrize( + 'env_name', + sorted(ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.keys())) +def test_observations_match(env_name): + seed = random.randrange(1000) + env1 = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](seed=seed) + env1.seeded_rand_vec = True + env2 = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](seed=seed) + env2.seeded_rand_vec = True + + obs1, obs2 = env1.reset(), env2.reset() + assert (obs1 == obs2).all() + + for i in range(env1.max_path_length): + a = np.random.uniform(low=-1, high=-1, size=4) + obs1, r1, done1, _ = env1.step(a) + obs2, r2, done2, _ = env2.step(a) + assert (obs1 == obs2).all() + assert r1 == r2 + assert not done1 + assert not done2