diff --git a/robohive/envs/env_base.py b/robohive/envs/env_base.py
index 0834f7ff..b128cbf7 100644
--- a/robohive/envs/env_base.py
+++ b/robohive/envs/env_base.py
@@ -5,6 +5,7 @@
 License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
 ================================================= """
 
+# TODO: find how to make this compatible with gymnasium. Maybe a global variable that indicates what to use as backend?
 import gym
 import numpy as np
 import os
@@ -13,6 +14,7 @@
 from robohive.envs.obj_vec_dict import ObsVecDict
 from robohive.utils import tensor_utils
 from robohive.robot.robot import Robot
+from robohive.utils.implement_for import implement_for
 from robohive.utils.prompt_utils import prompt, Prompt
 import skvideo.io
 from sys import platform
@@ -264,8 +266,23 @@ def step(self, a, **kwargs):
                                         render_cbk=self.mj_render if self.mujoco_render_frames else None)
         return self.forward(**kwargs)
 
+    @implement_for("gym", None, "0.26")  # gym < 0.26: 4-tuple (obs, reward, done, info)
+    def forward(self, **kwargs):
+        return self._forward(**kwargs)
+
+    @implement_for("gym", "0.26", None)  # gym >= 0.26: 5-tuple; same cut-off as the reset() variants
+    def forward(self, **kwargs):
+        obs, reward, done, info = self._forward(**kwargs)
+        terminal = done
+        return obs, reward, terminal, False, info  # truncated is always reported as False
 
+    @implement_for("gymnasium")
     def forward(self, **kwargs):
+        obs, reward, done, info = self._forward(**kwargs)
+        terminal = done
+        return obs, reward, terminal, False, info  # truncated is always reported as False
+
+    def _forward(self, **kwargs):
         """
         Forward propagate env to recover env details
         Returns current obs(t), rwd(t), done(t), info(t)
@@ -463,8 +480,7 @@ def seed(self, seed=None):
 
     def 
get_input_seed(self):
         return self.input_seed
-
-    def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+    def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
         """
         Reset the environment
         Default implemention provided. Override if env needs custom reset
@@ -473,7 +489,18 @@ def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
         qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel
         self.robot.reset(qpos, qvel, **kwargs)
         return self.get_obs()
 
+    @implement_for("gym", None, "0.26")  # gym < 0.26: reset() returns obs only
+    def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+        return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs)
+
+    @implement_for("gym", "0.26", None)  # gym >= 0.26: reset() returns (obs, info)
+    def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+        return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {}
+
+    @implement_for("gymnasium")  # gymnasium: reset() returns (obs, info)
+    def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+        return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {}
 
     @property
     def _step(self, a):
diff --git a/robohive/envs/env_variants.py b/robohive/envs/env_variants.py
index 07bb3374..4c5cf33f 100644
--- a/robohive/envs/env_variants.py
+++ b/robohive/envs/env_variants.py
@@ -11,6 +11,62 @@
 from copy import deepcopy
 from flatten_dict import flatten, unflatten
 
+from robohive.utils.implement_for import implement_for
+
+#TODO: check versions
+@implement_for("gym", None, "0.24")
+def gym_registry_specs():
+    return gym.envs.registry.env_specs
+
+@implement_for("gym", "0.24", None)
+def gym_registry_specs():
+    return gym.envs.registry
+
+@implement_for("gymnasium")
+def gym_registry_specs():
+    return gym.envs.registry
+
+# TODO: move to within the function?
+# Every variant merges `variants` into the spec's kwargs and must return the
+# key/value string that register_env_variant appends to the variant id.
+@implement_for("gym", None, "0.24")
+def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
+    env_variant_specs._kwargs, variants_update_keyval_str = update_dict(env_variant_specs._kwargs, variants, override_keys=override_keys)
+    return variants_update_keyval_str
+
+@implement_for("gym", "0.24", None)
+def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
+    env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys)
+    return variants_update_keyval_str
+
+@implement_for("gymnasium")
+def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
+    env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys)
+    return variants_update_keyval_str
+
+@implement_for("gym", None, "0.24")
+def _entry_point(env_variant_specs):
+    return env_variant_specs._entry_point
+
+@implement_for("gym", "0.24", None)
+def _entry_point(env_variant_specs):
+    return env_variant_specs.entry_point
+
+@implement_for("gymnasium")
+def _entry_point(env_variant_specs):
+    return env_variant_specs.entry_point
+
+@implement_for("gym", None, "0.24")
+def _kwargs(env_variant_specs):
+    return env_variant_specs._kwargs
+
+@implement_for("gym", "0.24", None)
+def _kwargs(env_variant_specs):
+    return env_variant_specs.kwargs
+
+@implement_for("gymnasium")
+def _kwargs(env_variant_specs):
+    return env_variant_specs.kwargs
 
 # Update base_dict using update_dict
 def update_dict(base_dict:dict, update_dict:dict, override_keys:list=None):
@@ -47,10 +103,10 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals
     """
 
     # check if the base env is registered
-    assert env_id in gym.envs.registry.env_specs.keys(), "ERROR: {} not found in env registry".format(env_id)
+    assert env_id in gym_registry_specs().keys(), "ERROR: {} not found in env registry".format(env_id)
 
     # recover the specs of the existing env
-    env_variant_specs = deepcopy(gym.envs.registry.env_specs[env_id])
+    env_variant_specs = deepcopy(gym_registry_specs()[env_id])
     env_variant_id = env_variant_specs.id[:-3]
 
     # update horizon if requested
@@ -60,16 +116,16 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals
         del variants['max_episode_steps']
 
     # merge specs._kwargs with variants
-    env_variant_specs._kwargs, variants_update_keyval_str = update_dict(env_variant_specs._kwargs, variants, override_keys=override_keys)
+    variants_update_keyval_str = _update_env_spec_kwarg(env_variant_specs, variants, override_keys)
     env_variant_id += variants_update_keyval_str
 
     # finalize name and register env
     env_variant_specs.id = env_variant_id+env_variant_specs.id[-3:] if variant_id is None else variant_id
     register(
         id=env_variant_specs.id,
-        entry_point=env_variant_specs._entry_point,
+        entry_point=_entry_point(env_variant_specs),
         max_episode_steps=env_variant_specs.max_episode_steps,
-        kwargs=env_variant_specs._kwargs
+        kwargs=_kwargs(env_variant_specs)
     )
     if not silent:
         print("Registered a new env-variant:", env_variant_specs.id)
@@ -96,11 +152,11 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals
 
     # Test variant
     print("Base-env kwargs: ")
-    pprint.pprint(gym.envs.registry.env_specs[base_env_name]._kwargs)
+    pprint.pprint(_kwargs(gym_registry_specs()[base_env_name]))
     print("Env-variant kwargs: ")
-    pprint.pprint(gym.envs.registry.env_specs[variant_env_name]._kwargs)
+    pprint.pprint(_kwargs(gym_registry_specs()[variant_env_name]))
     print("Env-variant (with override) kwargs: ")
-    pprint.pprint(gym.envs.registry.env_specs[variant_overide_env_name]._kwargs)
+    pprint.pprint(_kwargs(gym_registry_specs()[variant_overide_env_name]))
 
     # Test one of the newly minted env
     env = gym.make(variant_env_name)
