diff --git a/robosuite/robots/fixed_base_robot.py b/robosuite/robots/fixed_base_robot.py index 430bc5fd59..fdb761364a 100644 --- a/robosuite/robots/fixed_base_robot.py +++ b/robosuite/robots/fixed_base_robot.py @@ -6,6 +6,7 @@ import robosuite.utils.transform_utils as T from robosuite.controllers import composite_controller_factory +from robosuite.controllers.parts.generic import JointPositionController, JointTorqueController, JointVelocityController from robosuite.robots.robot import Robot @@ -118,6 +119,15 @@ def setup_references(self): self.eef_site_id[arm] = self.sim.model.site_name2id(self.gripper[arm].important_sites["grip_site"]) self.eef_cylinder_id[arm] = self.sim.model.site_name2id(self.gripper[arm].important_sites["grip_cylinder"]) + self._ref_actuator_to_joint_id = np.ones(self.sim.model.nu).astype(np.int32) * (-1) + for part_name, actuator_ids in self._ref_actuators_indexes_dict.items(): + self._ref_actuator_to_joint_id[actuator_ids] = np.array( + [ + self._ref_joints_indexes_dict[part_name].index(self.sim.model.actuator_trnid[i, 0]) + for i in actuator_ids + ] + ) + def control(self, action, policy_step=False): """ Actuate the robot with the @@ -149,8 +159,21 @@ def control(self, action, policy_step=False): for part_name, applied_action in applied_action_dict.items(): applied_action_low = self.sim.model.actuator_ctrlrange[self._ref_actuators_indexes_dict[part_name], 0] applied_action_high = self.sim.model.actuator_ctrlrange[self._ref_actuators_indexes_dict[part_name], 1] - applied_action = np.clip(applied_action, applied_action_low, applied_action_high) - self.sim.data.ctrl[self._ref_actuators_indexes_dict[part_name]] = applied_action + actuator_indexes = self._ref_actuators_indexes_dict[part_name] + actuator_gears = self.sim.model.actuator_gear[actuator_indexes, 0] + + part_controllers = self.composite_controller.get_controller(part_name) + if ( + isinstance(part_controllers, JointPositionController) + or isinstance(part_controllers, JointVelocityController) + or isinstance(part_controllers, JointTorqueController) + ): + # select only the joints that are actuated + actuated_joint_indexes = self._ref_actuator_to_joint_id[actuator_indexes] + applied_action = applied_action[actuated_joint_indexes] + + applied_action = np.clip(applied_action / actuator_gears, applied_action_low, applied_action_high) + self.sim.data.ctrl[actuator_indexes] = applied_action if policy_step: # Update proprioceptive values diff --git a/robosuite/robots/legged_robot.py b/robosuite/robots/legged_robot.py index 8d89431b5d..267ae4f012 100644 --- a/robosuite/robots/legged_robot.py +++ b/robosuite/robots/legged_robot.py @@ -8,6 +8,7 @@ import robosuite.utils.transform_utils as T from robosuite.controllers import composite_controller_factory, load_part_controller_config +from robosuite.controllers.parts.generic import JointPositionController, JointTorqueController, JointVelocityController from robosuite.models.bases.leg_base_model import LegBaseModel from robosuite.robots.mobile_robot import MobileRobot from robosuite.utils.log_utils import ROBOSUITE_DEFAULT_LOGGER @@ -188,6 +189,17 @@ def control(self, action, policy_step=False): applied_action_high = self.sim.model.actuator_ctrlrange[self._ref_actuators_indexes_dict[part_name], 1] actuator_indexes = self._ref_actuators_indexes_dict[part_name] actuator_gears = self.sim.model.actuator_gear[actuator_indexes, 0] + + part_controllers = self.composite_controller.get_controller(part_name) + if ( + isinstance(part_controllers, JointPositionController) + or isinstance(part_controllers, JointVelocityController) + or isinstance(part_controllers, JointTorqueController) + ): + # select only the joints that are actuated + actuated_joint_indexes = self._ref_actuator_to_joint_id[actuator_indexes] + applied_action = applied_action[actuated_joint_indexes] + applied_action = np.clip(applied_action / actuator_gears, applied_action_low, applied_action_high) self.sim.data.ctrl[actuator_indexes] = applied_action diff --git a/robosuite/robots/mobile_robot.py b/robosuite/robots/mobile_robot.py index 60c28bf35d..74ab9520e0 100644 --- a/robosuite/robots/mobile_robot.py +++ b/robosuite/robots/mobile_robot.py @@ -266,6 +266,15 @@ def setup_references(self): self.sim.model.joint_name2id(joint) for joint in self.robot_model.head_joints ] + self._ref_actuator_to_joint_id = np.ones(self.sim.model.nu).astype(np.int32) * (-1) + for part_name, actuator_ids in self._ref_actuators_indexes_dict.items(): + self._ref_actuator_to_joint_id[actuator_ids] = np.array( + [ + self._ref_joints_indexes_dict[part_name].index(self.sim.model.actuator_trnid[i, 0]) + for i in actuator_ids + ] + ) + def control(self, action, policy_step=False): """ Actuate the robot with the diff --git a/robosuite/robots/robot.py b/robosuite/robots/robot.py index bc1faca82b..44c953a95e 100644 --- a/robosuite/robots/robot.py +++ b/robosuite/robots/robot.py @@ -130,6 +130,7 @@ def __init__( self._ref_actuators_indexes_dict = {} self._ref_joints_indexes_dict = {} + self._ref_actuator_to_joint_id = None self._enabled_parts = {} self.composite_controller = None diff --git a/robosuite/robots/wheeled_robot.py b/robosuite/robots/wheeled_robot.py index 88cbd4265a..f40a804a35 100644 --- a/robosuite/robots/wheeled_robot.py +++ b/robosuite/robots/wheeled_robot.py @@ -6,6 +6,7 @@ import robosuite.utils.transform_utils as T from robosuite.controllers import composite_controller_factory +from robosuite.controllers.parts.generic import JointPositionController, JointTorqueController, JointVelocityController from robosuite.robots.mobile_robot import MobileRobot from robosuite.utils.log_utils import ROBOSUITE_DEFAULT_LOGGER @@ -124,8 +125,21 @@ def control(self, action, policy_step=False): for part_name, applied_action in applied_action_dict.items(): applied_action_low = self.sim.model.actuator_ctrlrange[self._ref_actuators_indexes_dict[part_name], 0] applied_action_high = self.sim.model.actuator_ctrlrange[self._ref_actuators_indexes_dict[part_name], 1] - applied_action = np.clip(applied_action, applied_action_low, applied_action_high) - self.sim.data.ctrl[self._ref_actuators_indexes_dict[part_name]] = applied_action + actuator_indexes = self._ref_actuators_indexes_dict[part_name] + actuator_gears = self.sim.model.actuator_gear[actuator_indexes, 0] + + part_controllers = self.composite_controller.get_controller(part_name) + if ( + isinstance(part_controllers, JointPositionController) + or isinstance(part_controllers, JointVelocityController) + or isinstance(part_controllers, JointTorqueController) + ): + # select only the joints that are actuated + actuated_joint_indexes = self._ref_actuator_to_joint_id[actuator_indexes] + applied_action = applied_action[actuated_joint_indexes] + + applied_action = np.clip(applied_action / actuator_gears, applied_action_low, applied_action_high) + self.sim.data.ctrl[actuator_indexes] = applied_action # If this is a policy step, also update buffers holding recent values of interest if policy_step: