diff --git a/examples/zbot_walking.kinfer b/examples/zbot_walking.kinfer new file mode 100644 index 00000000..e8705516 Binary files /dev/null and b/examples/zbot_walking.kinfer differ diff --git a/sim/envs/__init__.py b/sim/envs/__init__.py index 8291e70b..016df468 100755 --- a/sim/envs/__init__.py +++ b/sim/envs/__init__.py @@ -18,6 +18,8 @@ from sim.envs.humanoids.h1_env import H1FreeEnv from sim.envs.humanoids.xbot_config import XBotCfg, XBotCfgPPO from sim.envs.humanoids.xbot_env import XBotLFreeEnv +from sim.envs.humanoids.zbot2_config import ZBot2Cfg, ZBot2CfgPPO, ZBot2StandingCfg +from sim.envs.humanoids.zbot2_env import ZBot2Env from sim.envs.humanoids.zeroth_config import ZerothCfg, ZerothCfgPPO from sim.envs.humanoids.zeroth_env import ZerothEnv from sim.utils.task_registry import TaskRegistry # noqa: E402 @@ -29,4 +31,6 @@ task_registry.register("h1", H1FreeEnv, H1Cfg(), H1CfgPPO()) task_registry.register("g1", G1FreeEnv, G1Cfg(), G1CfgPPO()) task_registry.register("XBotL_free", XBotLFreeEnv, XBotCfg(), XBotCfgPPO()) +task_registry.register("zbot2", ZBot2Env, ZBot2Cfg(), ZBot2CfgPPO()) +task_registry.register("zbot2_standing", ZBot2Env, ZBot2StandingCfg(), ZBot2CfgPPO()) task_registry.register("zeroth", ZerothEnv, ZerothCfg(), ZerothCfgPPO()) diff --git a/sim/envs/base/legged_robot.py b/sim/envs/base/legged_robot.py index 47edade7..4d121ebf 100644 --- a/sim/envs/base/legged_robot.py +++ b/sim/envs/base/legged_robot.py @@ -80,8 +80,9 @@ def step(self, actions): def reset(self): """Reset all robots""" + self.reset_idx(torch.arange(self.num_envs, device=self.device)) - # self._resample_default_positions() + obs, privileged_obs, _, _, _ = self.step( torch.zeros(self.num_envs, self.num_actions, device=self.device, requires_grad=False) ) @@ -99,13 +100,8 @@ def post_physics_step(self): self.episode_length_buf += 1 self.common_step_counter += 1 - # prepare quantities - # TODO(pfb30) - debug this - origin = torch.tensor(self.cfg.init_state.rot, device=self.device).repeat(self.num_envs, 1) - origin = quat_conjugate(origin) - if self.imu_indices: - self.base_quat = quat_mul(origin, self.rigid_state[:, self.imu_indices, 3:7]) + self.base_quat = self.rigid_state[:, self.imu_indices, 3:7] self.base_lin_vel[:] = quat_rotate_inverse(self.base_quat, self.rigid_state[:, self.imu_indices, 7:10]) self.base_ang_vel[:] = quat_rotate_inverse(self.base_quat, self.rigid_state[:, self.imu_indices, 10:13]) else: @@ -161,6 +157,11 @@ def reset_idx(self, env_ids): if self.cfg.commands.curriculum and (self.common_step_counter % self.max_episode_length == 0): self.update_command_curriculum(env_ids) + # Add noise to the PD gains + if self.cfg.domain_rand.randomize_pd_gains: + self.p_gains[env_ids] = self.original_p_gains[env_ids] + torch.randn_like(self.p_gains[env_ids]) * 7 + self.d_gains[env_ids] = self.original_d_gains[env_ids] + torch.randn_like(self.d_gains[env_ids]) * 0.3 + # reset robot states self._reset_dofs(env_ids) @@ -193,14 +194,10 @@ def reset_idx(self, env_ids): if self.cfg.env.send_timeouts: self.extras["time_outs"] = self.time_out_buf - # fix reset gravity bug - # TODO(pfb30) - debug this - origin = torch.tensor(self.cfg.init_state.rot, device=self.device).repeat(self.num_envs, 1) - origin = quat_conjugate(origin) if self.imu_indices: - self.base_quat[env_ids] = quat_mul(origin[env_ids, :], self.rigid_state[env_ids, self.imu_indices, 3:7]) + self.base_quat[env_ids] = self.rigid_state[env_ids, self.imu_indices, 3:7] else: - self.base_quat[env_ids] = quat_mul(origin[env_ids, :], 
self.root_states[env_ids, 3:7]) + self.base_quat[env_ids] = self.root_states[env_ids, 3:7] self.base_euler_xyz = get_euler_xyz_tensor(self.base_quat) self.projected_gravity[env_ids] = quat_rotate_inverse(self.base_quat[env_ids], self.gravity_vec[env_ids]) @@ -502,15 +499,11 @@ def _init_buffers(self): self.rigid_state = gymtorch.wrap_tensor(rigid_body_state) # .view(self.num_envs, -1, 13) self.rigid_state = self.rigid_state.view(self.num_envs, -1, 13) - # TODO(pfb30): debug this - # self.base_quat = self.root_states[:, 3:7] - origin = torch.tensor(self.cfg.init_state.rot, device=self.device).repeat(self.num_envs, 1) - origin = quat_conjugate(origin) if self.imu_indices: - self.base_quat = quat_mul(origin, self.rigid_state[:, self.imu_indices, 3:7]) + self.base_quat = self.rigid_state[:, self.imu_indices, 3:7] else: - self.base_quat = quat_mul(origin, self.root_states[:, 3:7]) + self.base_quat = self.root_states[:, 3:7] self.base_euler_xyz = get_euler_xyz_tensor(self.base_quat) @@ -577,7 +570,7 @@ def _init_buffers(self): self.default_dof_pos = torch.zeros(self.num_dof, dtype=torch.float, device=self.device, requires_grad=False) for i in range(self.num_dofs): name = self.dof_names[i] - print(name) + print(i, name) self.default_dof_pos[i] = self.cfg.init_state.default_joint_angles[name] found = False @@ -591,6 +584,9 @@ def _init_buffers(self): self.d_gains[:, i] = 0.0 raise ValueError(f"PD gain of joint {name} were not defined, setting them to zero") + self.original_p_gains = self.p_gains.clone() + self.original_d_gains = self.d_gains.clone() + self.rand_push_force = torch.zeros((self.num_envs, 3), dtype=torch.float32, device=self.device) self.rand_push_torque = torch.zeros((self.num_envs, 3), dtype=torch.float32, device=self.device) self.default_dof_pos = self.default_dof_pos.unsqueeze(0) diff --git a/sim/envs/base/legged_robot_config.py b/sim/envs/base/legged_robot_config.py index 71428b71..9f16d9b0 100644 --- a/sim/envs/base/legged_robot_config.py +++ b/sim/envs/base/legged_robot_config.py @@ -218,6 +218,7 @@ class sim: substeps = 1 gravity = [0.0, 0.0, -9.81] # [m/s^2] up_axis = 1 # 0 is y, 1 is z + use_projected_gravity = False class physx: num_threads = 10 diff --git a/sim/envs/humanoids/gpr_config.py b/sim/envs/humanoids/gpr_config.py index 128183f9..3660736b 100644 --- a/sim/envs/humanoids/gpr_config.py +++ b/sim/envs/humanoids/gpr_config.py @@ -1,6 +1,6 @@ """Defines the environment configuration for the Getting up task""" -from kinfer import proto as P +# from kinfer import proto as P from sim.env import robot_urdf_path from sim.envs.base.legged_robot_config import ( # type: ignore @@ -28,92 +28,92 @@ class env(LeggedRobotCfg.env): episode_length_s = 24 # episode length in seconds use_ref_actions = False - input_schema = P.IOSchema( - values=[ - P.ValueSchema( - value_name="vector_command", - vector_command=P.VectorCommandSchema( - dimensions=3, # x_vel, y_vel, rot - ), - ), - P.ValueSchema( - value_name="timestamp", - timestamp=P.TimestampSchema( - start_seconds=0, - ), - ), - P.ValueSchema( - value_name="dof_pos", - joint_positions=P.JointPositionsSchema( - joint_names=Robot.joint_names(), - unit=P.JointPositionUnit.RADIANS, - ), - ), - P.ValueSchema( - value_name="dof_vel", - joint_velocities=P.JointVelocitiesSchema( - joint_names=Robot.joint_names(), - unit=P.JointVelocityUnit.RADIANS_PER_SECOND, - ), - ), - P.ValueSchema( - value_name="prev_actions", - joint_positions=P.JointPositionsSchema( - joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS - ), - ), - # 
Abusing the IMU schema to pass in euler and angular velocity instead of raw sensor data - P.ValueSchema( - value_name="imu_ang_vel", - imu=P.ImuSchema( - use_accelerometer=False, - use_gyroscope=True, - use_magnetometer=False, - ), - ), - P.ValueSchema( - value_name="imu_euler_xyz", - imu=P.ImuSchema( - use_accelerometer=True, - use_gyroscope=False, - use_magnetometer=False, - ), - ), - P.ValueSchema( - value_name="hist_obs", - state_tensor=P.StateTensorSchema( - # 11 is the number of single observation features - 6 from IMU, 5 from command input - # 3 comes from the number of times num_actions is repeated in the observation (dof_pos, dof_vel, prev_actions) - shape=[frame_stack * (11 + NUM_JOINTS * 3)], - dtype=P.DType.FP32, - ), - ), - ] - ) - - output_schema = P.IOSchema( - values=[ - P.ValueSchema( - value_name="actions", - joint_positions=P.JointPositionsSchema( - joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS - ), - ), - P.ValueSchema( - value_name="actions_raw", - joint_positions=P.JointPositionsSchema( - joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS - ), - ), - P.ValueSchema( - value_name="new_x", - state_tensor=P.StateTensorSchema( - shape=[frame_stack * (11 + NUM_JOINTS * 3)], - dtype=P.DType.FP32, - ), - ), - ] - ) + # input_schema = P.IOSchema( + # values=[ + # P.ValueSchema( + # value_name="vector_command", + # vector_command=P.VectorCommandSchema( + # dimensions=3, # x_vel, y_vel, rot + # ), + # ), + # P.ValueSchema( + # value_name="timestamp", + # timestamp=P.TimestampSchema( + # start_seconds=0, + # ), + # ), + # P.ValueSchema( + # value_name="dof_pos", + # joint_positions=P.JointPositionsSchema( + # joint_names=Robot.joint_names(), + # unit=P.JointPositionUnit.RADIANS, + # ), + # ), + # P.ValueSchema( + # value_name="dof_vel", + # joint_velocities=P.JointVelocitiesSchema( + # joint_names=Robot.joint_names(), + # unit=P.JointVelocityUnit.RADIANS_PER_SECOND, + # ), + # ), + # P.ValueSchema( + # value_name="prev_actions", + # joint_positions=P.JointPositionsSchema( + # joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS + # ), + # ), + # # Abusing the IMU schema to pass in euler and angular velocity instead of raw sensor data + # P.ValueSchema( + # value_name="imu_ang_vel", + # imu=P.ImuSchema( + # use_accelerometer=False, + # use_gyroscope=True, + # use_magnetometer=False, + # ), + # ), + # P.ValueSchema( + # value_name="imu_euler_xyz", + # imu=P.ImuSchema( + # use_accelerometer=True, + # use_gyroscope=False, + # use_magnetometer=False, + # ), + # ), + # P.ValueSchema( + # value_name="hist_obs", + # state_tensor=P.StateTensorSchema( + # # 11 is the number of single observation features - 6 from IMU, 5 from command input + # # 3 comes from the number of times num_actions is repeated in the observation (dof_pos, dof_vel, prev_actions) + # shape=[frame_stack * (11 + NUM_JOINTS * 3)], + # dtype=P.DType.FP32, + # ), + # ), + # ] + # ) + + # output_schema = P.IOSchema( + # values=[ + # P.ValueSchema( + # value_name="actions", + # joint_positions=P.JointPositionsSchema( + # joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS + # ), + # ), + # P.ValueSchema( + # value_name="actions_raw", + # joint_positions=P.JointPositionsSchema( + # joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS + # ), + # ), + # P.ValueSchema( + # value_name="new_x", + # state_tensor=P.StateTensorSchema( + # shape=[frame_stack * (11 + NUM_JOINTS * 3)], + # dtype=P.DType.FP32, + # ), + # ), + # ] + # ) class 
safety(LeggedRobotCfg.safety): # safety factors diff --git a/sim/envs/humanoids/zbot2_config.py b/sim/envs/humanoids/zbot2_config.py new file mode 100644 index 00000000..945c1e0c --- /dev/null +++ b/sim/envs/humanoids/zbot2_config.py @@ -0,0 +1,373 @@ +"""Defines the environment configuration for the Getting up task""" + + + +from sim.env import robot_urdf_path +from sim.envs.base.legged_robot_config import ( # type: ignore + LeggedRobotCfg, + LeggedRobotCfgPPO, +) +from sim.resources.zbot2.joints import Robot + +NUM_JOINTS = len(Robot.all_joints()) + + +class ZBot2Cfg(LeggedRobotCfg): + """Configuration class for the Legs humanoid robot.""" + + class env(LeggedRobotCfg.env): + # change the observation dim + frame_stack = 15 + c_frame_stack = 3 + # num_single_obs = 11 + NUM_JOINTS * 3 + num_single_obs = 8 + NUM_JOINTS * 3 # pfb30 + num_observations = int(frame_stack * num_single_obs) + single_num_privileged_obs = 25 + NUM_JOINTS * 4 + num_privileged_obs = int(c_frame_stack * single_num_privileged_obs) + num_actions = NUM_JOINTS + num_envs = 4096 + episode_length_s = 24 # episode length in seconds + use_ref_actions = False + + # from kinfer import proto as P + # input_schema = P.IOSchema( + # values=[ + # P.ValueSchema( + # value_name="vector_command", + # vector_command=P.VectorCommandSchema( + # dimensions=3, # x_vel, y_vel, rot + # ), + # ), + # P.ValueSchema( + # value_name="timestamp", + # timestamp=P.TimestampSchema( + # start_seconds=0, + # ), + # ), + # P.ValueSchema( + # value_name="dof_pos", + # joint_positions=P.JointPositionsSchema( + # joint_names=Robot.joint_names(), + # unit=P.JointPositionUnit.RADIANS, + # ), + # ), + # P.ValueSchema( + # value_name="dof_vel", + # joint_velocities=P.JointVelocitiesSchema( + # joint_names=Robot.joint_names(), + # unit=P.JointVelocityUnit.RADIANS_PER_SECOND, + # ), + # ), + # P.ValueSchema( + # value_name="prev_actions", + # joint_positions=P.JointPositionsSchema( + # joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS + # ), + # ), + # # Abusing the IMU schema to pass in euler and angular velocity instead of raw sensor data + # P.ValueSchema( + # value_name="imu_ang_vel", + # imu=P.ImuSchema( + # use_accelerometer=False, + # use_gyroscope=True, + # use_magnetometer=False, + # ), + # ), + # P.ValueSchema( + # value_name="imu_euler_xyz", + # imu=P.ImuSchema( + # use_accelerometer=True, + # use_gyroscope=False, + # use_magnetometer=False, + # ), + # ), + # P.ValueSchema( + # value_name="hist_obs", + # state_tensor=P.StateTensorSchema( + # # 11 is the number of single observation features - 6 from IMU, 5 from command input + # # 3 comes from the number of times num_actions is repeated in the observation (dof_pos, dof_vel, prev_actions) + # shape=[frame_stack * (11 + NUM_JOINTS * 3)], + # dtype=P.DType.FP32, + # ), + # ), + # ] + # ) + + # output_schema = P.IOSchema( + # values=[ + # P.ValueSchema( + # value_name="actions", + # joint_positions=P.JointPositionsSchema( + # joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS + # ), + # ), + # P.ValueSchema( + # value_name="actions_raw", + # joint_positions=P.JointPositionsSchema( + # joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS + # ), + # ), + # P.ValueSchema( + # value_name="new_x", + # state_tensor=P.StateTensorSchema( + # shape=[frame_stack * (11 + NUM_JOINTS * 3)], + # dtype=P.DType.FP32, + # ), + # ), + # ] + # ) + + class safety: + # safety factors + pos_limit = 1.0 + vel_limit = 1.0 + torque_limit = 0.85 + terminate_after_contacts_on = [] + + 
class asset(LeggedRobotCfg.asset): + name = "zbot2" + file = str(robot_urdf_path(name)) + + foot_name = ["FOOT", "FOOT_2"] + knee_name = ["WJ-DP00-0002-FK-AP-020_7_5", "WJ-DP00-0002-FK-AP-020_7_6"] + + termination_height = 0.1 + default_feet_height = 0.01 + + penalize_contacts_on = [] + self_collisions = 1 # 1 to disable, 0 to enable...bitwise filter + flip_visual_attachments = False + replace_cylinder_with_capsule = False + fix_base_link = False + + # pfb30 + friction = 0.013343597773929877 + armature = 0.008793405204572328 + + class terrain(LeggedRobotCfg.terrain): + mesh_type = "plane" + # mesh_type = "trimesh" + curriculum = False + # rough terrain only: + measure_heights = False + static_friction = 0.6 + dynamic_friction = 0.6 + terrain_length = 8.0 + terrain_width = 8.0 + num_rows = 10 # number of terrain rows (levels) + num_cols = 10 # number of terrain cols (types) + max_init_terrain_level = 10 # starting curriculum state + # plane; obstacles; uniform; slope_up; slope_down, stair_up, stair_down + terrain_proportions = [0.2, 0.2, 0.4, 0.1, 0.1, 0, 0] + restitution = 0.0 + + class noise: + add_noise = True + noise_level = 0.6 # scales other values + + class noise_scales: + dof_pos = 0.05 + dof_vel = 0.5 + ang_vel = 0.1 + lin_vel = 0.05 + quat = 0.03 + height_measurements = 0.1 + + class init_state(LeggedRobotCfg.init_state): + pos = [0.0, 0.0, Robot.height] + rot = Robot.rotation + + default_joint_angles = {k: 0.0 for k in Robot.all_joints()} + + default_positions = Robot.default_walking() + for joint in default_positions: + default_joint_angles[joint] = default_positions[joint] + + class control(LeggedRobotCfg.control): + # PD Drive parameters: + stiffness = Robot.stiffness() + damping = Robot.damping() + # action scale: target angle = actionScale * action + defaultAngle + action_scale = 0.25 + # decimation: Number of control action updates @ sim DT per policy DT + decimation = 10 # 100hz + + class sim(LeggedRobotCfg.sim): + dt = 0.001 # 1000 Hz + substeps = 1 # 2 + up_axis = 1 # 0 is y, 1 is z + use_projected_gravity = False + + class physx(LeggedRobotCfg.sim.physx): + num_threads = 10 + solver_type = 1 # 0: pgs, 1: tgs + num_position_iterations = 4 + num_velocity_iterations = 1 + contact_offset = 0.01 # [m] + rest_offset = 0.0 # [m] + bounce_threshold_velocity = 0.1 # [m/s] + max_depenetration_velocity = 1.0 + max_gpu_contact_pairs = 2**23 # 2**24 -> needed for 8000 envs and more + default_buffer_size_multiplier = 5 + # 0: never, 1: last sub-step, 2: all sub-steps (default=2) + contact_collection = 2 + + class domain_rand(LeggedRobotCfg.domain_rand): + start_pos_noise = 0.05 + randomize_friction = True + friction_range = [0.1, 1.5] + randomize_base_mass = True + added_mass_range = [-0.1, 0.2] + push_robots = True + push_interval_s = 4 + max_push_vel_xy = 0.1 + max_push_ang_vel = 0.2 + # dynamic randomization + action_delay = 0.5 + action_noise = 0.02 + randomize_pd_gains = False + + class commands(LeggedRobotCfg.commands): + # Vers: lin_vel_x, lin_vel_y, ang_vel_yaw, heading (in heading mode ang_vel_yaw is recomputed from heading error) + num_commands = 4 + resampling_time = 8.0 # time before command are changed[s] + heading_command = True # if true: compute ang vel command from heading error + + class ranges: + lin_vel_x = [-0.3, 0.6] # min max [m/s] + lin_vel_y = [-0.3, 0.3] # min max [m/s] + ang_vel_yaw = [-0.3, 0.3] # min max [rad/s] + heading = [-3.14, 3.14] + + class rewards: + base_height_target = Robot.height + min_dist = 0.07 + max_dist = 0.14 + + # put some settings here 
for LLM parameter tuning + # pfb30 + target_joint_pos_scale = 0.24 # rad + target_feet_height = 0.03 # m + cycle_time = 0.4 # sec + # if true negative total rewards are clipped at zero (avoids early termination problems) + only_positive_rewards = True + # tracking reward = exp(error*sigma) + tracking_sigma = 5.0 + max_contact_force = 400 # forces above this value are penalized + + class scales: + # reference motion tracking + joint_pos = 1.6 + feet_clearance = 1.5 + feet_contact_number = 1.5 + feet_air_time = 1.4 + foot_slip = -0.1 + feet_distance = 0.2 + knee_distance = 0.2 + # contact + feet_contact_forces = -0.01 + # vel tracking + tracking_lin_vel = 1.6 + tracking_ang_vel = 1.6 + vel_mismatch_exp = 0.5 # lin_z; ang x,y + low_speed = 0.4 + track_vel_hard = 0.5 + + # base pos + default_joint_pos = 1.0 + orientation = 1 + base_height = 0.2 + base_acc = 0.2 + # energy + action_smoothness = -0.002 + torques = -1e-5 + dof_vel = -5e-4 + dof_acc = -1e-7 + collision = -1.0 + + class normalization: + class obs_scales: + lin_vel = 2.0 + ang_vel = 1.0 + dof_pos = 1.0 + dof_vel = 0.05 + quat = 1.0 + height_measurements = 5.0 + + clip_observations = 18.0 + clip_actions = 18.0 + + class viewer: + ref_env = 0 + pos = [4, -4, 2] + lookat = [0, -2, 0] + + +class ZBot2StandingCfg(ZBot2Cfg): + """Standing configuration for the ZBot2 humanoid robot.""" + + class init_state(LeggedRobotCfg.init_state): + pos = [0.0, 0.0, Robot.standing_height] + rot = Robot.rotation + + default_joint_angles = {k: 0.0 for k in Robot.all_joints()} + + default_positions = Robot.default_standing() + for joint in default_positions: + default_joint_angles[joint] = default_positions[joint] + + class rewards: + base_height_target = Robot.height + min_dist = 0.2 + max_dist = 0.5 + target_joint_pos_scale = 0.17 # rad + target_feet_height = 0.05 # m + cycle_time = 0.5 # sec + only_positive_rewards = False + tracking_sigma = 5 + max_contact_force = 200 + + class scales: + default_joint_pos = 1.0 + orientation = 1 + base_height = 0.2 + base_acc = 0.2 + action_smoothness = -0.002 + torques = -1e-5 + dof_vel = -1e-3 + dof_acc = -2.5e-7 + collision = -1.0 + + +class ZBot2CfgPPO(LeggedRobotCfgPPO): + seed = 5 + runner_class_name = "OnPolicyRunner" + + class policy: + init_noise_std = 1.0 + actor_hidden_dims = [512, 256, 128] + critic_hidden_dims = [768, 256, 128] + + class algorithm(LeggedRobotCfgPPO.algorithm): + entropy_coef = 0.001 + learning_rate = 1e-5 + num_learning_epochs = 2 + gamma = 0.994 + lam = 0.9 + num_mini_batches = 4 + + class runner: + policy_class_name = "ActorCritic" + algorithm_class_name = "PPO" + num_steps_per_env = 60 # per iteration + max_iterations = 3001 # number of policy updates + + # logging + save_interval = 100 # check for potential saves every this many iterations + experiment_name = "zbot2" + run_name = "" + # load and resume + resume = False + load_run = -1 # -1 = last run + checkpoint = -1 # -1 = last saved model + resume_path = None # updated from load_run and chkpt diff --git a/sim/envs/humanoids/zbot2_env.py b/sim/envs/humanoids/zbot2_env.py new file mode 100644 index 00000000..ad30246f --- /dev/null +++ b/sim/envs/humanoids/zbot2_env.py @@ -0,0 +1,527 @@ +# mypy: disable-error-code="valid-newtype" +"""Defines the environment for training the humanoid.""" + +from sim.envs.base.legged_robot import LeggedRobot +from sim.resources.zbot2.joints import Robot +from sim.utils.terrain import HumanoidTerrain + +from isaacgym import gymtorch # isort:skip +from isaacgym.torch_utils import * # isort: skip + + 
+import torch # isort:skip + + +class ZBot2Env(LeggedRobot): + """ZBot2Env is a class that represents a custom environment for a legged robot. + + Args: + cfg: Configuration object for the legged robot. + sim_params: Parameters for the simulation. + physics_engine: Physics engin e used in the simulation. + sim_device: Device used for the simulation. + headless: Flag indicating whether the simulation should be run in headless mode. + + Attributes: + last_feet_z (float): The z-coordinate of the last feet position. + feet_height (torch.Tensor): Tensor representing the height of the feet. + sim (gymtorch.GymSim): The simulation object. + terrain (HumanoidTerrain): The terrain object. + up_axis_idx (int): The index representing the up axis. + command_input (torch.Tensor): Tensor representing the command input. + privileged_obs_buf (torch.Tensor): Tensor representing the privileged observations buffer. + obs_buf (torch.Tensor): Tensor representing the observations buffer. + obs_history (collections.deque): Deque containing the history of observations. + critic_history (collections.deque): Deque containing the history of critic observations. + + Methods: + _push_robots(): Randomly pushes the robots by setting a randomized base velocity. + _get_phase(): Calculates the phase of the gait cycle. + _get_gait_phase(): Calculates the gait phase. + compute_ref_state(): Computes the reference state. + create_sim(): Creates the simulation, terrain, and environments. + _get_noise_scale_vec(cfg): Sets a vector used to scale the noise added to the observations. + step(actions): Performs a simulation step with the given actions. + compute_observations(): Computes the observations. + reset_idx(env_ids): Resets the environment for the specified environment IDs. + """ + + def __init__(self, cfg, sim_params, physics_engine, sim_device, headless): + super().__init__(cfg, sim_params, physics_engine, sim_device, headless) + self.last_feet_z = self.cfg.asset.default_feet_height + self.feet_height = torch.zeros((self.num_envs, 2), device=self.device) + self.reset_idx(torch.tensor(range(self.num_envs), device=self.device)) + + env_handle = self.envs[0] + actor_handle = self.actor_handles[0] + + self.legs_joints = {} + for name, joint in Robot.legs.left.joints_motors(): + # print(name) + joint_handle = self.gym.find_actor_dof_handle(env_handle, actor_handle, joint) + self.legs_joints["left_" + name] = joint_handle + + for name, joint in Robot.legs.right.joints_motors(): + joint_handle = self.gym.find_actor_dof_handle(env_handle, actor_handle, joint) + self.legs_joints["right_" + name] = joint_handle + + self.compute_observations() + + def _push_robots(self): + """Random pushes the robots. 
Emulates an impulse by setting a randomized base velocity.""" + max_vel = self.cfg.domain_rand.max_push_vel_xy + max_push_angular = self.cfg.domain_rand.max_push_ang_vel + self.rand_push_force[:, :2] = torch_rand_float( + -max_vel, max_vel, (self.num_envs, 2), device=self.device + ) # lin vel x/y + self.root_states[:, 7:9] = self.rand_push_force[:, :2] + + self.rand_push_torque = torch_rand_float( + -max_push_angular, max_push_angular, (self.num_envs, 3), device=self.device + ) + self.root_states[:, 10:13] = self.rand_push_torque + + self.gym.set_actor_root_state_tensor(self.sim, gymtorch.unwrap_tensor(self.root_states)) + + def _get_phase(self): + cycle_time = self.cfg.rewards.cycle_time + phase = self.episode_length_buf * self.dt / cycle_time + return phase + + def _get_gait_phase(self): + # return float mask 1 is stance, 0 is swing + phase = self._get_phase() + sin_pos = torch.sin(2 * torch.pi * phase) + # Add double support phase + stance_mask = torch.zeros((self.num_envs, 2), device=self.device) + # left foot stance + stance_mask[:, 0] = sin_pos >= 0 + # right foot stance + stance_mask[:, 1] = sin_pos < 0 + # Double support phase + stance_mask[torch.abs(sin_pos) < 0.1] = 1 + + return stance_mask + + def check_termination(self): + """Check if environments need to be reset""" + self.reset_buf = torch.any( + torch.norm(self.contact_forces[:, self.termination_contact_indices, :], dim=-1) > 1.0, + dim=1, + ) + self.reset_buf |= self.root_states[:, 2] < self.cfg.asset.termination_height + self.time_out_buf = self.episode_length_buf > self.max_episode_length # no terminal reward for time-outs + self.reset_buf |= self.time_out_buf + + def compute_ref_state(self): + phase = self._get_phase() + sin_pos = torch.sin(2 * torch.pi * phase) + sin_pos_l = sin_pos.clone() + sin_pos_r = sin_pos.clone() + default_clone = self.default_dof_pos.clone() + self.ref_dof_pos = self.default_dof_pos.repeat(self.num_envs, 1) + + scale_1 = self.cfg.rewards.target_joint_pos_scale + scale_2 = 2 * scale_1 + # left foot stance phase set to default joint pos + sin_pos_l[sin_pos_l > 0] = 0 + self.ref_dof_pos[:, self.legs_joints["left_hip_pitch"]] += sin_pos_l * scale_1 + self.ref_dof_pos[:, self.legs_joints["left_knee_pitch"]] += sin_pos_l * scale_2 + self.ref_dof_pos[:, self.legs_joints["left_ankle_pitch"]] += sin_pos_l * scale_1 + + # right foot stance phase set to default joint pos + sin_pos_r[sin_pos_r < 0] = 0 + self.ref_dof_pos[:, self.legs_joints["right_hip_pitch"]] += sin_pos_r * scale_1 + self.ref_dof_pos[:, self.legs_joints["right_knee_pitch"]] += sin_pos_r * scale_2 + self.ref_dof_pos[:, self.legs_joints["right_ankle_pitch"]] += sin_pos_r * scale_1 + + # Double support phase + self.ref_dof_pos[torch.abs(sin_pos) < 0.1] = 0 + + self.ref_action = 2 * self.ref_dof_pos + + def create_sim(self): + """Creates simulation, terrain and evironments""" + self.up_axis_idx = 2 # 2 for z, 1 for y -> adapt gravity accordingly + self.sim = self.gym.create_sim( + self.sim_device_id, + self.graphics_device_id, + self.physics_engine, + self.sim_params, + ) + mesh_type = self.cfg.terrain.mesh_type + if mesh_type in ["heightfield", "trimesh"]: + self.terrain = HumanoidTerrain(self.cfg.terrain, self.num_envs) + if mesh_type == "plane": + self._create_ground_plane() + elif mesh_type == "heightfield": + self._create_heightfield() + elif mesh_type == "trimesh": + self._create_trimesh() + elif mesh_type is not None: + raise ValueError("Terrain mesh type not recognised. 
Allowed types are [None, plane, heightfield, trimesh]") + self._create_envs() + + def _get_noise_scale_vec(self, cfg): + """Sets a vector used to scale the noise added to the observations. + [NOTE]: Must be adapted when changing the observations structure + + Args: + cfg (Dict): Environment config file + + Returns: + [torch.Tensor]: Vector of scales used to multiply a uniform distribution in [-1, 1] + """ + num_actions = self.num_actions + noise_vec = torch.zeros(self.cfg.env.num_single_obs, device=self.device) + self.add_noise = self.cfg.noise.add_noise + noise_scales = self.cfg.noise.noise_scales + noise_vec[0:5] = 0.0 # commands + noise_vec[5 : (num_actions + 5)] = noise_scales.dof_pos * self.obs_scales.dof_pos + noise_vec[(num_actions + 5) : (2 * num_actions + 5)] = noise_scales.dof_vel * self.obs_scales.dof_vel + noise_vec[(2 * num_actions + 5) : (3 * num_actions + 5)] = 0.0 # previous actions + noise_vec[(3 * num_actions + 5) : (3 * num_actions + 5) + 3] = ( + noise_scales.ang_vel * self.obs_scales.ang_vel + ) # ang vel + noise_vec[(3 * num_actions + 5) + 3 : (3 * num_actions + 5)] = ( + noise_scales.quat * self.obs_scales.quat + ) # euler x,y + return noise_vec + + def compute_observations(self): + phase = self._get_phase() + self.compute_ref_state() + + sin_pos = torch.sin(2 * torch.pi * phase).unsqueeze(1) + cos_pos = torch.cos(2 * torch.pi * phase).unsqueeze(1) + + stance_mask = self._get_gait_phase() + contact_mask = self.contact_forces[:, self.feet_indices, 2] > 5.0 + + self.command_input = torch.cat((sin_pos, cos_pos, self.commands[:, :3] * self.commands_scale), dim=1) + q = (self.dof_pos - self.default_dof_pos) * self.obs_scales.dof_pos + dq = self.dof_vel * self.obs_scales.dof_vel + + diff = self.dof_pos - self.ref_dof_pos + + # pfb30 + # if self.cfg.sim.use_projected_gravity: + # observation_imu = self.projected_gravity.clone() + # else: + # observation_imu = self.base_euler_xyz.clone() * self.obs_scales.quat + + self.privileged_obs_buf = torch.cat( + ( + self.command_input, # 2 + 3 + (self.dof_pos - self.default_joint_pd_target) * self.obs_scales.dof_pos, # 12 + self.dof_vel * self.obs_scales.dof_vel, # 12 + self.actions, # 12 + diff, # 12 + self.base_lin_vel * self.obs_scales.lin_vel, # 3 + self.base_ang_vel * self.obs_scales.ang_vel, # 3 + self.projected_gravity, # 3 + self.rand_push_force[:, :2], # 3 + self.rand_push_torque, # 3 + self.env_frictions, # 1 + self.body_mass / 30.0, # 1 + stance_mask, # 2 + contact_mask, # 2 + ), + dim=-1, + ) + + + obs_buf = torch.cat( + ( + self.command_input, # 5 = 2D(sin cos) + 3D(vel_x, vel_y, aug_vel_yaw) + q, # 20D + dq, # 20D + self.actions, # 20D + # pfb30 + self.projected_gravity, + # self.base_quat, # 4 + # self.base_ang_vel * self.obs_scales.ang_vel, # 3 + # observation_imu, # 3 + ), + dim=-1, + ) + + if self.cfg.terrain.measure_heights: + heights = ( + torch.clip( + self.root_states[:, 2].unsqueeze(1) - 0.5 - self.measured_heights, + -1, + 1.0, + ) + * self.obs_scales.height_measurements + ) + self.privileged_obs_buf = torch.cat((self.obs_buf, heights), dim=-1) + + if self.add_noise: + obs_now = obs_buf.clone() + torch.randn_like(obs_buf) * self.noise_scale_vec * self.cfg.noise.noise_level + else: + obs_now = obs_buf.clone() + self.obs_history.append(obs_now) + self.critic_history.append(self.privileged_obs_buf) + + obs_buf_all = torch.stack([self.obs_history[i] for i in range(self.obs_history.maxlen)], dim=1) # N,T,K + + self.obs_buf = obs_buf_all.reshape(self.num_envs, -1) # N, T*K + self.privileged_obs_buf = 
torch.cat([self.critic_history[i] for i in range(self.cfg.env.c_frame_stack)], dim=1) + + def reset_idx(self, env_ids): + super().reset_idx(env_ids) + for i in range(self.obs_history.maxlen): + self.obs_history[i][env_ids] *= 0 + for i in range(self.critic_history.maxlen): + self.critic_history[i][env_ids] *= 0 + + # ================================================ Rewards ================================================== # + def _reward_joint_pos(self): + """Calculates the reward based on the difference between the current joint positions and the target joint positions.""" + joint_pos = self.dof_pos.clone() + pos_target = self.ref_dof_pos.clone() + diff = joint_pos - pos_target + r = torch.exp(-2 * torch.norm(diff, dim=1)) - 0.2 * torch.norm(diff, dim=1).clamp(0, 0.5) + + return r + + def _reward_feet_distance(self): + """Calculates the reward based on the distance between the feet. Penilize feet get close to each other or too far away.""" + foot_pos = self.rigid_state[:, self.feet_indices, :2] + foot_dist = torch.norm(foot_pos[:, 0, :] - foot_pos[:, 1, :], dim=1) + fd = self.cfg.rewards.min_dist + max_df = self.cfg.rewards.max_dist + d_min = torch.clamp(foot_dist - fd, -0.5, 0.0) + d_max = torch.clamp(foot_dist - max_df, 0, 0.5) + return (torch.exp(-torch.abs(d_min) * 100) + torch.exp(-torch.abs(d_max) * 100)) / 2 + + def _reward_knee_distance(self): + """Calculates the reward based on the distance between the knee of the humanoid.""" + foot_pos = self.rigid_state[:, self.knee_indices, :2] + foot_dist = torch.norm(foot_pos[:, 0, :] - foot_pos[:, 1, :], dim=1) + fd = self.cfg.rewards.min_dist + max_df = self.cfg.rewards.max_dist / 2 + d_min = torch.clamp(foot_dist - fd, -0.5, 0.0) + d_max = torch.clamp(foot_dist - max_df, 0, 0.5) + return (torch.exp(-torch.abs(d_min) * 100) + torch.exp(-torch.abs(d_max) * 100)) / 2 + + def _reward_foot_slip(self): + """Calculates the reward for minimizing foot slip. The reward is based on the contact forces + and the speed of the feet. A contact threshold is used to determine if the foot is in contact + with the ground. The speed of the foot is calculated and scaled by the contact condition. + """ + contact = self.contact_forces[:, self.feet_indices, 2] > 5.0 + foot_speed_norm = torch.norm(self.rigid_state[:, self.feet_indices, 7:9], dim=2) + rew = torch.sqrt(foot_speed_norm) + rew *= contact + return torch.sum(rew, dim=1) + + def _reward_feet_air_time(self): + """Calculates the reward for feet air time, promoting longer steps. This is achieved by + checking the first contact with the ground after being in the air. The air time is + limited to a maximum value for reward calculation. + """ + contact = self.contact_forces[:, self.feet_indices, 2] > 5.0 + stance_mask = self._get_gait_phase() + self.contact_filt = torch.logical_or(torch.logical_or(contact, stance_mask), self.last_contacts) + self.last_contacts = contact + first_contact = (self.feet_air_time > 0.0) * self.contact_filt + self.feet_air_time += self.dt + air_time = self.feet_air_time.clamp(0, 0.5) * first_contact + self.feet_air_time *= ~self.contact_filt + return air_time.sum(dim=1) + + def _reward_feet_contact_number(self): + """Calculates a reward based on the number of feet contacts aligning with the gait phase. + Rewards or penalizes depending on whether the foot contact matches the expected gait phase. 
+ """ + contact = self.contact_forces[:, self.feet_indices, 2] > 5.0 + stance_mask = self._get_gait_phase() + reward = torch.where(contact == stance_mask, 1, -0.3) + return torch.mean(reward, dim=1) + + def _reward_orientation(self): + """Calculates the reward for maintaining a flat base orientation. It penalizes deviation + from the desired base orientation using the base euler angles and the projected gravity vector. + """ + quat_mismatch = torch.exp(-torch.sum(torch.abs(self.base_euler_xyz[:, :2]), dim=1) * 10) + orientation = torch.exp(-torch.norm(self.projected_gravity[:, :2], dim=1) * 20) + return (quat_mismatch + orientation) / 2 + + def _reward_feet_contact_forces(self): + """Calculates the reward for keeping contact forces within a specified range. Penalizes + high contact forces on the feet. + """ + return torch.sum( + ( + torch.norm(self.contact_forces[:, self.feet_indices, :], dim=-1) - self.cfg.rewards.max_contact_force + ).clip(0, 400), + dim=1, + ) + + def _reward_default_joint_pos(self): + """Calculates the reward for keeping joint positions close to default positions, with a focus + on penalizing deviation in yaw and roll directions. Excludes yaw and roll from the main penalty. + """ + joint_diff = self.dof_pos - self.default_joint_pd_target + left_yaw_roll = joint_diff[:, [self.legs_joints["left_hip_roll"], self.legs_joints["left_hip_yaw"]]] + right_yaw_roll = joint_diff[:, [self.legs_joints["right_hip_roll"], self.legs_joints["right_hip_yaw"]]] + yaw_roll = torch.norm(left_yaw_roll, dim=1) + torch.norm(right_yaw_roll, dim=1) + yaw_roll = torch.clamp(yaw_roll - 0.1, 0, 50) + return torch.exp(-yaw_roll * 100) - 0.01 * torch.norm(joint_diff, dim=1) + + def _reward_base_height(self): + """Calculates the reward based on the robot's base height. Penalizes deviation from a target base height. + The reward is computed based on the height difference between the robot's base and the average height + of its feet when they are in contact with the ground. + """ + stance_mask = self._get_gait_phase() + measured_heights = torch.sum(self.rigid_state[:, self.feet_indices, 2] * stance_mask, dim=1) / torch.sum( + stance_mask, dim=1 + ) + base_height = self.root_states[:, 2] - (measured_heights - self.cfg.asset.default_feet_height) + reward = torch.exp(-torch.abs(base_height - self.cfg.rewards.base_height_target) * 100) + return reward + + def _reward_base_acc(self): + """Computes the reward based on the base's acceleration. Penalizes high accelerations of the robot's base, + encouraging smoother motion. + """ + root_acc = self.last_root_vel - self.root_states[:, 7:13] + rew = torch.exp(-torch.norm(root_acc, dim=1) * 3) + return rew + + def _reward_vel_mismatch_exp(self): + """Computes a reward based on the mismatch in the robot's linear and angular velocities. + Encourages the robot to maintain a stable velocity by penalizing large deviations. + """ + lin_mismatch = torch.exp(-torch.square(self.base_lin_vel[:, 2]) * 10) + ang_mismatch = torch.exp(-torch.norm(self.base_ang_vel[:, :2], dim=1) * 5.0) + + c_update = (lin_mismatch + ang_mismatch) / 2.0 + + return c_update + + def _reward_track_vel_hard(self): + """Calculates a reward for accurately tracking both linear and angular velocity commands. + Penalizes deviations from specified linear and angular velocity targets. 
+ """ + # Tracking of linear velocity commands (xy axes) + lin_vel_error = torch.norm(self.commands[:, :2] - self.base_lin_vel[:, :2], dim=1) + lin_vel_error_exp = torch.exp(-lin_vel_error * 10) + + # Tracking of angular velocity commands (yaw) + ang_vel_error = torch.abs(self.commands[:, 2] - self.base_ang_vel[:, 2]) + ang_vel_error_exp = torch.exp(-ang_vel_error * 10) + + linear_error = 0.2 * (lin_vel_error + ang_vel_error) + + return (lin_vel_error_exp + ang_vel_error_exp) / 2.0 - linear_error + + def _reward_tracking_lin_vel(self): + """Tracks linear velocity commands along the xy axes. + Calculates a reward based on how closely the robot's linear velocity matches the commanded values. + """ + lin_vel_error = torch.sum(torch.square(self.commands[:, :2] - self.base_lin_vel[:, :2]), dim=1) + return torch.exp(-lin_vel_error * self.cfg.rewards.tracking_sigma) + + def _reward_tracking_ang_vel(self): + """Tracks angular velocity commands for yaw rotation. + Computes a reward based on how closely the robot's angular velocity matches the commanded yaw values. + """ + ang_vel_error = torch.square(self.commands[:, 2] - self.base_ang_vel[:, 2]) + return torch.exp(-ang_vel_error * self.cfg.rewards.tracking_sigma) + + def _reward_feet_clearance(self): + """Calculates reward based on the clearance of the swing leg from the ground during movement. + Encourages appropriate lift of the feet during the swing phase of the gait. + """ + # Compute feet contact mask + contact = self.contact_forces[:, self.feet_indices, 2] > 5.0 + + # Get the z-position of the feet and compute the change in z-position + feet_z = self.rigid_state[:, self.feet_indices, 2] - self.cfg.asset.default_feet_height + delta_z = feet_z - self.last_feet_z + self.feet_height += delta_z + self.last_feet_z = feet_z + + # Compute swing mask + swing_mask = 1 - self._get_gait_phase() + + # feet height should be closed to target feet height at the peak + rew_pos = torch.abs(self.feet_height - self.cfg.rewards.target_feet_height) < 0.02 + rew_pos = torch.sum(rew_pos * swing_mask, dim=1) + self.feet_height *= ~contact + + return rew_pos + + def _reward_low_speed(self): + """Rewards or penalizes the robot based on its speed relative to the commanded speed. + This function checks if the robot is moving too slow, too fast, or at the desired speed, + and if the movement direction matches the command. + """ + # Calculate the absolute value of speed and command for comparison + absolute_speed = torch.abs(self.base_lin_vel[:, 0]) + absolute_command = torch.abs(self.commands[:, 0]) + + # Define speed criteria for desired range + speed_too_low = absolute_speed < 0.5 * absolute_command + speed_too_high = absolute_speed > 1.2 * absolute_command + speed_desired = ~(speed_too_low | speed_too_high) + + # Check if the speed and command directions are mismatched + sign_mismatch = torch.sign(self.base_lin_vel[:, 0]) != torch.sign(self.commands[:, 0]) + + # Initialize reward tensor + reward = torch.zeros_like(self.base_lin_vel[:, 0]) + + # Assign rewards based on conditions + # Speed too low + reward[speed_too_low] = -1.0 + # Speed too high + reward[speed_too_high] = 0.0 + # Speed within desired range + reward[speed_desired] = 1.2 + # Sign mismatch has the highest priority + reward[sign_mismatch] = -2.0 + return reward * (self.commands[:, 0].abs() > 0.1) + + def _reward_torques(self): + """Penalizes the use of high torques in the robot's joints. Encourages efficient movement by minimizing + the necessary force exerted by the motors. 
+ """ + return torch.sum(torch.square(self.torques), dim=1) + + def _reward_dof_vel(self): + """Penalizes high velocities at the degrees of freedom (DOF) of the robot. This encourages smoother and + more controlled movements. + """ + return torch.sum(torch.square(self.dof_vel), dim=1) + + def _reward_dof_acc(self): + """Penalizes high accelerations at the robot's degrees of freedom (DOF). This is important for ensuring + smooth and stable motion, reducing wear on the robot's mechanical parts. + """ + return torch.sum(torch.square((self.last_dof_vel - self.dof_vel) / self.dt), dim=1) + + def _reward_collision(self): + """Penalizes collisions of the robot with the environment, specifically focusing on selected body parts. + This encourages the robot to avoid undesired contact with objects or surfaces. + """ + return torch.sum( + 1.0 * (torch.norm(self.contact_forces[:, self.penalised_contact_indices, :], dim=-1) > 0.1), + dim=1, + ) + + def _reward_action_smoothness(self): + """Encourages smoothness in the robot's actions by penalizing large differences between consecutive actions. + This is important for achieving fluid motion and reducing mechanical stress. + """ + term_1 = torch.sum(torch.square(self.last_actions - self.actions), dim=1) + term_2 = torch.sum( + torch.square(self.actions + self.last_last_actions - 2 * self.last_actions), + dim=1, + ) + term_3 = 0.05 * torch.sum(torch.abs(self.actions), dim=1) + return term_1 + term_2 + term_3 diff --git a/sim/model_export2.py b/sim/model_export2.py new file mode 100644 index 00000000..965acb6e --- /dev/null +++ b/sim/model_export2.py @@ -0,0 +1,364 @@ +"""Script to convert weights to Rust-compatible format.""" + +import re +from dataclasses import dataclass, fields +from io import BytesIO +from typing import List, Optional, Tuple + +import onnx +import onnxruntime as ort +import torch +# from scripts.create_mjcf import load_embodiment +from torch import Tensor, nn +from torch.distributions import Normal +import importlib + + +def load_embodiment(embodiment: str): # noqa: ANN401 + # Dynamically import embodiment + module_name = f"sim.resources.{embodiment}.joints" + module = importlib.import_module(module_name) + robot = getattr(module, "Robot") + return robot + + +@dataclass +class ActorCfg: + embodiment: str + cycle_time: float # Cycle time for sinusoidal command input + action_scale: float # Scale for actions + lin_vel_scale: float # Scale for linear velocity + ang_vel_scale: float # Scale for angular velocity + quat_scale: float # Scale for quaternion + dof_pos_scale: float # Scale for joint positions + dof_vel_scale: float # Scale for joint velocities + frame_stack: int # Number of frames to stack for the policy input + clip_observations: float # Clip observations to this value + clip_actions: float # Clip actions to this value + sim_dt: float # Simulation time step + sim_decimation: int # Simulation decimation + tau_factor: float # Torque limit factor + use_projected_gravity: bool # Use projected gravity as IMU observation + + +class ActorCritic(nn.Module): + def __init__( + self, + num_actor_obs: int, + num_critic_obs: int, + num_actions: int, + actor_hidden_dims: List[int] = [256, 256, 256], + critic_hidden_dims: List[int] = [256, 256, 256], + init_noise_std: float = 1.0, + activation: nn.Module = nn.ELU(), + ) -> None: + super(ActorCritic, self).__init__() + + mlp_input_dim_a = num_actor_obs + mlp_input_dim_c = num_critic_obs + + # Policy function. 
+ actor_layers = [] + actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0])) + actor_layers.append(activation) + for dim_i in range(len(actor_hidden_dims)): + if dim_i == len(actor_hidden_dims) - 1: + actor_layers.append(nn.Linear(actor_hidden_dims[dim_i], num_actions)) + else: + actor_layers.append(nn.Linear(actor_hidden_dims[dim_i], actor_hidden_dims[dim_i + 1])) + actor_layers.append(activation) + self.actor = nn.Sequential(*actor_layers) + + # Value function. + critic_layers = [] + critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0])) + critic_layers.append(activation) + for dim_i in range(len(critic_hidden_dims)): + if dim_i == len(critic_hidden_dims) - 1: + critic_layers.append(nn.Linear(critic_hidden_dims[dim_i], 1)) + else: + critic_layers.append(nn.Linear(critic_hidden_dims[dim_i], critic_hidden_dims[dim_i + 1])) + critic_layers.append(activation) + self.critic = nn.Sequential(*critic_layers) + + # Action noise. + self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) + self.distribution = None + + # Disable args validation for speedup. + Normal.set_default_validate_args = False + + +class Actor(nn.Module): + """Actor model. + + Parameters: + policy: The policy network. + cfg: The configuration for the actor. + """ + + def __init__(self, policy: nn.Module, cfg: ActorCfg) -> None: + super().__init__() + + self.robot = load_embodiment(cfg.embodiment) + + # Policy config + default_dof_pos_dict = self.robot.default_standing() + self.num_actions = len(self.robot.all_joints()) + self.frame_stack = cfg.frame_stack + + # 11 is the number of single observation features - 6 from IMU, 5 from command input + # 9 is the number of single observation features - 3 from IMU(quat), 5 from command input + # 3 comes from the number of times num_actions is repeated in the observation (dof_pos, dof_vel, prev_actions) + self.num_single_obs = 8 + self.num_actions * 3 # pfb30 + self.num_observations = int(self.frame_stack * self.num_single_obs) + + self.policy = policy + + # This is the policy reference joint positions and should be the same order as the policy and mjcf file. + # CURRENTLY NOT USED IN FORWARD PASS TO MAKE MORE GENERALIZEABLE FOR REAL AND SIM + self.default_dof_pos = torch.tensor(list(default_dof_pos_dict.values())) + + self.action_scale = cfg.action_scale + self.lin_vel_scale = cfg.lin_vel_scale + self.ang_vel_scale = cfg.ang_vel_scale + self.quat_scale = cfg.quat_scale + self.dof_pos_scale = cfg.dof_pos_scale + self.dof_vel_scale = cfg.dof_vel_scale + + self.clip_observations = cfg.clip_observations + self.clip_actions = cfg.clip_actions + + self.cycle_time = cfg.cycle_time + self.use_projected_gravity = cfg.use_projected_gravity + + def get_init_buffer(self) -> Tensor: + return torch.zeros(self.num_observations) + + def forward( + self, + x_vel: Tensor, # x-coordinate of the target velocity + y_vel: Tensor, # y-coordinate of the target velocity + rot: Tensor, # target angular velocity + t: Tensor, # current policy time (sec) + dof_pos: Tensor, # current angular position of the DoFs relative to default + dof_vel: Tensor, # current angular velocity of the DoFs + prev_actions: Tensor, # previous actions taken by the model + projected_gravity: Tensor, # quaternion of the IMU + # imu_euler_xyz: Tensor, # euler angles of the IMU + buffer: Tensor, # buffer of previous observations + ) -> Tuple[Tensor, Tensor, Tensor]: + """Runs the actor model forward pass. + + Args: + x_vel: The x-coordinate of the target velocity, with shape (1). 
+ y_vel: The y-coordinate of the target velocity, with shape (1). + rot: The target angular velocity, with shape (1). + t: The current policy time step, with shape (1). + dof_pos: The current angular position of the DoFs relative to default, with shape (num_actions). + dof_vel: The current angular velocity of the DoFs, with shape (num_actions). + prev_actions: The previous actions taken by the model, with shape (num_actions). + imu_ang_vel: The angular velocity of the IMU, with shape (3), + in radians per second. If IMU is not used, can be all zeros. + imu_euler_xyz: The euler angles of the IMU, with shape (3), + in radians. "XYZ" means (roll, pitch, yaw). If IMU is not used, + can be all zeros. + buffer: The buffer of previous actions, with shape (frame_stack * num_single_obs). This is + the return value of the previous forward pass. On the first + pass, it should be all zeros. + + Returns: + actions_scaled: The actions to take, with shape (num_actions), scaled by the action_scale. + actions: The actions to take, with shape (num_actions). + x: The new buffer of observations, with shape (frame_stack * num_single_obs). + """ + sin_pos = torch.sin(2 * torch.pi * t / self.cycle_time) + cos_pos = torch.cos(2 * torch.pi * t / self.cycle_time) + + # Construct command input + command_input = torch.cat( + ( + sin_pos, + cos_pos, + x_vel * self.lin_vel_scale, + y_vel * self.lin_vel_scale, + rot * self.ang_vel_scale, + ), + dim=0, + ) + + # Calculate current position and velocity observations + q = dof_pos * self.dof_pos_scale + dq = dof_vel * self.dof_vel_scale + + # if self.use_projected_gravity: + # imu_observation = imu_euler_xyz + # else: + # imu_observation = imu_euler_xyz * self.quat_scale + # Construct new observation + new_x = torch.cat( + ( + command_input, + q, + dq, + prev_actions, + projected_gravity, + # imu_ang_vel * self.ang_vel_scale, + # imu_observation, + ), + dim=0, + ) + + # Clip the inputs + new_x = torch.clamp(new_x, -self.clip_observations, self.clip_observations) + + # Add the new frame to the buffer + x = torch.cat((buffer, new_x), dim=0) + # Pop the oldest frame + x = x[self.num_single_obs :] + + policy_input = x.unsqueeze(0) + + # Get actions from the policy + actions = self.policy(policy_input).squeeze(0) + actions_scaled = actions * self.action_scale + + return actions_scaled, actions, x + + +def get_actor_policy(model_path: str, cfg: ActorCfg) -> Tuple[nn.Module, dict, Tuple[Tensor, ...]]: + all_weights = torch.load(model_path, map_location="cpu")#, weights_only=True) + weights = all_weights["model_state_dict"] + num_actor_obs = weights["actor.0.weight"].shape[1] + num_critic_obs = weights["critic.0.weight"].shape[1] + num_actions = weights["std"].shape[0] + actor_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"actor\.\d+\.weight", k)] + critic_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"critic\.\d+\.weight", k)] + actor_hidden_dims = actor_hidden_dims[:-1] + critic_hidden_dims = critic_hidden_dims[:-1] + + ac_model = ActorCritic(num_actor_obs, num_critic_obs, num_actions, actor_hidden_dims, critic_hidden_dims) + ac_model.load_state_dict(weights) + + a_model = Actor(ac_model.actor, cfg) + + # Gets the model input tensors. 
+ x_vel = torch.randn(1) + y_vel = torch.randn(1) + rot = torch.randn(1) + t = torch.randn(1) + dof_pos = torch.randn(a_model.num_actions) + dof_vel = torch.randn(a_model.num_actions) + prev_actions = torch.randn(a_model.num_actions) + projected_gravity = torch.randn(3) # pfb30 + # imu_euler_xyz = torch.randn(3) + buffer = a_model.get_init_buffer() + input_tensors = (x_vel, y_vel, rot, t, dof_pos, dof_vel, prev_actions, projected_gravity, buffer) + + jit_model = torch.jit.script(a_model) + + # Add sim2sim metadata + robot_effort = list(a_model.robot.effort().values()) + robot_stiffness = list(a_model.robot.stiffness().values()) + robot_damping = list(a_model.robot.damping().values()) + default_standing = list(a_model.robot.default_standing().values()) + num_actions = a_model.num_actions + num_observations = a_model.num_observations + + return a_model, { + "robot_effort": robot_effort, + "robot_stiffness": robot_stiffness, + "robot_damping": robot_damping, + "default_standing": default_standing, + "num_actions": num_actions, + "num_observations": num_observations, + }, input_tensors + + +def convert_model_to_onnx(model_path: str, cfg: ActorCfg, save_path: Optional[str] = None) -> ort.InferenceSession: + """Converts a PyTorch model to a ONNX format. + + Args: + model_path: Path to the PyTorch model. + cfg: The configuration for the actor. + save_path: Path to save the ONNX model. + + Returns: + An ONNX inference session. + """ + all_weights = torch.load(model_path, map_location="cpu")#, weights_only=True) + weights = all_weights["model_state_dict"] + num_actor_obs = weights["actor.0.weight"].shape[1] + num_critic_obs = weights["critic.0.weight"].shape[1] + num_actions = weights["std"].shape[0] + actor_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"actor\.\d+\.weight", k)] + critic_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"critic\.\d+\.weight", k)] + actor_hidden_dims = actor_hidden_dims[:-1] + critic_hidden_dims = critic_hidden_dims[:-1] + + ac_model = ActorCritic(num_actor_obs, num_critic_obs, num_actions, actor_hidden_dims, critic_hidden_dims) + ac_model.load_state_dict(weights) + + a_model = Actor(ac_model.actor, cfg) + + # Gets the model input tensors. 
+ x_vel = torch.randn(1) + y_vel = torch.randn(1) + rot = torch.randn(1) + t = torch.randn(1) + dof_pos = torch.randn(a_model.num_actions) + dof_vel = torch.randn(a_model.num_actions) + prev_actions = torch.randn(a_model.num_actions) + imu_ang_vel = torch.randn(3) + imu_euler_xyz = torch.randn(3) + buffer = a_model.get_init_buffer() + input_tensors = (x_vel, y_vel, rot, t, dof_pos, dof_vel, prev_actions, imu_ang_vel, imu_euler_xyz, buffer) + + jit_model = torch.jit.script(a_model) + + # Export the model to a buffer + buffer = BytesIO() + torch.onnx.export(jit_model, input_tensors, buffer) + buffer.seek(0) + + # Load the model as an onnx model + model_proto = onnx.load_model(buffer) + + # Add sim2sim metadata + robot_effort = list(a_model.robot.effort().values()) + robot_stiffness = list(a_model.robot.stiffness().values()) + robot_damping = list(a_model.robot.damping().values()) + num_actions = a_model.num_actions + num_observations = a_model.num_observations + + for field_name, field in [ + ("robot_effort", robot_effort), + ("robot_stiffness", robot_stiffness), + ("robot_damping", robot_damping), + ("num_actions", num_actions), + ("num_observations", num_observations), + ]: + meta = model_proto.metadata_props.add() + meta.key = field_name + meta.value = str(field) + + # Add the configuration of the model + for field in fields(cfg): + value = getattr(cfg, field.name) + meta = model_proto.metadata_props.add() + meta.key = field.name + meta.value = str(value) + + if save_path: + onnx.save_model(model_proto, save_path) + + # Convert model to bytes + buffer2 = BytesIO() + onnx.save_model(model_proto, buffer2) + buffer2.seek(0) + + return ort.InferenceSession(buffer2.read()) + + +if __name__ == "__main__": + convert_model_to_onnx("model_3000.pt", ActorCfg(), "policy.onnx") \ No newline at end of file diff --git a/sim/play2.py b/sim/play2.py new file mode 100644 index 00000000..422b421d --- /dev/null +++ b/sim/play2.py @@ -0,0 +1,285 @@ +# mypy: ignore-errors +"""Play a trained policy in the environment. 
+ +Run: + python sim/play2.py --task zbot2 +""" +import argparse +import copy +import logging +import math +import os +import time +import uuid +from datetime import datetime +from typing import Any, Union + +import cv2 +import h5py +import numpy as np +from isaacgym import gymapi +from tqdm import tqdm + +# Local imports third +from sim.env import run_dir +from sim.envs import task_registry +from sim.h5_logger import HDF5Logger + +import torch # special case with isort: skip comment +from sim.env import run_dir # noqa: E402 +from sim.envs import task_registry # noqa: E402 +from sim.model_export2 import ActorCfg, get_actor_policy # noqa: E402 +from sim.utils.helpers import get_args # noqa: E402 +from sim.utils.logger import Logger # noqa: E402 +from kinfer.export.pytorch import export_to_onnx + +logger = logging.getLogger(__name__) + + +def export_policy_as_jit(actor_critic: Any, path: Union[str, os.PathLike]) -> None: + os.makedirs(path, exist_ok=True) + path = os.path.join(path, "policy_1.pt") + model = copy.deepcopy(actor_critic.actor).to("cpu") + traced_script_module = torch.jit.script(model) + traced_script_module.save(path) + + +def play(args: argparse.Namespace) -> None: + logger.info("Configuring environment and training settings...") + env_cfg, train_cfg = task_registry.get_cfgs(name=args.task) + + num_parallel_envs = 2 + env_cfg.env.num_envs = num_parallel_envs + env_cfg.sim.max_gpu_contact_pairs = 2**10 * num_parallel_envs + + if args.trimesh: + env_cfg.terrain.mesh_type = "trimesh" + else: + env_cfg.terrain.mesh_type = "plane" + env_cfg.terrain.num_rows = 5 + env_cfg.terrain.num_cols = 5 + env_cfg.terrain.curriculum = False + env_cfg.terrain.max_init_terrain_level = 5 + env_cfg.noise.add_noise = True + env_cfg.domain_rand.push_robots = True + env_cfg.domain_rand.joint_angle_noise = 0.0 + env_cfg.noise.curriculum = False + env_cfg.noise.noise_level = 0.5 + + train_cfg.seed = 123145 + logger.info("train_cfg.runner_class_name: %s", train_cfg.runner_class_name) + + # prepare environment + env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg) + env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat) + + obs = env.get_observations() + + # load policy + train_cfg.runner.resume = True + ppo_runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args, train_cfg=train_cfg) + policy = ppo_runner.get_inference_policy(device=env.device) + + # Export policy if needed + if args.export_policy: + path = os.path.join(".") + export_policy_as_jit(ppo_runner.alg.actor_critic, path) + print("Exported policy as jit script to: ", path) + + # export policy as a onnx module (used to run it on web) + if args.export_onnx: + path = ppo_runner.load_path + embodiment = ppo_runner.cfg['experiment_name'].lower() + policy_cfg = ActorCfg( + embodiment=embodiment, + cycle_time=env_cfg.rewards.cycle_time, + sim_dt=env_cfg.sim.dt, + sim_decimation=env_cfg.control.decimation, + tau_factor=env_cfg.safety.torque_limit, + action_scale=env_cfg.control.action_scale, + lin_vel_scale=env_cfg.normalization.obs_scales.lin_vel, + ang_vel_scale=env_cfg.normalization.obs_scales.ang_vel, + quat_scale=env_cfg.normalization.obs_scales.quat, + dof_pos_scale=env_cfg.normalization.obs_scales.dof_pos, + dof_vel_scale=env_cfg.normalization.obs_scales.dof_vel, + frame_stack=env_cfg.env.frame_stack, + clip_observations=env_cfg.normalization.clip_observations, + clip_actions=env_cfg.normalization.clip_actions, + use_projected_gravity=env_cfg.sim.use_projected_gravity, + ) + + actor_model, 
sim2sim_info, input_tensors = get_actor_policy(path, policy_cfg) + + # Merge policy_cfg and sim2sim_info into a single config object + export_config = {**vars(policy_cfg), **sim2sim_info} + + export_to_onnx( + actor_model, + input_tensors=input_tensors, + config=export_config, + save_path="kinfer_policy.onnx" + ) + print("Exported policy as kinfer-compatible onnx to: ", path) + + # Prepare for logging + env_logger = Logger(env.dt) + robot_index = 0 + joint_index = 1 + env_steps_to_run = 1000 + + now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + if args.log_h5: + # Create directory for HDF5 files + h5_dir = run_dir() / "h5_out" / args.task / now + h5_dir.mkdir(parents=True, exist_ok=True) + + # Get observation dimensions + num_actions = env.num_dof + obs_buffer = env.obs_buf.shape[1] + prev_actions = np.zeros((num_actions), dtype=np.double) + + h5_loggers = [] + for env_idx in range(env_cfg.env.num_envs): + h5_dir = run_dir() / "h5_out" / args.task / now / f"env_{env_idx}" + h5_dir.mkdir(parents=True, exist_ok=True) + + h5_loggers.append(HDF5Logger( + data_name=f"{args.task}_env_{env_idx}", + num_actions=num_actions, + max_timesteps=env_steps_to_run, + num_observations=obs_buffer, + h5_out_dir=str(h5_dir) + )) + + if args.render: + camera_properties = gymapi.CameraProperties() + camera_properties.width = 1920 + camera_properties.height = 1080 + h1 = env.gym.create_camera_sensor(env.envs[0], camera_properties) + camera_offset = gymapi.Vec3(3, -3, 1) + camera_rotation = gymapi.Quat.from_axis_angle(gymapi.Vec3(-0.3, 0.2, 1), np.deg2rad(135)) + actor_handle = env.gym.get_actor_handle(env.envs[0], 0) + body_handle = env.gym.get_actor_rigid_body_handle(env.envs[0], actor_handle, 0) + logger.info("body_handle: %s", body_handle) + logger.info("actor_handle: %s", actor_handle) + env.gym.attach_camera_to_body( + h1, env.envs[0], body_handle, gymapi.Transform(camera_offset, camera_rotation), gymapi.FOLLOW_POSITION + ) + + fourcc = cv2.VideoWriter_fourcc(*"MJPG") # type: ignore[attr-defined] + + # Creates a directory to store videos. 
+ video_dir = run_dir() / "videos" + experiment_dir = video_dir / train_cfg.runner.experiment_name + experiment_dir.mkdir(parents=True, exist_ok=True) + + dir = os.path.join(experiment_dir, now + str(args.run_name) + ".mp4") + if not os.path.exists(video_dir): + os.mkdir(video_dir) + if not os.path.exists(experiment_dir): + os.mkdir(experiment_dir) + video = cv2.VideoWriter(dir, fourcc, 50.0, (1920, 1080)) + + for t in tqdm(range(env_steps_to_run)): + actions = policy(obs.detach()) + + if args.fix_command: + env.commands[:, 0] = 0.2 + env.commands[:, 1] = 0.0 + env.commands[:, 2] = 0.0 + env.commands[:, 3] = 0.0 + obs, critic_obs, rews, dones, infos = env.step(actions.detach()) + + if args.render: + env.gym.fetch_results(env.sim, True) + env.gym.step_graphics(env.sim) + env.gym.render_all_camera_sensors(env.sim) + img = env.gym.get_camera_image(env.sim, env.envs[0], h1, gymapi.IMAGE_COLOR) + img = np.reshape(img, (1080, 1920, 4)) + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) + + video.write(img[..., :3]) + + # Log states + dof_pos_target = actions[robot_index, joint_index].item() * env.cfg.control.action_scale + dof_pos = env.dof_pos[robot_index, joint_index].item() + dof_vel = env.dof_vel[robot_index, joint_index].item() + dof_torque = env.torques[robot_index, joint_index].item() + command_x = env.commands[robot_index, 0].item() + command_y = env.commands[robot_index, 1].item() + command_yaw = env.commands[robot_index, 2].item() + base_vel_x = env.base_lin_vel[robot_index, 0].item() + base_vel_y = env.base_lin_vel[robot_index, 1].item() + base_vel_z = env.base_lin_vel[robot_index, 2].item() + base_vel_yaw = env.base_ang_vel[robot_index, 2].item() + contact_forces_z = env.contact_forces[robot_index, env.feet_indices, 2].cpu().numpy() + + env_logger.log_states( + { + "dof_pos_target": dof_pos_target, + "dof_pos": dof_pos, + "dof_vel": dof_vel, + "dof_torque": dof_torque, + "command_x": command_x, + "command_y": command_y, + "command_yaw": command_yaw, + "base_vel_x": base_vel_x, + "base_vel_y": base_vel_y, + "base_vel_z": base_vel_z, + "base_vel_yaw": base_vel_yaw, + "contact_forces_z": contact_forces_z, + } + ) + actions = actions.detach().cpu().numpy() + if args.log_h5: + # Extract the current observation + for env_idx in range(env_cfg.env.num_envs): + h5_loggers[env_idx].log_data({ + "t": np.array([t * env.dt], dtype=np.float32), + "2D_command": np.array( + [ + np.sin(2 * math.pi * t * env.dt / env.cfg.rewards.cycle_time), + np.cos(2 * math.pi * t * env.dt / env.cfg.rewards.cycle_time), + ], + dtype=np.float32, + ), + "3D_command": np.array(env.commands[env_idx, :3].cpu().numpy(), dtype=np.float32), + "joint_pos": np.array(env.dof_pos[env_idx].cpu().numpy(), dtype=np.float32), + "joint_vel": np.array(env.dof_vel[env_idx].cpu().numpy(), dtype=np.float32), + "prev_actions": prev_actions[env_idx].astype(np.float32), + "curr_actions": actions[env_idx].astype(np.float32), + "ang_vel": env.base_ang_vel[env_idx].cpu().numpy().astype(np.float32), + "euler_rotation": env.base_euler_xyz[env_idx].cpu().numpy().astype(np.float32), + "buffer": env.obs_buf[env_idx].cpu().numpy().astype(np.float32) + }) + + prev_actions = actions + + if infos["episode"]: + num_episodes = env.reset_buf.sum().item() + if num_episodes > 0: + env_logger.log_rewards(infos["episode"], num_episodes) + + env_logger.print_rewards() + + if args.render: + video.release() + + if args.log_h5: + # print(f"Saving HDF5 file to {h5_logger.h5_file_path}") # TODO use code from kdatagen + for h5_logger in h5_loggers: + h5_logger.close() + 
print(f"HDF5 file(s) saved!") + + +if __name__ == "__main__": + base_args = get_args() + parser = argparse.ArgumentParser(description="Extend base arguments with log_h5") + parser.add_argument("--log_h5", action="store_true", help="Enable HDF5 logging") + parser.add_argument("--render", action="store_true", help="Enable rendering", default=True) + parser.add_argument("--fix_command", action="store_true", help="Fix command", default=True) + parser.add_argument("--export_onnx", action="store_true", help="Export policy as ONNX", default=True) + parser.add_argument("--export_policy", action="store_true", help="Export policy as JIT", default=True) + args, unknown = parser.parse_known_args(namespace=base_args) + + play(args) \ No newline at end of file diff --git a/sim/resources/gpr/robot_fixed.xml b/sim/resources/gpr/robot_fixed.xml index 3bb52cbe..d89f95a9 100644 --- a/sim/resources/gpr/robot_fixed.xml +++ b/sim/resources/gpr/robot_fixed.xml @@ -191,4 +191,4 @@ - \ No newline at end of file + diff --git a/sim/resources/zbot2/imu_data.csv b/sim/resources/zbot2/imu_data.csv new file mode 100644 index 00000000..0008fa80 --- /dev/null +++ b/sim/resources/zbot2/imu_data.csv @@ -0,0 +1,128 @@ +timestamp,accel_x,accel_y,accel_z,gyro_x,gyro_y,gyro_z,mag_x,mag_y,mag_z +2025-01-07T17:45:54.101723,-1.1399999857,9.7899999619,-0.3599999845,-0.2500000000,0.4375000000,1.3750000000,-12.2500000000,-19.5000000000,-19.5000000000 +2025-01-07T17:45:54.180039,-0.8499999642,9.7799997330,-0.4299999774,-0.5000000000,-0.3750000000,2.6875000000,-12.2500000000,-19.1875000000,-19.5000000000 +2025-01-07T17:45:54.261766,-0.7699999809,9.8099994659,-0.3499999940,-0.8750000000,0.0625000000,0.1875000000,-12.2500000000,-20.1875000000,-19.2500000000 +2025-01-07T17:45:54.343540,-0.9799999595,9.7699995041,-0.3499999940,-0.5625000000,-0.2500000000,0.7500000000,-12.0000000000,-20.1875000000,-20.7500000000 +2025-01-07T17:45:54.420801,-0.8899999857,9.8000001907,-0.3299999833,-0.2500000000,0.0625000000,0.0000000000,-12.2500000000,-20.1875000000,-20.7500000000 +2025-01-07T17:45:54.501844,-0.8999999762,9.7899999619,-0.3299999833,-0.0625000000,-0.1250000000,-0.1250000000,-13.0000000000,-19.8750000000,-20.7500000000 +2025-01-07T17:45:54.584752,-0.8399999738,9.7699995041,-0.3100000024,0.0625000000,0.1250000000,0.3750000000,-12.2500000000,-21.0000000000,-21.1875000000 +2025-01-07T17:45:54.663277,-0.8999999762,9.7699995041,-0.3199999928,0.1875000000,0.1875000000,0.0000000000,-13.0000000000,-20.1875000000,-19.5625000000 +2025-01-07T17:45:54.744018,-0.8100000024,9.7899999619,-0.3299999833,0.0000000000,-0.0625000000,0.6250000000,-13.0000000000,-19.8750000000,-20.0625000000 +2025-01-07T17:45:54.825365,-0.8199999928,9.7799997330,-0.2999999821,0.5625000000,-0.3125000000,1.2500000000,-12.6875000000,-20.1875000000,-22.0000000000 +2025-01-07T17:45:54.904344,-0.7599999905,9.8299999237,-0.3999999762,0.6250000000,-0.2500000000,-0.1875000000,-13.0000000000,-19.8750000000,-20.7500000000 +2025-01-07T17:45:54.983908,-1.0000000000,9.8099994659,-0.4699999988,0.4375000000,0.1875000000,-0.8125000000,-13.0000000000,-19.5000000000,-20.3750000000 +2025-01-07T17:45:55.064528,-0.8899999857,9.8199996948,-0.3499999940,-0.4375000000,0.0000000000,0.3750000000,-13.3750000000,-20.5625000000,-19.5625000000 +2025-01-07T17:45:55.152842,-0.9300000072,9.7799997330,-0.3599999845,-0.3125000000,0.1875000000,0.3750000000,-12.2500000000,-20.5625000000,-20.3750000000 
+2025-01-07T17:45:55.235584,-0.8899999857,9.7799997330,-0.3999999762,-0.5625000000,0.3750000000,-0.1250000000,-13.0000000000,-20.5625000000,-20.3750000000 +2025-01-07T17:45:55.316714,-0.8100000024,9.8099994659,-0.3299999833,-0.5625000000,0.0000000000,-0.1250000000,-13.3750000000,-19.8750000000,-20.0000000000 +2025-01-07T17:45:55.397102,-0.8399999738,9.7699995041,-0.2899999917,-0.2500000000,-0.1250000000,-0.5000000000,-12.2500000000,-19.8750000000,-19.5625000000 +2025-01-07T17:45:55.484345,-0.8999999762,9.7899999619,-0.2999999821,0.0625000000,-0.4375000000,-0.3125000000,-13.7500000000,-20.1875000000,-19.2500000000 +2025-01-07T17:45:55.567359,-0.8299999833,9.8000001907,-0.3599999845,0.1875000000,-0.3750000000,0.6250000000,-12.6875000000,-20.1875000000,-19.5625000000 +2025-01-07T17:45:55.650892,-0.8299999833,9.8000001907,-0.3799999952,0.1250000000,0.0000000000,-0.1875000000,-12.6875000000,-21.0000000000,-20.7500000000 +2025-01-07T17:45:55.729605,-0.8999999762,9.7899999619,-0.3400000036,-0.1250000000,0.0625000000,-0.6875000000,-12.6875000000,-19.0625000000,-21.5625000000 +2025-01-07T17:45:55.809749,-0.8599999547,9.7899999619,-0.3199999928,-0.1250000000,0.0625000000,-0.3750000000,-12.6875000000,-20.1875000000,-20.0625000000 +2025-01-07T17:45:55.889831,-0.8899999857,9.7899999619,-0.2999999821,-0.1875000000,0.6250000000,-0.3750000000,-13.3750000000,-19.8750000000,-21.5625000000 +2025-01-07T17:45:55.968885,-0.6800000072,9.7799997330,-0.3299999833,-0.1875000000,1.1250000000,-0.8750000000,-13.0000000000,-20.5625000000,-20.7500000000 +2025-01-07T17:45:56.051414,-0.9199999571,9.8000001907,-0.3499999940,0.0000000000,-0.1250000000,-0.0625000000,-13.7500000000,-19.5000000000,-20.6875000000 +2025-01-07T17:45:56.131055,-0.8499999642,9.7799997330,-0.2699999809,0.0000000000,0.5625000000,-0.1250000000,-12.2500000000,-19.8750000000,-20.2500000000 +2025-01-07T17:45:56.215095,-0.9199999571,9.7799997330,-0.3100000024,0.3125000000,0.0625000000,0.3125000000,-13.3750000000,-19.8750000000,-20.6875000000 +2025-01-07T17:45:56.292211,-0.9499999881,9.8000001907,-0.3400000036,0.0000000000,0.0000000000,0.5625000000,-13.0000000000,-19.8750000000,-19.8750000000 +2025-01-07T17:45:56.371972,-0.9300000072,9.7699995041,-0.3100000024,-0.1875000000,0.5625000000,-0.1875000000,-13.0000000000,-19.8750000000,-19.5000000000 +2025-01-07T17:45:56.452707,-0.7999999523,9.8099994659,-0.3499999940,-0.3125000000,0.1875000000,0.1250000000,-12.6875000000,-19.5000000000,-20.2500000000 +2025-01-07T17:45:56.551917,-0.8999999762,9.7799997330,-0.3100000024,-0.3125000000,0.1250000000,-0.1875000000,-12.2500000000,-19.5000000000,-20.6875000000 +2025-01-07T17:45:56.650230,-1.0000000000,9.7799997330,-0.2399999946,-0.0625000000,0.0000000000,-0.1250000000,-13.0000000000,-20.1875000000,-21.0625000000 +2025-01-07T17:45:56.729698,-0.7799999714,9.7899999619,-0.3100000024,0.0625000000,-0.1875000000,-0.0625000000,-12.2500000000,-19.5000000000,-21.0625000000 +2025-01-07T17:45:56.812723,-1.0099999905,9.7899999619,-0.2899999917,0.3750000000,0.0000000000,0.8125000000,-12.6875000000,-19.8750000000,-19.0625000000 +2025-01-07T17:45:56.895949,-0.8799999952,9.8099994659,-0.2899999917,0.5000000000,-0.3750000000,0.5625000000,-12.0000000000,-19.5000000000,-20.2500000000 +2025-01-07T17:45:56.976409,-0.9599999785,9.7899999619,-0.3299999833,0.0000000000,0.3750000000,-0.3125000000,-12.2500000000,-19.1875000000,-20.6875000000 +2025-01-07T17:45:57.056415,-0.8999999762,9.7699995041,-0.3299999833,-0.2500000000,0.5625000000,0.1875000000,-12.2500000000,-20.5625000000,-20.6875000000 
+2025-01-07T17:45:57.136177,-0.8700000048,9.7899999619,-0.3100000024,0.0000000000,-0.0625000000,0.3750000000,-13.0000000000,-19.8750000000,-19.5000000000 +2025-01-07T17:45:57.216527,-0.9499999881,9.8000001907,-0.3700000048,0.1250000000,-0.6250000000,1.1250000000,-13.3750000000,-19.1875000000,-19.8750000000 +2025-01-07T17:45:57.297268,-0.8799999952,9.7799997330,-0.3299999833,-0.6250000000,-0.1875000000,0.0625000000,-12.2500000000,-19.5000000000,-20.6875000000 +2025-01-07T17:45:57.376972,-0.9799999595,9.7899999619,-0.2299999893,-0.6250000000,0.7500000000,0.1250000000,-12.6875000000,-19.8750000000,-20.6875000000 +2025-01-07T17:45:57.457879,-0.8999999762,9.7699995041,-0.2699999809,-0.1875000000,0.1875000000,1.8750000000,-12.2500000000,-20.1875000000,-20.6875000000 +2025-01-07T17:45:57.532355,-0.6899999976,9.8099994659,-0.3499999940,0.1250000000,-1.2500000000,1.9375000000,-13.3750000000,-19.8750000000,-20.6875000000 +2025-01-07T17:45:57.608843,-0.8100000024,9.8000001907,-0.3199999928,-0.4375000000,0.1875000000,0.8750000000,-12.6875000000,-19.8750000000,-21.0625000000 +2025-01-07T17:45:57.693140,-0.7400000095,9.7899999619,-0.2500000000,-0.2500000000,-0.3750000000,0.7500000000,-12.2500000000,-19.8750000000,-21.5000000000 +2025-01-07T17:45:57.770711,-0.7799999714,9.8000001907,-0.2999999821,-0.1875000000,0.0625000000,0.6875000000,-13.3750000000,-19.5000000000,-20.2500000000 +2025-01-07T17:45:57.850232,-0.8599999547,9.7899999619,-0.3499999940,-0.5625000000,0.5000000000,1.1875000000,-13.0000000000,-20.1875000000,-20.2500000000 +2025-01-07T17:45:57.930374,-0.6699999571,9.8199996948,-0.2800000012,-0.8750000000,-0.1250000000,0.3750000000,-12.6875000000,-20.1875000000,-21.5000000000 +2025-01-07T17:45:58.009076,-0.9099999666,9.7699995041,-0.2599999905,-1.1875000000,1.3125000000,-0.6250000000,-12.6875000000,-19.8750000000,-21.0625000000 +2025-01-07T17:45:58.090820,-0.7999999523,9.7799997330,-0.1999999881,-0.4375000000,1.0000000000,-0.3125000000,-12.2500000000,-19.8750000000,-21.5000000000 +2025-01-07T17:45:58.169681,-0.8899999857,9.8099994659,-0.2500000000,0.1250000000,0.0000000000,1.7500000000,-13.0000000000,-20.1875000000,-21.0625000000 +2025-01-07T17:45:58.250541,-0.4699999988,9.8299999237,-0.2299999893,0.5000000000,0.1875000000,-0.2500000000,-12.6875000000,-19.5000000000,-19.8750000000 +2025-01-07T17:45:58.333736,-0.7799999714,9.8099994659,-0.2099999934,1.0625000000,-0.6250000000,-1.6250000000,-12.6875000000,-19.8750000000,-21.5000000000 +2025-01-07T17:45:58.412990,-1.0399999619,9.7899999619,-0.2299999893,1.0000000000,0.3125000000,-0.5625000000,-12.6875000000,-20.1875000000,-20.6875000000 +2025-01-07T17:45:58.493966,-0.7599999905,9.7799997330,-0.2599999905,0.8125000000,0.0000000000,-0.3750000000,-12.0000000000,-19.8750000000,-21.0625000000 +2025-01-07T17:45:58.573059,-0.6299999952,9.8099994659,-0.3599999845,0.6875000000,-1.3125000000,0.3125000000,-12.6875000000,-19.8750000000,-20.6875000000 +2025-01-07T17:45:58.654102,-1.7300000191,9.6799993515,-0.3100000024,0.3125000000,2.7500000000,6.5625000000,-12.2500000000,-19.1875000000,-19.0625000000 +2025-01-07T17:45:58.736044,-3.3199999332,9.9200000763,-0.2099999934,6.6250000000,10.6250000000,32.8125000000,-13.7500000000,-19.8750000000,-20.6875000000 +2025-01-07T17:45:58.814798,1.8399999142,9.0399999619,-0.1299999952,-10.0625000000,17.2500000000,54.4375000000,-13.2500000000,-19.5000000000,-20.6875000000 +2025-01-07T17:45:58.887458,1.2599999905,9.9600000381,-1.8299999237,-2.0625000000,84.3750000000,72.7500000000,-13.6875000000,-19.5000000000,-18.6875000000 
+2025-01-07T17:45:58.964572,-1.9699999094,9.9399995804,-1.0799999237,-38.4375000000,28.8750000000,65.6875000000,-15.0000000000,-18.7500000000,-22.3750000000 +2025-01-07T17:45:59.041501,1.7500000000,8.7899999619,0.7599999905,-36.8125000000,19.4375000000,82.0000000000,-17.0625000000,-18.8750000000,-24.0000000000 +2025-01-07T17:45:59.123427,2.0799999237,7.7599997520,2.2400000095,-28.1250000000,2.2500000000,98.5625000000,-20.6875000000,-18.5625000000,-24.3750000000 +2025-01-07T17:45:59.198109,5.7599997520,7.5199999809,2.4400000572,2.5000000000,-12.6250000000,103.4375000000,-21.8750000000,-15.6875000000,-22.3750000000 +2025-01-07T17:45:59.278435,4.9400000572,6.1799998283,0.9899999499,1.5000000000,-24.8750000000,108.7500000000,-22.1875000000,-12.7500000000,-21.0000000000 +2025-01-07T17:45:59.357586,7.4400000572,6.3499999046,0.3100000024,17.1875000000,-37.0625000000,92.5625000000,-24.2500000000,-8.5000000000,-17.2500000000 +2025-01-07T17:45:59.438190,6.6099996567,4.8699998856,0.8199999928,23.5625000000,-15.0625000000,117.8125000000,-27.0000000000,-5.2500000000,-15.8750000000 +2025-01-07T17:45:59.518392,7.0599999428,3.5399999619,-1.0499999523,22.0625000000,-51.8750000000,134.9375000000,-28.0625000000,0.7500000000,-10.6875000000 +2025-01-07T17:45:59.599096,10.8800001144,1.4299999475,-2.2899999619,21.4375000000,-65.8750000000,127.3125000000,-32.0625000000,8.5000000000,4.1875000000 +2025-01-07T17:45:59.678530,11.3199996948,1.0099999905,-2.9299998283,24.6875000000,-27.6250000000,68.8750000000,-35.7500000000,11.1875000000,21.5000000000 +2025-01-07T17:45:59.767782,15.9200000763,-2.8699998856,-1.4199999571,-8.6875000000,18.5625000000,-13.0000000000,-35.6875000000,-3.2500000000,70.5000000000 +2025-01-07T17:45:59.845379,9.5999994278,1.6799999475,-1.1200000048,5.6875000000,31.8750000000,-5.6250000000,-31.1875000000,-7.6875000000,65.6875000000 +2025-01-07T17:45:59.928786,8.6099996567,1.5799999237,-2.0899999142,-0.6875000000,-29.7500000000,5.0625000000,-39.1875000000,-6.8750000000,60.5625000000 +2025-01-07T17:46:00.011361,9.5199995041,1.4599999189,-2.7200000286,-2.5625000000,-34.2500000000,1.0625000000,-34.7500000000,-8.5625000000,65.0625000000 +2025-01-07T17:46:00.089920,9.9799995422,1.4599999189,-2.9900000095,-0.0625000000,-6.2500000000,0.5000000000,-31.5625000000,-10.3750000000,71.3750000000 +2025-01-07T17:46:00.171093,8.9899997711,1.4399999380,-2.7100000381,-0.3125000000,-6.4375000000,1.1250000000,-31.1875000000,-12.0000000000,74.5000000000 +2025-01-07T17:46:00.254163,9.5399999619,1.5000000000,-3.1399998665,-0.2500000000,-6.8125000000,0.4375000000,-30.3750000000,-12.6875000000,75.6875000000 +2025-01-07T17:46:00.331684,8.8599996567,1.4199999571,-2.9600000381,-1.8750000000,-13.5000000000,3.9375000000,-29.0625000000,-13.3750000000,76.7500000000 +2025-01-07T17:46:00.407442,8.9099998474,1.3199999332,-3.3699998856,-6.0000000000,-32.1875000000,6.5000000000,-26.6875000000,-13.3750000000,79.0000000000 +2025-01-07T17:46:00.487482,8.9399995804,1.3899999857,-3.8699998856,-5.1875000000,-44.5625000000,5.4375000000,-22.5000000000,-14.8750000000,81.0000000000 +2025-01-07T17:46:00.570736,8.3099994659,1.2099999189,-4.4699997902,-4.5000000000,-48.8125000000,10.2500000000,-12.5000000000,-16.2500000000,84.0625000000 +2025-01-07T17:46:00.647384,8.5699996948,0.8299999833,-4.5399999619,-10.6875000000,-50.9375000000,9.6875000000,-4.8750000000,-16.5625000000,85.7500000000 +2025-01-07T17:46:00.727021,8.3899993896,0.7899999619,-5.2999997139,-3.5625000000,-47.8750000000,4.8125000000,9.0000000000,-17.7500000000,80.1875000000 
+2025-01-07T17:46:00.806000,7.8899998665,0.9099999666,-5.9800000191,1.1875000000,-23.5000000000,1.9375000000,22.5625000000,-16.8750000000,77.2500000000 +2025-01-07T17:46:00.885593,6.6900000572,1.0299999714,-6.1599998474,1.8750000000,-51.4375000000,8.6875000000,25.0625000000,-18.1875000000,78.1875000000 +2025-01-07T17:46:00.960687,-2.5899999142,-1.4800000191,1.8199999332,-3.3125000000,-54.6875000000,20.8125000000,36.3750000000,-15.0625000000,65.0000000000 +2025-01-07T17:46:01.045338,11.3099994659,-1.6599999666,-6.1299996376,25.8125000000,28.1875000000,-27.3750000000,40.7500000000,-9.3750000000,41.0000000000 +2025-01-07T17:46:01.126177,8.2200002670,0.3700000048,-4.6399998665,4.7500000000,34.8750000000,6.3125000000,40.8750000000,-12.6875000000,66.6875000000 +2025-01-07T17:46:01.204133,8.2100000381,0.4699999988,-5.1300001144,1.6875000000,6.7500000000,-0.5625000000,35.0625000000,-11.5625000000,71.5625000000 +2025-01-07T17:46:01.291573,8.6300001144,0.6399999857,-4.8699998856,0.6875000000,14.4375000000,-0.1250000000,34.1875000000,-12.8750000000,70.5000000000 +2025-01-07T17:46:01.368369,8.7100000381,0.5099999905,-4.6500000954,0.1875000000,10.0000000000,1.3125000000,31.7500000000,-10.5625000000,70.6875000000 +2025-01-07T17:46:01.449193,8.6499996185,0.4599999785,-4.7300000191,-0.0625000000,14.5625000000,-0.5000000000,30.8750000000,-11.2500000000,71.1875000000 +2025-01-07T17:46:01.527374,8.8299999237,0.4699999988,-4.4400000572,2.0000000000,25.9375000000,1.1250000000,27.5625000000,-10.5000000000,70.8750000000 +2025-01-07T17:46:01.605288,9.0199995041,0.4199999869,-4.0399999619,2.6250000000,34.4375000000,2.3125000000,26.0000000000,-9.6875000000,72.7500000000 +2025-01-07T17:46:01.686228,9.1700000763,0.2199999988,-3.2599999905,1.1875000000,36.3125000000,2.9375000000,19.3750000000,-9.6875000000,74.0625000000 +2025-01-07T17:46:01.761837,9.3599996567,0.4299999774,-3.1499998569,2.1250000000,55.0000000000,4.5625000000,12.0625000000,-8.0000000000,73.6875000000 +2025-01-07T17:46:01.841921,10.1099996567,-1.3199999332,4.2699999809,12.3125000000,46.5625000000,6.0625000000,8.5000000000,-6.1875000000,73.7500000000 +2025-01-07T17:46:01.917382,9.6300001144,1.6899999380,-3.1899998188,-2.6875000000,-6.6250000000,-4.1250000000,2.0625000000,-4.5000000000,70.2500000000 +2025-01-07T17:46:01.999825,9.4600000381,0.1400000006,-2.0799999237,0.5000000000,0.0000000000,-0.7500000000,0.7500000000,-5.2500000000,71.3750000000 +2025-01-07T17:46:02.080103,9.5699996948,0.3999999762,-2.2300000191,0.1250000000,-0.9375000000,-1.2500000000,2.1875000000,-5.8750000000,70.7500000000 +2025-01-07T17:46:02.153738,9.5799999237,0.3400000036,-2.1299998760,0.1875000000,0.0625000000,0.0000000000,1.3750000000,-5.2500000000,72.7500000000 +2025-01-07T17:46:02.231339,9.4799995422,0.3700000048,-2.1699998379,0.0000000000,0.1250000000,-0.1875000000,2.1875000000,-5.5625000000,72.0000000000 +2025-01-07T17:46:02.317104,9.5699996948,0.3999999762,-2.1599998474,0.1250000000,-0.1250000000,-0.5625000000,2.1875000000,-5.5625000000,71.8750000000 +2025-01-07T17:46:02.392925,9.4499998093,0.3899999857,-2.2500000000,-0.1250000000,-0.2500000000,-0.6250000000,1.3750000000,-5.2500000000,70.3750000000 +2025-01-07T17:46:02.473786,9.5699996948,0.3599999845,-2.1599998474,0.1250000000,0.0625000000,-0.4375000000,1.7500000000,-5.5625000000,72.7500000000 +2025-01-07T17:46:02.553734,9.4799995422,0.3799999952,-2.1699998379,0.0000000000,-0.1250000000,-0.3125000000,1.7500000000,-5.5625000000,71.5625000000 
+2025-01-07T17:46:02.634969,9.5500001907,0.4199999869,-2.2200000286,-0.1250000000,-0.3750000000,-0.6875000000,2.7500000000,-5.5625000000,72.0000000000 +2025-01-07T17:46:02.715675,9.6199998856,0.3799999952,-2.2000000477,0.1875000000,-0.1875000000,0.0625000000,1.7500000000,-5.2500000000,71.8750000000 +2025-01-07T17:46:02.794275,9.4499998093,0.3599999845,-2.1199998856,-0.1250000000,0.1875000000,0.1875000000,1.7500000000,-5.5625000000,71.5625000000 +2025-01-07T17:46:02.874033,9.5699996948,0.4099999964,-2.2000000477,0.1875000000,-0.1250000000,-0.0625000000,2.7500000000,-5.2500000000,71.1875000000 +2025-01-07T17:46:02.955848,9.5000000000,0.4099999964,-2.2100000381,-0.1250000000,-0.1250000000,-0.3125000000,1.7500000000,-5.2500000000,72.7500000000 +2025-01-07T17:46:03.035083,9.7199993134,0.3899999857,-2.2500000000,0.0625000000,-0.2500000000,-0.6250000000,1.3750000000,-5.8750000000,72.3750000000 +2025-01-07T17:46:03.113076,9.5000000000,0.3700000048,-2.2000000477,0.0000000000,0.1875000000,-0.0625000000,2.7500000000,-4.8750000000,72.0000000000 +2025-01-07T17:46:03.195464,9.5000000000,0.4499999881,-2.1099998951,0.2500000000,-0.2500000000,-0.2500000000,2.1875000000,-5.8750000000,72.7500000000 +2025-01-07T17:46:03.274580,9.5500001907,0.3599999845,-2.1900000572,-0.0625000000,0.0625000000,-0.2500000000,1.7500000000,-5.8750000000,71.5625000000 +2025-01-07T17:46:03.359508,9.6399993896,0.4099999964,-2.2000000477,0.1250000000,-0.1875000000,0.6250000000,2.1875000000,-5.2500000000,71.8750000000 +2025-01-07T17:46:03.438249,9.4699993134,0.4399999976,-2.2300000191,-0.1250000000,0.1875000000,0.1875000000,2.1875000000,-4.8750000000,71.5625000000 +2025-01-07T17:46:03.517519,9.6499996185,0.2800000012,-2.1599998474,-0.1250000000,0.3750000000,0.1250000000,1.3750000000,-6.6875000000,71.1875000000 +2025-01-07T17:46:03.599221,9.3999996185,0.5399999619,-2.1900000572,0.0000000000,0.0625000000,-0.0625000000,1.0000000000,-5.8750000000,72.7500000000 +2025-01-07T17:46:03.677903,9.5699996948,0.2500000000,-2.0399999619,0.0625000000,0.3750000000,0.0625000000,2.1875000000,-5.2500000000,71.8750000000 +2025-01-07T17:46:03.755578,9.4099998474,0.3999999762,-2.2000000477,0.2500000000,0.0625000000,0.3750000000,2.5000000000,-5.5625000000,70.7500000000 +2025-01-07T17:46:03.838712,9.5000000000,0.3899999857,-2.0699999332,-0.0625000000,-0.0625000000,-0.2500000000,1.7500000000,-5.8750000000,71.5625000000 +2025-01-07T17:46:03.917275,9.4799995422,0.3899999857,-2.2500000000,0.0625000000,0.0625000000,0.2500000000,1.7500000000,-5.8750000000,71.8750000000 +2025-01-07T17:46:03.997955,9.5199995041,0.4699999988,-2.1299998760,0.0625000000,-0.2500000000,-0.1875000000,2.1875000000,-6.2500000000,72.7500000000 +2025-01-07T17:46:04.080224,9.5299997330,0.3100000024,-2.2000000477,0.0625000000,0.0625000000,0.1875000000,1.7500000000,-4.8750000000,72.7500000000 +2025-01-07T17:46:04.158528,9.4799995422,0.4599999785,-2.1099998951,0.0625000000,-0.1875000000,-0.0625000000,1.7500000000,-5.8750000000,73.0625000000 +2025-01-07T17:46:04.240572,9.5399999619,0.4099999964,-2.2500000000,0.0000000000,-0.1250000000,0.0000000000,2.5000000000,-6.2500000000,72.0000000000 \ No newline at end of file diff --git a/sim/resources/zbot2/joints.py b/sim/resources/zbot2/joints.py new file mode 100644 index 00000000..cf7f24da --- /dev/null +++ b/sim/resources/zbot2/joints.py @@ -0,0 +1,261 @@ +"""Provides a Pythonic interface for referencing joint names from the given MuJoCo XML. + +Organizes them by sub-assembly (arms, legs) and defines convenient methods for +defaults, limits, etc. 
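+
+A hypothetical usage sketch (module path and names taken from this file, for illustration only):
+
+    from sim.resources.zbot2.joints import Robot
+
+    print(Robot.joint_names())          # joint ordering expected by the policy
+    print(Robot.default_standing())     # all-zero standing pose
+    print(Robot.default_limits()["L_Hip_Pitch"])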
+""" + +import textwrap +from abc import ABC +from typing import Dict, List, Tuple, Union + + +class Node(ABC): + @classmethod + def children(cls) -> List["Union[Node, str]"]: + # Returns Node or string attributes (for recursion) + return [ + attr + for attr in (getattr(cls, attr_name) for attr_name in dir(cls) if not attr_name.startswith("__")) + if isinstance(attr, (Node, str)) + ] + + @classmethod + def joints(cls) -> List[str]: + # Returns only the attributes that are strings (i.e., joint names). + return [ + attr + for attr in (getattr(cls, attr_name) for attr_name in dir(cls) if not attr_name.startswith("__")) + if isinstance(attr, str) + ] + + @classmethod + def joints_motors(cls) -> List[Tuple[str, str]]: + # Returns pairs of (attribute_name, joint_string) + joint_names: List[Tuple[str, str]] = [] + for attr_name in dir(cls): + if not attr_name.startswith("__"): + attr = getattr(cls, attr_name) + if isinstance(attr, str): + joint_names.append((attr_name, attr)) + return joint_names + + @classmethod + def all_joints(cls) -> List[str]: + # Recursively collect all string joint names + leaves = cls.joints() + for child in cls.children(): + if isinstance(child, Node): + leaves.extend(child.all_joints()) + return leaves + + def __str__(self) -> str: + # Pretty-print the hierarchy + parts = [str(child) for child in self.children() if isinstance(child, Node)] + parts_str = textwrap.indent("\n" + "\n".join(parts), " ") if parts else "" + return f"[{self.__class__.__name__}]{parts_str}" + + +# ----- Define the sub-assemblies --------------------------------------------------- +class RightArm(Node): + shoulder_yaw = "right_shoulder_yaw" + shoulder_pitch = "right_shoulder_pitch" + elbow_yaw = "right_elbow_yaw" + gripper = "right_gripper" + + +class LeftArm(Node): + shoulder_yaw = "left_shoulder_yaw" + shoulder_pitch = "left_shoulder_pitch" + elbow_yaw = "left_elbow_yaw" + gripper = "left_gripper" + + +class Arms(Node): + right = RightArm() + left = LeftArm() + + +class RightLeg(Node): + hip_roll = "R_Hip_Roll" + hip_yaw = "R_Hip_Yaw" + hip_pitch = "R_Hip_Pitch" + knee_pitch = "R_Knee_Pitch" + ankle_pitch = "R_Ankle_Pitch" + + +class LeftLeg(Node): + hip_roll = "L_Hip_Roll" + hip_yaw = "L_Hip_Yaw" + hip_pitch = "L_Hip_Pitch" + knee_pitch = "L_Knee_Pitch" + ankle_pitch = "L_Ankle_Pitch" + + +class Legs(Node): + right = RightLeg() + left = LeftLeg() + + +class Robot(Node): + legs = Legs() + + height = 0.40 + standing_height = 0.407 + rotation = [0, 0, 0, 1.0] + + @classmethod + def default_walking(cls) -> Dict[str, float]: + """Example angles for a nominal 'standing' pose. Adjust as needed.""" + return { + # Left Leg + cls.legs.left.hip_roll: 0.0, + cls.legs.left.hip_yaw: 0.0, + cls.legs.left.hip_pitch: -0.377, + cls.legs.left.knee_pitch: 0.796, + cls.legs.left.ankle_pitch: 0.377, + # Right Leg + cls.legs.right.hip_roll: 0.0, + cls.legs.right.hip_yaw: 0.0, + cls.legs.right.hip_pitch: 0.377, + cls.legs.right.knee_pitch: -0.796, + cls.legs.right.ankle_pitch: -0.377, + } + + @classmethod + def default_standing(cls) -> Dict[str, float]: + """Example angles for a nominal 'standing' pose. 
Adjust as needed.""" + return { + # Left Leg + cls.legs.left.hip_roll: 0.0, + cls.legs.left.hip_yaw: 0.0, + cls.legs.left.hip_pitch: 0.0, + cls.legs.left.knee_pitch: 0.0, + cls.legs.left.ankle_pitch: 0.0, + # Right Leg + cls.legs.right.hip_roll: 0.0, + cls.legs.right.hip_yaw: 0.0, + cls.legs.right.hip_pitch: 0.0, + cls.legs.right.knee_pitch: 0.0, + cls.legs.right.ankle_pitch: 0.0, + } + + # CONTRACT - this should be ordered according to how the policy is trained. + # E.g. the first entry should be the name of the first joint in the policy. + @classmethod + def joint_names(cls) -> List[str]: + return list(cls.default_standing().keys()) + + @classmethod + def default_limits(cls) -> Dict[str, Dict[str, float]]: + """Minimal example of per-joint limits. + + You can refine these to match your MJCF's 'range' tags or real specs. + """ + return { + # Left side + cls.legs.left.hip_roll: {"lower": -0.9, "upper": 0.9}, + cls.legs.left.hip_yaw: {"lower": -0.9, "upper": 0.9}, + cls.legs.left.hip_pitch: {"lower": -0.9, "upper": 0.9}, + cls.legs.left.knee_pitch: {"lower": -0.9, "upper": 0.9}, + cls.legs.left.ankle_pitch: {"lower": -0.9, "upper": 0.9}, + # Right side + cls.legs.right.hip_roll: {"lower": -0.9, "upper": 0.9}, + cls.legs.right.hip_yaw: {"lower": -0.9, "upper": 0.9}, + cls.legs.right.hip_pitch: {"lower": -0.9, "upper": 0.9}, + cls.legs.right.knee_pitch: {"lower": -0.9, "upper": 0.9}, + cls.legs.right.ankle_pitch: {"lower": -0.9, "upper": 0.9}, + } + + # p_gains + @classmethod + def stiffness(cls) -> Dict[str, float]: + return { + "Hip_Pitch": 17.68, + "Hip_Yaw": 17.68, + "Hip_Roll": 17.68, + "Knee_Pitch": 17.68, + "Ankle_Pitch": 17.68, + } + + # d_gains + @classmethod + def damping(cls) -> Dict[str, float]: + return { + "Hip_Pitch": 0.53, + "Hip_Yaw": 0.53, + "Hip_Roll": 0.53, + "Knee_Pitch": 0.53, + "Ankle_Pitch": 0.53, + } + + @classmethod + def effort(cls) -> Dict[str, float]: + return { + "Hip_Pitch": 3.0, + "Hip_Yaw": 3.0, + "Hip_Roll": 3.0, + "Knee_Pitch": 3.0, + "Ankle_Pitch": 3.0, + } + + # vel_limits + @classmethod + def velocity(cls) -> Dict[str, float]: + return { + "Hip_Pitch": 10, + "Hip_Yaw": 10, + "Hip_Roll": 10, + "Knee_Pitch": 10, + "Ankle_Pitch": 10, + } + + @classmethod + def friction(cls) -> Dict[str, float]: + """Example friction dictionary for certain joints.""" + # Usually you'd have more specific friction values or a model. + return { + cls.legs.left.ankle_pitch: 0.01, + cls.legs.right.ankle_pitch: 0.01, + # etc... + } + + @classmethod + def effort_mapping(cls) -> Dict[str, float]: + mapping = {} + effort = cls.effort() + for side in ["left", "right"]: + for joint, value in effort.items(): + mapping[f"{side}_{joint}"] = value + return mapping + + @classmethod + def stiffness_mapping(cls) -> Dict[str, float]: + mapping = {} + stiffness = cls.stiffness() + for side in ["left", "right"]: + for joint, value in stiffness.items(): + mapping[f"{side}_{joint}"] = value + return mapping + + @classmethod + def damping_mapping(cls) -> Dict[str, float]: + mapping = {} + damping = cls.damping() + for side in ["left", "right"]: + for joint, value in damping.items(): + mapping[f"{side}_{joint}"] = value + return mapping + + +def print_joints() -> None: + # Gather all joints and check for duplicates + joints_list = Robot.all_joints() + assert len(joints_list) == len(set(joints_list)), "Duplicate joint names found!" 
+ + # Print out the structure for debugging + print(Robot()) + print("\nAll Joints:", joints_list) + + +if __name__ == "__main__": + print_joints() diff --git a/sim/resources/zbot2/meshes/FINGER_1.stl b/sim/resources/zbot2/meshes/FINGER_1.stl new file mode 100644 index 00000000..3a671e16 Binary files /dev/null and b/sim/resources/zbot2/meshes/FINGER_1.stl differ diff --git a/sim/resources/zbot2/meshes/FINGER_1_2.stl b/sim/resources/zbot2/meshes/FINGER_1_2.stl new file mode 100644 index 00000000..c5c251e6 Binary files /dev/null and b/sim/resources/zbot2/meshes/FINGER_1_2.stl differ diff --git a/sim/resources/zbot2/meshes/FK-AP-019-25T_11.stl b/sim/resources/zbot2/meshes/FK-AP-019-25T_11.stl new file mode 100644 index 00000000..c880b008 Binary files /dev/null and b/sim/resources/zbot2/meshes/FK-AP-019-25T_11.stl differ diff --git a/sim/resources/zbot2/meshes/FK-AP-019-25T_11_2.stl b/sim/resources/zbot2/meshes/FK-AP-019-25T_11_2.stl new file mode 100644 index 00000000..5a156ab2 Binary files /dev/null and b/sim/resources/zbot2/meshes/FK-AP-019-25T_11_2.stl differ diff --git a/sim/resources/zbot2/meshes/FOOT.stl b/sim/resources/zbot2/meshes/FOOT.stl new file mode 100644 index 00000000..64b39838 Binary files /dev/null and b/sim/resources/zbot2/meshes/FOOT.stl differ diff --git a/sim/resources/zbot2/meshes/FOOT_2.stl b/sim/resources/zbot2/meshes/FOOT_2.stl new file mode 100644 index 00000000..86035664 Binary files /dev/null and b/sim/resources/zbot2/meshes/FOOT_2.stl differ diff --git a/sim/resources/zbot2/meshes/L-ARM_1.stl b/sim/resources/zbot2/meshes/L-ARM_1.stl new file mode 100644 index 00000000..9802e3a1 Binary files /dev/null and b/sim/resources/zbot2/meshes/L-ARM_1.stl differ diff --git a/sim/resources/zbot2/meshes/R-ARM-1.stl b/sim/resources/zbot2/meshes/R-ARM-1.stl new file mode 100644 index 00000000..7b2318c4 Binary files /dev/null and b/sim/resources/zbot2/meshes/R-ARM-1.stl differ diff --git a/sim/resources/zbot2/meshes/U-HIP-L.stl b/sim/resources/zbot2/meshes/U-HIP-L.stl new file mode 100644 index 00000000..ef0beac6 Binary files /dev/null and b/sim/resources/zbot2/meshes/U-HIP-L.stl differ diff --git a/sim/resources/zbot2/meshes/U-HIP-R.stl b/sim/resources/zbot2/meshes/U-HIP-R.stl new file mode 100644 index 00000000..87f45b16 Binary files /dev/null and b/sim/resources/zbot2/meshes/U-HIP-R.stl differ diff --git a/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7.stl b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7.stl new file mode 100644 index 00000000..9d55058a Binary files /dev/null and b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7.stl differ diff --git a/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_2.stl b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_2.stl new file mode 100644 index 00000000..3ad00b5d Binary files /dev/null and b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_2.stl differ diff --git a/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_3.stl b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_3.stl new file mode 100644 index 00000000..4eed15d4 Binary files /dev/null and b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_3.stl differ diff --git a/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_4.stl b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_4.stl new file mode 100644 index 00000000..33adf2b8 Binary files /dev/null and b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_4.stl differ diff --git a/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_5.stl 
b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_5.stl new file mode 100644 index 00000000..11a4e6cb Binary files /dev/null and b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_5.stl differ diff --git a/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_6.stl b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_6.stl new file mode 100644 index 00000000..da1939f0 Binary files /dev/null and b/sim/resources/zbot2/meshes/WJ-DP00-0002-FK-AP-020_7_6.stl differ diff --git a/sim/resources/zbot2/meshes/Z-BOT2-MASTER-SHOULDER2.stl b/sim/resources/zbot2/meshes/Z-BOT2-MASTER-SHOULDER2.stl new file mode 100644 index 00000000..bf3fc57b Binary files /dev/null and b/sim/resources/zbot2/meshes/Z-BOT2-MASTER-SHOULDER2.stl differ diff --git a/sim/resources/zbot2/meshes/Z-BOT2-MASTER-SHOULDER2_2.stl b/sim/resources/zbot2/meshes/Z-BOT2-MASTER-SHOULDER2_2.stl new file mode 100644 index 00000000..f6e73e6f Binary files /dev/null and b/sim/resources/zbot2/meshes/Z-BOT2-MASTER-SHOULDER2_2.stl differ diff --git a/sim/resources/zbot2/meshes/Z-BOT2_MASTER-BODY-SKELETON.stl b/sim/resources/zbot2/meshes/Z-BOT2_MASTER-BODY-SKELETON.stl new file mode 100644 index 00000000..779c7f58 Binary files /dev/null and b/sim/resources/zbot2/meshes/Z-BOT2_MASTER-BODY-SKELETON.stl differ diff --git a/sim/resources/zbot2/robot.urdf b/sim/resources/zbot2/robot.urdf new file mode 100644 index 00000000..356697ed --- /dev/null +++ b/sim/resources/zbot2/robot.urdf @@ -0,0 +1,558 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/sim/resources/zbot2/robot.xml b/sim/resources/zbot2/robot.xml new file mode 100644 index 00000000..157518eb --- /dev/null +++ b/sim/resources/zbot2/robot.xml @@ -0,0 +1,272 @@ + + \ No newline at end of file diff --git a/sim/resources/zbot2/robot_fixed.urdf b/sim/resources/zbot2/robot_fixed.urdf new file mode 100644 index 00000000..58b60276 --- /dev/null +++ b/sim/resources/zbot2/robot_fixed.urdf @@ -0,0 +1,558 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/sim/resources/zbot2/robot_fixed.xml b/sim/resources/zbot2/robot_fixed.xml new file mode 100644 index 00000000..8ebc0cb3 --- /dev/null +++ b/sim/resources/zbot2/robot_fixed.xml @@ -0,0 +1,210 @@ + + \ No newline at end of file diff --git a/sim/scripts/imu_data_comparison.py b/sim/scripts/imu_data_comparison.py new file mode 100644 index 00000000..8487a2a5 --- /dev/null +++ b/sim/scripts/imu_data_comparison.py @@ -0,0 +1,245 @@ +"""Testing the falling down IMU data comparison. + +Run: + python sim/scripts/imu_data_comparison.py --embodiment zbot2 +""" +import argparse +import os +from copy import deepcopy + +import matplotlib.pyplot as plt +import mediapy as media +import mujoco +import mujoco_viewer +import numpy as np +import pandas as pd +from tqdm import tqdm + + +def plot_comparison(sim_data: pd.DataFrame, real_data: pd.DataFrame) -> None: + """Plot the real IMU data. + + Args: + sim_data: The simulated IMU data. + real_data: The real IMU data. + """ + plt.figure(figsize=(10, 15)) + + + real_timestamps = (real_data['timestamp'] - real_data['timestamp'].iloc[0]).dt.total_seconds().to_numpy() + + # Accelerometer plots + plt.subplot(6, 1, 1) + plt.plot(real_timestamps, sim_data['accel_x'].to_numpy(), label='Simulated Accel X') + plt.plot(real_timestamps, real_data['accel_x'].to_numpy(), label='Real Accel X') + plt.title('Accelerometer X') + plt.legend() + + plt.subplot(6, 1, 2) + plt.plot(real_timestamps, sim_data['accel_y'].to_numpy(), label='Simulated Accel Y') + plt.plot(real_timestamps, real_data['accel_y'].to_numpy(), label='Real Accel Y') + plt.title('Accelerometer Y') + plt.legend() + + plt.subplot(6, 1, 3) + plt.plot(real_timestamps, sim_data['accel_z'].to_numpy(), label='Simulated Accel Z') + plt.plot(real_timestamps, real_data['accel_z'].to_numpy(), label='Real Accel Z') + plt.title('Accelerometer Z') + plt.legend() + + # Gyroscope plots + plt.subplot(6, 1, 4) + plt.plot(real_timestamps, sim_data['gyro_x'].to_numpy(), label='Simulated Gyro X') + plt.plot(real_timestamps, real_data['gyro_x'].to_numpy(), label='Real Gyro X') + plt.title('Gyroscope X') + plt.legend() + + plt.subplot(6, 1, 5) + plt.plot(real_timestamps, sim_data['gyro_y'].to_numpy(), label='Simulated Gyro Y') + plt.plot(real_timestamps, real_data['gyro_y'].to_numpy(), label='Real Gyro Y') + plt.title('Gyroscope Y') + plt.legend() + + plt.subplot(6, 1, 6) + plt.plot(real_timestamps, sim_data['gyro_z'].to_numpy(), label='Simulated Gyro Z') + plt.plot(real_timestamps, real_data['gyro_z'].to_numpy(), label='Real Gyro Z') + plt.title('Gyroscope Z') + plt.legend() + + plt.tight_layout() + plt.savefig('imu_data_comparison.png') + + +def read_real_data(data_file: str = "sim/resources/zbot2/imu_data.csv") -> None: + """Plot the real IMU data. + + Args: + data_file: The path to the real IMU data file. + + Returns: + The real IMU data. 
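+
+    A minimal usage sketch (column names assumed from imu_data.csv above):
+
+        df = read_real_data()
+        print(df[["accel_x", "accel_y", "accel_z"]].head())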
+ """ + # Reading the data from CSV file + df = pd.read_csv(data_file) + + df = df.apply(pd.to_numeric, errors='ignore') + df['timestamp'] = pd.to_datetime(df['timestamp']) + + return df + + +def pd_control( + target_q: np.ndarray, + q: np.ndarray, + kp: np.ndarray, + dq: np.ndarray, + kd: np.ndarray, + default: np.ndarray, +) -> np.ndarray: + """Calculates torques from position commands + + Args: + target_q: The target position. + q: The current position. + kp: The proportional gain. + dq: The current velocity. + kd: The derivative gain. + default: The default position. + + Returns: + The calculated torques. + """ + return kp * (target_q + default - q) - kd * dq + + +def run_simulation( + embodiment: str, + kp: float = 1.0, + kd: float = 1.0, + sim_duration: float = 15.0, + effort: float = 5.0, +) -> None: + """ + Run the Mujoco simulation using the provided policy and configuration. + + Args: + embodiment: The embodiment to use for the simulation. + kp: The proportional gain. + kd: The derivative gain. + sim_duration: The duration of the simulation. + effort: The effort to apply to the robot. + """ + model_info = { + "sim_dt": 0.001, + "tau_factor": 2, + "num_actions": 10, + "num_observations": 10, + "robot_effort": [effort] * 10, + "robot_stiffness": [kp] * 10, + "robot_damping": [kd] * 10, + } + frames = [] + framerate = 30 + model_dir = os.environ.get("MODEL_DIR", "sim/resources") + mujoco_model_path = f"{model_dir}/{embodiment}/robot_fixed.xml" + + model = mujoco.MjModel.from_xml_path(mujoco_model_path) + model.opt.timestep = model_info["sim_dt"] + data = mujoco.MjData(model) + + tau_limit = np.array(list(model_info["robot_effort"])) * model_info["tau_factor"] + kps = np.array(model_info["robot_stiffness"]) + kds = np.array(model_info["robot_damping"]) + + data.qpos = model.keyframe("standing").qpos + default = deepcopy(model.keyframe("standing").qpos)[-model_info["num_actions"] :] + print("Default position:", default) + + target_q = np.zeros((model_info["num_actions"]), dtype=np.double) + viewer = mujoco_viewer.MujocoViewer(model, data,"offscreen") + + force_duration = 400 # Duration of force application in timesteps + force_timer = 0 + + applied_force = np.array([0.0, -3, 0.0]) + + sim_data = { + "timestamp": [], + "gyro_x": [], + "gyro_y": [], + "gyro_z": [], + "accel_x": [], + "accel_y": [], + "accel_z": [], + "mag_x": [], + "mag_y": [], + "mag_z": [], + } + + for timestep in tqdm(range(int(sim_duration / model_info["sim_dt"])), desc="Simulating..."): + if timestep == 500: + force_timer = force_duration + if timestep % 10 == 0: + # Keep the robot in the same position + q = data.qpos.astype(np.double)[-model_info["num_actions"] :] + dq = data.qvel.astype(np.double)[-model_info["num_actions"] :] + tau = pd_control(target_q, q, kps, dq, kds, default) # Calc torques + tau = np.clip(tau, -tau_limit, tau_limit) # Clamp torques + data.ctrl = tau + mujoco.mj_step(model, data) + if timestep % 100 == 0: + img = viewer.read_pixels() + frames.append(img) + + # Obtain an observation + gyroscope = data.sensor("angular-velocity").data.astype(np.double) + accelerometer = data.sensor("linear-acceleration").data.astype(np.double) + magnetometer = data.sensor("magnetometer").data.astype(np.double) + + sim_data["timestamp"].append(timestep * model_info["sim_dt"]) + sim_data["gyro_x"].append(gyroscope[0]) + sim_data["gyro_y"].append(gyroscope[1]) + sim_data["gyro_z"].append(gyroscope[2]) + sim_data["accel_x"].append(accelerometer[0]) + sim_data["accel_y"].append(accelerometer[1]) + 
sim_data["accel_z"].append(accelerometer[2]) + sim_data["mag_x"].append(magnetometer[0]) + sim_data["mag_y"].append(magnetometer[1]) + sim_data["mag_z"].append(magnetometer[2]) + + if timestep == 12680: + break + + if force_timer > 0: + # Apply force if timer is active + if force_timer > 0: + data.xfrc_applied[1] = np.concatenate([applied_force, np.zeros(3)]) + force_timer -= 1 + else: + data.xfrc_applied[1] = np.zeros(6) + + media.write_video("push_tests.mp4", frames, fps=framerate) + + # sim_data["timestamp"] = np.array(sim_data["timestamp"]) + # sim_data["gyro_x"] = np.array(sim_data["gyro_x"]) + # sim_data["gyro_y"] = np.array(sim_data["gyro_y"]) + # sim_data["gyro_z"] = np.array(sim_data["gyro_z"]) + # sim_data["accel_x"] = np.array(sim_data["accel_x"]) + # sim_data["accel_y"] = np.array(sim_data["accel_y"]) + # sim_data["accel_z"] = np.array(sim_data["accel_z"]) + # sim_data["mag_x"] = np.array(sim_data["mag_x"]) + # sim_data["mag_y"] = np.array(sim_data["mag_y"]) + # sim_data["mag_z"] = np.array(sim_data["mag_z"]) + return pd.DataFrame(sim_data) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Deployment script.") + parser.add_argument("--embodiment", type=str, required=True, help="Embodiment name.") + parser.add_argument("--kp", type=float, default=10.0, help="Path to run to load from.") + parser.add_argument("--kd", type=float, default=1.0, help="Path to run to load from.") + args = parser.parse_args() + + sim_data = run_simulation(args.embodiment, args.kp, args.kd) + real_data = read_real_data() + plot_comparison(sim_data, real_data) \ No newline at end of file diff --git a/sim/scripts/push_standing_tests.py b/sim/scripts/push_standing_tests.py new file mode 100644 index 00000000..41208c24 --- /dev/null +++ b/sim/scripts/push_standing_tests.py @@ -0,0 +1,136 @@ +"""Id and standing test. + +Run: + python sim/scripts/push_standing_tests.py --load_model kinfer.onnx --embodiment zbot2 +""" +import argparse +import os +from copy import deepcopy + +import mediapy as media +import mujoco +import mujoco_viewer +import numpy as np +from tqdm import tqdm + + +def pd_control( + target_q: np.ndarray, + q: np.ndarray, + kp: np.ndarray, + dq: np.ndarray, + kd: np.ndarray, + default: np.ndarray, +) -> np.ndarray: + """Calculates torques from position commands""" + return kp * (target_q + default - q) - kd * dq + + +def run_test( + embodiment: str, + kp: float = 1.0, + kd: float = 1.0, + push_force: float = 1.0, + sim_duration: float = 3.0, + effort: float = 5.0, +) -> None: + """ + Run the Mujoco simulation using the provided policy and configuration. + + Args: + policy: The policy used for controlling the simulation. + cfg: The configuration object containing simulation settings. 
+ """ + model_info = { + "sim_dt": 0.001, + "tau_factor": 2, + "num_actions": 10, + "num_observations": 10, + "robot_effort": [effort] * 10, + "robot_stiffness": [kp] * 10, + "robot_damping": [kd] * 10, + } + frames = [] + framerate = 30 + model_dir = os.environ.get("MODEL_DIR", "sim/resources") + mujoco_model_path = f"{model_dir}/{embodiment}/robot_fixed.xml" + + model = mujoco.MjModel.from_xml_path(mujoco_model_path) + model.opt.timestep = model_info["sim_dt"] + data = mujoco.MjData(model) + + tau_limit = np.array(list(model_info["robot_effort"])) * model_info["tau_factor"] + kps = np.array(model_info["robot_stiffness"]) + kds = np.array(model_info["robot_damping"]) + print(kps) + print(kds) + print(tau_limit) + + data.qpos = model.keyframe("standing").qpos + default = deepcopy(model.keyframe("standing").qpos)[-model_info["num_actions"] :] + print("Default position:", default) + + mujoco.mj_step(model, data) + for ii in range(len(data.ctrl) + 1): + print(data.joint(ii).id, data.joint(ii).name) + + data.qvel = np.zeros_like(data.qvel) + data.qacc = np.zeros_like(data.qacc) + + target_q = np.zeros((model_info["num_actions"]), dtype=np.double) + viewer = mujoco_viewer.MujocoViewer(model, data,"offscreen") + + force_application_interval = 1000 # Apply force every 1000 steps (1 second at 1000Hz) + force_magnitude_range = (-push_force, push_force) # Force range in Newtons + force_duration = 100 # Duration of force application in timesteps + force_timer = 0 + + + for timestep in tqdm(range(int(sim_duration / model_info["sim_dt"])), desc="Simulating..."): + # Obtain an observation + q = data.qpos.astype(np.double)[-model_info["num_actions"] :] + dq = data.qvel.astype(np.double)[-model_info["num_actions"] :] + + # Generate PD control + tau = pd_control(target_q, q, kps, dq, kds, default) # Calc torques + tau = np.clip(tau, -tau_limit, tau_limit) # Clamp torques + + data.ctrl = tau + + # Apply random forces periodically + if timestep % force_application_interval == 0: + print("Applying force") + # Generate random force vector + random_force = np.random.uniform( + force_magnitude_range[0], + force_magnitude_range[1], + size=3 + ) + force_timer = force_duration + + # Apply force if timer is active + if force_timer > 0: + data.xfrc_applied[1] = np.concatenate([random_force, np.zeros(3)]) + force_timer -= 1 + else: + data.xfrc_applied[1] = np.zeros(6) + + mujoco.mj_step(model, data) + if timestep % 100 == 0: + img = viewer.read_pixels() + frames.append(img) + + # viewer.render() + breakpoint() + media.write_video("push_tests.mp4", frames, fps=framerate) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Deployment script.") + parser.add_argument("--embodiment", type=str, required=True, help="Embodiment name.") + parser.add_argument("--kp", type=float, default=17.0, help="Path to run to load from.") + parser.add_argument("--kd", type=float, default=1.0, help="Path to run to load from.") + parser.add_argument("--push_force", type=float, default=1.0, help="Path to run to load from.") + args = parser.parse_args() + + run_test(args.embodiment, args.kp, args.kd, args.push_force) diff --git a/sim/sim2sim2.py b/sim/sim2sim2.py new file mode 100755 index 00000000..8b549e60 --- /dev/null +++ b/sim/sim2sim2.py @@ -0,0 +1,329 @@ +"""Sim2sim deployment test. 
+ +Run: + python sim/sim2sim2.py --load_model examples/gpr_walking.kinfer --embodiment gpr + python sim/sim2sim2.py --load_model kinfer_policy.onnx --embodiment zbot2 +""" + +import argparse +import math +import os +from copy import deepcopy +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + +import mujoco +import mujoco_viewer +import numpy as np +import onnxruntime as ort +import pygame +import torch +from kinfer.export.pytorch import export_to_onnx +from kinfer.inference.python import ONNXModel +from scipy.spatial.transform import Rotation as R +from tqdm import tqdm + + +def handle_keyboard_input() -> None: + global x_vel_cmd, y_vel_cmd, yaw_vel_cmd + + keys = pygame.key.get_pressed() + + # Update movement commands based on arrow keys + if keys[pygame.K_UP]: + x_vel_cmd += 0.0005 + if keys[pygame.K_DOWN]: + x_vel_cmd -= 0.0005 + if keys[pygame.K_LEFT]: + y_vel_cmd += 0.0005 + if keys[pygame.K_RIGHT]: + y_vel_cmd -= 0.0005 + + # Yaw control + if keys[pygame.K_a]: + yaw_vel_cmd += 0.001 + if keys[pygame.K_z]: + yaw_vel_cmd -= 0.001 + + +def quaternion_to_euler_array(quat: np.ndarray) -> np.ndarray: + # Ensure quaternion is in the correct format [x, y, z, w] + x, y, z, w = quat + + # Roll (x-axis rotation) + t0 = +2.0 * (w * x + y * z) + t1 = +1.0 - 2.0 * (x * x + y * y) + roll_x = np.arctan2(t0, t1) + + # Pitch (y-axis rotation) + t2 = +2.0 * (w * y - z * x) + t2 = np.clip(t2, -1.0, 1.0) + pitch_y = np.arcsin(t2) + + # Yaw (z-axis rotation) + t3 = +2.0 * (w * z + x * y) + t4 = +1.0 - 2.0 * (y * y + z * z) + yaw_z = np.arctan2(t3, t4) + + # Returns roll, pitch, yaw in a NumPy array in radians + return np.array([roll_x, pitch_y, yaw_z]) + + +def get_gravity_orientation(quaternion): + """ + Args: + quaternion: np.ndarray[float, float, float, float] + + Returns: + gravity_orientation: np.ndarray[float, float, float] + """ + qw = quaternion[0] + qx = quaternion[1] + qy = quaternion[2] + qz = quaternion[3] + + gravity_orientation = np.zeros(3) + + gravity_orientation[0] = 2 * (-qz * qx + qw * qy) + gravity_orientation[1] = -2 * (qz * qy + qw * qx) + gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz) + + return gravity_orientation + + +def get_obs(data: mujoco.MjData) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Extracts an observation from the mujoco data structure""" + q = data.qpos.astype(np.double) + dq = data.qvel.astype(np.double) + quat = data.sensor("orientation").data[[1, 2, 3, 0]].astype(np.double) + r = R.from_quat(quat) + gvec = r.apply(np.array([0.0, 0.0, -1.0]), inverse=True).astype(np.double) + v = r.apply(data.qvel[:3], inverse=True).astype(np.double) # In the base frame + omega = data.sensor("angular-velocity").data.astype(np.double) + + # gvec = get_gravity_orientation(data.sensor("orientation").data) + return (q, dq, quat, v, omega, gvec) + + +def pd_control( + target_q: np.ndarray, + q: np.ndarray, + kp: np.ndarray, + dq: np.ndarray, + kd: np.ndarray, + default: np.ndarray, +) -> np.ndarray: + """Calculates torques from position commands""" + return kp * (target_q + default - q) - kd * dq + + +def run_mujoco( + embodiment: str, + policy: ort.InferenceSession, + model_info: Dict[str, Union[float, List[float], str]], + keyboard_use: bool = False, + log_h5: bool = False, + render: bool = True, + sim_duration: float = 60.0, + h5_out_dir: str = "sim/resources", +) -> None: + """ + Run the Mujoco simulation using the provided policy and configuration. 
+ + Args: + policy: The policy used for controlling the simulation. + cfg: The configuration object containing simulation settings. + """ + model_dir = os.environ.get("MODEL_DIR", "sim/resources") + mujoco_model_path = f"{model_dir}/{embodiment}/robot_fixed.xml" + + model = mujoco.MjModel.from_xml_path(mujoco_model_path) + model.opt.timestep = model_info["sim_dt"] + data = mujoco.MjData(model) + + assert isinstance(model_info["num_actions"], int) + assert isinstance(model_info["num_observations"], int) + assert isinstance(model_info["robot_effort"], list) + assert isinstance(model_info["robot_stiffness"], list) + assert isinstance(model_info["robot_damping"], list) + + tau_limit = np.array(list(model_info["robot_effort"]) + list(model_info["robot_effort"])) * model_info["tau_factor"] + kps = np.array(list(model_info["robot_stiffness"]) + list(model_info["robot_stiffness"])) + kds = np.array(list(model_info["robot_damping"]) + list(model_info["robot_damping"])) + + try: + data.qpos = model.keyframe("default").qpos + default = deepcopy(model.keyframe("default").qpos)[-model_info["num_actions"] :] + print("Default position:", default) + except: + print("No default position found, using zero initialization") + default = np.zeros(model_info["num_actions"]) # 3 for pos, 4 for quat, cfg.num_actions for joints + default += np.random.uniform(-0.03, 0.03, size=default.shape) + print("Default position:", default) + mujoco.mj_step(model, data) + for ii in range(len(data.ctrl) + 1): + print(data.joint(ii).id, data.joint(ii).name) + + data.qvel = np.zeros_like(data.qvel) + data.qacc = np.zeros_like(data.qacc) + + if render: + viewer = mujoco_viewer.MujocoViewer(model, data) + + target_q = np.zeros((model_info["num_actions"]), dtype=np.double) + prev_actions = np.zeros((model_info["num_actions"]), dtype=np.double) + hist_obs = np.zeros((model_info["num_observations"]), dtype=np.double) + + count_lowlevel = 0 + + input_data = { + "x_vel.1": np.zeros(1).astype(np.float32), + "y_vel.1": np.zeros(1).astype(np.float32), + "rot.1": np.zeros(1).astype(np.float32), + "t.1": np.zeros(1).astype(np.float32), + "dof_pos.1": np.zeros(model_info["num_actions"]).astype(np.float32), + "dof_vel.1": np.zeros(model_info["num_actions"]).astype(np.float32), + "prev_actions.1": np.zeros(model_info["num_actions"]).astype(np.float32), + # "imu_ang_vel.1": np.zeros(3).astype(np.float32), + # "imu_euler_xyz.1": np.zeros(3).astype(np.float32), + "projected_gravity.1": np.zeros(3).astype(np.float32), + "buffer.1": np.zeros(model_info["num_observations"]).astype(np.float32), + } + + if log_h5: + from sim.h5_logger import HDF5Logger + + stop_state_log = int(sim_duration / model_info["sim_dt"]) / model_info["sim_decimation"] + logger = HDF5Logger( + data_name=embodiment, + num_actions=model_info["num_actions"], + max_timesteps=stop_state_log, + num_observations=model_info["num_observations"], + h5_out_dir=h5_out_dir, + ) + + # Initialize variables for tracking upright steps and average speed + upright_steps = 0 + total_speed = 0.0 + step_count = 0 + + for _ in tqdm(range(int(sim_duration / model_info["sim_dt"])), desc="Simulating..."): + if keyboard_use: + handle_keyboard_input() + + # Obtain an observation + q, dq, quat, v, omega, gvec = get_obs(data) + q = q[-model_info["num_actions"] :] + dq = dq[-model_info["num_actions"] :] + + # eu_ang = quaternion_to_euler_array(quat) + # eu_ang[eu_ang > math.pi] -= 2 * math.pi + + # eu_ang = np.array([0.0, 0.0, 0.0]) + # omega = np.array([0.0, 0.0, 0.0]) + + # Calculate speed and accumulate 
for average speed calculation + speed = np.linalg.norm(v[:2]) # Speed in the x-y plane + total_speed += speed + step_count += 1 + + # 1000hz -> 50hz + if count_lowlevel % model_info["sim_decimation"] == 0: + # Convert sim coordinates to policy coordinates + cur_pos_obs = q - default + cur_vel_obs = dq + + input_data["x_vel.1"] = np.array([x_vel_cmd], dtype=np.float32) + input_data["y_vel.1"] = np.array([y_vel_cmd], dtype=np.float32) + input_data["rot.1"] = np.array([yaw_vel_cmd], dtype=np.float32) + + input_data["t.1"] = np.array([count_lowlevel * model_info["sim_dt"]], dtype=np.float32) + + input_data["dof_pos.1"] = cur_pos_obs.astype(np.float32) + input_data["dof_vel.1"] = cur_vel_obs.astype(np.float32) + + input_data["prev_actions.1"] = prev_actions.astype(np.float32) + + input_data["projected_gravity.1"] = gvec.astype(np.float32) + # input_data["imu_ang_vel.1"] = omega.astype(np.float32) + # input_data["imu_euler_xyz.1"] = eu_ang.astype(np.float32) + + input_data["buffer.1"] = hist_obs.astype(np.float32) + + policy_output = policy(input_data) + positions = policy_output["actions_scaled"] + curr_actions = policy_output["actions"] + hist_obs = policy_output["x.3"] + + target_q = positions + + prev_actions = curr_actions + + # Generate PD control + tau = pd_control(target_q, q, kps, dq, kds, default) # Calc torques + tau = np.clip(tau, -tau_limit, tau_limit) # Clamp torques + + data.ctrl = tau + mujoco.mj_step(model, data) + + if render: + viewer.render() + count_lowlevel += 1 + + if render: + viewer.close() + + # Calculate average speed + if step_count > 0: + average_speed = total_speed / step_count + else: + average_speed = 0.0 + + # Save or print the statistics at the end of the episode + print(f"Number of upright steps: {upright_steps}") + print(f"Average speed: {average_speed:.4f} m/s") + + if log_h5: + logger.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Deployment script.") + parser.add_argument("--embodiment", type=str, required=True, help="Embodiment name.") + parser.add_argument("--load_model", type=str, required=True, help="Path to run to load from.") + parser.add_argument("--keyboard_use", action="store_true", help="keyboard_use") + parser.add_argument("--log_h5", action="store_true", help="log_h5") + parser.add_argument("--h5_out_dir", type=str, default="sim/resources", help="Directory to save HDF5 files") + parser.add_argument("--no_render", action="store_false", dest="render", help="Disable rendering") + parser.set_defaults(render=True) + args = parser.parse_args() + + if args.keyboard_use: + x_vel_cmd, y_vel_cmd, yaw_vel_cmd = 0.0, 0.0, 0.0 + pygame.init() + pygame.display.set_caption("Simulation Control") + else: + x_vel_cmd, y_vel_cmd, yaw_vel_cmd = 0.25, 0.0, 0.0 + + policy = ONNXModel(args.load_model) + metadata = policy.get_metadata() + model_info = { + "num_actions": metadata["num_actions"], + "num_observations": metadata["num_observations"], + "robot_effort": metadata["robot_effort"], + "robot_stiffness": metadata["robot_stiffness"], + "robot_damping": metadata["robot_damping"], + "sim_dt": metadata["sim_dt"], + "sim_decimation": metadata["sim_decimation"], + "tau_factor": metadata["tau_factor"], + } + + run_mujoco( + embodiment=args.embodiment, + policy=policy, + model_info=model_info, + keyboard_use=args.keyboard_use, + log_h5=args.log_h5, + render=args.render, + h5_out_dir=args.h5_out_dir, + )
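
The hand-rolled get_gravity_orientation() above takes a w-x-y-z quaternion, while get_obs() projects gravity through scipy after reindexing the sensor quaternion to x-y-z-w; the two paths should agree. A minimal consistency check, assuming sim/sim2sim2.py is importable as sim.sim2sim2 and scipy is installed, is sketched below (the test quaternion is a hypothetical 30-degree roll, not a value from this diff):

    import numpy as np
    from scipy.spatial.transform import Rotation as R

    from sim.sim2sim2 import get_gravity_orientation

    # Unit quaternion in w-x-y-z order: roughly 30 degrees of roll about x.
    quat_wxyz = np.array([0.96592583, 0.25881905, 0.0, 0.0])

    # Reference path, matching get_obs(): reindex to x-y-z-w for scipy and rotate
    # the world gravity direction into the base frame with the inverse rotation.
    gvec_ref = R.from_quat(quat_wxyz[[1, 2, 3, 0]]).apply(np.array([0.0, 0.0, -1.0]), inverse=True)

    assert np.allclose(get_gravity_orientation(quat_wxyz), gvec_ref, atol=1e-6)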