diff --git a/robohive/utils/implement_for.py b/robohive/utils/implement_for.py
new file mode 100644
index 00000000..c1b4e102
--- /dev/null
+++ b/robohive/utils/implement_for.py
@@ -0,0 +1,211 @@
+from __future__ import annotations
+import collections
+import inspect
+import sys
+from copy import copy
+from functools import wraps
+from importlib import import_module
+from typing import Union, Callable, Dict
+from packaging.version import parse
+
+class implement_for:
+    """A version decorator that checks the version in the environment and implements a function with the fitting one.
+
+    If the specified module is missing or there is no fitting implementation, calling the decorated function
+    raises an explicit error.
+    If version ranges overlap, the last fitting implementation is used.
+
+    This wrapper also works to implement different backends for the same function (e.g. gym vs gymnasium,
+    numpy vs jax-numpy etc).
+
+    Args:
+        module_name (str or callable): version is checked for the module with this
+            name (e.g. "gym"). If a callable is provided, it should return the
+            module.
+        from_version: first version the implementation is compatible with (inclusive). Can be open (None).
+        to_version: first version the implementation is no longer compatible with (exclusive). Can be open (None).
+
+    Examples:
+        >>> @implement_for("gym", "0.13", "0.14")
+        ... def fun(self, x):
+        ...     # Older gym versions will return x + 1
+        ...     return x + 1
+        ...
+        >>> @implement_for("gym", "0.14", "0.23")
+        ... def fun(self, x):
+        ...     # More recent gym versions will return x + 2
+        ...     return x + 2
+        ...
+        >>> @implement_for(lambda: import_module("gym"), "0.23", None)
+        ... def fun(self, x):
+        ...     # gym 0.23 and later will also return x + 2
+        ...     return x + 2
+        ...
+        >>> @implement_for("gymnasium", "0.27", None)
+        ... def fun(self, x):
+        ...     # If gymnasium is to be used instead of gym, x+3 will be returned
+        ...     return x + 3
+        ...
+
+        The first implementation above is compatible with gym 0.13+ but not with gym 0.14+, i.e. it is used for gym 0.13.x only.
+ """ + + # Stores pointers to fitting implementations: dict[func_name] = func_pointer + _implementations = {} + _setters = [] + _cache_modules = {} + + def __init__( + self, + module_name: Union[str, Callable], + from_version: str = None, + to_version: str = None, + ): + self.module_name = module_name + self.from_version = from_version + self.to_version = to_version + implement_for._setters.append(self) + + @staticmethod + def check_version(version, from_version, to_version): + return (from_version is None or parse(version) >= parse(from_version)) and ( + to_version is None or parse(version) < parse(to_version) + ) + + @staticmethod + def get_class_that_defined_method(f): + """Returns the class of a method, if it is defined, and None otherwise.""" + out = f.__globals__.get(f.__qualname__.split(".")[0], None) + return out + + @classmethod + def get_func_name(cls, fn): + # produces a name like torchrl.module.Class.method or torchrl.module.function + first = str(fn).split(".")[0][len(" str: + """Imports module and returns its version.""" + if not callable(module_name): + module = cls._cache_modules.get(module_name, None) + if module is None: + if module_name in sys.modules: + sys.modules[module_name] = module = import_module(module_name) + else: + cls._cache_modules[module_name] = module = import_module( + module_name + ) + else: + module = module_name() + return module.__version__ + + _lazy_impl = collections.defaultdict(list) + + def _delazify(self, func_name): + for local_call in implement_for._lazy_impl[func_name]: + out = local_call() + return out + + def __call__(self, fn): + # function names are unique + self.func_name = self.get_func_name(fn) + self.fn = fn + implement_for._lazy_impl[self.func_name].append(self._call) + + @wraps(fn) + def _lazy_call_fn(*args, **kwargs): + # first time we call the function, we also do the replacement. 
+            # This will cause the imports to occur only during the first call to fn
+            return self._delazify(self.func_name)(*args, **kwargs)
+
+        return _lazy_call_fn
+
+    def _call(self):
+        """Return the callable to expose for ``self.func_name``: ``self.fn``, a previously set implementation, or a stub that raises."""
+        # If the module is missing replace the function with the mock.
+        fn = self.fn
+        func_name = self.func_name
+        implementations = implement_for._implementations
+
+        @wraps(fn)
+        def unsupported(*args, **kwargs):
+            raise ModuleNotFoundError(
+                f"Supported version of '{func_name}' has not been found."
+            )
+
+        self.do_set = False
+        # Return fitting implementation if it was encountered before.
+        if func_name in implementations:
+            try:
+                # check that backends don't conflict
+                version = self.import_module(self.module_name)
+                if self.check_version(version, self.from_version, self.to_version):
+                    self.do_set = True
+                if not self.do_set:
+                    return implementations[func_name].fn
+            except ModuleNotFoundError:
+                # then it's ok, there is no conflict
+                return implementations[func_name].fn
+        else:
+            try:
+                version = self.import_module(self.module_name)
+                if self.check_version(version, self.from_version, self.to_version):
+                    self.do_set = True
+            except ModuleNotFoundError:
+                return unsupported
+        if self.do_set:
+            self.module_set()
+            return fn
+        return unsupported
+
+    @classmethod
+    def reset(cls, setters_dict: Dict[str, implement_for] = None):
+        """Resets the setters in ``setters_dict``.
+
+        ``setters_dict`` defaults to a copy of the registered implementations. We just need to iterate
+        through its values and call :meth:`~.module_set` for each.
+
+        """
+        if setters_dict is None:
+            setters_dict = copy(cls._implementations)
+        for setter in setters_dict.values():
+            setter.module_set()
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}("
+            f"module_name={self.module_name}({self.from_version, self.to_version}), "
+            f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)}, is_set={getattr(self, 'do_set', False)})"
+        )