Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement_for to account for all gym versions #118

Merged
merged 5 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions robohive/envs/env_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
================================================= """

# TODO: find how to make this compatible with gymnasium. Maybe a global variable that indicates what to use as backend?
import gym
import numpy as np
import os
Expand All @@ -13,6 +14,7 @@
from robohive.envs.obj_vec_dict import ObsVecDict
from robohive.utils import tensor_utils
from robohive.robot.robot import Robot
from robohive.utils.implement_for import implement_for
from robohive.utils.prompt_utils import prompt, Prompt
import skvideo.io
from sys import platform
Expand Down Expand Up @@ -264,8 +266,23 @@ def step(self, a, **kwargs):
render_cbk=self.mj_render if self.mujoco_render_frames else None)
return self.forward(**kwargs)

@implement_for("gym", None, "0.24")
def forward(self, **kwargs):
    """Forward-propagate the env and return the ``_forward`` result unchanged.

    Selected for gym versions below 0.24; the result is the 4-tuple
    ``(obs, reward, done, info)`` produced by ``_forward``.
    """
    step_output = self._forward(**kwargs)
    return step_output

@implement_for("gym", "0.24", None)
def forward(self, **kwargs):
    """Forward-propagate the env and return a 5-tuple.

    Selected for gym 0.24 and later. ``done`` from ``_forward`` is reported
    as the third element and the fourth element is always ``False``.
    """
    # NOTE(review): the reset() overloads in this class switch API at gym 0.26,
    # while this one switches at 0.24 -- confirm the boundary is intended.
    obs, reward, done, info = self._forward(**kwargs)
    truncated = False
    return obs, reward, done, truncated, info

@implement_for("gymnasium")
def forward(self, **kwargs):
    """Forward-propagate the env and return a 5-tuple (gymnasium backend).

    ``done`` from ``_forward`` is reported as the third element and the
    fourth element is always ``False``.
    """
    obs, reward, done, info = self._forward(**kwargs)
    truncated = False
    return obs, reward, done, truncated, info

def _forward(self, **kwargs):
"""
Forward propagate env to recover env details
Returns current obs(t), rwd(t), done(t), info(t)
Expand Down Expand Up @@ -463,8 +480,7 @@ def seed(self, seed=None):
def get_input_seed(self):
    """Return the value currently stored in ``self.input_seed``."""
    seed = self.input_seed
    return seed


def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
"""
Reset the environment
Default implementation provided. Override if env needs custom reset
Expand All @@ -473,7 +489,15 @@ def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel
self.robot.reset(qpos, qvel, **kwargs)
return self.get_obs()

@implement_for("gym", None, "0.26")
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
    """Reset the env and return only the value produced by ``_reset``.

    Selected for gym versions below 0.26.
    """
    obs = self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs)
    return obs

@implement_for("gym", "0.26", None)
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
    """Reset the env and return ``(obs, info)`` with an empty ``info`` dict.

    Selected for gym 0.26 and later.
    """
    obs = self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs)
    info = {}
    return obs, info

@implement_for("gymnasium")
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
    """Reset the env and return ``(obs, info)`` with an empty ``info`` dict.

    Selected when the gymnasium backend is available.
    """
    obs = self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs)
    info = {}
    return obs, info

@property
def _step(self, a):
Expand Down
69 changes: 61 additions & 8 deletions robohive/envs/env_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,59 @@
from copy import deepcopy
from flatten_dict import flatten, unflatten

from robohive.utils.implement_for import implement_for

#TODO: check versions
@implement_for("gym", None, "0.24")
def gym_registry_specs():
    """Return the env-spec mapping for gym < 0.24 (``registry.env_specs``)."""
    registry = gym.envs.registry
    return registry.env_specs


@implement_for("gym", "0.24", None)
def gym_registry_specs():
    """Return the env-spec mapping for gym >= 0.24 (the registry itself)."""
    registry = gym.envs.registry
    return registry


@implement_for("gymnasium")
def gym_registry_specs():
    """Return the env-spec mapping when gymnasium is available."""
    # NOTE(review): this reads the registry through the module bound to the
    # name `gym` in this file -- confirm that it points at gymnasium here.
    registry = gym.envs.registry
    return registry

# TODO: move to within the function?
@implement_for("gym", None, "0.24")
def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
    """Merge ``variants`` into the spec's kwargs (gym < 0.24, ``_kwargs``).

    Args:
        env_variant_specs: env spec whose ``_kwargs`` attribute is rewritten.
        variants: dict of updates handed to ``update_dict``.
        override_keys: forwarded to ``update_dict`` as ``override_keys``.

    Returns:
        The key/value string produced by ``update_dict``; the caller appends
        it to the variant id, so it must not be dropped.
    """
    env_variant_specs._kwargs, variants_update_keyval_str = update_dict(
        env_variant_specs._kwargs, variants, override_keys=override_keys
    )
    return variants_update_keyval_str


@implement_for("gym", "0.24", None)
def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
    """Merge ``variants`` into the spec's kwargs (gym >= 0.24, ``kwargs``).

    Args:
        env_variant_specs: env spec whose ``kwargs`` attribute is rewritten.
        variants: dict of updates handed to ``update_dict``.
        override_keys: forwarded to ``update_dict`` as ``override_keys``.

    Returns:
        The key/value string produced by ``update_dict``.
    """
    env_variant_specs.kwargs, variants_update_keyval_str = update_dict(
        env_variant_specs.kwargs, variants, override_keys=override_keys
    )
    return variants_update_keyval_str


@implement_for("gymnasium")
def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
    """Merge ``variants`` into the spec's kwargs (gymnasium, ``kwargs``).

    Args:
        env_variant_specs: env spec whose ``kwargs`` attribute is rewritten.
        variants: dict of updates handed to ``update_dict``.
        override_keys: forwarded to ``update_dict`` as ``override_keys``.

    Returns:
        The key/value string produced by ``update_dict``.
    """
    env_variant_specs.kwargs, variants_update_keyval_str = update_dict(
        env_variant_specs.kwargs, variants, override_keys=override_keys
    )
    return variants_update_keyval_str

@implement_for("gym", None, "0.24")
def _entry_point(env_variant_specs):
    """Read the spec's entry point via the private ``_entry_point`` attribute (gym < 0.24)."""
    entry = env_variant_specs._entry_point
    return entry


