Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gymwrapper: support both gym and gymnasium, and support dict_obs and … #547

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 57 additions & 15 deletions robosuite/wrappers/gym_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,21 @@
interface.
"""

import gymnasium as gym
import numpy as np
from gymnasium import Env, spaces

try:
import gymnasium as gym
from gymnasium import spaces
except ImportError:
# Most APIs between gym and gymnasium are compatible
print("WARNING! gymnasium is not installed. We will try to use openai gym instead.")
import gym
from gym import spaces

def _parse_major_minor(version):
    """
    Extracts the leading (major, minor) integers from a dotted version string.

    Args:
        version (str): Version string such as "0.26.2" or "1.0.0a1"

    Returns:
        2-tuple:

            - (int) major version number
            - (int) minor version number (0 if absent or non-numeric)
    """
    numbers = []
    for part in version.split(".")[:2]:
        # Keep only the leading digits so suffixes like "0a1" or "26rc1" still parse
        digits = ""
        for char in part:
            if not char.isdigit():
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    # Pad so that a bare "1" compares as (1, 0)
    while len(numbers) < 2:
        numbers.append(0)
    return tuple(numbers)


# Due to API Changes in gym>=0.26.0, we need to ensure that the version is correct
# Please check: https://github.com/openai/gym/releases/tag/0.26.0
# Compare numerically: as plain strings, "0.9.6" >= "0.26.0" is True, which would let
# an old, incompatible release pass the check.
if _parse_major_minor(gym.__version__) < (0, 26):
    raise ImportError("Please ensure version of gym>=0.26.0 to use the GymWrapper.")

from robosuite.wrappers import Wrapper

Expand All @@ -23,12 +35,14 @@ class GymWrapper(Wrapper, gym.Env):
keys (None or list of str): If provided, each observation will
consist of concatenated keys from the wrapped environment's
observation dictionary. Defaults to proprio-state and object-state.
flatten_obs (bool):
Whether to flatten the observation dictionary into a 1d array. Defaults to True.

Raises:
AssertionError: [Object observations must be enabled if no keys]
"""

def __init__(self, env, keys=None):
def __init__(self, env, keys=None, flatten_obs=True):
# Run super method
super().__init__(env=env)
# Create name for gym
Expand Down Expand Up @@ -56,12 +70,32 @@ def __init__(self, env, keys=None):

# set up observation and action spaces
obs = self.env.reset()
self.modality_dims = {key: obs[key].shape for key in self.keys}
flat_ob = self._flatten_obs(obs)
self.obs_dim = flat_ob.size
high = np.inf * np.ones(self.obs_dim)
low = -high
self.observation_space = spaces.Box(low, high)

# Whether to flatten the observation space
self.flatten_obs: bool = flatten_obs

if self.flatten_obs:
flat_ob = self._flatten_obs(obs)
self.obs_dim = flat_ob.size
high = np.inf * np.ones(self.obs_dim)
low = -high
self.observation_space = spaces.Box(low, high)
else:

def get_box_space(sample):
    """
    Util fn to obtain the space of a single numpy sample data

    Args:
        sample (np.array): Sample observation; only its dtype and shape are read

    Returns:
        spaces.Box: Box with the sample's shape and dtype. Integer dtypes are bounded by the
            dtype's representable range, inexact dtypes are unbounded (-inf, inf).

    Raises:
        ValueError: [Sample dtype is neither integer nor inexact]
    """
    dtype = sample.dtype
    if np.issubdtype(dtype, np.integer):
        limits = np.iinfo(dtype)
        low, high = limits.min, limits.max
    elif np.issubdtype(dtype, np.inexact):
        low, high = float("-inf"), float("inf")
    else:
        # Name the offending dtype so the caller can tell which observation is unsupported
        raise ValueError(
            f"Unsupported observation dtype {dtype}; expected an integer or inexact (float/complex) dtype."
        )
    return spaces.Box(low=low, high=high, shape=sample.shape, dtype=dtype)

self.observation_space = spaces.Dict({key: get_box_space(obs[key]) for key in self.keys})

low, high = self.env.action_spec
self.action_space = spaces.Box(low, high)

Expand All @@ -84,13 +118,19 @@ def _flatten_obs(self, obs_dict, verbose=False):
ob_lst.append(np.array(obs_dict[key]).flatten())
return np.concatenate(ob_lst)

def _filter_obs(self, obs_dict) -> dict:
"""
Filters keys of interest out of the observation dictionary, returning a filterd dictionary.
"""
return {key: obs_dict[key] for key in self.keys if key in obs_dict}

def reset(self, seed=None, options=None):
"""
Extends env reset method to return flattened observation instead of normal OrderedDict and optionally resets seed
Extends env reset method to return observation instead of normal OrderedDict and optionally resets seed

Returns:
2-tuple:
- (np.array) flattened observations from the environment
- (np.array) observations from the environment
- (dict) an empty dictionary, as part of the standard return format
"""
if seed is not None:
Expand All @@ -99,26 +139,28 @@ def reset(self, seed=None, options=None):
else:
raise TypeError("Seed must be an integer type!")
ob_dict = self.env.reset()
return self._flatten_obs(ob_dict), {}
obs = self._flatten_obs(ob_dict) if self.flatten_obs else self._filter_obs(ob_dict)
return obs, {}

def step(self, action):
    """
    Extends vanilla step() function call to return observation instead of normal OrderedDict.

    Args:
        action (np.array): Action to take in environment

    Returns:
        5-tuple:

            - (np.array or dict) flattened observations if self.flatten_obs is set, otherwise
              the dictionary returned by self._filter_obs
            - (float) reward from the environment
            - (bool) episode ending after reaching an env terminal state
            - (bool) episode ending after an externally defined condition (always False here)
            - (dict) misc information
    """
    ob_dict, reward, terminated, info = self.env.step(action)
    obs = self._flatten_obs(ob_dict) if self.flatten_obs else self._filter_obs(ob_dict)
    # self.env.step yields four values with no truncation flag, so truncated is fixed to False
    return obs, reward, terminated, False, info

def compute_reward(self, achieved_goal, desired_goal, info):
"""
Expand Down
Loading