Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored and vikashplus committed Nov 18, 2023
1 parent 009a373 commit 2b6a164
Showing 1 changed file with 41 additions and 8 deletions.
49 changes: 41 additions & 8 deletions robohive/envs/env_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,39 @@
from copy import deepcopy
from flatten_dict import flatten, unflatten

from robohive.utils.implement_for import implement_for

#TODO: check versions
@implement_for("gym", None, "0.24")
def gym_registry_specs():
return gym.envs.registry.env_specs
@implement_for("gym", "0.24", None)
def gym_registry_specs():
return gym.envs.registry
# TODO: move to within the function?
@implement_for("gym", None, "0.24")
def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
env_variant_specs._kwargs, variants_update_keyval_str = update_dict(env_variant_specs._kwargs, variants, override_keys=override_keys)

@implement_for("gym", "0.24", None)
def _update_env_spec_kwarg(env_variant_specs, variants, override_keys):
env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys)
return variants_update_keyval_str

@implement_for("gym", None, "0.24")
def _entry_point(env_variant_specs):
return env_variant_specs._entry_point

@implement_for("gym", "0.24", None)
def _entry_point(env_variant_specs):
return env_variant_specs.entry_point
@implement_for("gym", None, "0.24")
def _kwargs(env_variant_specs):
return env_variant_specs._kwargs

@implement_for("gym", "0.24", None)
def _kwargs(env_variant_specs):
return env_variant_specs.kwargs

# Update base_dict using update_dict
def update_dict(base_dict:dict, update_dict:dict, override_keys:list=None):
Expand Down Expand Up @@ -47,10 +80,10 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals
"""

# check if the base env is registered
assert env_id in gym.envs.registry.env_specs.keys(), "ERROR: {} not found in env registry".format(env_id)
assert env_id in gym_registry_specs().keys(), "ERROR: {} not found in env registry".format(env_id)

# recover the specs of the existing env
env_variant_specs = deepcopy(gym.envs.registry.env_specs[env_id])
env_variant_specs = deepcopy(gym_registry_specs()[env_id])
env_variant_id = env_variant_specs.id[:-3]

# update horizon if requested
Expand All @@ -60,16 +93,16 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals
del variants['max_episode_steps']

# merge specs._kwargs with variants
env_variant_specs._kwargs, variants_update_keyval_str = update_dict(env_variant_specs._kwargs, variants, override_keys=override_keys)
variants_update_keyval_str = _update_env_spec_kwarg(env_variant_specs, variants, override_keys)
env_variant_id += variants_update_keyval_str

# finalize name and register env
env_variant_specs.id = env_variant_id+env_variant_specs.id[-3:] if variant_id is None else variant_id
register(
id=env_variant_specs.id,
entry_point=env_variant_specs._entry_point,
entry_point=_entry_point(env_variant_specs),
max_episode_steps=env_variant_specs.max_episode_steps,
kwargs=env_variant_specs._kwargs
kwargs=_kwargs(env_variant_specs)
)
if not silent:
print("Registered a new env-variant:", env_variant_specs.id)
Expand All @@ -96,11 +129,11 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals

# Test variant
print("Base-env kwargs: ")
pprint.pprint(gym.envs.registry.env_specs[base_env_name]._kwargs)
pprint.pprint(gym_registry_specs()[base_env_name]._kwargs)
print("Env-variant kwargs: ")
pprint.pprint(gym.envs.registry.env_specs[variant_env_name]._kwargs)
pprint.pprint(gym_registry_specs()[variant_env_name]._kwargs)
print("Env-variant (with override) kwargs: ")
pprint.pprint(gym.envs.registry.env_specs[variant_overide_env_name]._kwargs)
pprint.pprint(gym_registry_specs()[variant_overide_env_name]._kwargs)

# Test one of the newly minted env
env = gym.make(variant_env_name)
Expand Down

0 comments on commit 2b6a164

Please sign in to comment.