@implement_for("gym", "0.24", None)
def _entry_point(env_variant_specs):
    """Read the spec's entry point via the ``entry_point`` attribute (gym >= 0.24)."""
    entry = env_variant_specs.entry_point
    return entry


@implement_for("gymnasium")
def _entry_point(env_variant_specs):
    """Read the spec's entry point via the ``entry_point`` attribute (gymnasium)."""
    entry = env_variant_specs.entry_point
    return entry

@implement_for("gym", None, "0.24")
def _kwargs(env_variant_specs):
    """Read the spec's kwargs via the private ``_kwargs`` attribute (gym < 0.24)."""
    spec_kwargs = env_variant_specs._kwargs
    return spec_kwargs


@implement_for("gym", "0.24", None)
def _kwargs(env_variant_specs):
    """Read the spec's kwargs via the ``kwargs`` attribute (gym >= 0.24)."""
    spec_kwargs = env_variant_specs.kwargs
    return spec_kwargs


@implement_for("gymnasium")
def _kwargs(env_variant_specs):
    """Read the spec's kwargs via the ``kwargs`` attribute (gymnasium)."""
    spec_kwargs = env_variant_specs.kwargs
    return spec_kwargs

# Update base_dict using update_dict
def update_dict(base_dict:dict, update_dict:dict, override_keys:list=None):
Expand Down Expand Up @@ -47,10 +100,10 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals
"""

# check if the base env is registered
assert env_id in gym.envs.registry.env_specs.keys(), "ERROR: {} not found in env registry".format(env_id)
assert env_id in gym_registry_specs().keys(), "ERROR: {} not found in env registry".format(env_id)

# recover the specs of the existing env
env_variant_specs = deepcopy(gym.envs.registry.env_specs[env_id])
env_variant_specs = deepcopy(gym_registry_specs()[env_id])
env_variant_id = env_variant_specs.id[:-3]

# update horizon if requested
Expand All @@ -60,16 +113,16 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals
del variants['max_episode_steps']

# merge specs._kwargs with variants
env_variant_specs._kwargs, variants_update_keyval_str = update_dict(env_variant_specs._kwargs, variants, override_keys=override_keys)
variants_update_keyval_str = _update_env_spec_kwarg(env_variant_specs, variants, override_keys)
env_variant_id += variants_update_keyval_str

# finalize name and register env
env_variant_specs.id = env_variant_id+env_variant_specs.id[-3:] if variant_id is None else variant_id
register(
id=env_variant_specs.id,
entry_point=env_variant_specs._entry_point,
entry_point=_entry_point(env_variant_specs),
max_episode_steps=env_variant_specs.max_episode_steps,
kwargs=env_variant_specs._kwargs
kwargs=_kwargs(env_variant_specs)
)
if not silent:
print("Registered a new env-variant:", env_variant_specs.id)
Expand All @@ -96,11 +149,11 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals

# Test variant
print("Base-env kwargs: ")
pprint.pprint(gym.envs.registry.env_specs[base_env_name]._kwargs)
pprint.pprint(gym_registry_specs()[base_env_name]._kwargs)
print("Env-variant kwargs: ")
pprint.pprint(gym.envs.registry.env_specs[variant_env_name]._kwargs)
pprint.pprint(gym_registry_specs()[variant_env_name]._kwargs)
print("Env-variant (with override) kwargs: ")
pprint.pprint(gym.envs.registry.env_specs[variant_overide_env_name]._kwargs)
pprint.pprint(gym_registry_specs()[variant_overide_env_name]._kwargs)

# Test one of the newly minted env
env = gym.make(variant_env_name)
Expand Down
211 changes: 211 additions & 0 deletions robohive/utils/implement_for.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from __future__ import annotations
import collections
import inspect
import sys
from copy import copy
from functools import wraps
from importlib import import_module
from typing import Union, Callable, Dict
from packaging.version import parse

class implement_for:
    """A version decorator that checks the version in the environment and implements a function with the fitting one.

    If the specified module is missing or there is no fitting implementation, calling the
    decorated function raises an explicit ``ModuleNotFoundError``.
    In case of intersecting ranges, the last fitting implementation is used.

    This wrapper also works to implement different backends for a same function (eg. gym vs gymnasium,
    numpy vs jax-numpy etc).

    Args:
        module_name (str or callable): version is checked for the module with this
            name (e.g. "gym"). If a callable is provided, it should return the
            module.
        from_version: version from which implementation is compatible (inclusive). Can be open (None).
        to_version: version from which implementation is no longer compatible (exclusive). Can be open (None).

    Examples:
        >>> @implement_for("gym", "0.13", "0.14")
        ... def fun(self, x):
        ...     # gym versions in [0.13, 0.14) will return x + 1
        ...     return x + 1
        ...
        >>> @implement_for("gym", "0.14", "0.23")
        ... def fun(self, x):
        ...     # gym versions in [0.14, 0.23) will return x + 2
        ...     return x + 2
        ...
        >>> @implement_for(lambda: import_module("gym"), "0.23", None)
        ... def fun(self, x):
        ...     # gym 0.23 and later will return x + 2 as well
        ...     return x + 2
        ...
        >>> @implement_for("gymnasium", "0.27", None)
        ... def fun(self, x):
        ...     # If a fitting gymnasium is installed, this last implementation
        ...     # wins and x + 3 will be returned
        ...     return x + 3
        ...

    Each ``(from_version, to_version)`` pair is a half-open range: the first example is
    compatible with gym 0.13+ but no longer with gym 0.14+.
    """

    # Stores pointers to fitting implementations: dict[func_name] = setter
    _implementations = {}
    # Every decorator instance ever created, in creation order.
    _setters = []
    # Modules imported by name through `import_module`: dict[module_name] = module
    _cache_modules = {}
    # Pending resolvers per function name: dict[func_name] = [setter._call, ...]
    _lazy_impl = collections.defaultdict(list)

    def __init__(
        self,
        module_name: Union[str, Callable],
        from_version: str = None,
        to_version: str = None,
    ):
        self.module_name = module_name
        self.from_version = from_version
        self.to_version = to_version
        implement_for._setters.append(self)

    @staticmethod
    def check_version(version, from_version, to_version):
        """Return True if ``from_version <= version < to_version`` (None bounds are open)."""
        return (from_version is None or parse(version) >= parse(from_version)) and (
            to_version is None or parse(version) < parse(to_version)
        )

    @staticmethod
    def get_class_that_defined_method(f):
        """Look up the first component of ``f.__qualname__`` in ``f``'s globals.

        For a method this is the enclosing class once it has been bound in the
        module namespace; for a module-level function it is whatever is bound
        under the function's own name. Returns None when nothing is bound yet.
        """
        owner_name = f.__qualname__.split(".")[0]
        return f.__globals__.get(owner_name, None)

    @classmethod
    def get_func_name(cls, fn):
        """Build a dotted name such as ``package.module.Class.method`` from ``str(fn)``."""
        # str(fn) looks like "<function Class.method at 0x...>"
        first = str(fn).split(".")[0][len("<function ") :]
        last = str(fn).split(".")[1:]
        if last:
            first = [first]
            # drop the trailing " at 0x...>" part
            last[-1] = last[-1].split(" ")[0]
        else:
            last = [first.split(" ")[0]]
            first = []
        return ".".join([fn.__module__] + first + last)

    def _get_cls(self, fn):
        """Return the object on which ``fn`` should be set, or None if not available yet."""
        cls = self.get_class_that_defined_method(fn)
        if cls is None:
            # class not yet defined
            return None
        if cls.__class__.__name__ == "function":
            # module-level function: the owner is the module itself
            cls = inspect.getmodule(fn)
        return cls

    def module_set(self):
        """Sets the function in its module, if it exists already."""
        registry = type(self)._implementations
        name = self.get_func_name(self.fn)
        prev_setter = registry.get(name, None)
        if prev_setter is not None:
            prev_setter.do_set = False
        registry[name] = self
        owner = self._get_cls(self.fn)
        if owner is None:
            # class not yet defined
            return
        setattr(owner, self.fn.__name__, self.fn)

    @classmethod
    def import_module(cls, module_name: Union[Callable, str]) -> str:
        """Imports module and returns its version.

        A callable ``module_name`` is invoked on every call and must return the
        module. String names are looked up in ``_cache_modules`` first; only
        modules that were not already present in ``sys.modules`` are stored there.
        """
        if not callable(module_name):
            module = cls._cache_modules.get(module_name, None)
            if module is None:
                if module_name in sys.modules:
                    sys.modules[module_name] = module = import_module(module_name)
                else:
                    cls._cache_modules[module_name] = module = import_module(
                        module_name
                    )
        else:
            module = module_name()
        return module.__version__

    def _delazify(self, func_name):
        """Run every registered resolver for ``func_name`` and return the last result."""
        for local_call in implement_for._lazy_impl[func_name]:
            out = local_call()
        return out

    def __call__(self, fn):
        """Register ``fn`` as a candidate and return a wrapper that resolves lazily."""
        # function names are unique
        self.func_name = self.get_func_name(fn)
        self.fn = fn
        implement_for._lazy_impl[self.func_name].append(self._call)

        @wraps(fn)
        def _lazy_call_fn(*args, **kwargs):
            # first time we call the function, we also do the replacement.
            # This will cause the imports to occur only during the first call to fn
            return self._delazify(self.func_name)(*args, **kwargs)

        return _lazy_call_fn

    def _call(self):
        """Resolve this candidate and return the callable to use.

        Returns ``self.fn`` when the module version fits, a previously selected
        implementation when this one does not fit, or a stub raising
        ``ModuleNotFoundError`` when nothing fits.
        """
        # If the module is missing replace the function with the mock.
        fn = self.fn
        func_name = self.func_name
        implementations = implement_for._implementations

        @wraps(fn)
        def unsupported(*args, **kwargs):
            raise ModuleNotFoundError(
                f"Supported version of '{func_name}' has not been found."
            )

        self.do_set = False
        # Return fitting implementation if it was encountered before.
        if func_name in implementations:
            try:
                # check that backends don't conflict
                version = self.import_module(self.module_name)
                if self.check_version(version, self.from_version, self.to_version):
                    self.do_set = True
                if not self.do_set:
                    return implementations[func_name].fn
            except ModuleNotFoundError:
                # then it's ok, there is no conflict
                return implementations[func_name].fn
        else:
            try:
                version = self.import_module(self.module_name)
                if self.check_version(version, self.from_version, self.to_version):
                    self.do_set = True
            except ModuleNotFoundError:
                return unsupported
        if self.do_set:
            self.module_set()
            return fn
        return unsupported

    @classmethod
    def reset(cls, setters_dict: Dict[str, "implement_for"] = None):
        """Resets the setters in setter_dict.

        ``setter_dict`` is a copy of implementations. We just need to iterate through its
        values and call :meth:`~.module_set` for each.

        """
        if setters_dict is None:
            setters_dict = copy(cls._implementations)
        for setter in setters_dict.values():
            setter.module_set()

    def __repr__(self):
        # `fn` and `do_set` only exist once the instance has decorated a function
        # and been resolved; fall back to placeholders so repr never raises.
        fn = getattr(self, "fn", None)
        fn_name = getattr(fn, "__name__", None)
        owner = self._get_cls(fn) if fn is not None else None
        is_set = getattr(self, "do_set", False)
        return (
            f"{self.__class__.__name__}("
            f"module_name={self.module_name}({self.from_version, self.to_version}), "
            f"fn_name={fn_name}, cls={owner}, is_set={is_set})"
        )
Loading