diff --git a/robohive/envs/claws/reorient_v0.py b/robohive/envs/claws/reorient_v0.py index ae49abd9..2af1124a 100644 --- a/robohive/envs/claws/reorient_v0.py +++ b/robohive/envs/claws/reorient_v0.py @@ -98,7 +98,7 @@ def get_reward_dict(self, obs_dict): rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0) return rwd_dict - def reset(self): + def reset(self, **kwargs): desired_pos = self.np_random.uniform(high=self.target_xyz_range['high'], low=self.target_xyz_range['low']) self.sim.model.site_pos[self.target_sid] = desired_pos self.sim_obsd.model.site_pos[self.target_sid] = desired_pos @@ -108,5 +108,5 @@ def reset(self): self.sim.model.site_quat[self.target_sid] = euler2quat(desired_orien) self.sim_obsd.model.site_quat[self.target_sid] = euler2quat(desired_orien) - obs = super().reset(self.init_qpos, self.init_qvel) + obs = super().reset(self.init_qpos, self.init_qvel, **kwargs) return obs diff --git a/robohive/envs/fm/franka_ee_pose_v0.py b/robohive/envs/fm/franka_ee_pose_v0.py index 088ec2ac..82008527 100644 --- a/robohive/envs/fm/franka_ee_pose_v0.py +++ b/robohive/envs/fm/franka_ee_pose_v0.py @@ -119,9 +119,9 @@ def get_target_pose(self): return self.np_random.uniform(low=self.sim.model.actuator_ctrlrange[:,0], high=self.sim.model.actuator_ctrlrange[:,1]) - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): self.target_pose = self.get_target_pose() - obs = super().reset(reset_qpos, reset_qvel) + obs = super().reset(reset_qpos, reset_qvel, **kwargs) return obs class FrankaRobotiqPose(FrankaEEPose): diff --git a/robohive/envs/hands/baoding_v1.py b/robohive/envs/hands/baoding_v1.py index 049face9..1b7cafec 100644 --- a/robohive/envs/hands/baoding_v1.py +++ b/robohive/envs/hands/baoding_v1.py @@ -256,7 +256,7 @@ def get_reward_dict(self, obs_dict): rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0) return rwd_dict - def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6): + def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6, **kwargs): # reset counters self.counter=0 @@ -264,7 +264,7 @@ def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6) self.goal = self.create_goal_trajectory(time_period=time_period) if reset_goal is None else reset_goal.copy() # reset scene - obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel) + obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs) return obs def create_goal_trajectory(self, time_step=.1, time_period=6): @@ -325,6 +325,6 @@ def create_goal_trajectory(self, time_step=.1, time_period=6): class BaodingRandomEnvV1(BaodingFixedEnvV1): - def reset(self): - obs = super().reset(time_period = self.np_random.uniform(high=5, low=7)) + def reset(self, **kwargs): + obs = super().reset(time_period = self.np_random.uniform(high=5, low=7), **kwargs) return obs diff --git a/robohive/envs/multi_task/common/franka_appliance_v1.py b/robohive/envs/multi_task/common/franka_appliance_v1.py index 69e72996..9d5a5920 100644 --- a/robohive/envs/multi_task/common/franka_appliance_v1.py +++ b/robohive/envs/multi_task/common/franka_appliance_v1.py @@ -54,7 +54,7 @@ def _setup( **kwargs, ) - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): # randomize object bodies, if requested if self.obj_body_randomize: for body_name in self.obj_body_randomize: @@ -78,4 +78,4 @@ def reset(self, reset_qpos=None, reset_qvel=None): * (self.robot_ranges[:, 1] - self.robot_ranges[:, 0]) ) - return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel) + return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs) diff --git a/robohive/envs/multi_task/common/franka_kitchen_v2.py b/robohive/envs/multi_task/common/franka_kitchen_v2.py index 94f4b412..5162d11d 100644 --- a/robohive/envs/multi_task/common/franka_kitchen_v2.py +++ b/robohive/envs/multi_task/common/franka_kitchen_v2.py @@ -113,7 +113,7 @@ def _setup( ) - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): if reset_qpos is None: reset_qpos = self.init_qpos.copy() @@ -128,4 +128,4 @@ def reset(self, reset_qpos=None, reset_qvel=None): if self.robot_base_range: self.sim.model.body_pos[self.robot_base_bid] = self.robot_base_pos + self.np_random.uniform(**self.robot_base_range) - return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel) \ No newline at end of file + return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs) \ No newline at end of file diff --git a/robohive/envs/myo/myobase/baoding_v1.py b/robohive/envs/myo/myobase/baoding_v1.py index 418e6767..44f63fad 100644 --- a/robohive/envs/myo/myobase/baoding_v1.py +++ b/robohive/envs/myo/myobase/baoding_v1.py @@ -260,7 +260,7 @@ def evaluate_success(self, paths, logger=None, successful_steps=5): logger.log_kv('effort', effort) return success_percentage - def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None): + def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None, **kwargs): # reset counters self.counter=0 self.x_radius=self.np_random.uniform(low=self.goal_xrange[0], high=self.goal_xrange[1]) @@ -273,7 +273,7 @@ def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=No self.goal = self.create_goal_trajectory(time_step=self.dt, time_period=time_period) if reset_goal is None else reset_goal.copy() # reset scene - obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel) + obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs) return obs def create_goal_trajectory(self, time_step=.1, time_period=6): diff --git a/robohive/envs/myo/myodm/myodm_v0.py b/robohive/envs/myo/myodm/myodm_v0.py index 789d8a64..abd47ea0 100644 --- a/robohive/envs/myo/myodm/myodm_v0.py +++ b/robohive/envs/myo/myodm/myodm_v0.py @@ -289,10 +289,10 @@ def playback(self): return idxs[0] < self.ref.horizon-1 - def reset(self): + def reset(self, **kwargs): # print("Reset") self.ref.reset() - obs = super().reset(self.init_qpos, self.init_qvel) + obs = super().reset(self.init_qpos, self.init_qvel, **kwargs) # print(self.time, self.sim.data.qpos) return obs diff --git a/robohive/envs/myo/myomimic/myomimic_v0.py b/robohive/envs/myo/myomimic/myomimic_v0.py index 21148b49..c634de95 100644 --- a/robohive/envs/myo/myomimic/myomimic_v0.py +++ b/robohive/envs/myo/myomimic/myomimic_v0.py @@ -155,9 +155,9 @@ def playback(self): return idxs[0] < self.ref.horizon-1 - def reset(self): + def reset(self, **kwargs): # print("Reset") self.ref.reset() - obs = super().reset(self.init_qpos, self.init_qvel) + obs = super().reset(self.init_qpos, self.init_qvel, **kwargs) # print(self.time, self.sim.data.qpos) return obs diff --git a/robohive/envs/quadrupeds/orient_v0.py b/robohive/envs/quadrupeds/orient_v0.py index 92299535..776e9978 100644 --- a/robohive/envs/quadrupeds/orient_v0.py +++ b/robohive/envs/quadrupeds/orient_v0.py @@ -171,7 +171,7 @@ def get_reward_dict(self, obs_dict): return rwd_dict - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): reset_qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos reset_qpos[6:] += np.pi/8*self.np_random.uniform(low=-1, high=1, size=self.sim.model.nq-6) @@ -182,5 +182,5 @@ def reset(self, reset_qpos=None, reset_qvel=None): self.sim.model.site_pos[self.target_sid] = target_dist * np.array([np.cos(target_theta), np.sin(target_theta), 0]) # Heading target is a bit farther away to avoid heading oscillations when quad is near xy_target self.sim.model.site_pos[self.heading_sid] = (target_dist+0.5) * np.array([np.cos(target_theta), np.sin(target_theta), 0]) - obs = super().reset(reset_qpos, reset_qvel) + obs = super().reset(reset_qpos, reset_qvel, **kwargs) return obs diff --git a/robohive/envs/quadrupeds/stand_v0.py b/robohive/envs/quadrupeds/stand_v0.py index 234e2b81..72778084 100644 --- a/robohive/envs/quadrupeds/stand_v0.py +++ b/robohive/envs/quadrupeds/stand_v0.py @@ -170,7 +170,7 @@ def get_reward_dict(self, obs_dict): return rwd_dict - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): if reset_qpos is None: reset_qpos = self.init_qpos.copy() @@ -189,5 +189,5 @@ def reset(self, reset_qpos=None, reset_qvel=None): else: raise TypeError(f"Unknown reset type: {self.reset_type}") - obs = super().reset(reset_qpos, reset_qvel) + obs = super().reset(reset_qpos, reset_qvel, **kwargs) return obs diff --git a/robohive/envs/quadrupeds/walk_v0.py b/robohive/envs/quadrupeds/walk_v0.py index a78d98a0..bf91307f 100644 --- a/robohive/envs/quadrupeds/walk_v0.py +++ b/robohive/envs/quadrupeds/walk_v0.py @@ -171,7 +171,7 @@ def get_reward_dict(self, obs_dict): return rwd_dict - def reset(self, reset_qpos=None, reset_qvel=None): + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): reset_qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos reset_qpos[6:] += np.pi/8*self.np_random.uniform(low=-1, high=1, size=self.sim.model.nq-6) @@ -182,5 +182,5 @@ def reset(self, reset_qpos=None, reset_qvel=None): self.sim.model.site_pos[self.target_sid] = target_dist * np.array([np.cos(target_theta), np.sin(target_theta), 0]) # Heading target is a bit farther away to avoid heading oscillations when quad is near xy_target self.sim.model.site_pos[self.heading_sid] = (target_dist+0.5) * np.array([np.cos(target_theta), np.sin(target_theta), 0]) - obs = super().reset(reset_qpos, reset_qvel) + obs = super().reset(reset_qpos, reset_qvel, **kwargs) return obs diff --git a/robohive/envs/tcdm/track.py b/robohive/envs/tcdm/track.py index 025893c0..af75b558 100644 --- a/robohive/envs/tcdm/track.py +++ b/robohive/envs/tcdm/track.py @@ -277,10 +277,10 @@ def playback(self): return idxs[0] < self.ref.horizon-1 - def reset(self): + def reset(self, **kwargs): # print("Reset") self.ref.reset() - obs = super().reset(self.init_qpos, self.init_qvel) + obs = super().reset(self.init_qpos, self.init_qvel, **kwargs) # print(self.time, self.sim.data.qpos) return obs