Skip to content

Commit

Permalink
add gymnasium
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored and vikashplus committed Nov 18, 2023
1 parent 051d167 commit a7b61bd
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
5 changes: 4 additions & 1 deletion robohive/envs/env_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def forward(self, **kwargs):
terminal = done
return obs, reward, terminal, False, info

@implement_for("gymnasium", "0.24", None)
@implement_for("gymnasium")
def forward(self, **kwargs):
obs, reward, done, info = self._forward(**kwargs)
terminal = done
Expand Down Expand Up @@ -505,6 +505,9 @@ def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs)
@implement_for("gym", "0.26", None)
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {}
@implement_for("gymnasium")
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {}

Expand Down
20 changes: 20 additions & 0 deletions robohive/envs/env_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@
@implement_for("gym", None, "0.24")
def gym_registry_specs():
return gym.envs.registry.env_specs

@implement_for("gym", "0.24", None)
def gym_registry_specs():
return gym.envs.registry

@implement_for("gymnasium")
def gym_registry_specs():
return gym.envs.registry

# TODO: move to within the function?
@implement_for("gym", None, "0.24")
def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
Expand All @@ -30,13 +36,23 @@ def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys)
return variants_update_keyval_str

@implement_for("gymnasium")
def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys)
return variants_update_keyval_str

@implement_for("gym", None, "0.24")
def _entry_point(env_variant_specs):
return env_variant_specs._entry_point

@implement_for("gym", "0.24", None)
def _entry_point(env_variant_specs):
return env_variant_specs.entry_point

@implement_for("gymnasium")
def _entry_point(env_variant_specs):
return env_variant_specs.entry_point

@implement_for("gym", None, "0.24")
def _kwargs(env_variant_specs):
return env_variant_specs._kwargs
Expand All @@ -45,6 +61,10 @@ def _kwargs(env_variant_specs):
def _kwargs(env_variant_specs):
return env_variant_specs.kwargs

@implement_for("gymnasium")
def _kwargs(env_variant_specs):
return env_variant_specs.kwargs

# Update base_dict using update_dict
def update_dict(base_dict:dict, update_dict:dict, override_keys:list=None):
"""
Expand Down

0 comments on commit a7b61bd

Please sign in to comment.