Skip to content

Commit

Permalink
Merge pull request #547 from youliangtan/enhance-gym-wrapper
Browse files Browse the repository at this point in the history
Gymwrapper: support both gym and gymnasium, and support dict_obs and …
  • Loading branch information
kevin-thankyou-lin authored Nov 6, 2024
2 parents 0d082f8 + 3580e51 commit b81bde4
Showing 1 changed file with 57 additions and 15 deletions.
72 changes: 57 additions & 15 deletions robosuite/wrappers/gym_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,21 @@
interface.
"""

import gymnasium as gym
import numpy as np
from gymnasium import Env, spaces

try:
import gymnasium as gym
from gymnasium import spaces
except ImportError:
# Most APIs between gym and gymnasium are compatible
print("WARNING! gymnasium is not installed. We will try to use openai gym instead.")
import gym
from gym import spaces

if not gym.__version__ >= "0.26.0":
# Due to API Changes in gym>=0.26.0, we need to ensure that the version is correct
# Please check: https://github.com/openai/gym/releases/tag/0.26.0
raise ImportError("Please ensure version of gym>=0.26.0 to use the GymWrapper.")

from robosuite.wrappers import Wrapper

Expand All @@ -23,12 +35,14 @@ class GymWrapper(Wrapper, gym.Env):
keys (None or list of str): If provided, each observation will
consist of concatenated keys from the wrapped environment's
observation dictionary. Defaults to proprio-state and object-state.
flatten_obs (bool):
Whether to flatten the observation dictionary into a 1d array. Defaults to True.
Raises:
AssertionError: [Object observations must be enabled if no keys]
"""

def __init__(self, env, keys=None):
def __init__(self, env, keys=None, flatten_obs=True):
# Run super method
super().__init__(env=env)
# Create name for gym
Expand Down Expand Up @@ -56,12 +70,32 @@ def __init__(self, env, keys=None):

# set up observation and action spaces
obs = self.env.reset()
self.modality_dims = {key: obs[key].shape for key in self.keys}
flat_ob = self._flatten_obs(obs)
self.obs_dim = flat_ob.size
high = np.inf * np.ones(self.obs_dim)
low = -high
self.observation_space = spaces.Box(low, high)

# Whether to flatten the observation space
self.flatten_obs: bool = flatten_obs

if self.flatten_obs:
flat_ob = self._flatten_obs(obs)
self.obs_dim = flat_ob.size
high = np.inf * np.ones(self.obs_dim)
low = -high
self.observation_space = spaces.Box(low, high)
else:

def get_box_space(sample):
"""Util fn to obtain the space of a single numpy sample data"""
if np.issubdtype(sample.dtype, np.integer):
low = np.iinfo(sample.dtype).min
high = np.iinfo(sample.dtype).max
elif np.issubdtype(sample.dtype, np.inexact):
low = float("-inf")
high = float("inf")
else:
raise ValueError()
return spaces.Box(low=low, high=high, shape=sample.shape, dtype=sample.dtype)

self.observation_space = spaces.Dict({key: get_box_space(obs[key]) for key in self.keys})

low, high = self.env.action_spec
self.action_space = spaces.Box(low, high)

Expand All @@ -84,13 +118,19 @@ def _flatten_obs(self, obs_dict, verbose=False):
ob_lst.append(np.array(obs_dict[key]).flatten())
return np.concatenate(ob_lst)

def _filter_obs(self, obs_dict) -> dict:
"""
Filters keys of interest out of the observation dictionary, returning a filterd dictionary.
"""
return {key: obs_dict[key] for key in self.keys if key in obs_dict}

def reset(self, seed=None, options=None):
"""
Extends env reset method to return flattened observation instead of normal OrderedDict and optionally resets seed
Extends env reset method to return observation instead of normal OrderedDict and optionally resets seed
Returns:
2-tuple:
- (np.array) flattened observations from the environment
- (np.array) observations from the environment
- (dict) an empty dictionary, as part of the standard return format
"""
if seed is not None:
Expand All @@ -99,26 +139,28 @@ def reset(self, seed=None, options=None):
else:
raise TypeError("Seed must be an integer type!")
ob_dict = self.env.reset()
return self._flatten_obs(ob_dict), {}
obs = self._flatten_obs(ob_dict) if self.flatten_obs else self._filter_obs(ob_dict)
return obs, {}

def step(self, action):
"""
Extends vanilla step() function call to return flattened observation instead of normal OrderedDict.
Extends vanilla step() function call to return observation instead of normal OrderedDict.
Args:
action (np.array): Action to take in environment
Returns:
4-tuple:
- (np.array) flattened observations from the environment
- (np.array) observations from the environment
- (float) reward from the environment
- (bool) episode ending after reaching an env terminal state
- (bool) episode ending after an externally defined condition
- (dict) misc information
"""
ob_dict, reward, terminated, info = self.env.step(action)
return self._flatten_obs(ob_dict), reward, terminated, False, info
obs = self._flatten_obs(ob_dict) if self.flatten_obs else self._filter_obs(ob_dict)
return obs, reward, terminated, False, info

def compute_reward(self, achieved_goal, desired_goal, info):
"""
Expand Down

0 comments on commit b81bde4

Please sign in to comment.