diff --git a/mo_gymnasium/utils.py b/mo_gymnasium/utils.py
index 3fc6d0af..def90471 100644
--- a/mo_gymnasium/utils.py
+++ b/mo_gymnasium/utils.py
@@ -41,7 +41,7 @@ def __init__(self, env: gym.Env, weight: np.ndarray = None):
         gym.utils.RecordConstructorArgs.__init__(self, weight=weight)
         gym.Wrapper.__init__(self, env)
         if weight is None:
-            weight = np.ones(shape=env.reward_space.shape)
+            weight = np.ones(shape=env.unwrapped.reward_space.shape)
         self.set_weight(weight)
 
     def set_weight(self, weight: np.ndarray):
@@ -51,7 +51,7 @@ def set_weight(self, weight: np.ndarray):
             weight: new weights to set
         Returns: nothing
         """
-        assert weight.shape == self.env.reward_space.shape, "Reward weight has different shape than reward vector."
+        assert weight.shape == self.env.unwrapped.reward_space.shape, "Reward weight has different shape than reward vector."
         self.w = weight
 
     def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
@@ -169,7 +169,7 @@ def __init__(
         """
         SyncVectorEnv.__init__(self, env_fns, copy=copy)
         # Just overrides the rewards memory to add the number of objectives
-        self.reward_space = self.envs[0].reward_space
+        self.reward_space = self.envs[0].unwrapped.reward_space
         self._rewards = np.zeros(
             (
                 self.num_envs,
@@ -222,7 +222,7 @@ def __init__(self, env: gym.Env, gamma: float = 1.0, deque_size: int = 100):
         RecordEpisodeStatistics.__init__(self, env, deque_size=deque_size)
         # CHANGE: Here we just override the standard implementation to extend to MO
         # We also take care of the case where the env is vectorized
-        self.reward_dim = self.env.reward_space.shape[0]
+        self.reward_dim = self.env.unwrapped.reward_space.shape[0]
         if self.is_vector_env:
             self.rewards_shape = (self.num_envs, self.reward_dim)
         else:
@@ -344,7 +344,7 @@ def step(self, action):
         Returns:
             Max of the last two observations, reward, terminated, truncated, and info from the environment
         """
-        total_reward = np.zeros(self.env.reward_dim, dtype=np.float32)
+        total_reward = np.zeros(self.env.unwrapped.reward_dim, dtype=np.float32)
         terminated = truncated = False
         info = {}
         for i in range(self._skip):
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 5c3a8bf9..4443789c 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -100,13 +100,13 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
 
 def _test_reward_bounds(env: gym.Env):
     """Test that the reward bounds are respected."""
-    assert env.reward_dim is not None
-    assert env.reward_space is not None
+    assert env.unwrapped.reward_dim is not None
+    assert env.unwrapped.reward_space is not None
     env.reset()
     for _ in range(NUM_STEPS):
        action = env.action_space.sample()
        _, reward, terminated, truncated, _ = env.step(action)
-        assert env.reward_space.contains(reward)
+        assert env.unwrapped.reward_space.contains(reward)
        if terminated or truncated:
            env.reset()
 
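
The snippet below is a minimal sketch of the access pattern this diff standardizes on: custom attributes such as reward_space and reward_dim are defined on the base environment, so they are read through env.unwrapped rather than through the wrapper chain (recent Gymnasium releases no longer forward unknown attributes from wrappers). It assumes the public MO-Gymnasium API (mo_gymnasium.make, a top-level LinearReward export) and uses the "deep-sea-treasure-v0" environment id purely for illustration.

import numpy as np
import mo_gymnasium as mo_gym

env = mo_gym.make("deep-sea-treasure-v0")

# reward_space lives on the base env, not on the wrappers stacked on top of
# it, so reach it through `unwrapped` as the diff does.
reward_dim = env.unwrapped.reward_space.shape[0]

# Scalarize the vector reward with equal weights; LinearReward's own assert
# (see the hunk at utils.py:51) checks the weight shape the same way.
env = mo_gym.LinearReward(env, weight=np.ones(reward_dim) / reward_dim)

obs, info = env.reset(seed=0)
obs, scalar_reward, terminated, truncated, info = env.step(env.action_space.sample())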