Skip to content

Commit

Permalink
Add seeded_rand_vec flag (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
krzentner authored Jun 21, 2022
1 parent b0b66d1 commit 18118a2
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 2 deletions.
2 changes: 2 additions & 0 deletions metaworld/envs/mujoco/env_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ def initialize(env, seed=None):
env.reset()
env._freeze_rand_vec = True
if seed is not None:
env.seed(seed)
np.random.set_state(st0)

d['__init__'] = initialize
Expand Down Expand Up @@ -622,6 +623,7 @@ def initialize(env, seed=None):
env.reset()
env._freeze_rand_vec = True
if seed is not None:
env.seed(seed)
np.random.set_state(st0)

d['__init__'] = initialize
Expand Down
8 changes: 6 additions & 2 deletions metaworld/envs/mujoco/mujoco_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,14 @@ def __init__(self, model_path, frame_skip):

self._did_see_sim_exception = False

self.seed()
self.np_random, _ = seeding.np_random(None)

def seed(self, seed=None):
def seed(self, seed):
assert seed is not None
self.np_random, seed = seeding.np_random(seed)
self.action_space.seed(seed)
self.observation_space.seed(seed)
self.goal_space.seed(seed)
return [seed]

@abc.abstractmethod
Expand Down
8 changes: 8 additions & 0 deletions metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
self.mocap_low = np.hstack(mocap_low)
self.mocap_high = np.hstack(mocap_high)
self.curr_path_length = 0
self.seeded_rand_vec = False
self._freeze_rand_vec = True
self._last_rand_vec = None

Expand Down Expand Up @@ -469,6 +470,13 @@ def _get_state_rand_vec(self):
if self._freeze_rand_vec:
assert self._last_rand_vec is not None
return self._last_rand_vec
elif self.seeded_rand_vec:
rand_vec = self.np_random.uniform(
self._random_reset_space.low,
self._random_reset_space.high,
size=self._random_reset_space.low.size)
self._last_rand_vec = rand_vec
return rand_vec
else:
rand_vec = np.random.uniform(
self._random_reset_space.low,
Expand Down
29 changes: 29 additions & 0 deletions tests/metaworld/envs/mujoco/sawyer_xyz/test_seeded_rand_vec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import random

import pytest
import numpy as np

from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE


@pytest.mark.parametrize(
'env_name',
sorted(ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.keys()))
def test_observations_match(env_name):
seed = random.randrange(1000)
env1 = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](seed=seed)
env1.seeded_rand_vec = True
env2 = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](seed=seed)
env2.seeded_rand_vec = True

obs1, obs2 = env1.reset(), env2.reset()
assert (obs1 == obs2).all()

for i in range(env1.max_path_length):
a = np.random.uniform(low=-1, high=-1, size=4)
obs1, r1, done1, _ = env1.step(a)
obs2, r2, done2, _ = env2.step(a)
assert (obs1 == obs2).all()
assert r1 == r2
assert not done1
assert not done2

0 comments on commit 18118a2

Please sign in to comment.