diff --git a/flybody/tasks/task_utils.py b/flybody/tasks/task_utils.py index aa7c74b..d846360 100755 --- a/flybody/tasks/task_utils.py +++ b/flybody/tasks/task_utils.py @@ -28,12 +28,12 @@ def observable_indices_in_tensor( def wing_qpos_to_conventional(model_wing_qpos: np.ndarray, body_pitch_angle: float = 47.5, ) -> np.ndarray: - """Transform model wing qpos to conventional wing kinematics definition. + """Transform model wing joint qpos to conventional wing kinematics definition. Args: model_wing_qpos: Wing MjData.qpos in radians, shape (B, 6). - Order of angles: yaw_left, roll_left, pitch_left, - yaw_right, roll_right, pitch_right. + Order of joints: yaw, roll, pitch, yaw, roll, pitch. + Left-right order is arbitrary. body_pitch_angle: Body pitch angle for initial flight pose, relative to ground, degrees. 0: horizontal body position. Default value from https://doi.org/10.1126/science.1248955 @@ -45,14 +45,13 @@ def wing_qpos_to_conventional(model_wing_qpos: np.ndarray, model_wing_qpos = np.array(model_wing_qpos) conventional = np.zeros_like(model_wing_qpos) body_pitch_angle = np.deg2rad(body_pitch_angle) - for i in [0, 3]: - # Yaw, doesn't require transformation. - conventional[..., i] = model_wing_qpos[..., i].copy() - # Roll. - conventional[..., i+1] = - model_wing_qpos[..., i+1] - # Pitch. - conventional[..., i+2] = (np.pi / 2 - body_pitch_angle - - model_wing_qpos[..., i+2]) + # Yaw, doesn't require transformation. + conventional[..., [0, 3]] = model_wing_qpos[..., [0, 3]].copy() + # Roll. + conventional[..., [1, 4]] = - model_wing_qpos[..., [1, 4]] + # Pitch. + conventional[..., [2, 5]] = ( + np.pi / 2 - body_pitch_angle - model_wing_qpos[..., [2, 5]]) return conventional