diff --git a/README.md b/README.md index 052b916c..aa3aae59 100644 --- a/README.md +++ b/README.md @@ -7,15 +7,16 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use th -![PyPI](https://img.shields.io/pypi/v/robohive) +[![report](https://img.shields.io/badge/Project-Page-blue)](https://sites.google.com/view/robohive/) +[![report](https://img.shields.io/badge/ArXiv-Paper-green)](https://arxiv.org/abs/2310.06828) +[![Documentation](https://img.shields.io/static/v1?label=Wiki&message=Documentation&color= + + + + \ No newline at end of file diff --git a/robohive/envs/arms/pick_place_v0.py b/robohive/envs/arms/pick_place_v0.py index 298dc187..3fd9e66a 100644 --- a/robohive/envs/arms/pick_place_v0.py +++ b/robohive/envs/arms/pick_place_v0.py @@ -6,12 +6,14 @@ ================================================= """ import collections -import gym + import numpy as np from robohive.envs import env_base +from robohive.utils import gym from robohive.utils.quat_math import euler2quat + class PickPlaceV0(env_base.MujocoEnv): DEFAULT_OBS_KEYS = [ @@ -70,6 +72,12 @@ def _setup(self, self.randomize = randomize self.geom_sizes = geom_sizes + # Save body init pos + self.init_body_pos = {} + for body in ["obj0", "obj1", "obj2"]: + bid = self.sim.model.body_name2id(body) + self.init_body_pos[body] = self.sim.model.body_pos[bid].copy() + super()._setup(obs_keys=obs_keys, weighted_reward_keys=weighted_reward_keys, reward_mode=reward_mode, @@ -109,7 +117,7 @@ def get_reward_dict(self, obs_dict): rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0) return rwd_dict - def reset(self): + def reset(self, **kwargs): if self.randomize: # target location @@ -119,19 +127,19 @@ def reset(self): # object shapes and locations for body in ["obj0", "obj1", "obj2"]: bid = self.sim.model.body_name2id(body) - self.sim.model.body_pos[bid] += self.np_random.uniform(low=[-.010, -.010, -.010], high=[-.010, -.010, -.010])# random pos + self.sim.model.body_pos[bid] = self.init_body_pos[body] + self.np_random.uniform(low=[-.010, -.010, -.010], high=[-.010, -.010, -.010])# random pos self.sim.model.body_quat[bid] = euler2quat(self.np_random.uniform(low=(-np.pi/2, -np.pi/2, -np.pi/2), high=(np.pi/2, np.pi/2, np.pi/2)) ) # random quat for gid in range(self.sim.model.body_geomnum[bid]): gid+=self.sim.model.body_geomadr[bid] - self.sim.model.geom_type[gid]=self.np_random.randint(low=2, high=7) # random shape + self.sim.model.geom_type[gid]=self.np_random.choice([2,3,4,5,6]) # random shape self.sim.model.geom_size[gid]=self.np_random.uniform(low=self.geom_sizes['low'], high=self.geom_sizes['high']) # random size self.sim.model.geom_pos[gid]=self.np_random.uniform(low=-1*self.sim.model.geom_size[gid], high=self.sim.model.geom_size[gid]) # random pos self.sim.model.geom_quat[gid]=euler2quat(self.np_random.uniform(low=(-np.pi/2, -np.pi/2, -np.pi/2), high=(np.pi/2, np.pi/2, np.pi/2)) ) # random quat self.sim.model.geom_rgba[gid]=self.np_random.uniform(low=[.2, .2, .2, 1], high=[.9, .9, .9, 1]) # random color self.sim.forward() - obs = super().reset(self.init_qpos, self.init_qvel) + obs = super().reset(self.init_qpos, self.init_qvel, **kwargs) return obs # def viewer_setup(self): diff --git a/robohive/envs/arms/push_base_v0.py b/robohive/envs/arms/push_base_v0.py index fe0ba699..f69af1d8 100644 --- a/robohive/envs/arms/push_base_v0.py +++ b/robohive/envs/arms/push_base_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs import env_base @@ -103,8 +103,8 @@ def get_reward_dict(self, obs_dict): rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0) return rwd_dict - def reset(self): + def reset(self, **kwargs): self.sim.model.site_pos[self.target_sid] = self.np_random.uniform(high=self.target_xyz_range['high'], low=self.target_xyz_range['low']) self.sim_obsd.model.site_pos[self.target_sid] = self.sim.model.site_pos[self.target_sid] - obs = super().reset(self.init_qpos, self.init_qvel) + obs = super().reset(self.init_qpos, self.init_qvel, **kwargs) return obs diff --git a/robohive/envs/arms/reach_base_v0.py b/robohive/envs/arms/reach_base_v0.py index 6a78e031..5460bcc0 100644 --- a/robohive/envs/arms/reach_base_v0.py +++ b/robohive/envs/arms/reach_base_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs import env_base @@ -97,8 +97,8 @@ def get_reward_dict(self, obs_dict): rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0) return rwd_dict - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): self.sim.model.site_pos[self.target_sid] = self.np_random.uniform(high=self.target_xyz_range['high'], low=self.target_xyz_range['low']) self.sim_obsd.model.site_pos[self.target_sid] = self.sim.model.site_pos[self.target_sid] - obs = super().reset(reset_qpos, reset_qvel) + obs = super().reset(reset_qpos, reset_qvel, **kwargs) return obs diff --git a/robohive/envs/claws/__init__.py b/robohive/envs/claws/__init__.py index 1b78c421..ec019270 100644 --- a/robohive/envs/claws/__init__.py +++ b/robohive/envs/claws/__init__.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register import os curr_dir = os.path.dirname(os.path.abspath(__file__)) from robohive.envs.env_variants import register_env_variant @@ -38,7 +38,7 @@ 'model_path': curr_dir+'/trifinger/trifinger_reorient.xml', 'object_site_name': "object", 'target_site_name': "target", - 'target_xyz_range': {'high':[.05, .05, 0.9], 'low':[-.05, -.05, 0.99]}, + 'target_xyz_range': {'high':[.05, .05, 0.99], 'low':[-.05, -.05, 0.9]}, 'target_euler_range': {'high':[1, 1, 1], 'low':[-1, -1, -1]} } ) diff --git a/robohive/envs/claws/reorient_v0.py b/robohive/envs/claws/reorient_v0.py index ceabeef5..2af1124a 100644 --- a/robohive/envs/claws/reorient_v0.py +++ b/robohive/envs/claws/reorient_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs import env_base @@ -98,7 +98,7 @@ def get_reward_dict(self, obs_dict): rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0) return rwd_dict - def reset(self): + def reset(self, **kwargs): desired_pos = self.np_random.uniform(high=self.target_xyz_range['high'], low=self.target_xyz_range['low']) self.sim.model.site_pos[self.target_sid] = desired_pos self.sim_obsd.model.site_pos[self.target_sid] = desired_pos @@ -108,5 +108,5 @@ def reset(self): self.sim.model.site_quat[self.target_sid] = euler2quat(desired_orien) self.sim_obsd.model.site_quat[self.target_sid] = euler2quat(desired_orien) - obs = super().reset(self.init_qpos, self.init_qvel) + obs = super().reset(self.init_qpos, self.init_qvel, **kwargs) return obs diff --git a/robohive/envs/env_base.py b/robohive/envs/env_base.py index 9bff0158..1072b132 100644 --- a/robohive/envs/env_base.py +++ b/robohive/envs/env_base.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils import gym import numpy as np import os import time as timer @@ -13,11 +13,13 @@ from robohive.envs.obs_vec_dict import ObsVecDict from robohive.utils import tensor_utils from robohive.robot.robot import Robot +from robohive.utils.implement_for import implement_for from robohive.utils.prompt_utils import prompt, Prompt import skvideo.io from sys import platform from robohive.physics.sim_scene import SimScene import robohive.utils.import_utils as import_utils +from robohive.envs.env_variants import gym_registry_specs # TODO # remove rwd_mode @@ -130,7 +132,7 @@ def _setup(self, self._setup_rgb_encoders(self.visual_keys, device=None) # reset to get the env ready - observation, _reward, done, _info = self.step(np.zeros(self.sim.model.nu)) + observation, _reward, done, *_, _info = self.step(np.zeros(self.sim.model.nu)) # Question: Should we replace above with following? Its specially helpful for hardware as it forces a env reset before continuing, without which the hardware will make a big jump from its position to the position asked by step. # observation = self.reset() assert not done, "Check initialization. Simulation starts in a done state." @@ -263,8 +265,23 @@ def step(self, a, **kwargs): render_cbk=self.mj_render if self.mujoco_render_frames else None) return self.forward(**kwargs) + @implement_for("gym", None, "0.24") + def forward(self, **kwargs): + return self._forward(**kwargs) + + @implement_for("gym", "0.24", None) + def forward(self, **kwargs): + obs, reward, done, info = self._forward(**kwargs) + terminal = done + return obs, reward, terminal, False, info + @implement_for("gymnasium") def forward(self, **kwargs): + obs, reward, done, info = self._forward(**kwargs) + terminal = done + return obs, reward, terminal, False, info + + def _forward(self, **kwargs): """ Forward propagate env to recover env details Returns current obs(t), rwd(t), done(t), info(t) @@ -476,20 +493,28 @@ def get_input_seed(self): return self.input_seed - def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + def _reset(self, reset_qpos=None, reset_qvel=None, seed=None, **kwargs): """ Reset the environment Default implemention provided. Override if env needs custom reset """ qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel - self.robot.reset(qpos, qvel, **kwargs) + self.robot.reset(reset_pos=qpos, reset_vel=qvel, seed=seed, **kwargs) return self.get_obs() + @implement_for("gym", None, "0.26") + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs) + @implement_for("gym", "0.26", None) + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {} + @implement_for("gymnasium") + def reset(self, reset_qpos=None, reset_qvel=None, seed=None, **kwargs): + return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, seed=seed, **kwargs), {} - - @property - def _step(self, a): - return self.step(a) + # @property + # def _step(self, a): + # return self.step(a) @property @@ -507,9 +532,15 @@ def id(self): return self.spec.id + @implement_for("gym") + def _horizon(self): + return self.spec.max_episode_steps # paths could have early termination before horizon + @implement_for("gymnasium") + def _horizon(self): + return gym_registry_specs()[self.spec.id].max_episode_steps # gymnasium unwrapper overrides specs (https://github.com/Farama-Foundation/Gymnasium/issues/871) @property def horizon(self): - return self.spec.max_episode_steps # paths could have early termination before horizon + return self._horizon() def get_env_state(self): @@ -702,7 +733,7 @@ def examine_policy(self, ep_rwd = 0.0 while t < horizon and done is False: a = policy.get_action(o)[0] if mode == 'exploration' else policy.get_action(o)[1]['evaluation'] - next_o, rwd, done, env_info = self.step(a) + next_o, rwd, done, *_, env_info = self.step(a) ep_rwd += rwd # render offscreen visuals if render =='offscreen': @@ -794,7 +825,7 @@ def examine_policy_new(self, ep_rwd = 0.0 # Rollout -------------------------------- - obs, rwd, done, env_info = self.forward(update_exteroception=True) # t=0 + obs, rwd, done, *_, env_info = self.forward(update_exteroception=True) # t=0 while t < horizon and done is False: # print(t, t*self.dt, self.time, t*self.dt-self.time) @@ -825,7 +856,7 @@ def examine_policy_new(self, # step env using actions from t=>t+1 ---------------------- - obs, rwd, done, env_info = self.step(act, update_exteroception=True) + obs, rwd, done, *_, env_info = self.step(act, update_exteroception=True) t = t+1 ep_rwd += rwd diff --git a/robohive/envs/env_variants.py b/robohive/envs/env_variants.py index 07bb3374..4ec05f5e 100644 --- a/robohive/envs/env_variants.py +++ b/robohive/envs/env_variants.py @@ -5,12 +5,65 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register import collections from copy import deepcopy from flatten_dict import flatten, unflatten +from robohive.utils.implement_for import implement_for + +#TODO: check versions +@implement_for("gym", None, "0.24") +def gym_registry_specs(): + return gym.envs.registry.env_specs + +@implement_for("gym", "0.24", None) +def gym_registry_specs(): + return gym.envs.registry + +@implement_for("gymnasium") +def gym_registry_specs(): + return gym.envs.registry + +# TODO: move to within the function? +@implement_for("gym", None, "0.24") +def _update_env_spec_kwarg(env_variant_specs, variants, override_keys): + env_variant_specs._kwargs, variants_update_keyval_str = update_dict(env_variant_specs._kwargs, variants, override_keys=override_keys) + return variants_update_keyval_str + +@implement_for("gym", "0.24", None) +def _update_env_spec_kwarg(env_variant_specs, variants, override_keys): + env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys) + return variants_update_keyval_str + +@implement_for("gymnasium") +def _update_env_spec_kwarg(env_variant_specs, variants, override_keys): + env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys) + return variants_update_keyval_str + +@implement_for("gym", None, "0.24") +def _entry_point(env_variant_specs): + return env_variant_specs._entry_point + +@implement_for("gym", "0.24", None) +def _entry_point(env_variant_specs): + return env_variant_specs.entry_point + +@implement_for("gymnasium") +def _entry_point(env_variant_specs): + return env_variant_specs.entry_point + +@implement_for("gym", None, "0.24") +def _kwargs(env_variant_specs): + return env_variant_specs._kwargs + +@implement_for("gym", "0.24", None) +def _kwargs(env_variant_specs): + return env_variant_specs.kwargs + +@implement_for("gymnasium") +def _kwargs(env_variant_specs): + return env_variant_specs.kwargs # Update base_dict using update_dict def update_dict(base_dict:dict, update_dict:dict, override_keys:list=None): @@ -47,10 +100,10 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals """ # check if the base env is registered - assert env_id in gym.envs.registry.env_specs.keys(), "ERROR: {} not found in env registry".format(env_id) + assert env_id in gym_registry_specs().keys(), "ERROR: {} not found in env registry".format(env_id) # recover the specs of the existing env - env_variant_specs = deepcopy(gym.envs.registry.env_specs[env_id]) + env_variant_specs = deepcopy(gym_registry_specs()[env_id]) env_variant_id = env_variant_specs.id[:-3] # update horizon if requested @@ -60,16 +113,16 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals del variants['max_episode_steps'] # merge specs._kwargs with variants - env_variant_specs._kwargs, variants_update_keyval_str = update_dict(env_variant_specs._kwargs, variants, override_keys=override_keys) + variants_update_keyval_str = _update_env_spec_kwarg(env_variant_specs, variants, override_keys) env_variant_id += variants_update_keyval_str # finalize name and register env env_variant_specs.id = env_variant_id+env_variant_specs.id[-3:] if variant_id is None else variant_id register( id=env_variant_specs.id, - entry_point=env_variant_specs._entry_point, + entry_point=_entry_point(env_variant_specs), max_episode_steps=env_variant_specs.max_episode_steps, - kwargs=env_variant_specs._kwargs + kwargs=_kwargs(env_variant_specs) ) if not silent: print("Registered a new env-variant:", env_variant_specs.id) @@ -96,11 +149,11 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals # Test variant print("Base-env kwargs: ") - pprint.pprint(gym.envs.registry.env_specs[base_env_name]._kwargs) + pprint.pprint(gym_registry_specs()[base_env_name]._kwargs) print("Env-variant kwargs: ") - pprint.pprint(gym.envs.registry.env_specs[variant_env_name]._kwargs) + pprint.pprint(gym_registry_specs()[variant_env_name]._kwargs) print("Env-variant (with override) kwargs: ") - pprint.pprint(gym.envs.registry.env_specs[variant_overide_env_name]._kwargs) + pprint.pprint(gym_registry_specs()[variant_overide_env_name]._kwargs) # Test one of the newly minted env env = gym.make(variant_env_name) diff --git a/robohive/envs/fm/__init__.py b/robohive/envs/fm/__init__.py index 16dd0562..b9d71406 100644 --- a/robohive/envs/fm/__init__.py +++ b/robohive/envs/fm/__init__.py @@ -1,4 +1,5 @@ -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register + import numpy as np import os curr_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/robohive/envs/fm/franka_ee_pose_v0.py b/robohive/envs/fm/franka_ee_pose_v0.py index 896ad2f7..82008527 100644 --- a/robohive/envs/fm/franka_ee_pose_v0.py +++ b/robohive/envs/fm/franka_ee_pose_v0.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils import gym; import numpy as np from robohive.envs import env_base from robohive.physics.sim_scene import SimScene @@ -119,9 +119,9 @@ def get_target_pose(self): return self.np_random.uniform(low=self.sim.model.actuator_ctrlrange[:,0], high=self.sim.model.actuator_ctrlrange[:,1]) - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): self.target_pose = self.get_target_pose() - obs = super().reset(reset_qpos, reset_qvel) + obs = super().reset(reset_qpos, reset_qvel, **kwargs) return obs class FrankaRobotiqPose(FrankaEEPose): diff --git a/robohive/envs/fm/franka_robotiq_data_v0.py b/robohive/envs/fm/franka_robotiq_data_v0.py index bf7d1463..1f1bdbcf 100644 --- a/robohive/envs/fm/franka_robotiq_data_v0.py +++ b/robohive/envs/fm/franka_robotiq_data_v0.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils import gym; import numpy as np from robohive.envs import env_base from robohive.physics.sim_scene import SimScene diff --git a/robohive/envs/hands/__init__.py b/robohive/envs/hands/__init__.py index 11001702..9ffc7b7a 100644 --- a/robohive/envs/hands/__init__.py +++ b/robohive/envs/hands/__init__.py @@ -5,7 +5,8 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register + from robohive.envs.env_variants import register_env_variant import os curr_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/robohive/envs/hands/baoding_v1.py b/robohive/envs/hands/baoding_v1.py index 1e643916..1b7cafec 100644 --- a/robohive/envs/hands/baoding_v1.py +++ b/robohive/envs/hands/baoding_v1.py @@ -7,7 +7,7 @@ import collections import enum -import gym +from robohive.utils import gym import numpy as np from robohive.envs import env_base @@ -256,7 +256,7 @@ def get_reward_dict(self, obs_dict): rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0) return rwd_dict - def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6): + def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6, **kwargs): # reset counters self.counter=0 @@ -264,7 +264,7 @@ def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6) self.goal = self.create_goal_trajectory(time_period=time_period) if reset_goal is None else reset_goal.copy() # reset scene - obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel) + obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs) return obs def create_goal_trajectory(self, time_step=.1, time_period=6): @@ -272,7 +272,7 @@ def create_goal_trajectory(self, time_step=.1, time_period=6): # populate go-to task with a target location if self.which_task==Task.MOVE_TO_LOCATION: - goal_pos = np.random.randint(4) + goal_pos = self.np_random.choice([0,1,2,3]) desired_position = [] if goal_pos==0: desired_position.append(0.01) #x @@ -325,6 +325,6 @@ def create_goal_trajectory(self, time_step=.1, time_period=6): class BaodingRandomEnvV1(BaodingFixedEnvV1): - def reset(self): - obs = super().reset(time_period = self.np_random.uniform(high=5, low=7)) + def reset(self, **kwargs): + obs = super().reset(time_period = self.np_random.uniform(high=5, low=7), **kwargs) return obs diff --git a/robohive/envs/hands/door_v1.py b/robohive/envs/hands/door_v1.py index 877a84de..8cd9f5e1 100644 --- a/robohive/envs/hands/door_v1.py +++ b/robohive/envs/hands/door_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs import env_base @@ -128,4 +128,3 @@ def set_env_state(self, state_dict): self.sim.set_state(qpos=qp, qvel=qv) self.sim.model.body_pos[self.door_bid] = state_dict['door_body_pos'] self.sim.forward() - diff --git a/robohive/envs/hands/hammer_v1.py b/robohive/envs/hands/hammer_v1.py index a7607b1d..3ef29033 100644 --- a/robohive/envs/hands/hammer_v1.py +++ b/robohive/envs/hands/hammer_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.utils.quat_math import * diff --git a/robohive/envs/hands/pen_v1.py b/robohive/envs/hands/pen_v1.py index 2a4bcbcb..54539d34 100644 --- a/robohive/envs/hands/pen_v1.py +++ b/robohive/envs/hands/pen_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.utils.vector_math import calculate_cosine diff --git a/robohive/envs/hands/relocate_v1.py b/robohive/envs/hands/relocate_v1.py index 5c19b326..88e273e2 100644 --- a/robohive/envs/hands/relocate_v1.py +++ b/robohive/envs/hands/relocate_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/multi_task/common/franka_appliance_v1.py b/robohive/envs/multi_task/common/franka_appliance_v1.py index 14fc2d1d..9d5a5920 100644 --- a/robohive/envs/multi_task/common/franka_appliance_v1.py +++ b/robohive/envs/multi_task/common/franka_appliance_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.utils.quat_math import euler2quat @@ -54,7 +54,7 @@ def _setup( **kwargs, ) - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): # randomize object bodies, if requested if self.obj_body_randomize: for body_name in self.obj_body_randomize: @@ -78,4 +78,4 @@ def reset(self, reset_qpos=None, reset_qvel=None): * (self.robot_ranges[:, 1] - self.robot_ranges[:, 0]) ) - return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel) + return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs) diff --git a/robohive/envs/multi_task/common/franka_kitchen_v2.py b/robohive/envs/multi_task/common/franka_kitchen_v2.py index 50217e5b..5162d11d 100644 --- a/robohive/envs/multi_task/common/franka_kitchen_v2.py +++ b/robohive/envs/multi_task/common/franka_kitchen_v2.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils import gym from robohive.envs.multi_task.multi_task_base_v1 import KitchenBase class FrankaKitchen(KitchenBase): @@ -113,7 +113,7 @@ def _setup( ) - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): if reset_qpos is None: reset_qpos = self.init_qpos.copy() @@ -128,4 +128,4 @@ def reset(self, reset_qpos=None, reset_qvel=None): if self.robot_base_range: self.sim.model.body_pos[self.robot_base_bid] = self.robot_base_pos + self.np_random.uniform(**self.robot_base_range) - return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel) \ No newline at end of file + return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs) \ No newline at end of file diff --git a/robohive/envs/multi_task/multi_task_base_v1.py b/robohive/envs/multi_task/multi_task_base_v1.py index 3b4a443f..a18adda4 100644 --- a/robohive/envs/multi_task/multi_task_base_v1.py +++ b/robohive/envs/multi_task/multi_task_base_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/multi_task/substeps1/__init__.py b/robohive/envs/multi_task/substeps1/__init__.py index 839a8f43..db479c9e 100644 --- a/robohive/envs/multi_task/substeps1/__init__.py +++ b/robohive/envs/multi_task/substeps1/__init__.py @@ -6,7 +6,8 @@ ================================================= """ import os -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register + CURR_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/robohive/envs/multi_task/substeps1/franka_kitchen.py b/robohive/envs/multi_task/substeps1/franka_kitchen.py index 6305681e..4fdc7de5 100644 --- a/robohive/envs/multi_task/substeps1/franka_kitchen.py +++ b/robohive/envs/multi_task/substeps1/franka_kitchen.py @@ -6,7 +6,8 @@ ================================================= """ import os -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register + from robohive.envs.multi_task.common.franka_kitchen_v2 import FrankaKitchen import copy diff --git a/robohive/envs/multi_task/utils/parse_demos.py b/robohive/envs/multi_task/utils/parse_demos.py index 9bed1914..65795e3a 100644 --- a/robohive/envs/multi_task/utils/parse_demos.py +++ b/robohive/envs/multi_task/utils/parse_demos.py @@ -23,7 +23,7 @@ import robohive import time as timer # import skvideo.io -import gym +from robohive.utils import gym from tqdm import tqdm diff --git a/robohive/envs/myo/assets/arm/myoarm_relocate.xml b/robohive/envs/myo/assets/arm/myoarm_relocate.xml new file mode 100644 index 00000000..2e8414e1 --- /dev/null +++ b/robohive/envs/myo/assets/arm/myoarm_relocate.xml @@ -0,0 +1,74 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/robohive/envs/myo/assets/hand/MyoHand_object.xml b/robohive/envs/myo/assets/hand/myohand_object.xml similarity index 97% rename from robohive/envs/myo/assets/hand/MyoHand_object.xml rename to robohive/envs/myo/assets/hand/myohand_object.xml index 7d40a5b4..79f7a8db 100644 --- a/robohive/envs/myo/assets/hand/MyoHand_object.xml +++ b/robohive/envs/myo/assets/hand/myohand_object.xml @@ -7,7 +7,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ====================================================== --> - + diff --git a/robohive/envs/myo/assets/hand/MyoHand_tabletop.xml b/robohive/envs/myo/assets/hand/myohand_tabletop.xml similarity index 100% rename from robohive/envs/myo/assets/hand/MyoHand_tabletop.xml rename to robohive/envs/myo/assets/hand/myohand_tabletop.xml diff --git a/robohive/envs/myo/assets/leg/myolegs_chasetag.xml b/robohive/envs/myo/assets/leg/myolegs_chasetag.xml new file mode 100644 index 00000000..7f6f1dc6 --- /dev/null +++ b/robohive/envs/myo/assets/leg/myolegs_chasetag.xml @@ -0,0 +1,57 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/robohive/envs/myo/base_v0.py b/robohive/envs/myo/base_v0.py index 90b098ef..dd7ddca1 100644 --- a/robohive/envs/myo/base_v0.py +++ b/robohive/envs/myo/base_v0.py @@ -115,17 +115,4 @@ def step(self, a, **kwargs): realTimeSim=self.mujoco_render_frames, render_cbk=self.mj_render if self.mujoco_render_frames else None) - # observation - obs = self.get_obs(**kwargs) - - # rewards - self.expand_dims(self.obs_dict) # required for vectorized rewards calculations - self.rwd_dict = self.get_reward_dict(self.obs_dict) - self.squeeze_dims(self.rwd_dict) - self.squeeze_dims(self.obs_dict) - - # finalize step - env_info = self.get_env_infos() - - # returns obs(t+1), rwd(t+1), done(t+1), info(t+1) - return obs, env_info['rwd_'+self.rwd_mode], bool(env_info['done']), env_info \ No newline at end of file + return self.forward(**kwargs) \ No newline at end of file diff --git a/robohive/envs/myo/myobase/__init__.py b/robohive/envs/myo/myobase/__init__.py index 3ed59ef0..eac00fdb 100644 --- a/robohive/envs/myo/myobase/__init__.py +++ b/robohive/envs/myo/myobase/__init__.py @@ -3,7 +3,8 @@ Authors :: Vikash Kumar (vikashplus@gmail.com), Vittorio Caggiano (caggiano@gmail.com) ================================================= """ -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register + from robohive.envs.env_variants import register_env_variant import os @@ -287,11 +288,7 @@ def register_env_with_variants(id, entry_point, max_episode_steps, kwargs): # Gait Torso Reaching ============================== from robohive.physics.sim_scene import SimBackend sim_backend = SimBackend.get_sim_backend() -if sim_backend == SimBackend.MUJOCO_PY: - leg_model='/../../../simhive/myo_sim/leg/myolegs_v0.54(mj210).mjb' -elif sim_backend == SimBackend.MUJOCO: - leg_model='/../../../simhive/myo_sim/leg/myolegs_v0.56(mj237).mjb' - # leg_model='/../../../simhive/myo_sim/leg/myolegs_suspended_v0.56(mj236).mjb' +leg_model='/../../../simhive/myo_sim/leg/myolegs.xml' register_env_with_variants(id='myoLegStandRandom-v0', @@ -299,7 +296,7 @@ def register_env_with_variants(id, entry_point, max_episode_steps, kwargs): max_episode_steps=150, kwargs={ 'model_path': curr_dir+leg_model, - 'joint_random_range': (0.2, -0.2), #range of joint randomization (jnt = init_qpos + random(range) + 'joint_random_range': (-.2, 0.2), #range of joint randomization (jnt = init_qpos + random(range) 'target_reach_range': { 'pelvis': ((-.05, -.05, 0), (0.05, 0.05, 0)), }, diff --git a/robohive/envs/myo/myobase/baoding_v1.py b/robohive/envs/myo/myobase/baoding_v1.py index 42b7163c..44f63fad 100644 --- a/robohive/envs/myo/myobase/baoding_v1.py +++ b/robohive/envs/myo/myobase/baoding_v1.py @@ -5,7 +5,7 @@ import collections import enum -import gym +from robohive.utils import gym import numpy as np from robohive.envs.myo.base_v0 import BaseV0 @@ -260,7 +260,7 @@ def evaluate_success(self, paths, logger=None, successful_steps=5): logger.log_kv('effort', effort) return success_percentage - def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None): + def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None, **kwargs): # reset counters self.counter=0 self.x_radius=self.np_random.uniform(low=self.goal_xrange[0], high=self.goal_xrange[1]) @@ -273,7 +273,7 @@ def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=No self.goal = self.create_goal_trajectory(time_step=self.dt, time_period=time_period) if reset_goal is None else reset_goal.copy() # reset scene - obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel) + obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs) return obs def create_goal_trajectory(self, time_step=.1, time_period=6): @@ -281,7 +281,7 @@ def create_goal_trajectory(self, time_step=.1, time_period=6): # populate go-to task with a target location (pos likely needs update) if self.which_task==Task.MOVE_TO_LOCATION: - goal_pos = np.random.randint(4) + goal_pos = self.np_random.choice([0,1,2,3]) desired_position = [] if goal_pos==0: desired_position.append(-.195) #x diff --git a/robohive/envs/myo/myobase/key_turn_v0.py b/robohive/envs/myo/myobase/key_turn_v0.py index 37bdb988..9fc51e1d 100644 --- a/robohive/envs/myo/myobase/key_turn_v0.py +++ b/robohive/envs/myo/myobase/key_turn_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils import gym from robohive.envs.myo.base_v0 import BaseV0 @@ -110,11 +110,12 @@ def get_reward_dict(self, obs_dict): rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0) return rwd_dict - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel qpos[-1] = self.np_random.uniform(low=self.key_init_range[0], high=self.key_init_range[1]) if self.key_init_range[0]!=self.key_init_range[1]: # randomEnv self.sim.model.body_pos[-1] = self.key_init_pos+self.np_random.uniform(low=np.array([-0.01, -0.01, -.01]), high=np.array([0.01, 0.01, 0.01])) - self.robot.reset(qpos, qvel) - return self.get_obs() \ No newline at end of file + + obs = super().reset(reset_qpos=qpos, reset_qvel=qvel, **kwargs) + return obs \ No newline at end of file diff --git a/robohive/envs/myo/myobase/obj_hold_v0.py b/robohive/envs/myo/myobase/obj_hold_v0.py index 7aa3ae7b..b0864933 100644 --- a/robohive/envs/myo/myobase/obj_hold_v0.py +++ b/robohive/envs/myo/myobase/obj_hold_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils import gym from robohive.envs.myo.base_v0 import BaseV0 @@ -101,7 +101,7 @@ def get_reward_dict(self, obs_dict): class ObjHoldRandomEnvV0(ObjHoldFixedEnvV0): - def reset(self): + def reset(self, **kwargs): # randomize target pos self.sim.model.site_pos[self.goal_sid] = self.object_init_pos + self.np_random.uniform(high=np.array([0.030, 0.030, 0.030]), low=np.array([-.030, -.030, -.030])) # randomize object @@ -109,5 +109,5 @@ def reset(self): self.sim.model.geom_size[-1] = size self.sim.model.site_size[self.goal_sid] = size self.robot.sync_sims(self.sim, self.sim_obsd) - obs = super().reset() + obs = super().reset(**kwargs) return obs \ No newline at end of file diff --git a/robohive/envs/myo/myobase/pen_v0.py b/robohive/envs/myo/myobase/pen_v0.py index 6a0f9f87..4684ed77 100644 --- a/robohive/envs/myo/myobase/pen_v0.py +++ b/robohive/envs/myo/myobase/pen_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils import gym from robohive.envs.myo.base_v0 import BaseV0 from robohive.utils.quat_math import euler2quat @@ -123,12 +123,12 @@ def get_reward_dict(self, obs_dict): class PenTwirlRandomEnvV0(PenTwirlFixedEnvV0): - def reset(self): + def reset(self, **kwargs): # randomize target desired_orien = np.zeros(3) desired_orien[0] = self.np_random.uniform(low=-1, high=1) desired_orien[1] = self.np_random.uniform(low=-1, high=1) self.sim.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien) self.robot.sync_sims(self.sim, self.sim_obsd) - obs = super().reset() + obs = super().reset(**kwargs) return obs diff --git a/robohive/envs/myo/myobase/pose_v0.py b/robohive/envs/myo/myobase/pose_v0.py index 97798bac..f5f422f9 100644 --- a/robohive/envs/myo/myobase/pose_v0.py +++ b/robohive/envs/myo/myobase/pose_v0.py @@ -4,7 +4,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs.myo.base_v0 import BaseV0 @@ -142,7 +142,7 @@ def update_target(self, restore_sim=False): # reset_type = none; init; random # target_type = generate; switch - def reset(self): + def reset(self, **kwargs): # udpate wegith if self.weight_bodyname is not None: @@ -182,11 +182,11 @@ def reset(self): obs = self.get_obs() elif self.reset_type == "init": # reset to init state - obs = super().reset() + obs = super().reset(**kwargs) elif self.reset_type == "random": # reset to random state jnt_init = self.np_random.uniform(high=self.sim.model.jnt_range[:,1], low=self.sim.model.jnt_range[:,0]) - obs = super().reset(reset_qpos=jnt_init) + obs = super().reset(reset_qpos=jnt_init, **kwargs) else: print("Reset Type not found") diff --git a/robohive/envs/myo/myobase/reach_v0.py b/robohive/envs/myo/myobase/reach_v0.py index 1c9896bc..d66c102b 100644 --- a/robohive/envs/myo/myobase/reach_v0.py +++ b/robohive/envs/myo/myobase/reach_v0.py @@ -4,7 +4,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs.myo.base_v0 import BaseV0 @@ -116,8 +116,8 @@ def generate_target_pose(self): self.sim.forward() - def reset(self): + def reset(self, **kwargs): self.generate_target_pose() self.robot.sync_sims(self.sim, self.sim_obsd) - obs = super().reset() + obs = super().reset(**kwargs) return obs \ No newline at end of file diff --git a/robohive/envs/myo/myobase/reorient_sar_v0.py b/robohive/envs/myo/myobase/reorient_sar_v0.py index a87f9339..ed99a603 100644 --- a/robohive/envs/myo/myobase/reorient_sar_v0.py +++ b/robohive/envs/myo/myobase/reorient_sar_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils import gym from robohive.envs.myo.base_v0 import BaseV0 from robohive.utils.quat_math import euler2quat, mulQuat, negQuat, mat2quat @@ -122,7 +122,7 @@ def get_reward_dict(self, obs_dict): return rwd_dict class Geometries8EnvV0(ProprioceptiveEnvV0): - def reset(self): + def reset(self, **kwargs): ellips = {0: [[0.011, 0.025, 0.025], [0.74792, 0.35159, 0.80154, 1.0]], 1: [[0.019, 0.040, 0.040], [0.23366, 0.67864, 0.53721, 1.0]]} box = {0: [[0.017, 0.017, 0.017], [0.42829, 0.76091, 0.4914, 1.0]], 1: [[0.023, 0.023, 0.023], [0.21995, 0.60938, 0.18821, 1.0]]} @@ -191,11 +191,11 @@ def reset(self): self.sim.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien) self.robot.sync_sims(self.sim, self.sim_obsd) - obs = super().reset() + obs = super().reset(**kwargs) return obs class Geometries100EnvV0(ProprioceptiveEnvV0): - def reset(self): + def reset(self, **kwargs): ellips = {0: [[0.02843, 0.0256, 0.02902], [0.74792, 0.35159, 0.80154, 1.0]], 1: [[0.01057, 0.02655, 0.0328], [0.23366, 0.67864, 0.53721, 1.0]], 2: [[0.01126, 0.0273, 0.04264], [0.65043, 0.2313, 0.50699, 1.0]], 3: [[0.02641, 0.03524, 0.02831], [0.62441, 0.5802, 0.43566, 1.0]], 4: [[0.02804, 0.03722, 0.04313], [0.13892, 0.45695, 0.11598, 1.0]], 5: [[0.02305, 0.04456, 0.03709], [0.13836, 0.67832, 0.31776, 1.0]], 6: [[0.02332, 0.02673, 0.02606], [0.15742, 0.40625, 0.11557, 1.0]], 7: [[0.01247, 0.03233, 0.03759], [0.83, 0.24511, 0.30415, 1.0]], 8: [[0.02199, 0.029, 0.04484], [0.23557, 0.72447, 0.75669, 1.0]], 9: [[0.02674, 0.0428, 0.03764], [0.8393, 0.75063, 0.18226, 1.0]], 10: [[0.02278, 0.04006, 0.03556], [0.32785, 0.49373, 0.5858, 1.0]], 11: [[0.02392, 0.04095, 0.03467], [0.8965, 0.22427, 0.41412, 1.0]], 12: [[0.01928, 0.0348, 0.03044], [0.20289, 0.70564, 0.55928, 1.0]], 13: [[0.02388, 0.03644, 0.02817], [0.67021, 0.70081, 0.36769, 1.0]], 14: [[0.02739, 0.04338, 0.03457], [0.28136, 0.77765, 0.28719, 1.0]], 15: [[0.00962, 0.04047, 0.02614], [0.49566, 0.72634, 0.52086, 1.0]], 16: [[0.0163, 0.04443, 0.04326], [0.1731, 0.8899, 0.10808, 1.0]], 17: [[0.02417, 0.03157, 0.04038], [0.21701, 0.29525, 0.62152, 1.0]], 18: [[0.01927, 0.02814, 0.03786], [0.76065, 0.49735, 0.27818, 1.0]], 19: [[0.02477, 0.04456, 0.04493], [0.80959, 0.8233, 0.5421, 1.0]], 20: [[0.01656, 0.0291, 0.03996], [0.33429, 0.36101, 0.28275, 1.0]], 21: [[0.01763, 0.03877, 0.03636], [0.21692, 0.27226, 0.65917, 1.0]], 22: [[0.01915, 0.0346, 0.04245], [0.41227, 0.43577, 0.31358, 1.0]], 23: [[0.02485, 0.03324, 0.02881], [0.65375, 0.75452, 0.47755, 1.0]], 24: [[0.00856, 0.04185, 0.03749], [0.44235, 0.15332, 0.76038, 1.0]]} box = {0: [[0.02295, 0.02306, 0.02221], [0.42829, 0.76091, 0.4914, 1.0]], 1: [[0.02447, 0.0185, 0.02192], [0.21995, 0.20938, 0.48821, 1.0]], 2: [[0.01853, 0.01837, 0.01546], [0.76916, 0.26635, 0.16801, 1.0]], 3: [[0.01586, 0.02079, 0.022], [0.1423, 0.75584, 0.15317, 1.0]], 4: [[0.02293, 0.02116, 0.02255], [0.2323, 0.3077, 0.42493, 1.0]], 5: [[0.01542, 0.01651, 0.02381], [0.26685, 0.38304, 0.60438, 1.0]], 6: [[0.0186, 0.02402, 0.02333], [0.1624, 0.45497, 0.14676, 1.0]], 7: [[0.01782, 0.01584, 0.02208], [0.17459, 0.54131, 0.68087, 1.0]], 8: [[0.01907, 0.0195, 0.02161], [0.54092, 0.40078, 0.68101, 1.0]], 9: [[0.01751, 0.0211, 0.01864], [0.65326, 0.59045, 0.30555, 1.0]], 10: [[0.02258, 0.02334, 0.01856], [0.80835, 0.76567, 0.63477, 1.0]], 11: [[0.02195, 0.01617, 0.02438], [0.86178, 0.17993, 0.61248, 1.0]], 12: [[0.01627, 0.02254, 0.02073], [0.39115, 0.68792, 0.78923, 1.0]], 13: [[0.02364, 0.01946, 0.01777], [0.57742, 0.55447, 0.48724, 1.0]], 14: [[0.01754, 0.02463, 0.01549], [0.73537, 0.1708, 0.49452, 1.0]], 15: [[0.02394, 0.02382, 0.02387], [0.13934, 0.18804, 0.67206, 1.0]], 16: [[0.01997, 0.02372, 0.02032], [0.55577, 0.62793, 0.42524, 1.0]], 17: [[0.01741, 0.02316, 0.02203], [0.8914, 0.54996, 0.31562, 1.0]], 18: [[0.02032, 0.0217, 0.02432], [0.38009, 0.75075, 0.64515, 1.0]], 19: [[0.01961, 0.0248, 0.0176], [0.5283, 0.86192, 0.77579, 1.0]], 20: [[0.01906, 0.01999, 0.02399], [0.49911, 0.89906, 0.44505, 1.0]], 21: [[0.02472, 0.01826, 0.0151], [0.79248, 0.49588, 0.41427, 1.0]], 22: [[0.01636, 0.0158, 0.01958], [0.57198, 0.58271, 0.78801, 1.0]], 23: [[0.01542, 0.02434, 0.02237], [0.22467, 0.88589, 0.83947, 1.0]], 24: [[0.01731, 0.02185, 0.02019], [0.77122, 0.73006, 0.14257, 1.0]]} @@ -260,11 +260,11 @@ def reset(self): self.sim.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien) self.robot.sync_sims(self.sim, self.sim_obsd) - obs = super().reset() + obs = super().reset(**kwargs) return obs class InDistribution(ProprioceptiveEnvV0): - def reset(self): + def reset(self, **kwargs): ellips = {0: [[0.0179, 0.0446, 0.0356], [0.34, 0.2627, 0.3952, 1.0]], 1: [[0.0218, 0.0327, 0.0237], [0.3313, 0.8439, 0.3636, 1.0]], 2: [[0.0144, 0.0335, 0.0322], [0.6041, 0.4344, 0.1575, 1.0]], 3: [[0.0171, 0.0285, 0.0388], [0.8304, 0.58, 0.2844, 1.0]], 4: [[0.0091, 0.0373, 0.0214], [0.204, 0.5209, 0.2448, 1.0]], 5: [[0.0216, 0.0294, 0.0408], [0.8079, 0.4409, 0.8124, 1.0]], 6: [[0.012, 0.0342, 0.0324], [0.7284, 0.2844, 0.874, 1.0]], 7: [[0.0213, 0.0369, 0.0343], [0.3488, 0.1383, 0.132, 1.0]], 8: [[0.0119, 0.0309, 0.0338], [0.3082, 0.8545, 0.8087, 1.0]], 9: [[0.0092, 0.0368, 0.0249], [0.312, 0.4161, 0.3201, 1.0]], 10: [[0.0111, 0.0324, 0.0333], [0.5433, 0.8573, 0.5046, 1.0]], 11: [[0.0084, 0.0233, 0.0431], [0.2799, 0.1479, 0.8274, 1.0]], 12: [[0.0217, 0.0343, 0.029], [0.7758, 0.6188, 0.765, 1.0]], 13: [[0.0157, 0.0311, 0.0283], [0.7821, 0.1809, 0.8041, 1.0]], 14: [[0.012, 0.0408, 0.0395], [0.1677, 0.4763, 0.1768, 1.0]], 15: [[0.0193, 0.0333, 0.0411], [0.2521, 0.1056, 0.6368, 1.0]], 16: [[0.0124, 0.02, 0.024], [0.164, 0.1243, 0.7993, 1.0]], 17: [[0.0188, 0.0432, 0.0255], [0.843, 0.1857, 0.489, 1.0]], 18: [[0.0184, 0.024, 0.0298], [0.1337, 0.5341, 0.6517, 1.0]], 19: [[0.0135, 0.038, 0.0371], [0.6872, 0.7654, 0.373, 1.0]], 20: [[0.0216, 0.0225, 0.0261], [0.4133, 0.5901, 0.7897, 1.0]], 21: [[0.0082, 0.0204, 0.0241], [0.7814, 0.2021, 0.4204, 1.0]], 22: [[0.0096, 0.0232, 0.0246], [0.8161, 0.5409, 0.2123, 1.0]], 23: [[0.0095, 0.025, 0.0274], [0.1395, 0.7302, 0.4261, 1.0]], 24: [[0.0167, 0.0359, 0.0212], [0.4984, 0.6301, 0.5516, 1.0]], 25: [[0.0143, 0.0313, 0.0443], [0.5803, 0.2667, 0.4889, 1.0]], 26: [[0.0122, 0.0261, 0.0388], [0.3106, 0.1305, 0.5166, 1.0]], 27: [[0.0084, 0.0441, 0.0342], [0.2434, 0.1669, 0.6743, 1.0]], 28: [[0.0174, 0.0218, 0.0244], [0.4817, 0.4323, 0.7243, 1.0]], 29: [[0.008, 0.0232, 0.0404], [0.1758, 0.4021, 0.512, 1.0]], 30: [[0.0095, 0.029, 0.0313], [0.7042, 0.1984, 0.366, 1.0]], 31: [[0.0173, 0.0277, 0.0276], [0.4663, 0.3589, 0.2309, 1.0]], 32: [[0.0127, 0.0344, 0.0359], [0.1129, 0.8764, 0.2991, 1.0]], 33: [[0.0178, 0.0392, 0.033], [0.6397, 0.6061, 0.3422, 1.0]], 34: [[0.0193, 0.0272, 0.0208], [0.3213, 0.4007, 0.1996, 1.0]], 35: [[0.0096, 0.0202, 0.0359], [0.8733, 0.5662, 0.7954, 1.0]], 36: [[0.0138, 0.0413, 0.0359], [0.4423, 0.1624, 0.4841, 1.0]], 37: [[0.0158, 0.0291, 0.0403], [0.3637, 0.7927, 0.7459, 1.0]], 38: [[0.0205, 0.0251, 0.0208], [0.5048, 0.5567, 0.3631, 1.0]], 39: [[0.0086, 0.028, 0.0206], [0.2369, 0.1793, 0.5048, 1.0]], 40: [[0.0156, 0.0332, 0.0322], [0.4873, 0.8613, 0.6544, 1.0]], 41: [[0.0197, 0.0437, 0.0261], [0.3981, 0.1823, 0.2919, 1.0]], 42: [[0.0093, 0.0348, 0.0416], [0.624, 0.1488, 0.4907, 1.0]], 43: [[0.0135, 0.0423, 0.0398], [0.1624, 0.107, 0.6028, 1.0]], 44: [[0.0195, 0.0285, 0.023], [0.4965, 0.1902, 0.3155, 1.0]], 45: [[0.0163, 0.0268, 0.0375], [0.566, 0.8469, 0.8336, 1.0]], 46: [[0.0197, 0.0245, 0.0399], [0.1049, 0.7715, 0.563, 1.0]], 47: [[0.0199, 0.0317, 0.0309], [0.6912, 0.7888, 0.2288, 1.0]], 48: [[0.0132, 0.0374, 0.0393], [0.1806, 0.6692, 0.4167, 1.0]], 49: [[0.0134, 0.0421, 0.0387], [0.5341, 0.6567, 0.8963, 1.0]], 50: [[0.0126, 0.0389, 0.0264], [0.5698, 0.302, 0.5201, 1.0]], 51: [[0.0161, 0.0281, 0.0355], [0.8565, 0.4137, 0.6611, 1.0]], 52: [[0.0164, 0.0381, 0.0341], [0.8276, 0.2435, 0.7939, 1.0]], 53: [[0.0208, 0.0343, 0.0435], [0.2087, 0.7211, 0.6089, 1.0]], 54: [[0.011, 0.0355, 0.0249], [0.4889, 0.2257, 0.6651, 1.0]], 55: [[0.0163, 0.0355, 0.0245], [0.5397, 0.842, 0.8906, 1.0]], 56: [[0.0198, 0.022, 0.0284], [0.2942, 0.1247, 0.4229, 1.0]], 57: [[0.0091, 0.0304, 0.0211], [0.4287, 0.1466, 0.395, 1.0]], 58: [[0.0161, 0.039, 0.0276], [0.6075, 0.8244, 0.3582, 1.0]], 59: [[0.0202, 0.0442, 0.0352], [0.391, 0.3324, 0.3166, 1.0]], 60: [[0.0089, 0.0271, 0.0323], [0.5451, 0.225, 0.7743, 1.0]], 61: [[0.0136, 0.0223, 0.039], [0.2951, 0.8605, 0.7955, 1.0]], 62: [[0.0113, 0.0215, 0.0378], [0.2177, 0.6716, 0.1103, 1.0]], 63: [[0.0188, 0.0303, 0.0332], [0.8879, 0.1028, 0.5333, 1.0]], 64: [[0.0194, 0.0234, 0.0225], [0.8584, 0.1146, 0.5777, 1.0]], 65: [[0.0104, 0.0234, 0.0416], [0.1862, 0.5018, 0.3055, 1.0]], 66: [[0.0209, 0.0286, 0.0217], [0.5629, 0.4863, 0.8173, 1.0]], 67: [[0.01, 0.043, 0.021], [0.4847, 0.8588, 0.1886, 1.0]], 68: [[0.0147, 0.0349, 0.0404], [0.3467, 0.7475, 0.7833, 1.0]], 69: [[0.0217, 0.0402, 0.0293], [0.8445, 0.6944, 0.1076, 1.0]], 70: [[0.0108, 0.024, 0.0401], [0.2511, 0.5804, 0.2071, 1.0]], 71: [[0.0177, 0.0214, 0.0294], [0.814, 0.2083, 0.1068, 1.0]], 72: [[0.008, 0.0422, 0.0436], [0.5886, 0.1532, 0.3145, 1.0]], 73: [[0.0118, 0.0405, 0.0418], [0.1081, 0.5158, 0.695, 1.0]], 74: [[0.0086, 0.0321, 0.0365], [0.5083, 0.8788, 0.6599, 1.0]], 75: [[0.0202, 0.0231, 0.0276], [0.2305, 0.617, 0.5928, 1.0]], 76: [[0.0172, 0.0412, 0.0292], [0.7684, 0.5089, 0.6511, 1.0]], 77: [[0.013, 0.0284, 0.0422], [0.3124, 0.6509, 0.4347, 1.0]], 78: [[0.0181, 0.03, 0.0305], [0.7501, 0.4577, 0.5444, 1.0]], 79: [[0.0187, 0.0223, 0.0329], [0.1123, 0.3249, 0.4961, 1.0]], 80: [[0.0085, 0.0378, 0.0427], [0.6627, 0.4468, 0.8872, 1.0]], 81: [[0.014, 0.0311, 0.0343], [0.7632, 0.1089, 0.4045, 1.0]], 82: [[0.0177, 0.0324, 0.0359], [0.8607, 0.153, 0.3992, 1.0]], 83: [[0.0198, 0.0429, 0.0274], [0.4244, 0.6196, 0.4516, 1.0]], 84: [[0.0217, 0.0341, 0.045], [0.6909, 0.1468, 0.7247, 1.0]], 85: [[0.0146, 0.0398, 0.0203], [0.3826, 0.3718, 0.3496, 1.0]], 86: [[0.0123, 0.0394, 0.0258], [0.6635, 0.705, 0.2369, 1.0]], 87: [[0.0178, 0.0222, 0.0214], [0.3096, 0.5513, 0.5841, 1.0]], 88: [[0.0214, 0.0302, 0.0364], [0.7606, 0.1013, 0.8378, 1.0]], 89: [[0.0176, 0.027, 0.032], [0.7617, 0.7469, 0.2241, 1.0]], 90: [[0.01, 0.0383, 0.03], [0.89, 0.3931, 0.1526, 1.0]], 91: [[0.0183, 0.0221, 0.0229], [0.7542, 0.3996, 0.5302, 1.0]], 92: [[0.0196, 0.0205, 0.0345], [0.2052, 0.4023, 0.3618, 1.0]], 93: [[0.0219, 0.0429, 0.026], [0.7436, 0.3471, 0.2587, 1.0]], 94: [[0.0176, 0.0381, 0.0371], [0.7742, 0.1907, 0.3718, 1.0]], 95: [[0.0128, 0.0343, 0.0428], [0.2232, 0.2526, 0.2038, 1.0]], 96: [[0.0157, 0.0206, 0.0299], [0.7561, 0.6327, 0.4422, 1.0]], 97: [[0.0218, 0.0431, 0.0449], [0.1949, 0.842, 0.5496, 1.0]], 98: [[0.0104, 0.0233, 0.0417], [0.5364, 0.5095, 0.7928, 1.0]], 99: [[0.0218, 0.0357, 0.0252], [0.1999, 0.4042, 0.3603, 1.0]], 100: [[0.0175, 0.0318, 0.0299], [0.2983, 0.2628, 0.5035, 1.0]], 101: [[0.0138, 0.0251, 0.0415], [0.5903, 0.784, 0.8442, 1.0]], 102: [[0.0144, 0.037, 0.029], [0.8542, 0.4365, 0.7302, 1.0]], 103: [[0.016, 0.0318, 0.0409], [0.6409, 0.1171, 0.2057, 1.0]], 104: [[0.0198, 0.0298, 0.0394], [0.1016, 0.4901, 0.7547, 1.0]], 105: [[0.0197, 0.0433, 0.0311], [0.1091, 0.4195, 0.6097, 1.0]], 106: [[0.0081, 0.0204, 0.0431], [0.8465, 0.7843, 0.5533, 1.0]], 107: [[0.0167, 0.0406, 0.0447], [0.6874, 0.8019, 0.3514, 1.0]], 108: [[0.0158, 0.0441, 0.025], [0.3182, 0.8791, 0.5031, 1.0]], 109: [[0.0171, 0.0426, 0.0444], [0.5549, 0.6867, 0.6405, 1.0]], 110: [[0.0128, 0.0214, 0.025], [0.4276, 0.4977, 0.6437, 1.0]], 111: [[0.0103, 0.0217, 0.0446], [0.539, 0.492, 0.6311, 1.0]], 112: [[0.0154, 0.0379, 0.045], [0.4229, 0.3995, 0.237, 1.0]], 113: [[0.0127, 0.0256, 0.0409], [0.1151, 0.6203, 0.5401, 1.0]], 114: [[0.0158, 0.0382, 0.0319], [0.5734, 0.3454, 0.7076, 1.0]], 115: [[0.0183, 0.0414, 0.043], [0.1459, 0.7983, 0.1768, 1.0]], 116: [[0.0209, 0.027, 0.0269], [0.8093, 0.2445, 0.6262, 1.0]], 117: [[0.0184, 0.0244, 0.0354], [0.5865, 0.1397, 0.1062, 1.0]], 118: [[0.009, 0.0375, 0.0361], [0.4566, 0.1599, 0.8376, 1.0]], 119: [[0.015, 0.0405, 0.0355], [0.1976, 0.5422, 0.1322, 1.0]], 120: [[0.0186, 0.0285, 0.0331], [0.1194, 0.6774, 0.8185, 1.0]], 121: [[0.0087, 0.0384, 0.0221], [0.204, 0.5089, 0.5555, 1.0]], 122: [[0.0198, 0.0201, 0.0248], [0.5485, 0.6, 0.8289, 1.0]], 123: [[0.0116, 0.0375, 0.0279], [0.3355, 0.1299, 0.2203, 1.0]], 124: [[0.0205, 0.0351, 0.0281], [0.8562, 0.5161, 0.3179, 1.0]], 125: [[0.0092, 0.0206, 0.0388], [0.8301, 0.7225, 0.644, 1.0]], 126: [[0.0094, 0.0394, 0.0274], [0.8076, 0.5332, 0.4689, 1.0]], 127: [[0.0209, 0.042, 0.0267], [0.189, 0.5581, 0.4131, 1.0]], 128: [[0.0081, 0.0368, 0.0257], [0.3423, 0.619, 0.279, 1.0]], 129: [[0.0131, 0.0395, 0.027], [0.3179, 0.3905, 0.7289, 1.0]], 130: [[0.0118, 0.044, 0.0324], [0.4052, 0.7151, 0.5217, 1.0]], 131: [[0.015, 0.0369, 0.0368], [0.6741, 0.689, 0.4182, 1.0]], 132: [[0.0194, 0.0301, 0.0436], [0.8841, 0.3257, 0.6886, 1.0]], 133: [[0.0206, 0.038, 0.026], [0.1673, 0.144, 0.7722, 1.0]], 134: [[0.0123, 0.0221, 0.0394], [0.4072, 0.5328, 0.1444, 1.0]], 135: [[0.015, 0.0331, 0.0268], [0.5753, 0.4666, 0.6661, 1.0]], 136: [[0.0158, 0.0224, 0.0308], [0.8457, 0.2743, 0.5386, 1.0]], 137: [[0.0203, 0.0203, 0.0206], [0.5707, 0.7995, 0.1143, 1.0]], 138: [[0.0108, 0.0444, 0.0302], [0.883, 0.2207, 0.7003, 1.0]], 139: [[0.0142, 0.0227, 0.0441], [0.5038, 0.1334, 0.4707, 1.0]], 140: [[0.0179, 0.0412, 0.0395], [0.1797, 0.8912, 0.7051, 1.0]], 141: [[0.022, 0.0277, 0.0245], [0.7404, 0.6499, 0.6006, 1.0]], 142: [[0.0151, 0.0384, 0.0247], [0.5983, 0.8277, 0.2456, 1.0]], 143: [[0.0158, 0.0318, 0.0285], [0.6195, 0.6399, 0.5961, 1.0]], 144: [[0.0157, 0.0389, 0.0347], [0.7169, 0.8966, 0.4938, 1.0]], 145: [[0.015, 0.0304, 0.0296], [0.4485, 0.3861, 0.6756, 1.0]], 146: [[0.0175, 0.0296, 0.0369], [0.3544, 0.5575, 0.8562, 1.0]], 147: [[0.0157, 0.0355, 0.0217], [0.7051, 0.5541, 0.2275, 1.0]], 148: [[0.0191, 0.0248, 0.041], [0.1171, 0.3022, 0.5573, 1.0]], 149: [[0.0095, 0.0394, 0.0327], [0.1485, 0.7252, 0.3156, 1.0]], 150: [[0.0092, 0.0369, 0.0416], [0.528, 0.8994, 0.1997, 1.0]], 151: [[0.0085, 0.0251, 0.0226], [0.159, 0.1596, 0.3521, 1.0]], 152: [[0.0132, 0.0381, 0.0248], [0.2871, 0.247, 0.4446, 1.0]], 153: [[0.0188, 0.0387, 0.0373], [0.3758, 0.1651, 0.2371, 1.0]], 154: [[0.0136, 0.0334, 0.0428], [0.3628, 0.6433, 0.8349, 1.0]], 155: [[0.0093, 0.0238, 0.0335], [0.6441, 0.4322, 0.1294, 1.0]], 156: [[0.0121, 0.0357, 0.031], [0.4834, 0.3508, 0.498, 1.0]], 157: [[0.0161, 0.0422, 0.0213], [0.4219, 0.6501, 0.1276, 1.0]], 158: [[0.0205, 0.0396, 0.0417], [0.1712, 0.7632, 0.5686, 1.0]], 159: [[0.0215, 0.03, 0.0323], [0.1548, 0.4639, 0.8891, 1.0]], 160: [[0.011, 0.0241, 0.0238], [0.3156, 0.6099, 0.358, 1.0]], 161: [[0.0172, 0.0426, 0.0292], [0.6982, 0.7319, 0.6483, 1.0]], 162: [[0.0206, 0.0229, 0.0336], [0.3875, 0.8807, 0.8088, 1.0]], 163: [[0.0213, 0.0221, 0.0433], [0.5799, 0.824, 0.7818, 1.0]], 164: [[0.0214, 0.0407, 0.0204], [0.6081, 0.4471, 0.2829, 1.0]], 165: [[0.0135, 0.0391, 0.0336], [0.145, 0.5886, 0.3284, 1.0]], 166: [[0.0083, 0.0323, 0.0298], [0.6364, 0.1223, 0.7308, 1.0]], 167: [[0.015, 0.0268, 0.0405], [0.8594, 0.5667, 0.6114, 1.0]], 168: [[0.0126, 0.0299, 0.0224], [0.1043, 0.3084, 0.2389, 1.0]], 169: [[0.019, 0.0387, 0.0295], [0.3984, 0.8755, 0.3889, 1.0]], 170: [[0.0135, 0.0333, 0.0278], [0.6523, 0.5869, 0.2271, 1.0]], 171: [[0.0145, 0.0427, 0.0284], [0.1204, 0.4681, 0.1915, 1.0]], 172: [[0.0215, 0.0331, 0.0201], [0.4475, 0.8037, 0.8816, 1.0]], 173: [[0.0164, 0.026, 0.0418], [0.4476, 0.7209, 0.5665, 1.0]], 174: [[0.0138, 0.0208, 0.0259], [0.4744, 0.2627, 0.8776, 1.0]], 175: [[0.0171, 0.0312, 0.0221], [0.3405, 0.2247, 0.4041, 1.0]], 176: [[0.0126, 0.0314, 0.0228], [0.3861, 0.6054, 0.2509, 1.0]], 177: [[0.0137, 0.0263, 0.0264], [0.5871, 0.4488, 0.1386, 1.0]], 178: [[0.0158, 0.0325, 0.0204], [0.4723, 0.7872, 0.3829, 1.0]], 179: [[0.018, 0.043, 0.0277], [0.1599, 0.818, 0.5206, 1.0]], 180: [[0.0085, 0.0264, 0.0422], [0.8142, 0.7123, 0.7632, 1.0]], 181: [[0.0219, 0.0211, 0.0207], [0.2742, 0.2027, 0.8496, 1.0]], 182: [[0.0156, 0.0212, 0.0356], [0.7043, 0.135, 0.2614, 1.0]], 183: [[0.0187, 0.0211, 0.0376], [0.8517, 0.4961, 0.2077, 1.0]], 184: [[0.0204, 0.0221, 0.0355], [0.5143, 0.5176, 0.6385, 1.0]], 185: [[0.0217, 0.0404, 0.0298], [0.6098, 0.2487, 0.8383, 1.0]], 186: [[0.0191, 0.043, 0.0282], [0.4476, 0.7411, 0.1112, 1.0]], 187: [[0.0159, 0.0293, 0.0377], [0.6904, 0.4926, 0.5994, 1.0]], 188: [[0.0091, 0.0262, 0.0417], [0.6368, 0.3511, 0.1941, 1.0]], 189: [[0.0093, 0.0297, 0.0244], [0.1454, 0.8709, 0.5172, 1.0]], 190: [[0.0183, 0.0413, 0.0418], [0.2906, 0.8729, 0.6184, 1.0]], 191: [[0.0089, 0.0218, 0.0381], [0.4896, 0.6824, 0.1759, 1.0]], 192: [[0.0112, 0.0434, 0.0427], [0.5026, 0.235, 0.7239, 1.0]], 193: [[0.0197, 0.0345, 0.0382], [0.5842, 0.7965, 0.4111, 1.0]], 194: [[0.0095, 0.0404, 0.0245], [0.1075, 0.8717, 0.2179, 1.0]], 195: [[0.0187, 0.0382, 0.043], [0.5261, 0.5464, 0.4336, 1.0]], 196: [[0.0162, 0.0373, 0.0389], [0.4062, 0.174, 0.2001, 1.0]], 197: [[0.0095, 0.0281, 0.0374], [0.7505, 0.4524, 0.8543, 1.0]], 198: [[0.0178, 0.0431, 0.0346], [0.7836, 0.507, 0.2946, 1.0]], 199: [[0.0086, 0.0206, 0.0255], [0.2519, 0.3512, 0.723, 1.0]], 200: [[0.0166, 0.0225, 0.0293], [0.3807, 0.7226, 0.5609, 1.0]], 201: [[0.0216, 0.0245, 0.0301], [0.1207, 0.4613, 0.7248, 1.0]], 202: [[0.0193, 0.0404, 0.0375], [0.2981, 0.8526, 0.134, 1.0]], 203: [[0.0138, 0.0241, 0.0389], [0.6076, 0.1337, 0.8627, 1.0]], 204: [[0.0113, 0.02, 0.0411], [0.1686, 0.7173, 0.1742, 1.0]], 205: [[0.0133, 0.0329, 0.0428], [0.1381, 0.5346, 0.7832, 1.0]], 206: [[0.022, 0.0331, 0.0301], [0.4032, 0.2363, 0.2722, 1.0]], 207: [[0.0195, 0.0306, 0.0281], [0.4538, 0.7098, 0.381, 1.0]], 208: [[0.0094, 0.0414, 0.042], [0.5582, 0.4801, 0.8864, 1.0]], 209: [[0.0085, 0.0286, 0.0346], [0.4346, 0.1496, 0.867, 1.0]], 210: [[0.0161, 0.0225, 0.0287], [0.6963, 0.4941, 0.4147, 1.0]], 211: [[0.0156, 0.0344, 0.0285], [0.7432, 0.6502, 0.7875, 1.0]], 212: [[0.0193, 0.0335, 0.0239], [0.7239, 0.8395, 0.6681, 1.0]], 213: [[0.0177, 0.021, 0.035], [0.2067, 0.115, 0.3646, 1.0]], 214: [[0.0212, 0.0276, 0.0447], [0.8377, 0.5035, 0.1801, 1.0]], 215: [[0.0217, 0.0431, 0.0297], [0.1662, 0.6649, 0.1192, 1.0]], 216: [[0.0161, 0.0381, 0.0243], [0.7918, 0.8644, 0.1653, 1.0]], 217: [[0.0144, 0.039, 0.0205], [0.2515, 0.4627, 0.1858, 1.0]], 218: [[0.0164, 0.0231, 0.0394], [0.8649, 0.6655, 0.6936, 1.0]], 219: [[0.021, 0.034, 0.0412], [0.6322, 0.3109, 0.1315, 1.0]], 220: [[0.0146, 0.0403, 0.0291], [0.4969, 0.6351, 0.7072, 1.0]], 221: [[0.0199, 0.0365, 0.0415], [0.1443, 0.7003, 0.3025, 1.0]], 222: [[0.0175, 0.0357, 0.0434], [0.8958, 0.1377, 0.5829, 1.0]], 223: [[0.0182, 0.0214, 0.0364], [0.7676, 0.4967, 0.6749, 1.0]], 224: [[0.0111, 0.0348, 0.0422], [0.1985, 0.4115, 0.2922, 1.0]], 225: [[0.0143, 0.0291, 0.0392], [0.1168, 0.8456, 0.4299, 1.0]], 226: [[0.013, 0.0323, 0.0267], [0.193, 0.2611, 0.5682, 1.0]], 227: [[0.0097, 0.0404, 0.0399], [0.486, 0.4811, 0.8595, 1.0]], 228: [[0.0163, 0.0266, 0.04], [0.364, 0.3567, 0.3217, 1.0]], 229: [[0.0207, 0.0286, 0.0214], [0.8925, 0.1563, 0.3139, 1.0]], 230: [[0.0097, 0.0375, 0.0255], [0.1164, 0.3525, 0.4542, 1.0]], 231: [[0.0119, 0.039, 0.0255], [0.3666, 0.7901, 0.4978, 1.0]], 232: [[0.0156, 0.0215, 0.0274], [0.8466, 0.1781, 0.6678, 1.0]], 233: [[0.0121, 0.0235, 0.0444], [0.6763, 0.3172, 0.3356, 1.0]], 234: [[0.0098, 0.0233, 0.0413], [0.8073, 0.5084, 0.8901, 1.0]], 235: [[0.0153, 0.0346, 0.0326], [0.5169, 0.5373, 0.6524, 1.0]], 236: [[0.0198, 0.0298, 0.0413], [0.7362, 0.3284, 0.1694, 1.0]], 237: [[0.0166, 0.0338, 0.0254], [0.5003, 0.1036, 0.675, 1.0]], 238: [[0.0195, 0.0284, 0.0218], [0.7115, 0.8238, 0.1391, 1.0]], 239: [[0.0207, 0.0352, 0.0397], [0.5007, 0.3275, 0.1935, 1.0]], 240: [[0.0121, 0.0337, 0.039], [0.4878, 0.7455, 0.3974, 1.0]], 241: [[0.0144, 0.0438, 0.0364], [0.669, 0.3398, 0.4691, 1.0]], 242: [[0.021, 0.0401, 0.0237], [0.655, 0.8149, 0.2127, 1.0]], 243: [[0.0146, 0.0413, 0.0303], [0.679, 0.5893, 0.4474, 1.0]], 244: [[0.0082, 0.0377, 0.0411], [0.688, 0.8363, 0.4124, 1.0]], 245: [[0.0172, 0.0439, 0.0397], [0.1834, 0.4685, 0.4009, 1.0]], 246: [[0.015, 0.0334, 0.0343], [0.3602, 0.6536, 0.5352, 1.0]], 247: [[0.0095, 0.0311, 0.0418], [0.3878, 0.3408, 0.1713, 1.0]], 248: [[0.0206, 0.0414, 0.0255], [0.7165, 0.2262, 0.6045, 1.0]], 249: [[0.0098, 0.0289, 0.0259], [0.8523, 0.179, 0.7258, 1.0]]} caps = {0: [[0.0116, 0.0303, 0.0327], [0.3035, 0.8151, 0.1706, 1.0]], 1: [[0.0135, 0.0431, 0.027], [0.2176, 0.3692, 0.2129, 1.0]], 2: [[0.0211, 0.0425, 0.0369], [0.4259, 0.6632, 0.3458, 1.0]], 3: [[0.0157, 0.0271, 0.04], [0.2733, 0.7038, 0.8995, 1.0]], 4: [[0.0192, 0.0417, 0.0318], [0.4286, 0.4809, 0.8645, 1.0]], 5: [[0.0209, 0.0385, 0.0333], [0.7587, 0.4668, 0.4177, 1.0]], 6: [[0.0197, 0.0285, 0.0206], [0.2474, 0.6831, 0.4869, 1.0]], 7: [[0.0212, 0.0294, 0.0415], [0.5253, 0.4904, 0.8947, 1.0]], 8: [[0.0102, 0.0289, 0.0405], [0.2619, 0.1736, 0.7266, 1.0]], 9: [[0.0169, 0.0289, 0.0235], [0.3558, 0.3385, 0.149, 1.0]], 10: [[0.0196, 0.0215, 0.043], [0.1956, 0.2643, 0.4139, 1.0]], 11: [[0.0129, 0.0224, 0.0292], [0.1594, 0.8224, 0.1802, 1.0]], 12: [[0.0203, 0.0347, 0.0336], [0.1015, 0.1924, 0.7818, 1.0]], 13: [[0.0179, 0.0374, 0.0418], [0.4296, 0.2539, 0.5479, 1.0]], 14: [[0.021, 0.0438, 0.0232], [0.3144, 0.3437, 0.3675, 1.0]], 15: [[0.0175, 0.0292, 0.0202], [0.79, 0.7432, 0.2371, 1.0]], 16: [[0.0139, 0.0243, 0.0296], [0.6455, 0.715, 0.356, 1.0]], 17: [[0.0188, 0.0235, 0.0418], [0.8043, 0.6093, 0.609, 1.0]], 18: [[0.0141, 0.0402, 0.0338], [0.3568, 0.4025, 0.2161, 1.0]], 19: [[0.0198, 0.0225, 0.0278], [0.4036, 0.7152, 0.3897, 1.0]], 20: [[0.0162, 0.0363, 0.0384], [0.1466, 0.478, 0.6661, 1.0]], 21: [[0.021, 0.036, 0.0346], [0.7738, 0.8506, 0.8907, 1.0]], 22: [[0.0145, 0.0261, 0.0312], [0.2366, 0.1764, 0.8232, 1.0]], 23: [[0.0136, 0.0384, 0.0273], [0.6499, 0.2905, 0.4252, 1.0]], 24: [[0.0191, 0.0315, 0.0431], [0.7513, 0.8309, 0.1806, 1.0]], 25: [[0.0159, 0.0274, 0.0218], [0.2988, 0.6072, 0.7034, 1.0]], 26: [[0.0198, 0.0372, 0.0267], [0.5948, 0.4304, 0.3038, 1.0]], 27: [[0.0198, 0.0424, 0.0441], [0.7152, 0.7174, 0.3449, 1.0]], 28: [[0.0193, 0.0379, 0.0259], [0.6032, 0.3437, 0.8725, 1.0]], 29: [[0.0161, 0.0363, 0.0217], [0.2494, 0.595, 0.6631, 1.0]], 30: [[0.0167, 0.0403, 0.0245], [0.1006, 0.1567, 0.3473, 1.0]], 31: [[0.0173, 0.0217, 0.0409], [0.8376, 0.4187, 0.4746, 1.0]], 32: [[0.0162, 0.0391, 0.0438], [0.6416, 0.3594, 0.1724, 1.0]], 33: [[0.0142, 0.0385, 0.0401], [0.6182, 0.5726, 0.6438, 1.0]], 34: [[0.0184, 0.0333, 0.032], [0.8179, 0.4146, 0.531, 1.0]], 35: [[0.0194, 0.0427, 0.0374], [0.2303, 0.4759, 0.4294, 1.0]], 36: [[0.0102, 0.0207, 0.0367], [0.6047, 0.4501, 0.8338, 1.0]], 37: [[0.0152, 0.0282, 0.0407], [0.6631, 0.7613, 0.5563, 1.0]], 38: [[0.0199, 0.0262, 0.0347], [0.7692, 0.1686, 0.1261, 1.0]], 39: [[0.0185, 0.0245, 0.0259], [0.7763, 0.6939, 0.3496, 1.0]], 40: [[0.0179, 0.0324, 0.0317], [0.5161, 0.3995, 0.4301, 1.0]], 41: [[0.0199, 0.0379, 0.028], [0.4624, 0.6208, 0.8652, 1.0]], 42: [[0.0165, 0.0208, 0.0296], [0.201, 0.4577, 0.2458, 1.0]], 43: [[0.0113, 0.0264, 0.0277], [0.8419, 0.1653, 0.7582, 1.0]], 44: [[0.0101, 0.0212, 0.0321], [0.857, 0.8674, 0.8981, 1.0]], 45: [[0.0113, 0.042, 0.032], [0.4617, 0.3144, 0.5003, 1.0]], 46: [[0.0145, 0.0234, 0.0271], [0.566, 0.3929, 0.2595, 1.0]], 47: [[0.014, 0.0351, 0.0404], [0.2455, 0.3105, 0.6313, 1.0]], 48: [[0.0161, 0.0365, 0.0268], [0.6829, 0.483, 0.4624, 1.0]], 49: [[0.0174, 0.0222, 0.043], [0.1739, 0.7701, 0.5844, 1.0]], 50: [[0.0138, 0.0405, 0.0394], [0.4791, 0.582, 0.6131, 1.0]], 51: [[0.016, 0.0266, 0.0414], [0.219, 0.3952, 0.3512, 1.0]], 52: [[0.0193, 0.0374, 0.0285], [0.1733, 0.5253, 0.5952, 1.0]], 53: [[0.0195, 0.0305, 0.025], [0.2337, 0.2408, 0.2449, 1.0]], 54: [[0.0105, 0.0251, 0.0449], [0.3661, 0.8533, 0.2444, 1.0]], 55: [[0.0186, 0.041, 0.0409], [0.3865, 0.1818, 0.6864, 1.0]], 56: [[0.0192, 0.042, 0.0238], [0.5518, 0.7363, 0.1583, 1.0]], 57: [[0.0116, 0.0251, 0.0379], [0.2524, 0.8529, 0.3819, 1.0]], 58: [[0.021, 0.0371, 0.021], [0.1414, 0.384, 0.7155, 1.0]], 59: [[0.0213, 0.0265, 0.0395], [0.2749, 0.3127, 0.7909, 1.0]], 60: [[0.021, 0.0319, 0.0283], [0.5913, 0.6348, 0.5974, 1.0]], 61: [[0.0149, 0.0439, 0.0299], [0.8976, 0.4958, 0.1536, 1.0]], 62: [[0.0172, 0.0248, 0.0299], [0.1486, 0.2097, 0.2295, 1.0]], 63: [[0.0179, 0.0351, 0.0355], [0.8954, 0.4191, 0.8969, 1.0]], 64: [[0.0124, 0.0382, 0.0424], [0.5034, 0.4461, 0.1056, 1.0]], 65: [[0.0215, 0.0217, 0.0356], [0.5877, 0.1743, 0.4656, 1.0]], 66: [[0.0206, 0.0267, 0.0367], [0.855, 0.62, 0.1948, 1.0]], 67: [[0.0165, 0.0261, 0.0435], [0.4013, 0.3245, 0.5992, 1.0]], 68: [[0.0186, 0.0283, 0.0373], [0.2626, 0.2002, 0.8558, 1.0]], 69: [[0.0149, 0.0376, 0.0324], [0.3573, 0.691, 0.2245, 1.0]], 70: [[0.0177, 0.0345, 0.0291], [0.3173, 0.4388, 0.6166, 1.0]], 71: [[0.0161, 0.0401, 0.0311], [0.6692, 0.3539, 0.8887, 1.0]], 72: [[0.0137, 0.0416, 0.0388], [0.4142, 0.1037, 0.8764, 1.0]], 73: [[0.0214, 0.035, 0.0247], [0.2892, 0.4272, 0.5212, 1.0]], 74: [[0.0201, 0.0327, 0.0307], [0.4781, 0.1114, 0.8923, 1.0]], 75: [[0.0168, 0.0392, 0.0287], [0.4642, 0.8013, 0.873, 1.0]], 76: [[0.0154, 0.0246, 0.0297], [0.292, 0.4809, 0.3152, 1.0]], 77: [[0.02, 0.0412, 0.0386], [0.2382, 0.1954, 0.2127, 1.0]], 78: [[0.0115, 0.0286, 0.0446], [0.6675, 0.6493, 0.3174, 1.0]], 79: [[0.021, 0.0267, 0.0364], [0.4092, 0.7301, 0.4177, 1.0]], 80: [[0.0197, 0.036, 0.033], [0.7277, 0.137, 0.4312, 1.0]], 81: [[0.011, 0.0212, 0.024], [0.1985, 0.4726, 0.3116, 1.0]], 82: [[0.0148, 0.043, 0.0265], [0.8787, 0.3004, 0.3646, 1.0]], 83: [[0.0142, 0.0219, 0.0429], [0.3327, 0.3933, 0.1077, 1.0]], 84: [[0.0113, 0.0275, 0.0351], [0.12, 0.7265, 0.3532, 1.0]], 85: [[0.0105, 0.0441, 0.027], [0.1197, 0.6571, 0.298, 1.0]], 86: [[0.0104, 0.0272, 0.0389], [0.1791, 0.5694, 0.4313, 1.0]], 87: [[0.0169, 0.0287, 0.0275], [0.4656, 0.1723, 0.4079, 1.0]], 88: [[0.0182, 0.0292, 0.0267], [0.1662, 0.6236, 0.3404, 1.0]], 89: [[0.0158, 0.0357, 0.032], [0.8223, 0.2361, 0.3007, 1.0]], 90: [[0.0144, 0.0254, 0.0322], [0.4219, 0.7239, 0.5218, 1.0]], 91: [[0.0192, 0.0335, 0.0429], [0.8972, 0.4936, 0.1189, 1.0]], 92: [[0.0213, 0.0217, 0.0271], [0.4168, 0.6123, 0.1293, 1.0]], 93: [[0.0189, 0.0334, 0.0417], [0.5742, 0.5411, 0.1496, 1.0]], 94: [[0.0203, 0.0313, 0.035], [0.7883, 0.7032, 0.8112, 1.0]], 95: [[0.021, 0.0215, 0.0389], [0.7635, 0.167, 0.8432, 1.0]], 96: [[0.0194, 0.0237, 0.0261], [0.2485, 0.7071, 0.8598, 1.0]], 97: [[0.0177, 0.0341, 0.0236], [0.3404, 0.4298, 0.2811, 1.0]], 98: [[0.0141, 0.0237, 0.0399], [0.847, 0.2361, 0.6864, 1.0]], 99: [[0.0148, 0.0302, 0.0251], [0.7078, 0.2871, 0.125, 1.0]], 100: [[0.0201, 0.0372, 0.0237], [0.4901, 0.2259, 0.4513, 1.0]], 101: [[0.0135, 0.034, 0.0381], [0.626, 0.4381, 0.3036, 1.0]], 102: [[0.0115, 0.0209, 0.0369], [0.2645, 0.5309, 0.1583, 1.0]], 103: [[0.012, 0.0412, 0.0407], [0.4469, 0.6832, 0.5495, 1.0]], 104: [[0.0166, 0.0275, 0.0289], [0.7481, 0.4705, 0.2908, 1.0]], 105: [[0.0135, 0.0306, 0.0393], [0.2987, 0.5765, 0.8264, 1.0]], 106: [[0.0151, 0.0319, 0.028], [0.6073, 0.3117, 0.5346, 1.0]], 107: [[0.0111, 0.0299, 0.0264], [0.1762, 0.2882, 0.5015, 1.0]], 108: [[0.0121, 0.0274, 0.0258], [0.6427, 0.1698, 0.2819, 1.0]], 109: [[0.0191, 0.041, 0.0273], [0.8644, 0.7027, 0.4652, 1.0]], 110: [[0.0189, 0.0206, 0.0373], [0.6887, 0.1229, 0.7218, 1.0]], 111: [[0.01, 0.0426, 0.024], [0.8984, 0.5439, 0.8335, 1.0]], 112: [[0.0206, 0.0336, 0.0303], [0.6008, 0.8212, 0.6378, 1.0]], 113: [[0.0119, 0.0205, 0.0411], [0.1392, 0.3216, 0.5325, 1.0]], 114: [[0.0111, 0.0374, 0.0226], [0.497, 0.7804, 0.6547, 1.0]], 115: [[0.0184, 0.0374, 0.0253], [0.8878, 0.6867, 0.4106, 1.0]], 116: [[0.0186, 0.0382, 0.0427], [0.2857, 0.7416, 0.8699, 1.0]], 117: [[0.0153, 0.025, 0.0303], [0.8224, 0.7924, 0.7773, 1.0]], 118: [[0.0124, 0.0309, 0.0344], [0.7696, 0.7589, 0.4938, 1.0]], 119: [[0.016, 0.03, 0.035], [0.475, 0.3274, 0.6608, 1.0]], 120: [[0.0142, 0.0362, 0.025], [0.7899, 0.7263, 0.7698, 1.0]], 121: [[0.0202, 0.0213, 0.0249], [0.3437, 0.5714, 0.5224, 1.0]], 122: [[0.0114, 0.0344, 0.0268], [0.4585, 0.3573, 0.7211, 1.0]], 123: [[0.0107, 0.0202, 0.0315], [0.5748, 0.1965, 0.3126, 1.0]], 124: [[0.0114, 0.043, 0.0358], [0.762, 0.5276, 0.3324, 1.0]], 125: [[0.0187, 0.0345, 0.0342], [0.7973, 0.2488, 0.4734, 1.0]], 126: [[0.0157, 0.0442, 0.0308], [0.623, 0.1523, 0.4676, 1.0]], 127: [[0.017, 0.0314, 0.0259], [0.4546, 0.1244, 0.8692, 1.0]], 128: [[0.0152, 0.0292, 0.0205], [0.8583, 0.186, 0.6164, 1.0]], 129: [[0.0217, 0.0387, 0.0409], [0.7588, 0.2858, 0.4066, 1.0]], 130: [[0.0151, 0.0213, 0.0275], [0.1535, 0.1794, 0.2188, 1.0]], 131: [[0.0165, 0.0202, 0.0226], [0.6488, 0.5128, 0.7284, 1.0]], 132: [[0.0198, 0.0306, 0.0344], [0.8172, 0.1524, 0.5647, 1.0]], 133: [[0.012, 0.0202, 0.0431], [0.3239, 0.2815, 0.1753, 1.0]], 134: [[0.014, 0.0407, 0.029], [0.3331, 0.5083, 0.1023, 1.0]], 135: [[0.0129, 0.039, 0.0278], [0.8779, 0.8324, 0.5216, 1.0]], 136: [[0.0211, 0.0407, 0.033], [0.3116, 0.2184, 0.1684, 1.0]], 137: [[0.0215, 0.0237, 0.0446], [0.4445, 0.7085, 0.321, 1.0]], 138: [[0.021, 0.0431, 0.0259], [0.2011, 0.8037, 0.248, 1.0]], 139: [[0.018, 0.0385, 0.0354], [0.4092, 0.3509, 0.3809, 1.0]], 140: [[0.0192, 0.0435, 0.0387], [0.5259, 0.6084, 0.2877, 1.0]], 141: [[0.0146, 0.0442, 0.0437], [0.8129, 0.5661, 0.2454, 1.0]], 142: [[0.0122, 0.0318, 0.0204], [0.1138, 0.2666, 0.6396, 1.0]], 143: [[0.0149, 0.025, 0.0414], [0.8627, 0.1379, 0.792, 1.0]], 144: [[0.019, 0.0394, 0.0342], [0.8871, 0.7036, 0.7611, 1.0]], 145: [[0.013, 0.0405, 0.0222], [0.6197, 0.1222, 0.7248, 1.0]], 146: [[0.0118, 0.0286, 0.0435], [0.245, 0.8639, 0.4145, 1.0]], 147: [[0.0215, 0.0257, 0.0278], [0.5146, 0.4355, 0.1102, 1.0]], 148: [[0.0168, 0.0289, 0.0336], [0.5404, 0.4091, 0.4605, 1.0]], 149: [[0.0183, 0.0321, 0.0219], [0.3991, 0.5365, 0.1933, 1.0]], 150: [[0.0136, 0.0414, 0.0226], [0.2694, 0.6687, 0.1661, 1.0]], 151: [[0.0201, 0.0408, 0.0352], [0.8922, 0.5367, 0.3726, 1.0]], 152: [[0.0213, 0.0407, 0.0402], [0.6644, 0.8317, 0.2544, 1.0]], 153: [[0.0147, 0.0428, 0.0258], [0.8475, 0.6459, 0.42, 1.0]], 154: [[0.0127, 0.0217, 0.0264], [0.3133, 0.1272, 0.2741, 1.0]], 155: [[0.0176, 0.0326, 0.0372], [0.1748, 0.2515, 0.7733, 1.0]], 156: [[0.0206, 0.029, 0.0234], [0.5347, 0.6689, 0.5907, 1.0]], 157: [[0.02, 0.0377, 0.0282], [0.7088, 0.7381, 0.124, 1.0]], 158: [[0.0125, 0.0261, 0.0223], [0.5638, 0.1319, 0.2414, 1.0]], 159: [[0.0161, 0.0206, 0.0384], [0.7087, 0.8315, 0.1975, 1.0]], 160: [[0.0176, 0.0311, 0.0226], [0.2817, 0.6628, 0.1593, 1.0]], 161: [[0.0152, 0.0399, 0.0434], [0.7062, 0.806, 0.7353, 1.0]], 162: [[0.0126, 0.0419, 0.0367], [0.1241, 0.4649, 0.6977, 1.0]], 163: [[0.021, 0.0294, 0.0446], [0.274, 0.4746, 0.8286, 1.0]], 164: [[0.0166, 0.025, 0.0284], [0.4418, 0.4761, 0.4187, 1.0]], 165: [[0.0148, 0.0323, 0.0439], [0.1281, 0.2521, 0.2884, 1.0]], 166: [[0.0166, 0.037, 0.0249], [0.8255, 0.1862, 0.2677, 1.0]], 167: [[0.0158, 0.0344, 0.0354], [0.8431, 0.4058, 0.7443, 1.0]], 168: [[0.0161, 0.0433, 0.0431], [0.4497, 0.8796, 0.6064, 1.0]], 169: [[0.012, 0.0246, 0.0357], [0.8279, 0.5592, 0.5058, 1.0]], 170: [[0.012, 0.0328, 0.0301], [0.4054, 0.2396, 0.8119, 1.0]], 171: [[0.0207, 0.0285, 0.0449], [0.2582, 0.7108, 0.6562, 1.0]], 172: [[0.0198, 0.0407, 0.0406], [0.7306, 0.8985, 0.2908, 1.0]], 173: [[0.0116, 0.0432, 0.0367], [0.8991, 0.3089, 0.3308, 1.0]], 174: [[0.0176, 0.0316, 0.0209], [0.2477, 0.3467, 0.8524, 1.0]], 175: [[0.0165, 0.0304, 0.0219], [0.2946, 0.2616, 0.1959, 1.0]], 176: [[0.021, 0.0418, 0.0226], [0.3537, 0.7351, 0.877, 1.0]], 177: [[0.0172, 0.0448, 0.0285], [0.4011, 0.2376, 0.5996, 1.0]], 178: [[0.0204, 0.0305, 0.0347], [0.1494, 0.3643, 0.4896, 1.0]], 179: [[0.0174, 0.0297, 0.0284], [0.1526, 0.1234, 0.3979, 1.0]], 180: [[0.0156, 0.0356, 0.0265], [0.7139, 0.5448, 0.4136, 1.0]], 181: [[0.0142, 0.0307, 0.0289], [0.8888, 0.1212, 0.5144, 1.0]], 182: [[0.0146, 0.0374, 0.0376], [0.299, 0.6482, 0.2818, 1.0]], 183: [[0.0188, 0.035, 0.0298], [0.5388, 0.6216, 0.5589, 1.0]], 184: [[0.0139, 0.0411, 0.0392], [0.6896, 0.2936, 0.1024, 1.0]], 185: [[0.0156, 0.043, 0.0221], [0.7578, 0.1587, 0.4894, 1.0]], 186: [[0.0176, 0.0349, 0.0379], [0.8278, 0.2228, 0.4548, 1.0]], 187: [[0.016, 0.0359, 0.0345], [0.8741, 0.7924, 0.2308, 1.0]], 188: [[0.0152, 0.0271, 0.0202], [0.3433, 0.3167, 0.6326, 1.0]], 189: [[0.0115, 0.0368, 0.0403], [0.7768, 0.7014, 0.4184, 1.0]], 190: [[0.0165, 0.0352, 0.0243], [0.603, 0.6831, 0.155, 1.0]], 191: [[0.0172, 0.0436, 0.0227], [0.8262, 0.7358, 0.3365, 1.0]], 192: [[0.0124, 0.0269, 0.0377], [0.58, 0.3878, 0.7276, 1.0]], 193: [[0.0159, 0.0288, 0.0428], [0.6612, 0.2328, 0.3988, 1.0]], 194: [[0.0137, 0.0243, 0.0409], [0.7058, 0.7323, 0.8918, 1.0]], 195: [[0.0168, 0.0399, 0.0405], [0.3749, 0.891, 0.4413, 1.0]], 196: [[0.0171, 0.0282, 0.0323], [0.8426, 0.2891, 0.3984, 1.0]], 197: [[0.0165, 0.0402, 0.0215], [0.3392, 0.2928, 0.7445, 1.0]], 198: [[0.0114, 0.0256, 0.0442], [0.1266, 0.1478, 0.4688, 1.0]], 199: [[0.0202, 0.0225, 0.02], [0.7923, 0.8564, 0.5706, 1.0]], 200: [[0.0207, 0.0428, 0.0432], [0.8461, 0.395, 0.4055, 1.0]], 201: [[0.0165, 0.0201, 0.044], [0.1453, 0.4584, 0.6714, 1.0]], 202: [[0.0188, 0.0261, 0.0379], [0.3333, 0.1802, 0.444, 1.0]], 203: [[0.0183, 0.0285, 0.0449], [0.1724, 0.1512, 0.702, 1.0]], 204: [[0.0108, 0.0317, 0.023], [0.2897, 0.4708, 0.7926, 1.0]], 205: [[0.0165, 0.0332, 0.0368], [0.6661, 0.4885, 0.5519, 1.0]], 206: [[0.0151, 0.0255, 0.0348], [0.3782, 0.1494, 0.2391, 1.0]], 207: [[0.0164, 0.0208, 0.0212], [0.5459, 0.8448, 0.7581, 1.0]], 208: [[0.0125, 0.0423, 0.0228], [0.5731, 0.2101, 0.7136, 1.0]], 209: [[0.0157, 0.0434, 0.0228], [0.5913, 0.7602, 0.1991, 1.0]], 210: [[0.0102, 0.0383, 0.02], [0.4456, 0.3827, 0.7466, 1.0]], 211: [[0.0171, 0.0269, 0.0221], [0.148, 0.329, 0.2079, 1.0]], 212: [[0.0179, 0.0273, 0.0309], [0.8428, 0.7415, 0.6008, 1.0]], 213: [[0.014, 0.0265, 0.02], [0.4166, 0.8607, 0.3437, 1.0]], 214: [[0.0159, 0.0424, 0.0282], [0.2146, 0.6623, 0.7518, 1.0]], 215: [[0.0113, 0.0248, 0.0214], [0.2293, 0.8791, 0.8644, 1.0]], 216: [[0.0106, 0.0358, 0.0437], [0.8324, 0.4802, 0.4808, 1.0]], 217: [[0.0147, 0.04, 0.0201], [0.5684, 0.1587, 0.8058, 1.0]], 218: [[0.0204, 0.0264, 0.0332], [0.586, 0.546, 0.8121, 1.0]], 219: [[0.0202, 0.0444, 0.0354], [0.1814, 0.1217, 0.8219, 1.0]], 220: [[0.0202, 0.027, 0.0394], [0.2476, 0.165, 0.5762, 1.0]], 221: [[0.0213, 0.0419, 0.0288], [0.891, 0.7265, 0.3415, 1.0]], 222: [[0.018, 0.0265, 0.0213], [0.4445, 0.7071, 0.6048, 1.0]], 223: [[0.0133, 0.04, 0.0208], [0.2507, 0.2076, 0.5588, 1.0]], 224: [[0.0137, 0.0312, 0.0411], [0.7524, 0.5718, 0.6326, 1.0]], 225: [[0.019, 0.0398, 0.0295], [0.2345, 0.1574, 0.8824, 1.0]], 226: [[0.0199, 0.0424, 0.0407], [0.262, 0.4326, 0.2912, 1.0]], 227: [[0.0159, 0.0388, 0.0319], [0.8816, 0.1561, 0.8673, 1.0]], 228: [[0.0144, 0.0436, 0.0302], [0.6982, 0.1408, 0.122, 1.0]], 229: [[0.0171, 0.0229, 0.0331], [0.6688, 0.6537, 0.8223, 1.0]], 230: [[0.019, 0.035, 0.041], [0.2898, 0.2755, 0.3883, 1.0]], 231: [[0.0106, 0.0209, 0.0318], [0.3899, 0.4914, 0.4955, 1.0]], 232: [[0.0117, 0.0268, 0.0275], [0.5498, 0.6662, 0.236, 1.0]], 233: [[0.0171, 0.0338, 0.0323], [0.7855, 0.3418, 0.8884, 1.0]], 234: [[0.0106, 0.0322, 0.0206], [0.1908, 0.2949, 0.1147, 1.0]], 235: [[0.0104, 0.0408, 0.0394], [0.5442, 0.8103, 0.8033, 1.0]], 236: [[0.0146, 0.0374, 0.0234], [0.1747, 0.6012, 0.1749, 1.0]], 237: [[0.012, 0.0379, 0.0339], [0.3885, 0.4516, 0.2483, 1.0]], 238: [[0.0158, 0.0327, 0.0382], [0.5894, 0.7818, 0.1015, 1.0]], 239: [[0.0117, 0.0289, 0.0422], [0.4988, 0.4957, 0.278, 1.0]], 240: [[0.0114, 0.0378, 0.0282], [0.1492, 0.6123, 0.5169, 1.0]], 241: [[0.0129, 0.0259, 0.0323], [0.6051, 0.2834, 0.5271, 1.0]], 242: [[0.0157, 0.0404, 0.0397], [0.2497, 0.1461, 0.7958, 1.0]], 243: [[0.0113, 0.0409, 0.027], [0.5513, 0.1007, 0.6764, 1.0]], 244: [[0.0124, 0.0448, 0.0357], [0.2974, 0.878, 0.1887, 1.0]], 245: [[0.0105, 0.0279, 0.0406], [0.6258, 0.2195, 0.5361, 1.0]], 246: [[0.0159, 0.0381, 0.029], [0.3421, 0.3551, 0.2471, 1.0]], 247: [[0.022, 0.0368, 0.0404], [0.5323, 0.849, 0.3699, 1.0]], 248: [[0.0183, 0.0261, 0.0257], [0.3421, 0.3559, 0.4682, 1.0]], 249: [[0.018, 0.0308, 0.0279], [0.7701, 0.5332, 0.2324, 1.0]]} @@ -330,11 +330,11 @@ def reset(self): self.sim.model.site_rgba[self.success_indicator_sid, :2] = np.array([2, 0]) self.robot.sync_sims(self.sim, self.sim_obsd) - obs = super().reset() + obs = super().reset(**kwargs) return obs class OutofDistribution(ProprioceptiveEnvV0): - def reset(self): + def reset(self, **kwargs): ellips = {0: [[0.0192, 0.0452, 0.0499], [0.5267, 0.4574, 0.8141, 1.0]], 1: [[0.0139, 0.0151, 0.0196], [0.732, 0.149, 0.7571, 1.0]], 2: [[0.0199, 0.0193, 0.0195], [0.8499, 0.8535, 0.5279, 1.0]], 3: [[0.013, 0.0498, 0.0495], [0.4756, 0.6692, 0.2101, 1.0]], 4: [[0.0101, 0.0184, 0.017], [0.2398, 0.142, 0.2351, 1.0]], 5: [[0.0163, 0.0181, 0.0158], [0.7945, 0.3168, 0.8549, 1.0]], 6: [[0.0123, 0.016, 0.0178], [0.2792, 0.164, 0.1963, 1.0]], 7: [[0.0207, 0.0151, 0.0191], [0.7214, 0.2926, 0.1578, 1.0]], 8: [[0.0101, 0.0178, 0.0184], [0.2483, 0.829, 0.394, 1.0]], 9: [[0.0152, 0.0168, 0.0184], [0.4377, 0.3745, 0.6418, 1.0]], 10: [[0.0181, 0.049, 0.0478], [0.6609, 0.7573, 0.3534, 1.0]], 11: [[0.0209, 0.0472, 0.0461], [0.8082, 0.6887, 0.5339, 1.0]], 12: [[0.0207, 0.0458, 0.0479], [0.2804, 0.6847, 0.2201, 1.0]], 13: [[0.021, 0.0464, 0.0483], [0.4936, 0.2515, 0.1887, 1.0]], 14: [[0.0208, 0.0171, 0.0162], [0.5238, 0.8167, 0.4633, 1.0]], 15: [[0.0191, 0.0173, 0.0152], [0.1136, 0.1179, 0.4157, 1.0]], 16: [[0.0164, 0.0186, 0.0168], [0.1566, 0.4181, 0.4308, 1.0]], 17: [[0.0218, 0.046, 0.0452], [0.7179, 0.2007, 0.6449, 1.0]], 18: [[0.013, 0.0455, 0.049], [0.2808, 0.8917, 0.8089, 1.0]], 19: [[0.0172, 0.0151, 0.0184], [0.6515, 0.5491, 0.1872, 1.0]], 20: [[0.0147, 0.0196, 0.0198], [0.4904, 0.2028, 0.1975, 1.0]], 21: [[0.018, 0.048, 0.0471], [0.1999, 0.5807, 0.2375, 1.0]], 22: [[0.021, 0.0478, 0.0459], [0.4834, 0.2712, 0.8256, 1.0]], 23: [[0.0215, 0.0174, 0.0194], [0.876, 0.1412, 0.5414, 1.0]], 24: [[0.0138, 0.0499, 0.0451], [0.2633, 0.2109, 0.1106, 1.0]], 25: [[0.0127, 0.0498, 0.0465], [0.217, 0.5914, 0.4403, 1.0]], 26: [[0.0205, 0.0477, 0.0479], [0.6415, 0.5312, 0.1625, 1.0]], 27: [[0.0166, 0.016, 0.0188], [0.3175, 0.7062, 0.3952, 1.0]], 28: [[0.0153, 0.0176, 0.0176], [0.4573, 0.1392, 0.3729, 1.0]], 29: [[0.018, 0.0175, 0.0161], [0.6205, 0.5582, 0.4211, 1.0]], 30: [[0.0148, 0.0186, 0.0172], [0.6425, 0.8501, 0.547, 1.0]], 31: [[0.0106, 0.0152, 0.0171], [0.5701, 0.5896, 0.423, 1.0]], 32: [[0.0159, 0.0166, 0.0159], [0.4597, 0.5511, 0.5647, 1.0]], 33: [[0.0197, 0.0158, 0.0159], [0.5407, 0.8049, 0.5615, 1.0]], 34: [[0.0123, 0.0499, 0.0479], [0.2203, 0.4778, 0.6365, 1.0]], 35: [[0.0192, 0.0151, 0.0158], [0.2206, 0.8939, 0.2668, 1.0]], 36: [[0.0169, 0.0196, 0.0172], [0.8576, 0.2538, 0.2064, 1.0]], 37: [[0.0158, 0.0182, 0.0151], [0.4524, 0.197, 0.5453, 1.0]], 38: [[0.0207, 0.0493, 0.0478], [0.6939, 0.1908, 0.4939, 1.0]], 39: [[0.0175, 0.0156, 0.0197], [0.2389, 0.5661, 0.5583, 1.0]], 40: [[0.0166, 0.0459, 0.0465], [0.7683, 0.2604, 0.1557, 1.0]], 41: [[0.0208, 0.0186, 0.0156], [0.3433, 0.4539, 0.1055, 1.0]], 42: [[0.0205, 0.0477, 0.0456], [0.1845, 0.5981, 0.1483, 1.0]], 43: [[0.0124, 0.0489, 0.0499], [0.5175, 0.2887, 0.3097, 1.0]], 44: [[0.0113, 0.0485, 0.0461], [0.3786, 0.8318, 0.1601, 1.0]], 45: [[0.0215, 0.016, 0.0166], [0.3173, 0.1281, 0.1034, 1.0]], 46: [[0.0144, 0.0186, 0.0179], [0.1642, 0.5316, 0.5814, 1.0]], 47: [[0.0193, 0.0487, 0.0457], [0.3788, 0.4253, 0.8784, 1.0]], 48: [[0.0185, 0.0197, 0.0198], [0.1222, 0.1569, 0.3243, 1.0]], 49: [[0.0107, 0.0196, 0.0186], [0.7762, 0.2969, 0.1212, 1.0]], 50: [[0.0185, 0.0475, 0.0464], [0.4258, 0.596, 0.1708, 1.0]], 51: [[0.0219, 0.0463, 0.0453], [0.6682, 0.2874, 0.8677, 1.0]], 52: [[0.0216, 0.0184, 0.0167], [0.2306, 0.5239, 0.6937, 1.0]], 53: [[0.0142, 0.0153, 0.0172], [0.4521, 0.3318, 0.8279, 1.0]], 54: [[0.0205, 0.0163, 0.0187], [0.7757, 0.2893, 0.5572, 1.0]], 55: [[0.0111, 0.0487, 0.0458], [0.8859, 0.8771, 0.1853, 1.0]], 56: [[0.0123, 0.0494, 0.046], [0.2689, 0.5928, 0.3019, 1.0]], 57: [[0.0181, 0.0495, 0.046], [0.7198, 0.1171, 0.6706, 1.0]], 58: [[0.0101, 0.0157, 0.0152], [0.7989, 0.6149, 0.3325, 1.0]], 59: [[0.0119, 0.0199, 0.0164], [0.2336, 0.7501, 0.1612, 1.0]], 60: [[0.0138, 0.0165, 0.0183], [0.7101, 0.7613, 0.5269, 1.0]], 61: [[0.0149, 0.0193, 0.0181], [0.2383, 0.8573, 0.657, 1.0]], 62: [[0.0205, 0.0462, 0.0483], [0.8779, 0.876, 0.5988, 1.0]], 63: [[0.0126, 0.0485, 0.0498], [0.2322, 0.7434, 0.3378, 1.0]], 64: [[0.0208, 0.0166, 0.0175], [0.5481, 0.3963, 0.2458, 1.0]], 65: [[0.0123, 0.0458, 0.0473], [0.8363, 0.584, 0.8345, 1.0]], 66: [[0.0152, 0.0461, 0.0488], [0.2852, 0.6593, 0.4106, 1.0]], 67: [[0.0175, 0.018, 0.0156], [0.1591, 0.8999, 0.1738, 1.0]], 68: [[0.0104, 0.0167, 0.0164], [0.8758, 0.6222, 0.7088, 1.0]], 69: [[0.0161, 0.0456, 0.0488], [0.2751, 0.6347, 0.1566, 1.0]], 70: [[0.0118, 0.0183, 0.0177], [0.2599, 0.4178, 0.5191, 1.0]], 71: [[0.0169, 0.0484, 0.0471], [0.6084, 0.1645, 0.8656, 1.0]], 72: [[0.0144, 0.0191, 0.0158], [0.3274, 0.1633, 0.1479, 1.0]], 73: [[0.0196, 0.0182, 0.0178], [0.2689, 0.1003, 0.857, 1.0]], 74: [[0.0176, 0.047, 0.0498], [0.2583, 0.8694, 0.2129, 1.0]], 75: [[0.0216, 0.0166, 0.0178], [0.5812, 0.4521, 0.2085, 1.0]], 76: [[0.0193, 0.0184, 0.0176], [0.2953, 0.1088, 0.1919, 1.0]], 77: [[0.0181, 0.0194, 0.0178], [0.2713, 0.1443, 0.2227, 1.0]], 78: [[0.0164, 0.0156, 0.0177], [0.2967, 0.2742, 0.7264, 1.0]], 79: [[0.0219, 0.049, 0.0496], [0.58, 0.3954, 0.3932, 1.0]], 80: [[0.0128, 0.049, 0.0457], [0.7774, 0.1565, 0.7805, 1.0]], 81: [[0.0124, 0.0489, 0.0468], [0.4525, 0.3835, 0.832, 1.0]], 82: [[0.0186, 0.0486, 0.0464], [0.8187, 0.6362, 0.1601, 1.0]], 83: [[0.0206, 0.046, 0.0488], [0.623, 0.8552, 0.8165, 1.0]], 84: [[0.0117, 0.0165, 0.0153], [0.4809, 0.7749, 0.8775, 1.0]], 85: [[0.0213, 0.0459, 0.0464], [0.4702, 0.4813, 0.3434, 1.0]], 86: [[0.019, 0.045, 0.0475], [0.7441, 0.8441, 0.8382, 1.0]], 87: [[0.0122, 0.0151, 0.0156], [0.2572, 0.4567, 0.6492, 1.0]], 88: [[0.0215, 0.017, 0.0172], [0.2149, 0.479, 0.4043, 1.0]], 89: [[0.0115, 0.0182, 0.0173], [0.1044, 0.857, 0.6485, 1.0]], 90: [[0.0153, 0.0463, 0.0498], [0.3174, 0.8098, 0.6281, 1.0]], 91: [[0.0149, 0.0485, 0.0462], [0.811, 0.6085, 0.3801, 1.0]], 92: [[0.0164, 0.0167, 0.0152], [0.4186, 0.4523, 0.6883, 1.0]], 93: [[0.012, 0.045, 0.048], [0.7487, 0.1029, 0.1244, 1.0]], 94: [[0.012, 0.0158, 0.0167], [0.6503, 0.2879, 0.7908, 1.0]], 95: [[0.0139, 0.0474, 0.045], [0.5876, 0.8936, 0.6603, 1.0]], 96: [[0.0161, 0.0473, 0.0474], [0.8403, 0.7127, 0.4783, 1.0]], 97: [[0.0194, 0.0184, 0.0176], [0.2173, 0.1284, 0.533, 1.0]], 98: [[0.0213, 0.0478, 0.0486], [0.6218, 0.1434, 0.7881, 1.0]], 99: [[0.0187, 0.0457, 0.0454], [0.701, 0.5286, 0.1528, 1.0]], 100: [[0.0212, 0.0164, 0.0173], [0.1359, 0.3418, 0.1702, 1.0]], 101: [[0.0115, 0.0494, 0.0474], [0.1115, 0.6422, 0.2154, 1.0]], 102: [[0.0201, 0.0482, 0.0467], [0.3541, 0.5037, 0.1197, 1.0]], 103: [[0.0113, 0.018, 0.018], [0.2665, 0.6401, 0.5914, 1.0]], 104: [[0.0186, 0.0452, 0.0464], [0.5232, 0.5509, 0.4378, 1.0]], 105: [[0.016, 0.0453, 0.0474], [0.3037, 0.1085, 0.4448, 1.0]], 106: [[0.0173, 0.0455, 0.0477], [0.7182, 0.1459, 0.3942, 1.0]], 107: [[0.0202, 0.0184, 0.0151], [0.4666, 0.5228, 0.4178, 1.0]], 108: [[0.0172, 0.018, 0.0193], [0.2233, 0.215, 0.2582, 1.0]], 109: [[0.0167, 0.0175, 0.0187], [0.8508, 0.2018, 0.5478, 1.0]], 110: [[0.0108, 0.0175, 0.0179], [0.5524, 0.4684, 0.4966, 1.0]], 111: [[0.0143, 0.0191, 0.0193], [0.7332, 0.8646, 0.5665, 1.0]], 112: [[0.0139, 0.0482, 0.0454], [0.3081, 0.1629, 0.2863, 1.0]], 113: [[0.0147, 0.0473, 0.0491], [0.3088, 0.7483, 0.6019, 1.0]], 114: [[0.0185, 0.0466, 0.05], [0.116, 0.1138, 0.7998, 1.0]], 115: [[0.011, 0.0192, 0.0193], [0.7837, 0.3943, 0.6835, 1.0]], 116: [[0.0131, 0.0178, 0.0177], [0.8052, 0.3264, 0.2248, 1.0]], 117: [[0.0117, 0.0182, 0.02], [0.3734, 0.845, 0.856, 1.0]], 118: [[0.0141, 0.0185, 0.017], [0.7165, 0.2071, 0.2128, 1.0]], 119: [[0.0138, 0.0462, 0.0474], [0.1744, 0.39, 0.2223, 1.0]], 120: [[0.0171, 0.0461, 0.0475], [0.3088, 0.2349, 0.4867, 1.0]], 121: [[0.0117, 0.016, 0.0181], [0.2226, 0.8378, 0.3225, 1.0]], 122: [[0.0105, 0.0158, 0.0171], [0.8364, 0.7014, 0.8305, 1.0]], 123: [[0.0189, 0.0165, 0.0187], [0.297, 0.3524, 0.7998, 1.0]], 124: [[0.0215, 0.0461, 0.045], [0.4239, 0.1262, 0.3183, 1.0]], 125: [[0.0172, 0.0153, 0.0164], [0.1557, 0.6216, 0.3463, 1.0]], 126: [[0.0103, 0.0156, 0.0183], [0.7072, 0.1275, 0.6285, 1.0]], 127: [[0.0181, 0.0159, 0.016], [0.237, 0.2653, 0.556, 1.0]], 128: [[0.0173, 0.0464, 0.0495], [0.4488, 0.2617, 0.806, 1.0]], 129: [[0.0189, 0.0179, 0.0169], [0.7203, 0.2278, 0.7525, 1.0]], 130: [[0.0123, 0.0181, 0.0176], [0.2859, 0.6673, 0.7392, 1.0]], 131: [[0.0195, 0.019, 0.0177], [0.437, 0.4429, 0.1932, 1.0]], 132: [[0.0116, 0.019, 0.0183], [0.4918, 0.1073, 0.5199, 1.0]], 133: [[0.0104, 0.0486, 0.0455], [0.4454, 0.6188, 0.4804, 1.0]], 134: [[0.0147, 0.0182, 0.0187], [0.8536, 0.2603, 0.6406, 1.0]], 135: [[0.0101, 0.0155, 0.0154], [0.6714, 0.4946, 0.3238, 1.0]], 136: [[0.021, 0.0174, 0.0163], [0.6446, 0.837, 0.5574, 1.0]], 137: [[0.0213, 0.049, 0.0485], [0.4923, 0.1998, 0.6318, 1.0]], 138: [[0.0216, 0.0158, 0.0157], [0.7656, 0.8456, 0.4701, 1.0]], 139: [[0.0211, 0.0457, 0.0455], [0.8384, 0.2285, 0.1857, 1.0]], 140: [[0.0153, 0.0462, 0.049], [0.4427, 0.8192, 0.1261, 1.0]], 141: [[0.0173, 0.0486, 0.0458], [0.1271, 0.7625, 0.6444, 1.0]], 142: [[0.0111, 0.0167, 0.0181], [0.2497, 0.3426, 0.5949, 1.0]], 143: [[0.0129, 0.0162, 0.0176], [0.1535, 0.4215, 0.1287, 1.0]], 144: [[0.0173, 0.05, 0.0454], [0.7078, 0.6966, 0.7814, 1.0]], 145: [[0.0173, 0.0173, 0.0189], [0.8446, 0.8726, 0.1454, 1.0]], 146: [[0.0175, 0.0467, 0.0452], [0.3159, 0.2075, 0.4135, 1.0]], 147: [[0.0189, 0.0472, 0.0489], [0.3509, 0.7609, 0.247, 1.0]], 148: [[0.011, 0.0473, 0.0477], [0.5892, 0.7222, 0.7374, 1.0]], 149: [[0.0129, 0.0474, 0.0453], [0.4001, 0.3362, 0.3461, 1.0]], 150: [[0.0147, 0.0492, 0.0495], [0.3652, 0.3845, 0.7267, 1.0]], 151: [[0.0177, 0.0455, 0.048], [0.3723, 0.8321, 0.8714, 1.0]], 152: [[0.0166, 0.0487, 0.0466], [0.4703, 0.3864, 0.4177, 1.0]], 153: [[0.0153, 0.0499, 0.0499], [0.7648, 0.7416, 0.7357, 1.0]], 154: [[0.0134, 0.0493, 0.0457], [0.3237, 0.6859, 0.2325, 1.0]], 155: [[0.018, 0.018, 0.0197], [0.7054, 0.213, 0.2625, 1.0]], 156: [[0.0162, 0.0452, 0.0477], [0.6377, 0.6928, 0.5753, 1.0]], 157: [[0.0148, 0.046, 0.0469], [0.5586, 0.8162, 0.6545, 1.0]], 158: [[0.0211, 0.0157, 0.0184], [0.8362, 0.3753, 0.7235, 1.0]], 159: [[0.0214, 0.0196, 0.0156], [0.5081, 0.4477, 0.5708, 1.0]], 160: [[0.0205, 0.0492, 0.0497], [0.1427, 0.7907, 0.8017, 1.0]], 161: [[0.0133, 0.0166, 0.0178], [0.7213, 0.3542, 0.781, 1.0]], 162: [[0.0174, 0.0483, 0.0453], [0.449, 0.4152, 0.215, 1.0]], 163: [[0.0118, 0.0183, 0.0173], [0.8446, 0.4294, 0.1347, 1.0]], 164: [[0.0127, 0.0193, 0.0175], [0.5537, 0.6098, 0.4473, 1.0]], 165: [[0.0109, 0.046, 0.045], [0.8556, 0.5505, 0.5512, 1.0]], 166: [[0.0147, 0.0457, 0.0481], [0.2713, 0.7413, 0.2877, 1.0]], 167: [[0.0144, 0.0462, 0.0459], [0.165, 0.5973, 0.8201, 1.0]], 168: [[0.0212, 0.0157, 0.0171], [0.214, 0.1838, 0.1647, 1.0]], 169: [[0.019, 0.0468, 0.0494], [0.6828, 0.5465, 0.1629, 1.0]], 170: [[0.0185, 0.0177, 0.0197], [0.5079, 0.5769, 0.7325, 1.0]], 171: [[0.0203, 0.0477, 0.0491], [0.2193, 0.22, 0.6761, 1.0]], 172: [[0.0103, 0.0186, 0.0178], [0.7868, 0.5255, 0.352, 1.0]], 173: [[0.0196, 0.0184, 0.0155], [0.4391, 0.3064, 0.5309, 1.0]], 174: [[0.0189, 0.0452, 0.0459], [0.6549, 0.5861, 0.727, 1.0]], 175: [[0.0117, 0.0485, 0.0461], [0.1796, 0.2544, 0.4843, 1.0]], 176: [[0.0138, 0.0162, 0.0169], [0.7049, 0.7908, 0.8643, 1.0]], 177: [[0.0165, 0.0452, 0.0477], [0.6733, 0.1488, 0.204, 1.0]], 178: [[0.0166, 0.0184, 0.0165], [0.788, 0.1421, 0.3831, 1.0]], 179: [[0.013, 0.0183, 0.0167], [0.2977, 0.5449, 0.3888, 1.0]], 180: [[0.0154, 0.0172, 0.0179], [0.4412, 0.8275, 0.5608, 1.0]], 181: [[0.0104, 0.0196, 0.016], [0.1381, 0.1568, 0.6259, 1.0]], 182: [[0.0133, 0.0475, 0.0488], [0.1533, 0.3559, 0.1507, 1.0]], 183: [[0.0155, 0.0185, 0.0164], [0.5281, 0.6594, 0.1726, 1.0]], 184: [[0.0211, 0.0485, 0.0467], [0.8173, 0.4514, 0.5882, 1.0]], 185: [[0.0143, 0.0153, 0.0158], [0.8264, 0.3948, 0.6738, 1.0]], 186: [[0.0109, 0.0493, 0.0474], [0.8409, 0.4085, 0.5155, 1.0]], 187: [[0.0128, 0.0192, 0.0157], [0.5373, 0.433, 0.8742, 1.0]], 188: [[0.0209, 0.0455, 0.0498], [0.1971, 0.8868, 0.2746, 1.0]], 189: [[0.018, 0.0454, 0.0469], [0.4346, 0.538, 0.2588, 1.0]], 190: [[0.0116, 0.0165, 0.0163], [0.4802, 0.811, 0.4269, 1.0]], 191: [[0.0109, 0.018, 0.0163], [0.7326, 0.4583, 0.8561, 1.0]], 192: [[0.014, 0.0166, 0.0167], [0.2543, 0.5965, 0.3042, 1.0]], 193: [[0.013, 0.0166, 0.0165], [0.5034, 0.7119, 0.4114, 1.0]], 194: [[0.0177, 0.0182, 0.0182], [0.4691, 0.8266, 0.8231, 1.0]], 195: [[0.018, 0.0498, 0.0485], [0.426, 0.4928, 0.5865, 1.0]], 196: [[0.0108, 0.0481, 0.0478], [0.4718, 0.469, 0.2384, 1.0]], 197: [[0.0208, 0.0499, 0.0488], [0.5592, 0.7773, 0.8967, 1.0]], 198: [[0.0115, 0.0458, 0.0462], [0.4334, 0.1846, 0.3883, 1.0]], 199: [[0.0206, 0.0456, 0.0495], [0.6058, 0.7777, 0.429, 1.0]], 200: [[0.0177, 0.0464, 0.0472], [0.1362, 0.1824, 0.5835, 1.0]], 201: [[0.014, 0.0191, 0.0181], [0.7206, 0.3437, 0.37, 1.0]], 202: [[0.0118, 0.0474, 0.0455], [0.1148, 0.262, 0.4485, 1.0]], 203: [[0.0188, 0.0455, 0.0477], [0.5114, 0.8115, 0.7676, 1.0]], 204: [[0.0172, 0.0499, 0.0462], [0.5308, 0.8111, 0.4337, 1.0]], 205: [[0.0103, 0.0178, 0.0162], [0.1149, 0.2917, 0.5876, 1.0]], 206: [[0.0153, 0.0476, 0.0499], [0.4999, 0.5089, 0.4157, 1.0]], 207: [[0.0102, 0.0171, 0.0199], [0.2822, 0.6051, 0.1381, 1.0]], 208: [[0.0117, 0.0166, 0.0161], [0.628, 0.8254, 0.2762, 1.0]], 209: [[0.0214, 0.0191, 0.0198], [0.7102, 0.6143, 0.7504, 1.0]], 210: [[0.011, 0.0197, 0.0181], [0.4121, 0.5831, 0.6519, 1.0]], 211: [[0.0112, 0.0482, 0.0489], [0.5244, 0.4626, 0.3611, 1.0]], 212: [[0.0207, 0.0452, 0.049], [0.5354, 0.5679, 0.6929, 1.0]], 213: [[0.0181, 0.0178, 0.0199], [0.1985, 0.8685, 0.6615, 1.0]], 214: [[0.0161, 0.018, 0.0184], [0.8486, 0.384, 0.3626, 1.0]], 215: [[0.014, 0.0186, 0.0164], [0.8014, 0.1281, 0.8958, 1.0]], 216: [[0.0186, 0.0174, 0.0167], [0.2922, 0.3107, 0.4267, 1.0]], 217: [[0.0146, 0.0156, 0.0177], [0.8846, 0.8364, 0.6047, 1.0]], 218: [[0.0112, 0.047, 0.0465], [0.5951, 0.6992, 0.1147, 1.0]], 219: [[0.0158, 0.0489, 0.0452], [0.2499, 0.7388, 0.8561, 1.0]], 220: [[0.0106, 0.048, 0.0469], [0.7771, 0.1326, 0.4109, 1.0]], 221: [[0.0214, 0.0464, 0.0482], [0.3484, 0.482, 0.6405, 1.0]], 222: [[0.0114, 0.0465, 0.0492], [0.2858, 0.3632, 0.6027, 1.0]], 223: [[0.0153, 0.0173, 0.0154], [0.4912, 0.2048, 0.5629, 1.0]], 224: [[0.0109, 0.016, 0.0168], [0.3627, 0.4122, 0.7033, 1.0]], 225: [[0.0105, 0.0492, 0.046], [0.5845, 0.6587, 0.7785, 1.0]], 226: [[0.0215, 0.0199, 0.0167], [0.6111, 0.4582, 0.5161, 1.0]], 227: [[0.0177, 0.0452, 0.0487], [0.887, 0.7108, 0.8265, 1.0]], 228: [[0.0182, 0.0454, 0.0477], [0.4018, 0.2768, 0.6919, 1.0]], 229: [[0.0141, 0.0466, 0.0481], [0.6269, 0.4912, 0.5564, 1.0]], 230: [[0.0191, 0.0195, 0.0175], [0.428, 0.3875, 0.1871, 1.0]], 231: [[0.0174, 0.0491, 0.048], [0.3444, 0.1266, 0.8867, 1.0]], 232: [[0.0186, 0.0159, 0.0187], [0.2639, 0.3621, 0.5915, 1.0]], 233: [[0.0209, 0.0193, 0.0156], [0.7967, 0.8694, 0.8187, 1.0]], 234: [[0.0154, 0.0191, 0.0199], [0.2638, 0.3083, 0.3342, 1.0]], 235: [[0.0143, 0.0464, 0.0469], [0.7041, 0.4356, 0.8954, 1.0]], 236: [[0.0161, 0.0185, 0.0166], [0.1647, 0.6814, 0.3447, 1.0]], 237: [[0.0113, 0.0493, 0.0467], [0.6876, 0.7429, 0.7811, 1.0]], 238: [[0.0156, 0.0473, 0.0485], [0.5644, 0.27, 0.7321, 1.0]], 239: [[0.0163, 0.0461, 0.049], [0.1809, 0.2318, 0.6899, 1.0]], 240: [[0.0152, 0.0492, 0.0493], [0.2864, 0.6311, 0.2679, 1.0]], 241: [[0.015, 0.0162, 0.0169], [0.7129, 0.8796, 0.4358, 1.0]], 242: [[0.0131, 0.0483, 0.0494], [0.1308, 0.4755, 0.1302, 1.0]], 243: [[0.0208, 0.0183, 0.0167], [0.5402, 0.816, 0.8953, 1.0]], 244: [[0.0219, 0.0188, 0.0188], [0.4618, 0.2172, 0.1317, 1.0]], 245: [[0.0144, 0.0157, 0.0165], [0.7438, 0.3295, 0.4762, 1.0]], 246: [[0.0214, 0.0479, 0.047], [0.2621, 0.1049, 0.6837, 1.0]], 247: [[0.015, 0.0486, 0.0472], [0.1136, 0.1078, 0.1676, 1.0]], 248: [[0.0202, 0.0474, 0.0472], [0.5658, 0.7594, 0.5825, 1.0]], 249: [[0.0143, 0.0171, 0.017], [0.1856, 0.2551, 0.2272, 1.0]]} caps = {0: [[0.014, 0.0462, 0.0456], [0.7552, 0.5999, 0.1518, 1.0]], 1: [[0.0216, 0.0457, 0.0467], [0.2093, 0.2803, 0.6963, 1.0]], 2: [[0.0164, 0.0488, 0.0469], [0.533, 0.7576, 0.3658, 1.0]], 3: [[0.0155, 0.0467, 0.0474], [0.7339, 0.4907, 0.8917, 1.0]], 4: [[0.018, 0.05, 0.0487], [0.7503, 0.278, 0.1356, 1.0]], 5: [[0.0124, 0.0484, 0.0491], [0.6518, 0.8439, 0.2416, 1.0]], 6: [[0.0197, 0.0483, 0.0459], [0.1324, 0.4663, 0.3485, 1.0]], 7: [[0.0129, 0.0476, 0.0495], [0.2285, 0.6781, 0.8076, 1.0]], 8: [[0.0135, 0.0455, 0.0458], [0.6364, 0.4566, 0.7004, 1.0]], 9: [[0.0193, 0.0485, 0.0487], [0.2589, 0.7651, 0.819, 1.0]], 10: [[0.0159, 0.0461, 0.0485], [0.3405, 0.7015, 0.4264, 1.0]], 11: [[0.0124, 0.0465, 0.0454], [0.4863, 0.8691, 0.3991, 1.0]], 12: [[0.0129, 0.046, 0.0493], [0.3887, 0.4909, 0.2819, 1.0]], 13: [[0.0152, 0.0494, 0.048], [0.336, 0.572, 0.237, 1.0]], 14: [[0.017, 0.0463, 0.049], [0.1482, 0.6936, 0.3843, 1.0]], 15: [[0.0212, 0.0489, 0.0472], [0.1987, 0.7815, 0.2925, 1.0]], 16: [[0.0159, 0.0454, 0.0494], [0.4086, 0.5416, 0.2677, 1.0]], 17: [[0.0209, 0.0484, 0.0461], [0.6409, 0.6521, 0.4668, 1.0]], 18: [[0.0117, 0.0461, 0.0493], [0.1992, 0.4997, 0.1159, 1.0]], 19: [[0.0167, 0.0474, 0.0474], [0.2572, 0.8054, 0.2026, 1.0]], 20: [[0.0176, 0.0479, 0.0453], [0.8732, 0.411, 0.7781, 1.0]], 21: [[0.0131, 0.0483, 0.05], [0.7602, 0.7548, 0.5154, 1.0]], 22: [[0.0127, 0.047, 0.0495], [0.1458, 0.7376, 0.6436, 1.0]], 23: [[0.0185, 0.0454, 0.048], [0.4824, 0.429, 0.2692, 1.0]], 24: [[0.0183, 0.0467, 0.0467], [0.2042, 0.1897, 0.5394, 1.0]], 25: [[0.0166, 0.0491, 0.0487], [0.2161, 0.8852, 0.116, 1.0]], 26: [[0.0155, 0.0456, 0.0485], [0.4469, 0.7486, 0.2039, 1.0]], 27: [[0.017, 0.0489, 0.0491], [0.2923, 0.296, 0.463, 1.0]], 28: [[0.0172, 0.0469, 0.0488], [0.5067, 0.8173, 0.2359, 1.0]], 29: [[0.0114, 0.0467, 0.0487], [0.7463, 0.4096, 0.2019, 1.0]], 30: [[0.0156, 0.0491, 0.046], [0.5904, 0.1744, 0.3366, 1.0]], 31: [[0.0176, 0.0478, 0.0451], [0.3127, 0.827, 0.4654, 1.0]], 32: [[0.0164, 0.0476, 0.0455], [0.5944, 0.7672, 0.201, 1.0]], 33: [[0.0198, 0.0459, 0.0474], [0.1264, 0.7821, 0.2798, 1.0]], 34: [[0.0191, 0.0456, 0.0482], [0.8076, 0.7139, 0.8896, 1.0]], 35: [[0.0165, 0.0494, 0.0481], [0.8064, 0.5783, 0.2943, 1.0]], 36: [[0.0153, 0.0464, 0.0482], [0.7696, 0.8593, 0.8418, 1.0]], 37: [[0.013, 0.0491, 0.0486], [0.7956, 0.5799, 0.2945, 1.0]], 38: [[0.0164, 0.0481, 0.0478], [0.1812, 0.6424, 0.6937, 1.0]], 39: [[0.0207, 0.0492, 0.048], [0.4824, 0.3024, 0.5251, 1.0]], 40: [[0.0189, 0.0496, 0.0465], [0.1554, 0.181, 0.8387, 1.0]], 41: [[0.0143, 0.0473, 0.0488], [0.2187, 0.5529, 0.1564, 1.0]], 42: [[0.0135, 0.0489, 0.0485], [0.2515, 0.7067, 0.6655, 1.0]], 43: [[0.0152, 0.0465, 0.0491], [0.1965, 0.7194, 0.3763, 1.0]], 44: [[0.0178, 0.0467, 0.0491], [0.6523, 0.6147, 0.4241, 1.0]], 45: [[0.0163, 0.0497, 0.0474], [0.106, 0.4336, 0.751, 1.0]], 46: [[0.013, 0.0455, 0.0497], [0.127, 0.8312, 0.266, 1.0]], 47: [[0.0144, 0.0491, 0.0497], [0.4803, 0.174, 0.3265, 1.0]], 48: [[0.0165, 0.0495, 0.0468], [0.3985, 0.1923, 0.5643, 1.0]], 49: [[0.0108, 0.0464, 0.0464], [0.8164, 0.1074, 0.6519, 1.0]], 50: [[0.0125, 0.0496, 0.0476], [0.3726, 0.3997, 0.7442, 1.0]], 51: [[0.0202, 0.049, 0.0452], [0.6754, 0.1224, 0.4845, 1.0]], 52: [[0.0152, 0.0491, 0.047], [0.4577, 0.1851, 0.8076, 1.0]], 53: [[0.0157, 0.047, 0.0469], [0.7069, 0.7766, 0.7214, 1.0]], 54: [[0.0164, 0.0491, 0.0478], [0.3279, 0.7244, 0.7102, 1.0]], 55: [[0.0215, 0.0495, 0.0483], [0.432, 0.3687, 0.2987, 1.0]], 56: [[0.0111, 0.0474, 0.0478], [0.4287, 0.5973, 0.2854, 1.0]], 57: [[0.0198, 0.047, 0.047], [0.6354, 0.6947, 0.3978, 1.0]], 58: [[0.0157, 0.045, 0.0482], [0.7976, 0.8864, 0.5672, 1.0]], 59: [[0.0145, 0.0473, 0.0467], [0.7185, 0.634, 0.3194, 1.0]], 60: [[0.0136, 0.0457, 0.0498], [0.6684, 0.8659, 0.6903, 1.0]], 61: [[0.0126, 0.0458, 0.0499], [0.582, 0.6602, 0.855, 1.0]], 62: [[0.0112, 0.0471, 0.0475], [0.7666, 0.4851, 0.2647, 1.0]], 63: [[0.0131, 0.0476, 0.0482], [0.8042, 0.4006, 0.1974, 1.0]], 64: [[0.0209, 0.0452, 0.05], [0.68, 0.7969, 0.8836, 1.0]], 65: [[0.0161, 0.0473, 0.048], [0.6047, 0.4247, 0.859, 1.0]], 66: [[0.0125, 0.0494, 0.048], [0.1788, 0.8538, 0.4675, 1.0]], 67: [[0.0118, 0.0461, 0.049], [0.7032, 0.7573, 0.5692, 1.0]], 68: [[0.0135, 0.049, 0.0475], [0.1252, 0.7411, 0.6948, 1.0]], 69: [[0.019, 0.0477, 0.0451], [0.6571, 0.5464, 0.3748, 1.0]], 70: [[0.0169, 0.0491, 0.0474], [0.4295, 0.4868, 0.5209, 1.0]], 71: [[0.0145, 0.0468, 0.0468], [0.4492, 0.7508, 0.8496, 1.0]], 72: [[0.0118, 0.0485, 0.0469], [0.7431, 0.3396, 0.4076, 1.0]], 73: [[0.021, 0.0482, 0.0476], [0.1269, 0.2122, 0.6914, 1.0]], 74: [[0.0179, 0.0479, 0.0475], [0.2867, 0.8822, 0.2719, 1.0]], 75: [[0.0155, 0.0473, 0.0485], [0.1142, 0.8501, 0.6702, 1.0]], 76: [[0.0106, 0.0493, 0.0461], [0.7289, 0.8452, 0.7046, 1.0]], 77: [[0.0105, 0.0451, 0.0481], [0.7829, 0.3362, 0.4289, 1.0]], 78: [[0.0111, 0.0452, 0.0451], [0.8546, 0.3946, 0.6141, 1.0]], 79: [[0.013, 0.0473, 0.0464], [0.7705, 0.1568, 0.4001, 1.0]], 80: [[0.0109, 0.0476, 0.0464], [0.7354, 0.5219, 0.8647, 1.0]], 81: [[0.0139, 0.047, 0.0487], [0.656, 0.3886, 0.8112, 1.0]], 82: [[0.0191, 0.0496, 0.0454], [0.2665, 0.2847, 0.8008, 1.0]], 83: [[0.013, 0.0486, 0.0474], [0.2069, 0.1691, 0.6395, 1.0]], 84: [[0.0179, 0.0463, 0.0495], [0.4257, 0.7388, 0.7933, 1.0]], 85: [[0.0194, 0.047, 0.0455], [0.8254, 0.4672, 0.3752, 1.0]], 86: [[0.0196, 0.0451, 0.0464], [0.3792, 0.8519, 0.6723, 1.0]], 87: [[0.015, 0.0492, 0.0453], [0.2409, 0.1773, 0.7596, 1.0]], 88: [[0.0172, 0.0456, 0.0471], [0.3939, 0.6507, 0.3653, 1.0]], 89: [[0.0186, 0.0496, 0.0486], [0.5724, 0.1403, 0.5072, 1.0]], 90: [[0.0169, 0.0484, 0.0484], [0.1981, 0.3563, 0.7455, 1.0]], 91: [[0.0216, 0.0461, 0.0472], [0.7385, 0.2926, 0.8827, 1.0]], 92: [[0.0172, 0.0487, 0.0476], [0.1522, 0.3485, 0.5212, 1.0]], 93: [[0.0113, 0.0455, 0.0489], [0.1023, 0.8846, 0.1536, 1.0]], 94: [[0.0162, 0.046, 0.0458], [0.671, 0.378, 0.2139, 1.0]], 95: [[0.0189, 0.0452, 0.0465], [0.5306, 0.4105, 0.8137, 1.0]], 96: [[0.0177, 0.0466, 0.0457], [0.4245, 0.1615, 0.7385, 1.0]], 97: [[0.019, 0.0471, 0.046], [0.4411, 0.3306, 0.8051, 1.0]], 98: [[0.0213, 0.0463, 0.0461], [0.6587, 0.8677, 0.6339, 1.0]], 99: [[0.0129, 0.0498, 0.0495], [0.2374, 0.1145, 0.2975, 1.0]], 100: [[0.0119, 0.0466, 0.0457], [0.1556, 0.7113, 0.3381, 1.0]], 101: [[0.0155, 0.0454, 0.0468], [0.3652, 0.3653, 0.1902, 1.0]], 102: [[0.0119, 0.0488, 0.0498], [0.8815, 0.2709, 0.3494, 1.0]], 103: [[0.0142, 0.0455, 0.0473], [0.5871, 0.6631, 0.3971, 1.0]], 104: [[0.0123, 0.0473, 0.0489], [0.2607, 0.5424, 0.5473, 1.0]], 105: [[0.0161, 0.0482, 0.0461], [0.6337, 0.3421, 0.5175, 1.0]], 106: [[0.018, 0.0455, 0.0494], [0.2013, 0.3448, 0.8066, 1.0]], 107: [[0.0166, 0.0479, 0.0462], [0.4087, 0.2224, 0.206, 1.0]], 108: [[0.0149, 0.0469, 0.0474], [0.5612, 0.6575, 0.8316, 1.0]], 109: [[0.0142, 0.0474, 0.0489], [0.214, 0.8244, 0.3106, 1.0]], 110: [[0.0141, 0.049, 0.0475], [0.355, 0.3569, 0.2908, 1.0]], 111: [[0.0138, 0.0494, 0.0484], [0.508, 0.2495, 0.7607, 1.0]], 112: [[0.0162, 0.0476, 0.0496], [0.113, 0.7339, 0.1047, 1.0]], 113: [[0.0142, 0.0457, 0.0487], [0.897, 0.5763, 0.7634, 1.0]], 114: [[0.0149, 0.0464, 0.0482], [0.5559, 0.3195, 0.2474, 1.0]], 115: [[0.0142, 0.0462, 0.0452], [0.3724, 0.3881, 0.5863, 1.0]], 116: [[0.0164, 0.0459, 0.0466], [0.4296, 0.3586, 0.2608, 1.0]], 117: [[0.018, 0.049, 0.0472], [0.3819, 0.6104, 0.47, 1.0]], 118: [[0.0118, 0.0499, 0.0475], [0.345, 0.5759, 0.5034, 1.0]], 119: [[0.017, 0.0474, 0.0471], [0.1843, 0.7297, 0.5133, 1.0]], 120: [[0.0109, 0.0491, 0.0466], [0.1814, 0.7445, 0.3425, 1.0]], 121: [[0.0116, 0.0491, 0.0495], [0.4409, 0.6216, 0.359, 1.0]], 122: [[0.0138, 0.0495, 0.0492], [0.3301, 0.1702, 0.5307, 1.0]], 123: [[0.012, 0.0475, 0.0497], [0.1439, 0.1612, 0.4065, 1.0]], 124: [[0.0214, 0.0485, 0.0469], [0.5616, 0.2181, 0.7746, 1.0]], 125: [[0.0121, 0.0493, 0.0465], [0.801, 0.4947, 0.2195, 1.0]], 126: [[0.0211, 0.0456, 0.0457], [0.7768, 0.3856, 0.8558, 1.0]], 127: [[0.019, 0.0483, 0.0464], [0.4195, 0.4983, 0.3311, 1.0]], 128: [[0.019, 0.0454, 0.0484], [0.1451, 0.2257, 0.3421, 1.0]], 129: [[0.0144, 0.0463, 0.0474], [0.275, 0.7622, 0.4729, 1.0]], 130: [[0.0191, 0.0474, 0.0495], [0.8404, 0.1813, 0.4804, 1.0]], 131: [[0.0161, 0.0465, 0.048], [0.8094, 0.7558, 0.363, 1.0]], 132: [[0.0198, 0.0463, 0.0487], [0.4962, 0.5524, 0.6939, 1.0]], 133: [[0.0132, 0.0459, 0.0481], [0.7799, 0.3286, 0.59, 1.0]], 134: [[0.019, 0.0456, 0.0468], [0.5001, 0.7448, 0.146, 1.0]], 135: [[0.0216, 0.0499, 0.0476], [0.681, 0.1947, 0.5377, 1.0]], 136: [[0.0171, 0.0478, 0.0495], [0.2267, 0.2009, 0.3166, 1.0]], 137: [[0.0168, 0.0478, 0.0476], [0.1241, 0.8835, 0.1198, 1.0]], 138: [[0.0166, 0.0477, 0.0466], [0.8359, 0.8665, 0.589, 1.0]], 139: [[0.0136, 0.0497, 0.0484], [0.641, 0.8936, 0.1003, 1.0]], 140: [[0.0124, 0.048, 0.0458], [0.8078, 0.3435, 0.7498, 1.0]], 141: [[0.0107, 0.0464, 0.0455], [0.5088, 0.6734, 0.2235, 1.0]], 142: [[0.0211, 0.0452, 0.0461], [0.8815, 0.8621, 0.3036, 1.0]], 143: [[0.0143, 0.047, 0.0489], [0.5841, 0.5472, 0.5041, 1.0]], 144: [[0.0163, 0.0474, 0.0493], [0.3747, 0.4394, 0.3849, 1.0]], 145: [[0.0146, 0.0495, 0.0453], [0.5791, 0.8076, 0.8573, 1.0]], 146: [[0.0134, 0.0477, 0.0459], [0.8958, 0.3424, 0.784, 1.0]], 147: [[0.0137, 0.0476, 0.0451], [0.5203, 0.5384, 0.1478, 1.0]], 148: [[0.0137, 0.046, 0.0485], [0.4583, 0.8113, 0.1408, 1.0]], 149: [[0.0138, 0.0466, 0.0492], [0.2218, 0.6236, 0.4278, 1.0]], 150: [[0.0195, 0.0486, 0.0481], [0.8044, 0.8925, 0.5465, 1.0]], 151: [[0.0107, 0.0489, 0.0459], [0.2298, 0.5506, 0.7442, 1.0]], 152: [[0.0131, 0.049, 0.0496], [0.5106, 0.5685, 0.2025, 1.0]], 153: [[0.0179, 0.0462, 0.05], [0.5488, 0.4512, 0.5557, 1.0]], 154: [[0.0167, 0.0452, 0.0468], [0.2355, 0.5682, 0.3305, 1.0]], 155: [[0.0145, 0.0494, 0.046], [0.7485, 0.5168, 0.5872, 1.0]], 156: [[0.0161, 0.0486, 0.05], [0.7493, 0.2481, 0.7278, 1.0]], 157: [[0.013, 0.0478, 0.0468], [0.7072, 0.558, 0.6206, 1.0]], 158: [[0.021, 0.0454, 0.0476], [0.3211, 0.7554, 0.1626, 1.0]], 159: [[0.021, 0.0463, 0.046], [0.1345, 0.2543, 0.8641, 1.0]], 160: [[0.0159, 0.0481, 0.0458], [0.8549, 0.3534, 0.7343, 1.0]], 161: [[0.0167, 0.0463, 0.0482], [0.1375, 0.2357, 0.7441, 1.0]], 162: [[0.0196, 0.0465, 0.0463], [0.7306, 0.1254, 0.4492, 1.0]], 163: [[0.0112, 0.0458, 0.0462], [0.3003, 0.211, 0.472, 1.0]], 164: [[0.0107, 0.048, 0.0462], [0.3184, 0.1849, 0.8337, 1.0]], 165: [[0.0101, 0.0487, 0.0451], [0.1582, 0.3694, 0.8754, 1.0]], 166: [[0.0126, 0.0459, 0.047], [0.8508, 0.4215, 0.384, 1.0]], 167: [[0.0185, 0.0453, 0.0499], [0.8908, 0.4287, 0.7101, 1.0]], 168: [[0.01, 0.048, 0.05], [0.1278, 0.6523, 0.1568, 1.0]], 169: [[0.0216, 0.0477, 0.0493], [0.5271, 0.4004, 0.7093, 1.0]], 170: [[0.0156, 0.0485, 0.0491], [0.5416, 0.4668, 0.8501, 1.0]], 171: [[0.0189, 0.0465, 0.0489], [0.1752, 0.8665, 0.4717, 1.0]], 172: [[0.0134, 0.0488, 0.0497], [0.1015, 0.5939, 0.2164, 1.0]], 173: [[0.0129, 0.0462, 0.0454], [0.4864, 0.7432, 0.5212, 1.0]], 174: [[0.0154, 0.049, 0.0461], [0.4718, 0.836, 0.8369, 1.0]], 175: [[0.0119, 0.046, 0.0465], [0.4722, 0.5093, 0.448, 1.0]], 176: [[0.0176, 0.049, 0.0471], [0.5645, 0.1044, 0.1248, 1.0]], 177: [[0.0176, 0.0487, 0.0495], [0.5378, 0.2277, 0.4826, 1.0]], 178: [[0.0118, 0.05, 0.0489], [0.1466, 0.5063, 0.7052, 1.0]], 179: [[0.0206, 0.049, 0.0459], [0.6002, 0.7016, 0.157, 1.0]], 180: [[0.0198, 0.0471, 0.0461], [0.3864, 0.7201, 0.4741, 1.0]], 181: [[0.0195, 0.046, 0.0457], [0.1372, 0.6083, 0.8012, 1.0]], 182: [[0.0207, 0.0463, 0.0488], [0.2979, 0.2253, 0.1311, 1.0]], 183: [[0.0137, 0.0498, 0.0487], [0.6296, 0.5709, 0.6801, 1.0]], 184: [[0.0123, 0.0462, 0.0482], [0.7159, 0.2452, 0.5007, 1.0]], 185: [[0.0144, 0.0474, 0.0495], [0.5067, 0.2733, 0.4707, 1.0]], 186: [[0.0161, 0.047, 0.0499], [0.2278, 0.6304, 0.3455, 1.0]], 187: [[0.0202, 0.0467, 0.0493], [0.3487, 0.6855, 0.2026, 1.0]], 188: [[0.0211, 0.0495, 0.0493], [0.7012, 0.4327, 0.5554, 1.0]], 189: [[0.0164, 0.0491, 0.0474], [0.7327, 0.6326, 0.4928, 1.0]], 190: [[0.0123, 0.0487, 0.0483], [0.1216, 0.8714, 0.3256, 1.0]], 191: [[0.0125, 0.0464, 0.0465], [0.5306, 0.8276, 0.4073, 1.0]], 192: [[0.0189, 0.0452, 0.0466], [0.5741, 0.4903, 0.6112, 1.0]], 193: [[0.0209, 0.0492, 0.0461], [0.3885, 0.3919, 0.7096, 1.0]], 194: [[0.0183, 0.045, 0.049], [0.5015, 0.2727, 0.1527, 1.0]], 195: [[0.0161, 0.0467, 0.0457], [0.8348, 0.8837, 0.803, 1.0]], 196: [[0.0197, 0.0463, 0.0487], [0.4166, 0.2679, 0.4325, 1.0]], 197: [[0.0112, 0.0484, 0.0458], [0.3721, 0.7998, 0.1675, 1.0]], 198: [[0.0209, 0.0483, 0.049], [0.7851, 0.2541, 0.6033, 1.0]], 199: [[0.0164, 0.0465, 0.0463], [0.5876, 0.8114, 0.655, 1.0]], 200: [[0.0178, 0.0455, 0.0457], [0.4663, 0.1764, 0.7163, 1.0]], 201: [[0.0105, 0.0499, 0.048], [0.6529, 0.6654, 0.3025, 1.0]], 202: [[0.0129, 0.0485, 0.0469], [0.5254, 0.738, 0.7322, 1.0]], 203: [[0.0128, 0.0454, 0.0496], [0.6982, 0.7002, 0.372, 1.0]], 204: [[0.0179, 0.0498, 0.0462], [0.5059, 0.5297, 0.4374, 1.0]], 205: [[0.0158, 0.047, 0.0491], [0.2421, 0.4905, 0.3823, 1.0]], 206: [[0.0164, 0.0459, 0.046], [0.879, 0.2691, 0.4236, 1.0]], 207: [[0.0168, 0.0479, 0.0493], [0.6304, 0.1239, 0.1815, 1.0]], 208: [[0.0177, 0.05, 0.0497], [0.2669, 0.6729, 0.1417, 1.0]], 209: [[0.0109, 0.0467, 0.0472], [0.6868, 0.8308, 0.1593, 1.0]], 210: [[0.0163, 0.0468, 0.0483], [0.355, 0.3442, 0.494, 1.0]], 211: [[0.0125, 0.0488, 0.0463], [0.7516, 0.3278, 0.3894, 1.0]], 212: [[0.0201, 0.0486, 0.0478], [0.7729, 0.6964, 0.1113, 1.0]], 213: [[0.0144, 0.0488, 0.0498], [0.5506, 0.4388, 0.1844, 1.0]], 214: [[0.0128, 0.047, 0.0462], [0.2999, 0.2076, 0.79, 1.0]], 215: [[0.0184, 0.0482, 0.0492], [0.1001, 0.616, 0.1964, 1.0]], 216: [[0.0209, 0.0477, 0.0491], [0.299, 0.4896, 0.7637, 1.0]], 217: [[0.0212, 0.0456, 0.0452], [0.2345, 0.863, 0.1339, 1.0]], 218: [[0.0136, 0.047, 0.0484], [0.6714, 0.2591, 0.2313, 1.0]], 219: [[0.013, 0.0484, 0.0496], [0.2711, 0.198, 0.564, 1.0]], 220: [[0.0145, 0.0464, 0.0486], [0.1741, 0.7011, 0.3521, 1.0]], 221: [[0.018, 0.0497, 0.0486], [0.1392, 0.4709, 0.4622, 1.0]], 222: [[0.0158, 0.0489, 0.0462], [0.6808, 0.7848, 0.6046, 1.0]], 223: [[0.019, 0.0499, 0.0484], [0.4008, 0.627, 0.5309, 1.0]], 224: [[0.0182, 0.0475, 0.0471], [0.6352, 0.7478, 0.5901, 1.0]], 225: [[0.0161, 0.0453, 0.0488], [0.4907, 0.5478, 0.322, 1.0]], 226: [[0.0112, 0.0451, 0.0467], [0.6312, 0.8163, 0.7085, 1.0]], 227: [[0.016, 0.0494, 0.0464], [0.4536, 0.2015, 0.3697, 1.0]], 228: [[0.0182, 0.0459, 0.0495], [0.2824, 0.4644, 0.4464, 1.0]], 229: [[0.0129, 0.0459, 0.0469], [0.2873, 0.2099, 0.2296, 1.0]], 230: [[0.0127, 0.0462, 0.0497], [0.8869, 0.2588, 0.1459, 1.0]], 231: [[0.0188, 0.047, 0.0483], [0.3265, 0.4428, 0.7746, 1.0]], 232: [[0.0163, 0.0474, 0.0463], [0.7222, 0.7784, 0.179, 1.0]], 233: [[0.0215, 0.0454, 0.0496], [0.4895, 0.7109, 0.5859, 1.0]], 234: [[0.0205, 0.0475, 0.0491], [0.379, 0.1502, 0.4715, 1.0]], 235: [[0.0177, 0.0493, 0.0493], [0.2987, 0.8402, 0.4573, 1.0]], 236: [[0.0131, 0.0451, 0.0466], [0.7328, 0.2241, 0.1718, 1.0]], 237: [[0.0216, 0.0479, 0.048], [0.6541, 0.8914, 0.8037, 1.0]], 238: [[0.0198, 0.0482, 0.05], [0.6548, 0.1264, 0.2441, 1.0]], 239: [[0.0141, 0.0475, 0.0474], [0.7782, 0.7148, 0.4431, 1.0]], 240: [[0.0186, 0.0458, 0.0492], [0.1186, 0.8537, 0.2014, 1.0]], 241: [[0.0197, 0.0471, 0.0498], [0.2241, 0.227, 0.3461, 1.0]], 242: [[0.0192, 0.0473, 0.0481], [0.7069, 0.7175, 0.5007, 1.0]], 243: [[0.0218, 0.0466, 0.0459], [0.6437, 0.7739, 0.5479, 1.0]], 244: [[0.018, 0.0462, 0.0499], [0.7862, 0.1087, 0.8509, 1.0]], 245: [[0.0188, 0.0499, 0.0468], [0.4566, 0.4582, 0.5993, 1.0]], 246: [[0.0187, 0.0471, 0.0462], [0.4948, 0.6884, 0.8971, 1.0]], 247: [[0.0181, 0.0488, 0.0465], [0.1297, 0.3699, 0.8934, 1.0]], 248: [[0.0156, 0.047, 0.0461], [0.7562, 0.5196, 0.331, 1.0]], 249: [[0.0101, 0.0456, 0.0455], [0.1511, 0.8475, 0.322, 1.0]]} @@ -399,6 +399,5 @@ def reset(self): self.sim.model.site_rgba[self.success_indicator_sid, :2] = np.array([2, 0]) self.robot.sync_sims(self.sim, self.sim_obsd) - obs = super().reset() + obs = super().reset(**kwargs) return obs - diff --git a/robohive/envs/myo/myobase/walk_v0.py b/robohive/envs/myo/myobase/walk_v0.py index 53349c08..a5bf3164 100644 --- a/robohive/envs/myo/myobase/walk_v0.py +++ b/robohive/envs/myo/myobase/walk_v0.py @@ -4,7 +4,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs.myo.base_v0 import BaseV0 from robohive.utils.quat_math import quat2mat @@ -62,6 +62,10 @@ def _setup(self, # Change the alpha value to make it transparent self.sim.model.geom_rgba[geom_1_indices, 3] = 0 + # move heightfield down if not used + self.sim.model.geom_rgba[self.sim.model.geom_name2id('terrain')][-1] = 0.0 + self.sim.model.geom_pos[self.sim.model.geom_name2id('terrain')] = np.array([0, 0, -10]) + def get_obs_dict(self, sim): obs_dict = {} @@ -122,7 +126,7 @@ def generate_qpos(self): return qpos_new - def reset(self): + def reset(self, **kwargs): # generate random targets if np.ptp(self.joint_random_range)>0: self.sim.data.qpos = self.generate_qpos() @@ -134,9 +138,9 @@ def reset(self): # generate resets if np.ptp(self.joint_random_range)>0: - obs = super().reset(reset_qpos= self.generate_qpos()) + obs = super().reset(reset_qpos= self.generate_qpos(), **kwargs) else: - obs = super().reset() + obs = super().reset(**kwargs) return obs class WalkEnvV0(BaseV0): @@ -207,6 +211,10 @@ def _setup(self, self.init_qpos[:] = self.sim.model.key_qpos[0] self.init_qvel[:] = 0.0 + # move heightfield down if not used + self.sim.model.geom_rgba[self.sim.model.geom_name2id('terrain')][-1] = 0.0 + self.sim.model.geom_pos[self.sim.model.geom_name2id('terrain')] = np.array([0, 0, -10]) + def get_obs_dict(self, sim): obs_dict = {} obs_dict['t'] = np.array([sim.data.time]) @@ -270,11 +278,11 @@ def get_randomized_initial_state(self): return qpos, qvel def step(self, *args, **kwargs): - obs, reward, done, info = super().step(*args, **kwargs) + results = super().step(*args, **kwargs) self.steps += 1 - return obs, reward, done, info + return results - def reset(self): + def reset(self, **kwargs): self.steps = 0 if self.reset_type == 'random': qpos, qvel = self.get_randomized_initial_state() @@ -283,7 +291,7 @@ def reset(self): else: qpos, qvel = self.sim.model.key_qpos[0], self.sim.model.key_qvel[0] self.robot.sync_sims(self.sim, self.sim_obsd) - obs = super().reset(reset_qpos=qpos, reset_qvel=qvel) + obs = super().reset(reset_qpos=qpos, reset_qvel=qvel, **kwargs) return obs def muscle_lengths(self): @@ -470,7 +478,7 @@ def _setup(self, self.init_qpos[:] = self.sim.model.key_qpos[0] self.init_qvel[:] = 0.0 - def reset(self): + def reset(self, **kwargs): self.steps = 0 if self.terrain == 'rough': rough = self.np_random.uniform(low=-.5, high=.5, size=(10000,)) @@ -512,7 +520,7 @@ def reset(self): else: qpos, qvel = self.sim.model.key_qpos[0], self.sim.model.key_qvel[0] self.robot.sync_sims(self.sim, self.sim_obsd) - obs = BaseV0.reset(self, reset_qpos=qpos, reset_qvel=qvel) + obs = BaseV0.reset(self, reset_qpos=qpos, reset_qvel=qvel, **kwargs) return obs def _get_done(self): diff --git a/robohive/envs/myo/myochallenge/__init__.py b/robohive/envs/myo/myochallenge/__init__.py index 7fa2d003..26fc6f7c 100644 --- a/robohive/envs/myo/myochallenge/__init__.py +++ b/robohive/envs/myo/myochallenge/__init__.py @@ -1,5 +1,4 @@ -from gym.envs.registration import register - +from robohive.utils import gym; register=gym.register import os curr_dir = os.path.dirname(os.path.abspath(__file__)) import numpy as np @@ -11,12 +10,12 @@ entry_point='robohive.envs.myo.myochallenge.relocate_v0:RelocateEnvV0', max_episode_steps=150, kwargs={ - 'model_path': curr_dir+'/../../../simhive/myo_sim/arm/myoarm_object_v0.16(mj237).mjb', + 'model_path': curr_dir+'/../assets/arm/myoarm_relocate.xml', 'normalize_act': True, 'frame_skip': 5, 'pos_th': 0.1, # cover entire base of the receptacle 'rot_th': np.inf, # ignore rotation errors - 'target_xyz_range': {'high':[0.2, -.35, 0.9], 'low':[0.0, -.1, 0.9]}, + 'target_xyz_range': {'high':[0.2, -.1, 0.9], 'low':[0.0, -.35, 0.9]}, 'target_rxryrz_range': {'high':[0.0, 0.0, 0.0], 'low':[0.0, 0.0, 0.0]} } ) @@ -26,7 +25,7 @@ entry_point='robohive.envs.myo.myochallenge.relocate_v0:RelocateEnvV0', max_episode_steps=150, kwargs={ - 'model_path': curr_dir+'/../../../simhive/myo_sim/arm/myoarm_object_v0.16(mj237).mjb', + 'model_path': curr_dir+'/../assets/arm/myoarm_relocate.xml', 'normalize_act': True, 'frame_skip': 5, 'pos_th': 0.1, # cover entire base of the receptacle @@ -41,13 +40,33 @@ } ) +# Register MyoChallenge Manipulation P2 Evals +register(id='myoChallengeRelocateP2eval-v0', + entry_point='robohive.envs.myo.myochallenge.relocate_v0:RelocateEnvV0', + max_episode_steps=150, + kwargs={ + 'model_path': curr_dir + '/../assets/arm/myoarm_relocate.xml', + 'normalize_act': True, + 'frame_skip': 5, + 'pos_th': 0.1, # cover entire base of the receptacle + 'rot_th': np.inf, # ignore rotation errors + 'qpos_noise_range':0.015, # jnt initialization range + 'target_xyz_range': {'high':[0.4, -.1, 1.1], 'low':[-.5, -.5, .9]}, + 'target_rxryrz_range': {'high':[.3, .3, .3], 'low':[-.3, -.3, -.3]}, + 'obj_xyz_range': {'high':[0.15, -.10, 1.0], 'low':[-0.20, -.40, 1.0]}, + 'obj_geom_range': {'high':[.025, .025, .035], 'low':[.015, 0.015, 0.015]}, + 'obj_mass_range': {'high':0.300, 'low':0.050},# 50gms to 250 gms + 'obj_friction_range': {'high':[1.2, 0.006, 0.00012], 'low':[0.8, 0.004, 0.00008]} + } +) + ## MyoChallenge Locomotion P1 register(id='myoChallengeChaseTagP1-v0', entry_point='robohive.envs.myo.myochallenge.chasetag_v0:ChaseTagEnvV0', max_episode_steps=2000, kwargs={ - 'model_path': curr_dir+'/../../../simhive/myo_sim/leg/myolegs_chasetag_v0.11(mj237).mjb', + 'model_path': curr_dir+'/../assets/leg/myolegs_chasetag.xml', 'normalize_act': True, 'win_distance': 0.5, 'min_spawn_distance': 2, @@ -57,6 +76,7 @@ 'hills_range': (0.0, 0.0), 'rough_range': (0.0, 0.0), 'relief_range': (0.0, 0.0), + 'opponent_probabilities': (0.1, 0.45, 0.45), } ) @@ -66,7 +86,7 @@ entry_point='robohive.envs.myo.myochallenge.chasetag_v0:ChaseTagEnvV0', max_episode_steps=2000, kwargs={ - 'model_path': curr_dir+'/../../../simhive/myo_sim/leg/myolegs_chasetag_v0.11(mj237).mjb', + 'model_path': curr_dir+'/../assets/leg/myolegs_chasetag.xml', 'normalize_act': True, 'win_distance': 0.5, 'min_spawn_distance': 2, @@ -76,9 +96,35 @@ 'hills_range': (0.03, 0.23), 'rough_range': (0.05, 0.1), 'relief_range': (0.1, 0.3), + 'repeller_opponent': False, + 'chase_vel_range': (1.0, 1.0), + 'random_vel_range': (-2, 2), + 'opponent_probabilities': (0.1, 0.45, 0.45), } ) +# Register MyoChallenge Locomotion P2 Evals +register(id='myoChallengeChaseTagP2eval-v0', + entry_point='robohive.envs.myo.myochallenge.chasetag_v0:ChaseTagEnvV0', + max_episode_steps=2000, + kwargs={ + 'model_path': curr_dir+'/../assets/leg/myolegs_chasetag.xml', + 'normalize_act': True, + 'win_distance': 0.5, + 'min_spawn_distance': 2, + 'reset_type': 'random', # none, init, random + 'terrain': 'random', # FLAT, random + 'task_choice': 'random', # CHASE, EVADE, random + 'hills_range': (0.03, 0.23), + 'rough_range': (0.05, 0.1), + 'relief_range': (0.1, 0.3), + 'repeller_opponent': True, + 'chase_vel_range': (1, 5), + 'random_vel_range': (-2, 2), + 'repeller_vel_range': (0.3, 1), + 'opponent_probabilities': (0.1, 0.35, 0.35, 0.2), + } + ) # MyoChallenge 2022 envs ============================================== # MyoChallenge Die: Trial env diff --git a/robohive/envs/myo/myochallenge/baoding_v1.py b/robohive/envs/myo/myochallenge/baoding_v1.py index a46a15fd..861a1f40 100644 --- a/robohive/envs/myo/myochallenge/baoding_v1.py +++ b/robohive/envs/myo/myochallenge/baoding_v1.py @@ -5,7 +5,7 @@ import collections import enum -import gym +from robohive.utils import gym import numpy as np from robohive.envs.myo.base_v0 import BaseV0 @@ -233,7 +233,7 @@ def get_metrics(self, paths): return metrics - def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None): + def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None, **kwargs): # reset task if self.task_choice == 'random': self.which_task = self.np_random.choice(Task) @@ -266,7 +266,7 @@ def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=No self.sim.model.geom_size[self.object2_gid] = self.np_random.uniform(**self.obj_size_range) # reset scene - obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel) + obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs) return obs def create_goal_trajectory(self, time_step=.1, time_period=6): diff --git a/robohive/envs/myo/myochallenge/chasetag_v0.py b/robohive/envs/myo/myochallenge/chasetag_v0.py index 28568857..2c035765 100644 --- a/robohive/envs/myo/myochallenge/chasetag_v0.py +++ b/robohive/envs/myo/myochallenge/chasetag_v0.py @@ -4,11 +4,12 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np import pink import os from enum import Enum +from typing import Optional, Tuple from robohive.envs.myo.base_v0 import BaseV0 from robohive.envs.myo.myobase.walk_v0 import WalkEnvV0 @@ -36,11 +37,31 @@ class ChallengeOpponent: Contains several different policies. For the final evaluation, an additional non-disclosed policy will be used. """ - def __init__(self, sim, rng, probabilities: list, min_spawn_distance: float): - self.dt = 0.01 + def __init__(self, + sim, + rng, + probabilities: Tuple[float], + min_spawn_distance: float, + chase_vel_range: Tuple[float], + random_vel_range: Tuple[float], + dt=0.01, + ): + """ + Initialize the opponent class. + :param sim: Mujoco sim object. + :param rng: np_random generator. + :param probabilities: Probabilities for the different policies, (static_stationary, stationary, random). + :param min_spawn_distance: Minimum distance for opponent to spawn from the model. + :param chase_vel_range: Range of velocities for the chase policy. Randomly drawn. + :param random_vel_range: Range of velocities for the random policy. Clipped. + :param dt: Simulation timestep. + """ + self.dt = dt self.sim = sim self.opponent_probabilities = probabilities self.min_spawn_distance = min_spawn_distance + self.chase_vel_range = chase_vel_range + self.random_vel_range = random_vel_range self.reset_opponent(rng=rng) def reset_noise_process(self): @@ -91,7 +112,7 @@ def random_movement(self): This moves the opponent randomly in a correlated pattern. """ - return self.noise_process.sample() + return np.clip(self.noise_process.sample(), self.random_vel_range[0], self.random_vel_range[1]) def sample_opponent_policy(self): """ @@ -150,6 +171,9 @@ def reset_opponent(self, player_task='CHASE', rng=None): self.set_opponent_pose(pose) self.opponent_vel[:] = 0.0 + # Randomize opponent forward velocity + self.chase_velocity = self.rng.uniform(self.chase_vel_range[0], self.chase_vel_range[1]) + def chase_player(self): """ This moves the opponent randomly in a correlated @@ -162,7 +186,7 @@ def chase_player(self): new_vec = np.array([np.cos(theta), np.sin(theta)]) new_vec2 = pel - vec vel = np.dot(new_vec, new_vec2) - return np.array([1.0, vel]) + return np.array([self.chase_velocity, vel]) class HeightField: @@ -199,6 +223,17 @@ def __init__(self, self.relief_range = relief_range self._populate_patches() + def flatten_agent_patch(self, qpos): + """ + Turn terrain in the patch around the agent to flat. + """ + # convert position to map position + pos = self.cart2map(qpos[:2]) + # get patch that belongs to the position + i = pos[0] // self.patch_size + j = pos[1] // self.patch_size + self._fill_patch(i, j, terrain_type=TerrainTypes.FLAT) + def _compute_patch_data(self, terrain_type): if terrain_type.name == 'FLAT': return np.zeros((self.patch_size, self.patch_size)) @@ -223,10 +258,10 @@ def _populate_patches(self): self._fill_patch(i, j, terrain_type) # put special terrain only once in 20% of episodes if self.rng.uniform() < 0.2: - i, j = self.rng.randint(0, self.patches_per_side, size=2) + i, j = np.random.randint(0, self.patches_per_side, size=2) self._fill_patch(i, j, SpecialTerrains.RELIEF) - def _fill_patch(self, i, j, terrain_type='FLAT'): + def _fill_patch(self, i, j, terrain_type=TerrainTypes.FLAT): """ Fill patch at position , with terrain """ @@ -234,20 +269,37 @@ def _fill_patch(self, i, j, terrain_type='FLAT'): j * self.patch_size: j * self.patch_size + self.patch_size] = self._compute_patch_data(terrain_type) def get_heightmap_obs(self): + """ + Get heightmap observation. + """ if self.heightmap_window is None: self.heightmap_window = np.zeros((10, 10)) self._measure_height() return self.heightmap_window[:].flatten().copy() - def cart2map(self, pos): + def cart2map(self, + points_1: list, + points_2: Optional[list] = None): """ Transform cartesian position [m * m] to rounded map position [nrow * ncol] + If only points_1 is given: Expects cartesian positions in [x, y] format. + If also points_2 is given: Expects points_1 = [x1, x2, ...] points_2 = [y1, y2, ...] """ delta_map = self.real_length / self.nrow offset = self.hfield.data.shape[0] / 2 - return pos[:] / delta_map + offset + # x, y needs to be switched to match hfield. + if points_2 is None: + return np.array(points_1[::-1] / delta_map + offset, dtype=np.int16) + else: + ret1 = np.array(points_1[:] / delta_map + offset, dtype=np.int16) + ret2 = np.array(points_2[:] / delta_map + offset, dtype=np.int16) + return ret2, ret1 def sample(self, rng=None): + """ + Sample an entire heightfield for the episode. + Update geom in viewer if rendering. + """ if not rng is None: self.rng = rng self._populate_patches() @@ -256,6 +308,9 @@ def sample(self, rng=None): # Patch types --------------- def _compute_rough_terrain(self): + """ + Compute data for a random noise rough terrain. + """ rough = self.rng.uniform(low=-1.0, high=1.0, size=(self.patch_size, self.patch_size)) normalized_data = (rough - np.min(rough)) / (np.max(rough) - np.min(rough)) scalar, offset = .08, .02 @@ -263,12 +318,18 @@ def _compute_rough_terrain(self): return normalized_data * scalar - offset def _compute_relief_terrain(self): + """ + Compute data for a special logo terrain. + """ curr_dir = os.path.dirname(__file__) relief = np.load(os.path.join(curr_dir, '../assets/myo_relief.npy')) normalized_data = (relief - np.min(relief)) / (np.max(relief) - np.min(relief)) return np.flipud(normalized_data) * self.rng.uniform(self.relief_range[0], self.relief_range[1]) def _compute_hilly_terrain(self): + """ + Compute data for a terrain with smooth hills. + """ frequency = 10 scalar = self.rng.uniform(low=self.hills_range[0], high=self.hills_range[1]) data = np.sin(np.linspace(0, frequency * np.pi, self.patch_size * self.patch_size) + np.pi / 2) - 1 @@ -279,7 +340,7 @@ def _compute_hilly_terrain(self): return normalized_data def _init_height_points(self): - """ Compute points at which height measurments are sampled (in base frame) + """ Compute grid points at which height measurements are sampled (in base frame) Saves the points in ndarray of shape (self.num_height_points, 3) """ measured_points_x = [-0.4, -0.3, -0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5] @@ -295,10 +356,14 @@ def _init_height_points(self): self.height_points = points def _measure_height(self): + """ + Update heights at grid points around + model. + """ rot_direction = quat2euler(self.sim.data.qpos[3:7])[2] rot_mat = euler2mat([0, 0, rot_direction]) # rotate points around z-direction to match model - points = self.height_points @ rot_mat + points = np.einsum("ij,kj->ik", self.height_points, rot_mat) # increase point spacing points = (points * self.view_distance) # translate points to model frame @@ -307,20 +372,17 @@ def _measure_height(self): px = self.points[:, 0] py = self.points[:, 1] # get map_index coordinates of points - px = np.asarray(self.cart2map(px), dtype=np.int16) - py = np.asarray(self.cart2map(py), dtype=np.int16) + px, py = self.cart2map(px, py) # avoid out-of-bounds by clipping indices to map boundaries # -2 because we go one further and shape is 1 longer than map index px = np.clip(px, 0, self.hfield.data.shape[0] - 2) py = np.clip(py, 0, self.hfield.data.shape[1] - 2) - # switch x and y here because of array indexing - heights = self.hfield.data[py, px] - + heights = self.hfield.data[px, py] if not hasattr(self, 'length'): self.length = 0 self.length += 1 # align with egocentric view of model - self.heightmap_window[:] = np.rot90((heights).reshape(10, 10)) + self.heightmap_window[:] = np.flipud(np.rot90(heights.reshape(10, 10), axes=(1,0))) @property def size(self): @@ -335,6 +397,188 @@ def ncol(self): return self.hfield.ncol +class RepellerChallengeOpponent(ChallengeOpponent): + # Repeller parameters + DIST_INFLUENCE = 3.5 # Distance of influence by the repeller + ETA = 20.0 # Scaling factor + MIN_SPAWN_DIST = 1.5 + BOUND_RESOLUTIONS = [-8.7, 8.7, 25] + + def __init__(self, + sim, + rng, + probabilities: Tuple[float], + min_spawn_distance: float, + chase_vel_range: Tuple[float], + random_vel_range: Tuple[float], + repeller_vel_range: Tuple[float], + dt=0.01, + ): + """ + Initialize the opponent class. This class additionally contains a repeller policy which always runs away from the + agent. + :param sim: Mujoco sim object. + :param rng: np_random generator. + :param probabilities: Probabilities for the different policies, (static_stationary, stationary, random, repeller). + :param min_spawn_distance: Minimum distance for opponent to spawn from the model. + :param chase_vel_range: Range of velocities for the chase policy. Randomly drawn. + :param random_vel_range: Range of velocities for the random policy. Clipped. + :param dt: Simulation timestep. + """ + self.dt = dt + self.sim = sim + self.rng = rng + self.opponent_probabilities = probabilities + + self.min_spawn_distance = min_spawn_distance + self.noise_process = pink.ColoredNoiseProcess(beta=2, size=(2, 2000), scale=10, rng=rng) + self.chase_vel_range = chase_vel_range + self.random_vel_range = random_vel_range + self.repeller_vel_range = repeller_vel_range + self.reset_opponent() + + def get_agent_pos(self): + """ + Get agent Pose + :param pose: Pose of the agent, measured from the pelvis. + :type pose: array -> [x, y] + """ + return self.sim.data.body('pelvis').xpos[:2] + + def get_wall_pos(self): + """ + Get location of quad boundaries. + :param pose: Pose of points along quad boundaries. + :type pose: array -> [x, y] + """ + bound_resolution = np.linspace(self.BOUND_RESOLUTIONS[0], self.BOUND_RESOLUTIONS[1], self.BOUND_RESOLUTIONS[2]) + right_left_bounds = np.vstack( (np.array([[8.7,x] for x in bound_resolution]), + np.array([[-8.7,x] for x in bound_resolution])) ) + all_bounds = np.vstack( (right_left_bounds, right_left_bounds[:,[1,0]]) ) + + return all_bounds + + def get_repellers(self): + """ + Get location of all repellers. + :param pose: Pose of all repellers + :type pose: array -> [x, y] + """ + agent_pos = self.get_agent_pos() + wall_pos = self.get_wall_pos() + + obstacle_list = np.vstack( (agent_pos, wall_pos) ) + return obstacle_list + + def repeller_stochastic(self): + """ + Returns the linear velocity for the opponent + :param pose: Pose of points of all repellers + :type pose: array -> [x, y, rotation] + """ + obstacle_pos = self.get_repellers() + opponent_pos = self.get_opponent_pose().copy() + + # Calculate over all the workspace + distance = np.array([np.linalg.norm(diff) for diff in (obstacle_pos - opponent_pos[0:2])]) + + # Check if any obstacles are around + dist_idx = np.where(distance < self.DIST_INFLUENCE)[0] + + # Take a random step if no repellers are close by, making it a non-stationary target + if len(dist_idx) == 0: + lin, rot = self.noise_process.sample() + escape_linear = np.clip(lin, self.repeller_vel_range[0], self.repeller_vel_range[1]) + escape_ang_rot = self._calc_angular_vel(opponent_pos[2], rot) + return np.hstack((escape_linear, escape_ang_rot)) + + repel_COM = np.mean(obstacle_pos[dist_idx,:], axis=0) + # Use repeller force as linear velocity to escape + repel_force = 0.5 * self.ETA * ( 1/np.maximum(distance[dist_idx], 0.00001) - 1/self.DIST_INFLUENCE )**2 + escape_linear = np.clip(np.mean(repel_force), self.repeller_vel_range[0], self.repeller_vel_range[1]) + escape_xpos = opponent_pos[0:2] - repel_COM + + equil_idx = np.where(np.abs(escape_xpos) <= 0.1 )[0] + if len(equil_idx) != 0: + for idx in equil_idx: + escape_xpos[idx] = -1*np.sign(escape_xpos[idx]) * self.rng.uniform(low=0.3, high=0.9) + + escape_direction = np.arctan2(escape_xpos[1], escape_xpos[0]) # Direction + escape_direction = escape_direction + 1.57 # Account for rotation in world frame + + # Determines turning direction + escape_ang_rot = self._calc_angular_vel(opponent_pos[2], escape_direction) + + return np.hstack((escape_linear, escape_ang_rot)) + + def _calc_angular_vel(self, current_pos, desired_pos): + # Checking for sign of the current position and escape position to prevent inefficient turning + # E.g. 3.14 and -3.14 are pointing in the same direction, so a simple substraction of facing direction will make the opponent turn a lot + + # Bring the current pos and desired pos to be between 0 to 2pi + if current_pos > (2*np.pi): + while current_pos > (2*np.pi): + current_pos = current_pos - (2*np.pi) + elif np.sign(current_pos) < 0: + while np.sign(current_pos) < 0: + current_pos = current_pos + (2*np.pi) + + if desired_pos > (2*np.pi): + while desired_pos > (2*np.pi): + desired_pos = desired_pos - (2*np.pi) + elif np.sign(desired_pos) < 0: + while np.sign(desired_pos) < 0: + desired_pos = desired_pos + (2*np.pi) + + direction_clock = np.abs(0 - current_pos) + (2*np.pi - desired_pos) # Clockwise rotation + direction_anticlock = (2*np.pi - current_pos) + (0 + desired_pos) # Anticlockwise rotation + + if direction_clock < direction_anticlock: + return 1 + else: + return -1 + + def repeller_policy(self): + """ + This uses the repeller policy to move the opponent. + """ + return self.repeller_stochastic() + + def sample_opponent_policy(self): + """ + Takes in three probabilities and returns the policies with the given frequency. + """ + rand_num = self.rng.uniform() + if rand_num < self.opponent_probabilities[0]: + self.opponent_policy = 'static_stationary' + elif rand_num < self.opponent_probabilities[0] + self.opponent_probabilities[1]: + self.opponent_policy = 'stationary' + elif rand_num < self.opponent_probabilities[0] + self.opponent_probabilities[1] + self.opponent_probabilities[2]: + self.opponent_policy = 'random' + else: + self.opponent_policy = 'repeller' + + def update_opponent_state(self): + """ + This function executes an opponent step with + one of the control policies. + """ + if self.opponent_policy == 'stationary' or self.opponent_policy == 'static_stationary': + opponent_vel = np.zeros(2,) + + elif self.opponent_policy == 'random': + opponent_vel = self.random_movement() + + elif self.opponent_policy == 'repeller': + opponent_vel = self.repeller_policy() + + elif self.opponent_policy == 'chase_player': + opponent_vel = self.chase_player() + else: + raise NotImplementedError(f"This opponent policy doesn't exist. Chose: static_stationary, stationary or random. Policy was: {self.opponent_policy}") + self.move_opponent(opponent_vel) + + class ChaseTagEnvV0(WalkEnvV0): DEFAULT_OBS_KEYS = [ @@ -375,14 +619,12 @@ def __init__(self, model_path, obsd_model_path=None, seed=None, **kwargs): # first construct the inheritance chain, which is just __init__ calls all the way down, with env_base # creating the sim / sim_obsd instances. Next we run through "setup" which relies on sim / sim_obsd # created in __init__ to complete the setup. - # base().__init__(model_path=model_path, obsd_model_path=obsd_model_path, seed=seed) BaseV0.__init__(self, model_path=model_path, obsd_model_path=obsd_model_path, seed=seed, env_credits=self.MYO_CREDIT) self._setup(**kwargs) def _setup(self, obs_keys: list = DEFAULT_OBS_KEYS, weighted_reward_keys: dict = DEFAULT_RWD_KEYS_AND_WEIGHTS, - opponent_probabilities=[0.1, 0.45, 0.45], reset_type='none', win_distance=0.5, min_spawn_distance=2, @@ -391,6 +633,11 @@ def _setup(self, hills_range=(0,0), rough_range=(0,0), relief_range=(0,0), + repeller_opponent=False, + chase_vel_range=(1.0, 1.0), + random_vel_range=(1.0, 1.0), + repeller_vel_range=(1.0, 1.0), + opponent_probabilities=(0.1, 0.45, 0.45), **kwargs, ): @@ -406,12 +653,27 @@ def _setup(self, self.task_choice = task_choice self.terrain = terrain self.maxTime = 20 + if repeller_opponent: + self.opponent = RepellerChallengeOpponent(sim=self.sim, + rng=self.np_random, + probabilities=opponent_probabilities, + min_spawn_distance=min_spawn_distance, + chase_vel_range=chase_vel_range, + random_vel_range=random_vel_range, + repeller_vel_range=repeller_vel_range) + else: + self.opponent = ChallengeOpponent(sim=self.sim, + rng=self.np_random, + probabilities=opponent_probabilities, + min_spawn_distance=min_spawn_distance, + chase_vel_range=chase_vel_range, + random_vel_range=random_vel_range) self.win_distance = win_distance self.grf_sensor_names = ['r_foot', 'r_toes', 'l_foot', 'l_toes'] - self.opponent = ChallengeOpponent(sim=self.sim, rng=self.np_random, probabilities=opponent_probabilities, min_spawn_distance = min_spawn_distance) self.success_indicator_sid = self.sim.model.site_name2id("opponent_indicator") self.current_task = Task.CHASE + self.repeller_opponent = repeller_opponent super()._setup(obs_keys=obs_keys, weighted_reward_keys=weighted_reward_keys, reset_type=reset_type, @@ -420,6 +682,25 @@ def _setup(self, self.init_qpos[:] = self.sim.model.key_qpos[0] self.init_qvel[:] = 0.0 self.startFlag = True + self.assert_settings() + self.opponent.dt = self.sim.model.opt.timestep * self.frame_skip + + + + def assert_settings(self): + # chase always positive + assert self.opponent.chase_vel_range[0] >= 0 and self.opponent.chase_vel_range[1] > 0, f"Chase velocity range should be positive. {self.opponent.chase_vel_range}" + # others assert that range end is bigger than range start + assert self.opponent.chase_vel_range[0] <= self.opponent.chase_vel_range[1], f"Chase velocity range is not valid. {self.opponent.chase_vel_range}" + assert self.opponent.random_vel_range[0] <= self.opponent.random_vel_range[1], f"Random movement velocity range is not valid {self.opponent.random_vel_range}" + if hasattr(self.opponent, 'repeller_vel_range'): + assert self.opponent.repeller_vel_range[0] <= self.opponent.repeller_vel_range[1], f"Repeller velocity range is not valid {self.opponent.repeller_vel_range}" + if self.repeller_opponent == True: + assert len(self.opponent.opponent_probabilities) == 4, "Repeller opponent requires 4 probabilities" + else: + assert len(self.opponent.opponent_probabilities) == 3, "Standard opponent requires 3 probabilities" + for x in self.opponent.opponent_probabilities: + assert 0 <= x <= 1, "Probabilities should be between 0 and 1" def get_obs_dict(self, sim): obs_dict = {} @@ -495,8 +776,7 @@ def get_reward_dict(self, obs_dict): rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0) # Success Indicator - self.sim.model.site_rgba[self.success_indicator_sid, :] = np.array([0, 2, 0, 0.1]) if rwd_dict['solved'] else np.array([2, 0, 0, 0]) - + self.sim.model.site_rgba[self.success_indicator_sid, :] = np.array([0, 2, 0, 0.2]) if rwd_dict['solved'] else np.array([2, 0, 0, 0]) return rwd_dict def get_metrics(self, paths): @@ -518,22 +798,32 @@ def get_metrics(self, paths): def step(self, *args, **kwargs): self.opponent.update_opponent_state() - obs, reward, done, info = super().step(*args, **kwargs) - return obs, reward, done, info + results = super().step(*args, **kwargs) + return results - def reset(self): + def reset(self, **kwargs): # randomized terrain types self._maybe_sample_terrain() # randomized tasks self._sample_task() # randomized initial state qpos, qvel = self._get_reset_state() + self._maybe_flatten_agent_patch(qpos) self.robot.sync_sims(self.sim, self.sim_obsd) - obs = super(WalkEnvV0, self).reset(reset_qpos=qpos, reset_qvel=qvel) + obs = super(WalkEnvV0, self).reset(reset_qpos=qpos, reset_qvel=qvel, **kwargs) self.opponent.reset_opponent(player_task=self.current_task.name, rng=self.np_random) self.sim.forward() return obs + def _maybe_flatten_agent_patch(self, qpos): + """ + Ensure that initial state patch is flat. + """ + if self.heightfield is not None: + self.heightfield.flatten_agent_patch(qpos) + if hasattr(self.sim, 'renderer') and not self.sim.renderer._window is None: + self.sim.renderer._window.update_hfield(0) + def _sample_task(self): if self.task_choice == 'random': self.current_task = self.np_random.choice(Task) @@ -559,6 +849,8 @@ def _randomize_position_orientation(self, qpos, qvel): euler_angle = quat2euler(qpos[3:7]) euler_angle[-1] = orientation qpos[3:7] = euler2quat(euler_angle) + # rotate original velocity with unit direction vector + qvel[:2] = np.array([np.cos(orientation), np.sin(orientation)]) * np.linalg.norm(qvel[:2]) return qpos, qvel def _get_reset_state(self): @@ -570,6 +862,17 @@ def _get_reset_state(self): else: return self.sim.model.key_qpos[0], self.sim.model.key_qvel[0] + def _maybe_adjust_height(self, qpos, qvel): + """ + Currently not used. + """ + if self.heightfield is not None: + map_i, map_j = self.heightfield.cart2map(qpos[:2]) + hfield_val = self.heightfield.hfield.data[map_i, map_j] + if hfield_val > 0.05: + qpos[2] += hfield_val + return qpos, qvel + def viewer_setup(self, *args, **kwargs): """ Setup the default camera @@ -748,4 +1051,4 @@ def _get_fallen_condition(self): if head[2] - mean[2] < 0.2: return 1 else: - return 0 + return 0 \ No newline at end of file diff --git a/robohive/envs/myo/myochallenge/relocate_v0.py b/robohive/envs/myo/myochallenge/relocate_v0.py index 6f4fab00..5547c169 100644 --- a/robohive/envs/myo/myochallenge/relocate_v0.py +++ b/robohive/envs/myo/myochallenge/relocate_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils import gym from robohive.envs.myo.base_v0 import BaseV0 from robohive.utils.quat_math import mat2euler, euler2quat @@ -137,7 +137,7 @@ def get_metrics(self, paths, successful_steps=5): return metrics - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): self.sim.model.body_pos[self.goal_bid] = self.np_random.uniform(**self.target_xyz_range) self.sim.model.body_quat[self.goal_bid] = euler2quat(self.np_random.uniform(**self.target_rxryrz_range)) @@ -153,7 +153,7 @@ def reset(self, reset_qpos=None, reset_qvel=None): for gid in range(self.sim.model.body_geomnum[bid]): gid+=self.sim.model.body_geomadr[bid] # get geom ids # update type, size, and collision bounds - self.sim.model.geom_type[gid]=self.np_random.randint(low=2, high=7) # random shape + self.sim.model.geom_type[gid]=self.np_random.choice([2,3,4,5,6]) # random shape self.sim.model.geom_size[gid]=self.np_random.uniform(low=self.obj_geom_range['low'], high=self.obj_geom_range['high']) # random size self.sim.model.geom_aabb[gid][3:]= self.obj_geom_range['high'] # bounding box, (center, size) self.sim.model.geom_rbound[gid] = 2.0*max(self.obj_geom_range['high']) # radius of bounding sphere @@ -180,8 +180,8 @@ def reset(self, reset_qpos=None, reset_qvel=None): else: reset_qpos_local = reset_qpos - obs = super().reset(reset_qpos_local, reset_qvel) + obs = super().reset(reset_qpos_local, reset_qvel,**kwargs) if self.sim.data.ncon>0: - self.reset(reset_qpos, reset_qvel) + self.reset(reset_qpos, reset_qvel,**kwargs) return obs \ No newline at end of file diff --git a/robohive/envs/myo/myochallenge/reorient_v0.py b/robohive/envs/myo/myochallenge/reorient_v0.py index fb15ec37..fd919193 100644 --- a/robohive/envs/myo/myochallenge/reorient_v0.py +++ b/robohive/envs/myo/myochallenge/reorient_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils import gym from robohive.envs.myo.base_v0 import BaseV0 from robohive.utils.quat_math import mat2euler, euler2quat @@ -144,7 +144,7 @@ def get_metrics(self, paths, successful_steps=5): } return metrics - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): self.sim.model.body_pos[self.goal_bid] = self.goal_init_pos + \ self.np_random.uniform( high=self.goal_pos[1], low=self.goal_pos[0], size=3) @@ -167,5 +167,5 @@ def reset(self, reset_qpos=None, reset_qvel=None): object_gpos = self.sim.model.geom_pos[self.object_gid0:self.object_gidn] self.sim.model.geom_pos[self.object_gid0:self.object_gidn] = object_gpos/abs(object_gpos+1e-16) * (abs(self.object_default_pos) + del_size) - obs = super().reset(reset_qpos, reset_qvel) + obs = super().reset(reset_qpos, reset_qvel, **kwargs) return obs \ No newline at end of file diff --git a/robohive/envs/myo/myodm/__init__.py b/robohive/envs/myo/myodm/__init__.py index a49237e5..f2dcbbde 100644 --- a/robohive/envs/myo/myodm/__init__.py +++ b/robohive/envs/myo/myodm/__init__.py @@ -1,4 +1,5 @@ -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register + import collections import os import numpy as np @@ -120,7 +121,7 @@ def register_myohand_object_trackref(task_name, object_name, motion_path=None): entry_point='robohive.envs.myo.myodm.myodm_v0:TrackEnv', max_episode_steps=75, #50steps*40Skip*2ms = 4s kwargs={ - 'model_path': '/../assets/hand/MyoHand_object.xml', + 'model_path': '/../assets/hand/myohand_object.xml', 'object_name': object_name, 'reference':curr_dir+'/data/'+motion_path, } @@ -144,7 +145,7 @@ def register_MyoHand_object(object_name): entry_point='robohive.envs.myo.myodm.myodm_v0:TrackEnv', max_episode_steps=50, #50steps*40Skip*2ms = 4s kwargs={ - 'model_path': '/../assets/hand/MyoHand_object.xml', + 'model_path': '/../assets/hand/myohand_object.xml', 'object_name': object_name, 'reference': {'time':(0.0, 4.0), 'robot':np.zeros((1, dof_robot)), @@ -163,7 +164,7 @@ def register_MyoHand_object(object_name): entry_point='robohive.envs.myo.myodm.myodm_v0:TrackEnv', max_episode_steps=50, #50steps*40Skip*2ms = 4s kwargs={ - 'model_path': '/../assets/hand/MyoHand_object.xml', + 'model_path': '/../assets/hand/myohand_object.xml', 'object_name': object_name, 'reference': {'time':(0.0, 4.0), 'robot':np.zeros((2, dof_robot)), diff --git a/robohive/envs/myo/myodm/data/MyoHand_cylindersmall_lift.npz b/robohive/envs/myo/myodm/data/MyoHand_cylindersmall_lift.npz index 1879c491..5bb5d138 100644 Binary files a/robohive/envs/myo/myodm/data/MyoHand_cylindersmall_lift.npz and b/robohive/envs/myo/myodm/data/MyoHand_cylindersmall_lift.npz differ diff --git a/robohive/envs/myo/myodm/data/MyoHand_fryingpan_cook2.npz b/robohive/envs/myo/myodm/data/MyoHand_fryingpan_cook2.npz index d8209fa8..0a465362 100644 Binary files a/robohive/envs/myo/myodm/data/MyoHand_fryingpan_cook2.npz and b/robohive/envs/myo/myodm/data/MyoHand_fryingpan_cook2.npz differ diff --git a/robohive/envs/myo/myodm/data/MyoHand_hand_pass1.npz b/robohive/envs/myo/myodm/data/MyoHand_hand_pass1.npz index 8a3cbc22..a1698c1f 100644 Binary files a/robohive/envs/myo/myodm/data/MyoHand_hand_pass1.npz and b/robohive/envs/myo/myodm/data/MyoHand_hand_pass1.npz differ diff --git a/robohive/envs/myo/myodm/data/MyoHand_knife_lift.npz b/robohive/envs/myo/myodm/data/MyoHand_knife_lift.npz index cb949293..40b1061a 100644 Binary files a/robohive/envs/myo/myodm/data/MyoHand_knife_lift.npz and b/robohive/envs/myo/myodm/data/MyoHand_knife_lift.npz differ diff --git a/robohive/envs/myo/myodm/data/MyoHand_wineglass_drink1.npz b/robohive/envs/myo/myodm/data/MyoHand_wineglass_drink1.npz index 4ac45d68..7da359ff 100644 Binary files a/robohive/envs/myo/myodm/data/MyoHand_wineglass_drink1.npz and b/robohive/envs/myo/myodm/data/MyoHand_wineglass_drink1.npz differ diff --git a/robohive/envs/myo/myodm/myodm_v0.py b/robohive/envs/myo/myodm/myodm_v0.py index d59feda6..abd47ea0 100644 --- a/robohive/envs/myo/myodm/myodm_v0.py +++ b/robohive/envs/myo/myodm/myodm_v0.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils import gym from robohive.envs import env_base from robohive.logger.reference_motion import ReferenceMotion from robohive.utils.quat_math import quat2euler, euler2quat, quatDiff2Vel, mat2quat @@ -289,10 +289,10 @@ def playback(self): return idxs[0] < self.ref.horizon-1 - def reset(self): + def reset(self, **kwargs): # print("Reset") self.ref.reset() - obs = super().reset(self.init_qpos, self.init_qvel) + obs = super().reset(self.init_qpos, self.init_qvel, **kwargs) # print(self.time, self.sim.data.qpos) return obs diff --git a/robohive/envs/myo/myomimic/__init__.py b/robohive/envs/myo/myomimic/__init__.py index 88bd1c3e..e42a82f6 100644 --- a/robohive/envs/myo/myomimic/__init__.py +++ b/robohive/envs/myo/myomimic/__init__.py @@ -1,4 +1,5 @@ -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register + import collections import os import numpy as np diff --git a/robohive/envs/myo/myomimic/myomimic_v0.py b/robohive/envs/myo/myomimic/myomimic_v0.py index 4100dde1..c634de95 100644 --- a/robohive/envs/myo/myomimic/myomimic_v0.py +++ b/robohive/envs/myo/myomimic/myomimic_v0.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils import gym from robohive.envs import env_base from robohive.logger.reference_motion import ReferenceMotion from robohive.utils.quat_math import quat2euler, euler2quat, quatDiff2Vel, mat2quat @@ -155,9 +155,9 @@ def playback(self): return idxs[0] < self.ref.horizon-1 - def reset(self): + def reset(self, **kwargs): # print("Reset") self.ref.reset() - obs = super().reset(self.init_qpos, self.init_qvel) + obs = super().reset(self.init_qpos, self.init_qvel, **kwargs) # print(self.time, self.sim.data.qpos) return obs diff --git a/robohive/envs/myo/sync_myo.sh b/robohive/envs/myo/sync_myo.sh index 59d4b569..c199a46a 100755 --- a/robohive/envs/myo/sync_myo.sh +++ b/robohive/envs/myo/sync_myo.sh @@ -79,11 +79,23 @@ rsync -av --progress $src_path/robohive/tests/test_envs.py $dst_path/myosuite/te rsync -av --progress $src_path/robohive/tests/test_myo.py $dst_path/myosuite/tests/ # Replace -# sed -i "s/robohive\./myosuite\./g" $dst_path/myosuite/envs/myo/__init__.py -find $dst_path/myosuite -type f -name "*.py" -exec sed -i "s/robohive\./myosuite\./g" {} \; -find $dst_path/myosuite/tests -type f -name "*.py" -exec sed -i "s/robohive/myosuite/g" {} \; -find $dst_path/myosuite/logger -type f -name "examine_reference.py" -exec sed -i "s/robohive/myosuite/g" {} \; -find $dst_path/myosuite -type f -name "*.py" -exec sed -i "s/RoboHive:>/MyoSuite:>/g" {} \; +if [ "$(uname)" == "Darwin" ]; then + # macOS + sed -i '' "s/robohive\./myosuite\./g" $dst_path/myosuite/envs/myo/myobase/__init__.py + find $dst_path/myosuite -type f -name "*.py" -exec sed -i '' "s/robohive\./myosuite\./g" {} \; + find $dst_path/myosuite/tests -type f -name "*.py" -exec sed -i '' "s/robohive/myosuite/g" {} \; + find $dst_path/myosuite/logger -type f -name "examine_reference.py" -exec sed -i '' "s/robohive/myosuite/g" {} \; + find $dst_path/myosuite -type f -name "*.py" -exec sed -i '' "s/RoboHive:>/MyoSuite:>/g" {} \; +elif [ "$(uname)" == "Linux" ]; then + sed -i "s/robohive\./myosuite\./g" $dst_path/myosuite/envs/myo/myobase/__init__.py + find $dst_path/myosuite -type f -name "*.py" -exec sed -i "s/robohive\./myosuite\./g" {} \; + find $dst_path/myosuite/tests -type f -name "*.py" -exec sed -i "s/robohive/myosuite/g" {} \; + find $dst_path/myosuite/logger -type f -name "examine_reference.py" -exec sed -i "s/robohive/myosuite/g" {} \; + find $dst_path/myosuite -type f -name "*.py" -exec sed -i "s/RoboHive:>/MyoSuite:>/g" {} \; +else + # Other or unknown OS + echo "This is neither macOS nor Linux" +fi # configs rsync -av --progress $src_path/.gitignore $dst_path/ diff --git a/robohive/envs/quadrupeds/__init__.py b/robohive/envs/quadrupeds/__init__.py index b56c9fb5..d5cb8855 100644 --- a/robohive/envs/quadrupeds/__init__.py +++ b/robohive/envs/quadrupeds/__init__.py @@ -1,4 +1,5 @@ -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register + from robohive.envs.env_variants import register_env_variant import numpy as np import os diff --git a/robohive/envs/quadrupeds/orient_v0.py b/robohive/envs/quadrupeds/orient_v0.py index ef92f28c..776e9978 100644 --- a/robohive/envs/quadrupeds/orient_v0.py +++ b/robohive/envs/quadrupeds/orient_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs import env_base @@ -171,7 +171,7 @@ def get_reward_dict(self, obs_dict): return rwd_dict - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): reset_qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos reset_qpos[6:] += np.pi/8*self.np_random.uniform(low=-1, high=1, size=self.sim.model.nq-6) @@ -182,5 +182,5 @@ def reset(self, reset_qpos=None, reset_qvel=None): self.sim.model.site_pos[self.target_sid] = target_dist * np.array([np.cos(target_theta), np.sin(target_theta), 0]) # Heading target is a bit farther away to avoid heading oscillations when quad is near xy_target self.sim.model.site_pos[self.heading_sid] = (target_dist+0.5) * np.array([np.cos(target_theta), np.sin(target_theta), 0]) - obs = super().reset(reset_qpos, reset_qvel) + obs = super().reset(reset_qpos, reset_qvel, **kwargs) return obs diff --git a/robohive/envs/quadrupeds/stand_v0.py b/robohive/envs/quadrupeds/stand_v0.py index 071b7e17..72778084 100644 --- a/robohive/envs/quadrupeds/stand_v0.py +++ b/robohive/envs/quadrupeds/stand_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs import env_base @@ -170,7 +170,7 @@ def get_reward_dict(self, obs_dict): return rwd_dict - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): if reset_qpos is None: reset_qpos = self.init_qpos.copy() @@ -189,6 +189,5 @@ def reset(self, reset_qpos=None, reset_qvel=None): else: raise TypeError(f"Unknown reset type: {self.reset_type}") - obs = super().reset(reset_qpos, reset_qvel) + obs = super().reset(reset_qpos, reset_qvel, **kwargs) return obs - diff --git a/robohive/envs/quadrupeds/walk_v0.py b/robohive/envs/quadrupeds/walk_v0.py index 4a793c3b..bf91307f 100644 --- a/robohive/envs/quadrupeds/walk_v0.py +++ b/robohive/envs/quadrupeds/walk_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils import gym import numpy as np from robohive.envs import env_base @@ -171,7 +171,7 @@ def get_reward_dict(self, obs_dict): return rwd_dict - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): reset_qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos reset_qpos[6:] += np.pi/8*self.np_random.uniform(low=-1, high=1, size=self.sim.model.nq-6) @@ -182,6 +182,5 @@ def reset(self, reset_qpos=None, reset_qvel=None): self.sim.model.site_pos[self.target_sid] = target_dist * np.array([np.cos(target_theta), np.sin(target_theta), 0]) # Heading target is a bit farther away to avoid heading oscillations when quad is near xy_target self.sim.model.site_pos[self.heading_sid] = (target_dist+0.5) * np.array([np.cos(target_theta), np.sin(target_theta), 0]) - obs = super().reset(reset_qpos, reset_qvel) + obs = super().reset(reset_qpos, reset_qvel, **kwargs) return obs - diff --git a/robohive/envs/tcdm/__init__.py b/robohive/envs/tcdm/__init__.py index 7a63ca85..055ac7ac 100644 --- a/robohive/envs/tcdm/__init__.py +++ b/robohive/envs/tcdm/__init__.py @@ -5,7 +5,8 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -from gym.envs.registration import register +from robohive.utils import gym; register=gym.register + import numpy as np import os import collections @@ -152,4 +153,3 @@ def register_Franka_object(object_name, data_path=None): ) for obj in OBJECTS: register_Franka_object(obj, data_path=None) - diff --git a/robohive/envs/tcdm/playback_mocap.py b/robohive/envs/tcdm/playback_mocap.py index 719593fc..44ce818a 100644 --- a/robohive/envs/tcdm/playback_mocap.py +++ b/robohive/envs/tcdm/playback_mocap.py @@ -1,5 +1,5 @@ import robohive -import gym +from robohive.utils import gym import numpy as np from mocap_utils import MoCapController, MoCapTask import dm_env diff --git a/robohive/envs/tcdm/track.py b/robohive/envs/tcdm/track.py index e865ee80..af75b558 100644 --- a/robohive/envs/tcdm/track.py +++ b/robohive/envs/tcdm/track.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils import gym from robohive.envs import env_base from robohive.logger.reference_motion import ReferenceMotion from robohive.utils.quat_math import quat2euler, euler2quat, quatDiff2Vel, mat2quat @@ -277,10 +277,10 @@ def playback(self): return idxs[0] < self.ref.horizon-1 - def reset(self): + def reset(self, **kwargs): # print("Reset") self.ref.reset() - obs = super().reset(self.init_qpos, self.init_qvel) + obs = super().reset(self.init_qpos, self.init_qvel, **kwargs) # print(self.time, self.sim.data.qpos) return obs diff --git a/robohive/logger/README.md b/robohive/logger/README.md index 775e96f7..bc8c695d 100644 --- a/robohive/logger/README.md +++ b/robohive/logger/README.md @@ -7,7 +7,19 @@ name: { ... } ``` -Let's understand RoboHive's Logger details using its most common usecase -- i.e. recording Robot trajectories. +Nested groups can be created simply by using nested `group` keys. For example keys `g1/sg1`, `g1/sg2` will lead to the following nesting - +``` +name: { + group{g1: + group{sg1: dataset{k1:v1}, dataset{k2:v2}, ...} + group{sg2: dataset{kx:vx}, dataset{ky:vy}, ...} + } + ... +} +``` +Note: Logger preserves the nested keys allowing data access using `data["g1/sg1"]`. The saved H5 dataset is more flexible. Data can be accessed using either `data["g1/sg1"]` or `data["g1"]["sg1"]`. + +Next, let's understand RoboHive's Logger details using its most common usecase -- i.e. recording Robot trajectories. ## Robot Trajectories (Rollouts) diff --git a/robohive/logger/examine_logs.py b/robohive/logger/examine_logs.py index 95df7050..621c3711 100644 --- a/robohive/logger/examine_logs.py +++ b/robohive/logger/examine_logs.py @@ -17,7 +17,7 @@ from robohive.utils.paths_utils import plot as plotnsave_paths from robohive.utils import tensor_utils -import gym +from robohive.utils import gym import click import numpy as np import time @@ -48,7 +48,7 @@ def examine_logs(env_name, rollout_path, rollout_format, mode, horizon, seed, nu # seed and load environments np.random.seed(seed) env = gym.make(env_name) if env_args==None else gym.make(env_name, **(eval(env_args))) - env = env.env + env = env.unwrapped env.seed(seed) # Start a "trace" for recording rollouts @@ -122,7 +122,7 @@ def examine_logs(env_name, rollout_path, rollout_format, mode, horizon, seed, nu trace_horizon = horizon if mode=='record' else path_data['time'].shape[0]-1 # Rollout path -------------------------------- - obs, rwd, done, env_info = env.forward(update_exteroception=include_exteroception) + obs, rwd, done, *_, env_info = env.forward(update_exteroception=include_exteroception) ep_rwd = rwd for i_step in range(trace_horizon+1): @@ -205,10 +205,10 @@ def examine_logs(env_name, rollout_path, rollout_format, mode, horizon, seed, nu env.set_env_state(path_state[i_step+1]) else: raise NotImplementedError("Settings not found") - obs, rwd, done, env_info = env.forward(update_exteroception=include_exteroception) + obs, rwd, done, *_, env_info = env.forward(update_exteroception=include_exteroception) ep_rwd += rwd elif i_step < trace_horizon: # incase last step actions (nans) can cause issues in step - obs, rwd, done, env_info = env.step(act, update_exteroception=include_exteroception) + obs, rwd, done, *_, env_info = env.step(act, update_exteroception=include_exteroception) ep_rwd += rwd # save offscreen buffers as video and clear the dataset diff --git a/robohive/logger/examine_reference.py b/robohive/logger/examine_reference.py index 1aa7c9b6..32a63af7 100644 --- a/robohive/logger/examine_reference.py +++ b/robohive/logger/examine_reference.py @@ -1,5 +1,5 @@ import robohive -import gym +from robohive.utils import gym import time import click from tqdm import tqdm @@ -10,18 +10,21 @@ """ @click.command(help=DESC) -@click.option('-e', '--env_name', type=str, help='environment to load', default="AdroitBananaPass-v0") +@click.option('-e', '--env_name', type=str, help='environment to load', default="MyoHandBananaPass-v0") @click.option('-h', '--horizon', type=int, help='playback horizon', default=-1) @click.option('-n', '--num_playback', type=int, help='Number of time to loop playback', default=1) @click.option('-r', '--render', type=click.Choice(['onscreen', 'none']), help='visualize onscreen?', default='onscreen') def examine_reference(env_name, horizon, num_playback, render): env = gym.make(env_name) + # fixed or random reference + if horizon==1: + horizon = env.spec.max_episode_steps + # infer reference horizon + env = env.unwrapped if horizon==-1: - horizon = env.env.ref.horizon - if horizon==1: # fixed or random reference - horizon = env.env.horizon + horizon = env.ref.horizon # Start playback loops print(f"Rending reference motion (total frames: {horizon})") diff --git a/robohive/logger/roboset_logger.py b/robohive/logger/roboset_logger.py index 371dab7a..d85350e0 100644 --- a/robohive/logger/roboset_logger.py +++ b/robohive/logger/roboset_logger.py @@ -12,7 +12,7 @@ def __init__(self, name, **kwargs): # parse path from robohive format into robopen dataset format def path2dataset(self, path:dict, config_path=None)->dict: """ - Convert Robohive format into roboset format + Convert RoboHive format into roboset format """ path_keys = path.keys() diff --git a/robohive/physics/mj_sim_scene.py b/robohive/physics/mj_sim_scene.py index 0beaa806..6705e8b8 100644 --- a/robohive/physics/mj_sim_scene.py +++ b/robohive/physics/mj_sim_scene.py @@ -37,6 +37,8 @@ def _load_simulation(self, model_handle: Any) -> Any: if isinstance(model_handle, str): if model_handle.endswith('.xml'): sim = dm_mujoco.Physics.from_xml_path(model_handle) + elif isinstance(model_handle, str) and " Robot class is being cleared from the workspace. This is expected if we still need to maintain the active connection to the hardware. A persistent connection to robot is still maintained and will be used next time a robot class is created. Ensure that a robot.close() is called to terminate the persistent connection before exiting the program.") + + # Close the persistnent connection to the robot. This should be called only once at the end when persistent connection is no longer needed. + def close(self): + if self.robot_config is not None: + status = self.hardware_close() if self.is_hardware else True + if status: + prompt(f"Closed {self.name} (Status: {status})", 'white', 'on_grey', flush=True) + self.robot_config = None + else: + prompt(f"Error closing {self.name} (Status: {status})", 'red', 'on_grey', flush=True, type=Prompt.ERROR) + else: + prompt(f"Trying to close a non-existent robot", flush=True, type=Prompt.WARN) def demo_robot(): - import gym + from robohive.utils import gym prompt("Starting Robot===================") env = gym.make('FrankaReachFixed-v0') diff --git a/robohive/simhive/fetch_sim b/robohive/simhive/fetch_sim index 58d561fa..7f6d25ae 160000 --- a/robohive/simhive/fetch_sim +++ b/robohive/simhive/fetch_sim @@ -1 +1 @@ -Subproject commit 58d561fa416b6a151761ced18f2dc8f067188909 +Subproject commit 7f6d25ae8a6f5778379a48fa60c17d685075e64d diff --git a/robohive/simhive/myo_sim b/robohive/simhive/myo_sim index aff0bc09..5e462da7 160000 --- a/robohive/simhive/myo_sim +++ b/robohive/simhive/myo_sim @@ -1 +1 @@ -Subproject commit aff0bc096d98085ee0a6befd613cc9fbff024944 +Subproject commit 5e462da71589fe42164af25ef3c4311231a0d6b2 diff --git a/robohive/tests/test_envs.py b/robohive/tests/test_envs.py index 9ce870d9..78137191 100644 --- a/robohive/tests/test_envs.py +++ b/robohive/tests/test_envs.py @@ -7,12 +7,23 @@ import unittest -import gym +from robohive.utils import gym import numpy as np import pickle import copy -import torch.testing import os +from flatten_dict import flatten + +def assert_close(prm1, prm2, atol=1e-05, rtol=1e-08): + if prm1 is None and prm2 is None: + return True + elif isinstance(prm1,dict) and isinstance(prm2, dict): + prm1_dict = flatten(prm1) + prm2_dict = flatten(prm2) + for key in prm1_dict.keys(): + assert_close(prm1_dict[key], prm2_dict[key], atol=atol, rtol=rtol) + else: + np.testing.assert_allclose(prm1, prm2, atol=atol, rtol=rtol) class TestEnvs(unittest.TestCase): @@ -26,29 +37,31 @@ def check_envs(self, module_name, env_names, lite=False, input_seed=1234): def check_env(self, environment_id, input_seed): - # Skip tests for envs that requires encoder downloading + # If requested, skip tests for envs that requires encoder downloading ROBOHIVE_TEST = os.getenv('ROBOHIVE_TEST') if ROBOHIVE_TEST == 'LITE': if "r3m" in environment_id or "rrl" in environment_id or "vc1" in environment_id: return # test init - env1 = gym.make(environment_id, seed=input_seed) + env1w = gym.make(environment_id, seed=input_seed) + env1 = env1w.unwrapped assert env1.get_input_seed() == input_seed - # test reset - env1.env.reset() + # test reseed and reset + env1.seed(input_seed) + reset_obs1, *_ = env1.reset() # step - u = 0.01*np.random.uniform(low=0, high=1, size=env1.env.sim.model.nu) # small controls - obs1, rwd1, done1, infos1 = env1.env.step(u.copy()) + u = 0.01*np.random.uniform(low=0, high=1, size=env1.sim.model.nu) # small controls + obs1, rwd1, done1, *_, infos1 = env1.step(u.copy()) infos1 = copy.deepcopy(infos1) #info points to internal variables. - proprio1 = env1.env.get_proprioception() - extero1 = env1.env.get_exteroception() + proprio1_t, proprio1_vec, proprio1_dict = env1.get_proprioception() + extero1 = env1.get_exteroception() assert len(obs1>0) # assert len(rwd1>0) # test dicts assert len(infos1) > 0 - obs_dict1 = env1.get_obs_dict(env1.env.sim) + obs_dict1 = env1.get_obs_dict(env1.sim) assert len(obs_dict1) > 0 rwd_dict1 = env1.get_reward_dict(obs_dict1) assert len(rwd_dict1) > 0 @@ -56,27 +69,33 @@ def check_env(self, environment_id, input_seed): env1.reset() # serialize / deserialize env ------------ - env2 = pickle.loads(pickle.dumps(env1)) - # test reset - env2.reset() + env2w = pickle.loads(pickle.dumps(env1w)) + env2 = env2w.unwrapped # test seed assert env2.get_input_seed() == input_seed assert env1.get_input_seed() == env2.get_input_seed(), {env1.get_input_seed(), env2.get_input_seed()} # check input output spaces assert env1.action_space == env2.action_space, (env1.action_space, env2.action_space) assert env1.observation_space == env2.observation_space, (env1.observation_space, env2.observation_space) + + # test reseed and reset + env2.seed(input_seed) + reset_obs2, *_ = env2.reset() + assert_close(reset_obs1, reset_obs2) + # step - obs2, rwd2, done2, infos2 = env2.env.step(u) + obs2, rwd2, done2, *_, infos2 = env2.step(u) infos2 = copy.deepcopy(infos2) - proprio2 = env2.env.get_proprioception() - extero2 = env2.env.get_exteroception() - torch.testing.assert_close(obs1, obs2) - torch.testing.assert_close(proprio1, proprio2) - torch.testing.assert_close(extero1, extero2, atol=2, rtol=0.04) - torch.testing.assert_close(rwd1, rwd2) + proprio2_t, proprio2_vec, proprio2_dict = env2.get_proprioception() + extero2 = env2.get_exteroception() + + assert_close(obs1, obs2) + assert_close(proprio1_vec, proprio2_vec)#, f"Difference in Proprio: {proprio1_vec-proprio2_vec}" + assert_close(extero1, extero2, atol=2, rtol=0.04)#, f"Difference in Extero {extero1}, {extero2}" + assert_close(rwd1, rwd2)#, "Difference in Rewards" assert (done1==done2), (done1, done2) assert len(infos1)==len(infos2), (infos1, infos2) - torch.testing.assert_close(infos1, infos2) + assert_close(infos1, infos2) # reset env2.reset() diff --git a/robohive/tests/test_examine_env.py b/robohive/tests/test_examine_env.py index dd7ded87..d4e531ca 100644 --- a/robohive/tests/test_examine_env.py +++ b/robohive/tests/test_examine_env.py @@ -3,9 +3,39 @@ import unittest from robohive.utils.examine_env import main as examine_env import os +import glob +import time class TestExamineEnv(unittest.TestCase): + + def delete_recent_file(self, filename_pattern, directory='.', age=5): + + # Get the current time + current_time = time.time() + + # Use glob to find files matching the pattern in the specified directory + matching_files = glob.glob(os.path.join(directory, filename_pattern)) + + # Iterate over the matching files + for file_path in matching_files: + try: + # Get the creation time of the file + creation_time = os.path.getctime(file_path) + + # Calculate the time difference between current time and creation time + time_difference = current_time - creation_time + + # If the file was created within the last 5 seconds, delete it + if time_difference <= 5: + os.remove(file_path) + print(f"Deleted file created within {age} seconds: {file_path}") + else: + print(f"File not deleted: {file_path}, created {time_difference} seconds ago.") + except Exception as e: + print(f"Error deleting file: {file_path} - {e}") + + def test_main(self): # Call your function and test its output/assertions print("Testing env with random policy") @@ -32,7 +62,22 @@ def test_offscreen_rendering(self): print("EXCEPTION", result.exception, flush=True) # print(result.output.strip()) self.assertEqual(result.exception, None, result.exception) - os.remove('random_policy0.mp4') + self.delete_recent_file(filename_pattern="random_policy*.mp4") + + def test_paths_plotting(self): + # Call your function and test its output/assertions + print("Testing plotting paths") + runner = click.testing.CliRunner() + result = runner.invoke(examine_env, ["--env_name", "door-v1", \ + "--num_episodes", 1, \ + "--render", "none",\ + "--plot_paths", True]) + print("OUTPUT", result.output.strip(), flush=True) + print("RESULT", result, flush=True) + print("EXCEPTION", result.exception, flush=True) + # print(result.output.strip()) + self.assertEqual(result.exception, None, result.exception) + self.delete_recent_file(filename_pattern="random_policy*Trial*.pdf") def no_test_scripted_policy_loading(self): # Call your function and test its output/assertions diff --git a/robohive/tests/test_logger.py b/robohive/tests/test_logger.py index 56dbf194..d07e0010 100644 --- a/robohive/tests/test_logger.py +++ b/robohive/tests/test_logger.py @@ -6,6 +6,7 @@ from robohive.logger.examine_logs import examine_logs from robohive.utils.examine_env import main as examine_env import os +import re class TestTrace(unittest.TestCase): def teast_trace(self): @@ -24,7 +25,8 @@ def test_logs_playback(self): "--render", "none",\ "--save_paths", True,\ "--output_name", "door_test_logs"]) - log_name = result.output.strip()[-38:] + log_name_pattern = re.compile(r'Saved: (?:.+\.h5)') + log_name = log_name_pattern.search(result.output)[0][7:] result = runner.invoke(examine_logs, ["--env_name", "door-v1", \ "--rollout_path", log_name, \ diff --git a/robohive/tests/test_myo.py b/robohive/tests/test_myo.py index a8823572..be6c1339 100644 --- a/robohive/tests/test_myo.py +++ b/robohive/tests/test_myo.py @@ -53,6 +53,4 @@ def no_test_myomimic(self): if __name__ == '__main__': - unittest.main() - - + unittest.main() \ No newline at end of file diff --git a/robohive/tests/test_sb.py b/robohive/tests/test_sb.py new file mode 100644 index 00000000..5d88b852 --- /dev/null +++ b/robohive/tests/test_sb.py @@ -0,0 +1,11 @@ +from stable_baselines3 import PPO +import robohive +from robohive.utils import gym + +from robohive import robohive_arm_suite +for env_name in sorted(robohive_arm_suite): + print(f"Training {env_name} ========================================") + env = gym.make(env_name) + model = PPO("MlpPolicy", env, verbose=0) + model.learn(total_timesteps=2) + break diff --git a/robohive/tests/test_versions.sh b/robohive/tests/test_versions.sh new file mode 100755 index 00000000..0b1b2f29 --- /dev/null +++ b/robohive/tests/test_versions.sh @@ -0,0 +1,23 @@ +pip uninstall -y gym +pip uninstall -y gymnasium +pip uninstall -y stable-baselines3 + +echo "=================== Testing gym==0.13 ===================" +pip install gym==0.13 +python tests/test_all.py +pip uninstall -y gym + +echo "=================== Testing gym==0.26.2 ===================" +pip install gym==0.26.2 +python tests/test_all.py +pip uninstall -y gym + +echo "=================== Testing gymnasium ===================" +pip install gymnasium +python tests/test_all.py + +echo "=================== Testing Stable Baselines ===================" +pip install stable-baselines3 +python tests/test_sb.py +pip uninstall -y gymnasium +pip uninstall -y stable-baselines3 diff --git a/robohive/tutorials/ee_teleop.py b/robohive/tutorials/ee_teleop.py index e2abbea2..2808e775 100644 --- a/robohive/tutorials/ee_teleop.py +++ b/robohive/tutorials/ee_teleop.py @@ -18,10 +18,11 @@ from robohive.logger.grouped_datasets import Trace as RoboHive_Trace import numpy as np import click -import gym +from robohive.utils import gym try: from vtils.input.keyboard import KeyInput as KeyBoard + from vtils.input.gamepad import GamePad from vtils.input.spacemouse import SpaceMouse except ImportError as e: raise ImportError("Please install vtils -- https://github.com/vikashplus/vtils") @@ -100,12 +101,59 @@ def poll_spacemouse(input_device): return delta_pos, delta_euler, delta_gripper, done +# Poll and process gamepad values +def poll_gamepad(input_device): + # get sensors + sen = input_device.get_sensors() + + # exit request + done = True if (sen["BTN_START"] and sen["BTN_SELECT"]) else False + + scale_factor = 1.0 + + if sen["BTN_NORTH"] == 1: + scale_factor = 0.25 # Hold X to slow down arm movement + + if sen["BTN_WEST"] == 1: # B: open gripper + delta_gripper = 1 + elif sen["BTN_EAST"] == 1: # Y: close gripper + delta_gripper = -1 + else: + delta_gripper = 0 + + # positions + delta_pos = np.array([0, 0, 0]) + # Moving EE forward or backward, when facing the robot + delta_pos[0] = sen["ABS_Y"] + # Moving EE left or right, when facing the robot + delta_pos[1] = sen["ABS_X"] + # Raise or lower + delta_pos[2] = sen["ABS_Z"] - sen["ABS_RZ"] + + # rotations + delta_euler = np.array([0, 0, 0]) + if sen["ABS_HAT0X"] == -1: + delta_euler[0] = -1 + elif sen["ABS_HAT0X"] == 1: + delta_euler[0] = 1 + elif sen["ABS_HAT0Y"] == 1: + delta_euler[1] = -1 + elif sen["ABS_HAT0Y"] == -1: + delta_euler[1] = 1 + elif sen["BTN_TL"] == 1: + delta_euler[2] = -1 + elif sen["BTN_TR"] == 1: + delta_euler[2] = 1 + + return delta_pos * scale_factor, delta_euler * scale_factor, delta_gripper, done + + @click.command(help=DESC) @click.option('-e', '--env_name', type=str, help='environment to load', default='rpFrankaRobotiqData-v0') @click.option('-ea', '--env_args', type=str, default=None, help=('env args. E.g. --env_args "{\'is_hardware\':True}"')) @click.option('-rn', '--reset_noise', type=float, default=0.0, help=('Amplitude of noise during reset')) @click.option('-an', '--action_noise', type=float, default=0.0, help=('Amplitude of action noise during rollout')) -@click.option('-i', '--input_device', type=click.Choice(['keyboard', 'spacemouse']), help='input to use for teleOp', default='keyboard') +@click.option('-i', '--input_device', type=click.Choice(['keyboard', 'spacemouse', 'gamepad']), help='input to use for teleOp', default='keyboard') @click.option('-o', '--output', type=str, default="teleOp_trace.h5", help=('Output name')) @click.option('-h', '--horizon', type=int, help='Rollout horizon', default=100) @click.option('-n', '--num_rollouts', type=int, help='number of repeats for the rollouts', default=1) @@ -141,6 +189,8 @@ def main(env_name, env_args, reset_noise, action_noise, input_device, output, ho # prep input device if input_device=='keyboard': input = KeyBoard() + elif input_device=='gamepad': + input = GamePad() elif input_device=='spacemouse': input = SpaceMouse(vendor_id=vendor_id, product_id=product_id) print("Press both keys to stop listening") @@ -171,6 +221,8 @@ def main(env_name, env_args, reset_noise, action_noise, input_device, output, ho # poll input device -------------------------------------- if input_device=='keyboard': delta_pos, delta_euler, delta_gripper, exit_request = poll_keyboard(input) + elif input_device=='gamepad': + delta_pos, delta_euler, delta_gripper, exit_request = poll_gamepad(input) elif input_device=='spacemouse': delta_pos, delta_euler, delta_gripper, exit_request = poll_spacemouse(input) if exit_request: diff --git a/robohive/tutorials/ee_teleop_oculus.py b/robohive/tutorials/ee_teleop_oculus.py index 01d5fd0e..bd8f3165 100644 --- a/robohive/tutorials/ee_teleop_oculus.py +++ b/robohive/tutorials/ee_teleop_oculus.py @@ -15,7 +15,7 @@ import time import numpy as np import click -import gym +from robohive.utils import gym from robohive.utils.quat_math import euler2quat, euler2mat, mat2quat, diffQuat, mulQuat from robohive.utils.inverse_kinematics import IKResult, qpos_from_site_pose from robohive.logger.roboset_logger import RoboSet_Trace diff --git a/robohive/utils/__init__.py b/robohive/utils/__init__.py index e69de29b..4027afdd 100644 --- a/robohive/utils/__init__.py +++ b/robohive/utils/__init__.py @@ -0,0 +1,19 @@ +import importlib.util + +# Utility to import gym/gymnasium +def import_gym(): + help = """ + Either gym or gymnasium is required to use this library + Options: + (1) re-run the setup instructions for this package (pip install -e .) + (2) install chosen version of gym (pip install gym==0.13) + (3) install chosen version of gymnasium (pip install gymnasium==0.29.1) + """ + if importlib.util.find_spec("gymnasium"): + import gymnasium as gg + elif importlib.util.find_spec("gym"): + import gym as gg + else: + raise ModuleNotFoundError(help) + return gg +gym = import_gym() \ No newline at end of file diff --git a/robohive/utils/examine_env.py b/robohive/utils/examine_env.py index d70f6b41..f8370a64 100644 --- a/robohive/utils/examine_env.py +++ b/robohive/utils/examine_env.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils import gym from robohive.utils.paths_utils import plot as plotnsave_paths import click import numpy as np @@ -28,7 +28,7 @@ class rand_policy(): def __init__(self, env, seed): self.env = env - self.env.action_space.np_random.seed(seed) # requires exlicit seeding + self.env.action_space.seed(seed) # requires explicit seeding def get_action(self, obs): # return self.env.np_random.uniform(high=self.env.action_space.high, low=self.env.action_space.low) @@ -61,7 +61,8 @@ def main(env_name, policy_path, mode, seed, num_episodes, render, camera_name, o # seed and load environments np.random.seed(seed) - env = gym.make(env_name) if env_args==None else gym.make(env_name, **(eval(env_args))) + envw = gym.make(env_name) if env_args==None else gym.make(env_name, **(eval(env_args))) + env = envw.unwrapped env.seed(seed) # resolve policy and outputs @@ -91,7 +92,7 @@ def main(env_name, policy_path, mode, seed, num_episodes, render, camera_name, o # examine policy's behavior to recover paths paths = env.examine_policy_new( policy=pi, - horizon=env.spec.max_episode_steps, + horizon=envw.spec.max_episode_steps, num_episodes=num_episodes, frame_size=(640,480), mode=mode, @@ -101,7 +102,7 @@ def main(env_name, policy_path, mode, seed, num_episodes, render, camera_name, o render=render) # evaluate paths - success_percentage = env.env.evaluate_success(paths) + success_percentage = env.evaluate_success(paths) print(f'Average success over rollouts: {success_percentage}%') # save paths diff --git a/robohive/utils/examine_sim.py b/robohive/utils/examine_sim.py index 621c5369..97d13c8d 100644 --- a/robohive/utils/examine_sim.py +++ b/robohive/utils/examine_sim.py @@ -8,7 +8,7 @@ - python utils/examine_sim.py --sim_path envs/arms/franka/assets/franka_reach_v0.xml --ctrl "0, 0, -1, -1, 0, 0, 0, 0, 0"\n """ -from mujoco_py import load_model_from_path, MjSim, MjViewer +from mujoco import MjModel, MjData, mj_step, mj_forward, viewer import click import numpy as np @@ -19,20 +19,19 @@ @click.option('-h', '--horizon', type=int, help='time (s) to simulate', default=5) def main(sim_path, qpos, ctrl, horizon): - model = load_model_from_path(sim_path) - sim = MjSim(model) - viewer = MjViewer(sim) + model = MjModel.from_xml_path(sim_path) + data = MjData(model) - while sim.data.time>> @implement_for("gym", "0.13", "0.14") + >>> def fun(self, x): + ... # Older gym versions will return x + 1 + ... return x + 1 + ... + >>> @implement_for("gym", "0.14", "0.23") + >>> def fun(self, x): + ... # More recent gym versions will return x + 2 + ... return x + 2 + ... + >>> @implement_for(lambda: import_module("gym"), "0.23", None) + >>> def fun(self, x): + ... # More recent gym versions will return x + 2 + ... return x + 2 + ... + >>> @implement_for("gymnasium", "0.27", None) + >>> def fun(self, x): + ... # If gymnasium is to be used instead of gym, x+3 will be returned + ... return x + 3 + ... + + This indicates that the function is compatible with gym 0.13+, but doesn't with gym 0.14+. + """ + + # Stores pointers to fitting implementations: dict[func_name] = func_pointer + _implementations = {} + _setters = [] + _cache_modules = {} + + def __init__( + self, + module_name: Union[str, Callable], + from_version: str = None, + to_version: str = None, + ): + self.module_name = module_name + self.from_version = from_version + self.to_version = to_version + implement_for._setters.append(self) + + @staticmethod + def check_version(version, from_version, to_version): + return (from_version is None or parse(version) >= parse(from_version)) and ( + to_version is None or parse(version) < parse(to_version) + ) + + @staticmethod + def get_class_that_defined_method(f): + """Returns the class of a method, if it is defined, and None otherwise.""" + out = f.__globals__.get(f.__qualname__.split(".")[0], None) + return out + + @classmethod + def get_func_name(cls, fn): + # produces a name like torchrl.module.Class.method or torchrl.module.function + first = str(fn).split(".")[0][len(" str: + """Imports module and returns its version.""" + if not callable(module_name): + module = cls._cache_modules.get(module_name, None) + if module is None: + if module_name in sys.modules: + sys.modules[module_name] = module = import_module(module_name) + else: + cls._cache_modules[module_name] = module = import_module( + module_name + ) + else: + module = module_name() + return module.__version__ + + _lazy_impl = collections.defaultdict(list) + + def _delazify(self, func_name): + for local_call in implement_for._lazy_impl[func_name]: + out = local_call() + return out + + def __call__(self, fn): + # function names are unique + self.func_name = self.get_func_name(fn) + self.fn = fn + implement_for._lazy_impl[self.func_name].append(self._call) + + @wraps(fn) + def _lazy_call_fn(*args, **kwargs): + # first time we call the function, we also do the replacement. + # This will cause the imports to occur only during the first call to fn + return self._delazify(self.func_name)(*args, **kwargs) + + return _lazy_call_fn + + def _call(self): + + # If the module is missing replace the function with the mock. + fn = self.fn + func_name = self.func_name + implementations = implement_for._implementations + + @wraps(fn) + def unsupported(*args, **kwargs): + raise ModuleNotFoundError( + f"Supported version of '{func_name}' has not been found." + ) + + self.do_set = False + # Return fitting implementation if it was encountered before. + if func_name in implementations: + try: + # check that backends don't conflict + version = self.import_module(self.module_name) + if self.check_version(version, self.from_version, self.to_version): + self.do_set = True + if not self.do_set: + return implementations[func_name].fn + except ModuleNotFoundError: + # then it's ok, there is no conflict + return implementations[func_name].fn + else: + try: + version = self.import_module(self.module_name) + if self.check_version(version, self.from_version, self.to_version): + self.do_set = True + except ModuleNotFoundError: + return unsupported + if self.do_set: + self.module_set() + return fn + return unsupported + + @classmethod + def reset(cls, setters_dict: Dict[str, implement_for] = None): + """Resets the setters in setter_dict. + + ``setter_dict`` is a copy of implementations. We just need to iterate through its + values and call :meth:`~.module_set` for each. + + """ + if setters_dict is None: + setters_dict = copy(cls._implementations) + for setter in setters_dict.values(): + setter.module_set() + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"module_name={self.module_name}({self.from_version, self.to_version}), " + f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)}, is_set={self.do_set})" + ) diff --git a/robohive/utils/import_utils.py b/robohive/utils/import_utils.py index 7dfa9f12..1f36ffe6 100644 --- a/robohive/utils/import_utils.py +++ b/robohive/utils/import_utils.py @@ -1,8 +1,10 @@ import importlib +import importlib.util import os from os.path import expanduser import git + def mujoco_py_isavailable(): help = """ Options: @@ -20,7 +22,6 @@ def mujoco_isavailable(): (1) install robohive with encoders (pip install robohive['mujoco']) (2) follow setup instructions here: https://github.com/deepmind/mujoco (3) install mujoco via pip (pip install mujoco) - """ if importlib.util.find_spec("mujoco") is None: raise ModuleNotFoundError(help) @@ -141,5 +142,4 @@ def fetch_git(repo_url, commit_hash, clone_directory, clone_path=None): torch_isavailable() torchvision_isavailable() r3m_isavailable() - vc_isavailable() - + vc_isavailable() \ No newline at end of file diff --git a/robohive/utils/paths_utils.py b/robohive/utils/paths_utils.py index 7b2e5f36..32c44e7d 100644 --- a/robohive/utils/paths_utils.py +++ b/robohive/utils/paths_utils.py @@ -16,9 +16,20 @@ from robohive.utils.dict_utils import flatten_dict, dict_numpify import json +#TODO: Harmonize names, remove rollout_paths, use path for one and paths for multiple -# Useful to check the horizon for teleOp / Hardware experiments +# Check the horizon for teleOp / Hardware experiments def plot_horizon(paths, env, fileName_prefix=None): + """ + Check the horizon for teleOp / Hardware experiments + + Args: + paths: paths to examine + env: unwrapped env + fileName_prefix (str): prefix to use in the filename + Saves: + fileName_prefix + '_horizon.pdf' + """ import matplotlib as mpl mpl.use('TkAgg') import matplotlib.pyplot as plt @@ -30,7 +41,7 @@ def plot_horizon(paths, env, fileName_prefix=None): # plot timesteps plt.clf() - rl_dt_ideal = env.env.frame_skip * env.env.model.opt.timestep + rl_dt_ideal = env.frame_skip * env.model.opt.timestep for i, path in enumerate(paths): dt = path['env_infos']['time'][1:] - path['env_infos']['time'][:-1] horizon[i] = path['env_infos']['time'][-1] - path['env_infos'][ @@ -75,8 +86,19 @@ def plot_horizon(paths, env, fileName_prefix=None): print("Saved:", file_name) -# Plot paths to a pdf file +# 2D-plot of paths detailing obs, act, rwds across time def plot(paths, env=None, fileName_prefix=''): + """ + 2D-plot of paths detailing obs, act, rwds across time + + Args: + paths: paths to examine + env: unwrapped env + fileName_prefix: prefix to use in the filename + + Saves: + fileName_prefix + path_name + '.pdf' + """ import matplotlib as mpl mpl.use('Agg') import matplotlib.pyplot as plt @@ -108,7 +130,7 @@ def plot(paths, env=None, fileName_prefix=''): nplt2 = 3 ax = plt.subplot(nplt2, 2, 2) ax.set_prop_cycle(None) - # h4 = plt.plot(path['env_infos']['time'], env.env.act_mid + path['actions']*env.env.act_rng, '-', label='act') # plot scaled actions + # h4 = plt.plot(path['env_infos']['time'], env.act_mid + path['actions']*env.act_rng, '-', label='act') # plot scaled actions h4 = plt.plot( path['env_infos']['time'], path['actions'], '-', label='act') # plot normalized actions @@ -143,13 +165,14 @@ def plot(paths, env=None, fileName_prefix=''): ax.axes.xaxis.set_ticklabels([]) plt.ylabel('rewards') ax.yaxis.tick_right() - if env and hasattr(env.env, "rwd_keys_wt"): + + if env and hasattr(env, "rwd_keys_wt"): ax = plt.subplot(nplt2, 2, 6) ax.set_prop_cycle(None) - for key in sorted(env.env.rwd_keys_wt.keys()): + for key in sorted(env.rwd_keys_wt.keys()): plt.plot( path['env_infos']['time'], - path['env_infos']['rwd_dict'][key]*env.env.rwd_keys_wt[key], + path['env_infos']['rwd_dict'][key]*env.rwd_keys_wt[key], label=key) plt.legend( loc='upper left', @@ -167,9 +190,30 @@ def plot(paths, env=None, fileName_prefix=''): # Render frames/videos def render(rollout_path, render_format:str="mp4", cam_names:list=["left"]): - # rollout_path: Absolute path of the rollout (h5/pickle)', default=None - # format: Format to save. Choice['rgb', 'mp4'] - # cam: list of cameras to render. Example ['left', 'right', 'top', 'Franka_wrist'] + """ + Render the frames from a given rollout. + + Parameters: + rollout_path (str): Absolute path of the rollout (h5/pickle). + render_format (str, optional): Format to save the rendered frames. Default is "mp4". + cam_names (list, optional): List of cameras to render. Default is ["left"]. Example ['left', 'right', 'top', 'Franka_wrist'] + + Returns: + None + + Raises: + TypeError: If the path format is unknown. + + Notes: + - The frames are saved in the specified render format. + - The rendered frames can be saved as an mp4 video or as individual RGB images. + - The frames are rendered for each camera specified in the cam_names list. + - The frames are saved in the same directory as the rollout path. + - The output file names are generated based on the rollout name and the camera names. + + Example: + render(rollout_path="/path/to/rollout.h5", render_format="mp4", cam_names=["left", "right"]) + """ output_dir = os.path.dirname(rollout_path) rollout_name = os.path.split(rollout_path)[-1] diff --git a/robohive/utils/prompt_utils.py b/robohive/utils/prompt_utils.py index 8801f316..9d59d632 100644 --- a/robohive/utils/prompt_utils.py +++ b/robohive/utils/prompt_utils.py @@ -7,7 +7,7 @@ """ Utility script to help with information verbosity produced by RoboHive -To control verbosity set env variable ROBOHIVE_VERBOSITY=NONE(default)/INFO/WARN/ERROR/ONCE/ALWAYS +To control verbosity set env variable ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS """ from termcolor import cprint @@ -18,12 +18,13 @@ # Define verbosity levels class Prompt(enum.IntEnum): """Prompt verbosity types""" - NONE = 0 # default (lowest priority) + ALL = 0 # print everything (lowest priority) INFO = 1 WARN = 2 ERROR = 3 - ONCE = 4 # print only once - ALWAYS = 5 # print always (highest priority) + ONCE = 4 # print: once and higher + ALWAYS = 5 # print: only always (highest priority) + SILENT = 6 # Supress all prints # Prompt Cache (to track for Prompt.ONCE messages) @@ -36,7 +37,9 @@ class Prompt(enum.IntEnum): VERBOSE_MODE = Prompt.WARN else: VERBOSE_MODE = VERBOSE_MODE.upper() - if VERBOSE_MODE == 'ALWAYS': + if VERBOSE_MODE == 'SILENT': + VERBOSE_MODE = Prompt.SILENT + elif VERBOSE_MODE == 'ALWAYS': VERBOSE_MODE = Prompt.ALWAYS elif VERBOSE_MODE == 'ERROR': VERBOSE_MODE = Prompt.ERROR @@ -44,14 +47,14 @@ class Prompt(enum.IntEnum): VERBOSE_MODE = Prompt.WARN elif VERBOSE_MODE == 'INFO': VERBOSE_MODE = Prompt.INFO - elif VERBOSE_MODE == 'NONE': - VERBOSE_MODE = Prompt.NONE + elif VERBOSE_MODE == 'ALL': + VERBOSE_MODE = Prompt.ALL else: raise TypeError("Unknown ROBOHIVE_VERBOSITY option") # Programatically override the verbosity -def set_prompt_verbosity(verbose_mode:Prompt=Prompt.NONE): +def set_prompt_verbosity(verbose_mode:Prompt=Prompt.ALL): global VERBOSE_MODE VERBOSE_MODE = verbose_mode @@ -61,7 +64,6 @@ def prompt(data, color=None, on_color=None, flush=False, end="\n", type:Prompt=P global PROMPT_CACHE - # Resolve if we need to print if type == Prompt.ONCE: data_hash = hash(data) if data_hash in PROMPT_CACHE: @@ -79,7 +81,9 @@ def prompt(data, color=None, on_color=None, flush=False, end="\n", type:Prompt=P on_color = "on_red" # resolve printing - if type>=VERBOSE_MODE: + if VERBOSE_MODE == Prompt.SILENT: + return + elif type>=VERBOSE_MODE: if not isinstance(data, str): data = data.__str__() cprint(data, color=color, on_color=on_color, flush=flush, end=end) diff --git a/robohive_init.py b/robohive_init.py index 593803c1..b5ff0b10 100644 --- a/robohive_init.py +++ b/robohive_init.py @@ -1,5 +1,7 @@ -import os, shutil +import os +import shutil from os.path import expanduser + import git curr_dir = os.path.dirname(os.path.abspath(__file__)) @@ -59,7 +61,7 @@ def fetch_simhive(): print("RoboHive:> Initializing...") # Mark the SimHive version (ToDo: Remove this when commits hashes are auto fetched from submodules) - __version__ = "0.6.0" + __version__ = "0.7.0" # Fetch SimHive print("RoboHive:> Downloading simulation assets (upto ~300MBs)") @@ -89,7 +91,7 @@ def fetch_simhive(): clone_path=simhive_path) fetch_git(repo_url="https://github.com/vikashplus/fetch_sim.git", - commit_hash="58d561fa416b6a151761ced18f2dc8f067188909", + commit_hash="7f6d25ae8a6f5778379a48fa60c17d685075e64d", clone_directory="fetch_sim", clone_path=simhive_path) @@ -104,7 +106,7 @@ def fetch_simhive(): clone_path=simhive_path) fetch_git(repo_url="https://github.com/vikashplus/object_sim.git", - commit_hash="87cd8dd5a11518b94fca16bc22bb04f6836c6aa7", + commit_hash="ee0ff14a5369c277687a4636165c5b703bccbf84", clone_directory="object_sim", clone_path=simhive_path) @@ -124,7 +126,7 @@ def fetch_simhive(): clone_path=simhive_path) fetch_git(repo_url="https://github.com/MyoHub/myo_sim.git", - commit_hash="aff0bc096d98085ee0a6befd613cc9fbff024944", + commit_hash="5e462da71589fe42164af25ef3c4311231a0d6b2", clone_directory="myo_sim", clone_path=simhive_path) diff --git a/setup.py b/setup.py index c8e28984..b19980c0 100644 --- a/setup.py +++ b/setup.py @@ -6,12 +6,22 @@ ================================================= """ import os +import shutil import sys -from setuptools import setup, find_packages -if sys.version_info.major != 3: - print("This library is only compatible with Python 3, but you are running " - "Python {}. The installation will likely fail.".format(sys.version_info.major)) +from setuptools import find_packages, setup + +# Check and warn if FFmpeg is not available +if shutil.which("ffmpeg") is None: + help = """FFmpeg not found in your system. Please install FFmpeg before proceeding + Options: + (1) LINUX: apt-get install ffmpeg + (2) OSX: brew install ffmpeg""" + raise ModuleNotFoundError(help) + +if sys.version_info.major < 3 or (sys.version_info.major == 3 and sys.version_info.minor < 8): + print("This library requires Python 3.8 or higher, but you are running " + "Python {}.{}. The installation will likely fail.".format(sys.version_info.major, sys.version_info.minor)) def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() @@ -27,21 +37,23 @@ def package_files(directory): setup( name='robohive', - version='0.6.0', + version='0.7.0', license='Apache 2.0', packages=find_packages(), - package_data={"": extra_files}, + package_data={"": extra_files+['../robohive_init.py']}, include_package_data=True, - description='environments simulated in MuJoCo', + description='A Unified Framework for Robot Learning', long_description=read('README.md'), long_description_content_type="text/markdown", url='https://github.com/vikashplus/robohive.git', - author='Movement Control Lab, UW', + author='Vikash Kumar', + author_email="vikahsplus@gmail.com", install_requires=[ 'click', - 'gym==0.13', - 'mujoco==2.3.7', - 'dm-control==1.0.14', + # 'gym==0.13', # default to this stable point if caught in gym issues. + 'gymnasium==0.29.1', + 'mujoco==3.1.3', + 'dm-control==1.0.16', 'termcolor', 'sk-video', 'flatten_dict',