From 3580e51ad070d242a58d598b58233c3f6d0ca6f5 Mon Sep 17 00:00:00 2001 From: You Liang Tan Date: Sun, 3 Nov 2024 23:59:47 -0800 Subject: [PATCH] Gymwrapper: support both gym and gymnasium, and support dict_obs and flatten_obs Signed-off-by: You Liang Tan --- robosuite/wrappers/gym_wrapper.py | 72 ++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/robosuite/wrappers/gym_wrapper.py b/robosuite/wrappers/gym_wrapper.py index f1da6100c1..51bb6f5efb 100644 --- a/robosuite/wrappers/gym_wrapper.py +++ b/robosuite/wrappers/gym_wrapper.py @@ -4,9 +4,21 @@ interface. """ -import gymnasium as gym import numpy as np -from gymnasium import Env, spaces + +try: + import gymnasium as gym + from gymnasium import spaces +except ImportError: + # Most APIs between gym and gymnasium are compatible + print("WARNING! gymnasium is not installed. We will try to use openai gym instead.") + import gym + from gym import spaces + + if not gym.__version__ >= "0.26.0": + # Due to API Changes in gym>=0.26.0, we need to ensure that the version is correct + # Please check: https://github.com/openai/gym/releases/tag/0.26.0 + raise ImportError("Please ensure version of gym>=0.26.0 to use the GymWrapper.") from robosuite.wrappers import Wrapper @@ -23,12 +35,14 @@ class GymWrapper(Wrapper, gym.Env): keys (None or list of str): If provided, each observation will consist of concatenated keys from the wrapped environment's observation dictionary. Defaults to proprio-state and object-state. + flatten_obs (bool): + Whether to flatten the observation dictionary into a 1d array. Defaults to True. Raises: AssertionError: [Object observations must be enabled if no keys] """ - def __init__(self, env, keys=None): + def __init__(self, env, keys=None, flatten_obs=True): # Run super method super().__init__(env=env) # Create name for gym @@ -56,12 +70,32 @@ def __init__(self, env, keys=None): # set up observation and action spaces obs = self.env.reset() - self.modality_dims = {key: obs[key].shape for key in self.keys} - flat_ob = self._flatten_obs(obs) - self.obs_dim = flat_ob.size - high = np.inf * np.ones(self.obs_dim) - low = -high - self.observation_space = spaces.Box(low, high) + + # Whether to flatten the observation space + self.flatten_obs: bool = flatten_obs + + if self.flatten_obs: + flat_ob = self._flatten_obs(obs) + self.obs_dim = flat_ob.size + high = np.inf * np.ones(self.obs_dim) + low = -high + self.observation_space = spaces.Box(low, high) + else: + + def get_box_space(sample): + """Util fn to obtain the space of a single numpy sample data""" + if np.issubdtype(sample.dtype, np.integer): + low = np.iinfo(sample.dtype).min + high = np.iinfo(sample.dtype).max + elif np.issubdtype(sample.dtype, np.inexact): + low = float("-inf") + high = float("inf") + else: + raise ValueError() + return spaces.Box(low=low, high=high, shape=sample.shape, dtype=sample.dtype) + + self.observation_space = spaces.Dict({key: get_box_space(obs[key]) for key in self.keys}) + low, high = self.env.action_spec self.action_space = spaces.Box(low, high) @@ -84,13 +118,19 @@ def _flatten_obs(self, obs_dict, verbose=False): ob_lst.append(np.array(obs_dict[key]).flatten()) return np.concatenate(ob_lst) + def _filter_obs(self, obs_dict) -> dict: + """ + Filters keys of interest out of the observation dictionary, returning a filterd dictionary. + """ + return {key: obs_dict[key] for key in self.keys if key in obs_dict} + def reset(self, seed=None, options=None): """ - Extends env reset method to return flattened observation instead of normal OrderedDict and optionally resets seed + Extends env reset method to return observation instead of normal OrderedDict and optionally resets seed Returns: 2-tuple: - - (np.array) flattened observations from the environment + - (np.array) observations from the environment - (dict) an empty dictionary, as part of the standard return format """ if seed is not None: @@ -99,11 +139,12 @@ def reset(self, seed=None, options=None): else: raise TypeError("Seed must be an integer type!") ob_dict = self.env.reset() - return self._flatten_obs(ob_dict), {} + obs = self._flatten_obs(ob_dict) if self.flatten_obs else self._filter_obs(ob_dict) + return obs, {} def step(self, action): """ - Extends vanilla step() function call to return flattened observation instead of normal OrderedDict. + Extends vanilla step() function call to return observation instead of normal OrderedDict. Args: action (np.array): Action to take in environment @@ -111,14 +152,15 @@ def step(self, action): Returns: 4-tuple: - - (np.array) flattened observations from the environment + - (np.array) observations from the environment - (float) reward from the environment - (bool) episode ending after reaching an env terminal state - (bool) episode ending after an externally defined condition - (dict) misc information """ ob_dict, reward, terminated, info = self.env.step(action) - return self._flatten_obs(ob_dict), reward, terminated, False, info + obs = self._flatten_obs(ob_dict) if self.flatten_obs else self._filter_obs(ob_dict) + return obs, reward, terminated, False, info def compute_reward(self, achieved_goal, desired_goal, info): """