diff --git a/robohive/envs/env_base.py b/robohive/envs/env_base.py
index 60ff1f89..c5a72d21 100644
--- a/robohive/envs/env_base.py
+++ b/robohive/envs/env_base.py
@@ -275,7 +275,7 @@ def forward(self, **kwargs):
         terminal = done
         return obs, reward, terminal, False, info
 
-    @implement_for("gymnasium", "0.24", None)
+    @implement_for("gymnasium")
     def forward(self, **kwargs):
         obs, reward, done, info = self._forward(**kwargs)
         terminal = done
@@ -505,6 +505,9 @@ def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
     def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
         return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs)
     @implement_for("gym", "0.26", None)
+    def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+        return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {}
+    @implement_for("gymnasium")
     def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
         return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {}
 
diff --git a/robohive/envs/env_variants.py b/robohive/envs/env_variants.py
index 201b3a0b..4c5cf33f 100644
--- a/robohive/envs/env_variants.py
+++ b/robohive/envs/env_variants.py
@@ -17,9 +17,15 @@
 @implement_for("gym", None, "0.24")
 def gym_registry_specs():
     return gym.envs.registry.env_specs
+
 @implement_for("gym", "0.24", None)
 def gym_registry_specs():
     return gym.envs.registry
+
+@implement_for("gymnasium")
+def gym_registry_specs():
+    return gym.envs.registry
+
 # TODO: move to within the function?
 @implement_for("gym", None, "0.24")
 def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
@@ -30,6 +36,11 @@ def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
     env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys)
     return variants_update_keyval_str
 
+@implement_for("gymnasium")
+def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
+    env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys)
+    return variants_update_keyval_str
+
 @implement_for("gym", None, "0.24")
 def _entry_point(env_variant_specs):
     return env_variant_specs._entry_point
@@ -37,6 +48,11 @@ def _entry_point(env_variant_specs):
 @implement_for("gym", "0.24", None)
 def _entry_point(env_variant_specs):
     return env_variant_specs.entry_point
+
+@implement_for("gymnasium")
+def _entry_point(env_variant_specs):
+    return env_variant_specs.entry_point
+
 @implement_for("gym", None, "0.24")
 def _kwargs(env_variant_specs):
     return env_variant_specs._kwargs
@@ -45,6 +61,10 @@ def _kwargs(env_variant_specs):
 def _kwargs(env_variant_specs):
     return env_variant_specs.kwargs
 
+@implement_for("gymnasium")
+def _kwargs(env_variant_specs):
+    return env_variant_specs.kwargs
+
 # Update base_dict using update_dict
 def update_dict(base_dict:dict, update_dict:dict, override_keys:list=None):
     """