diff --git a/mj_envs/logger/roboset_logger.py b/mj_envs/logger/roboset_logger.py index 1319552c..f783e78b 100644 --- a/mj_envs/logger/roboset_logger.py +++ b/mj_envs/logger/roboset_logger.py @@ -41,8 +41,8 @@ def path2dataset(self, path:dict, config_path=None)->dict: # Derived ===== pose_ee = [] - if 'pos_ee' in path_keys or 'rot_ee' in path_keys: - assert ('pos_ee' in path_keys and 'rot_ee' in path_keys), "Both pose_ee and rot_ee are required" + if 'env_infos/obs_dict/pos_ee' in path_keys or 'env_infos/obs_dict/rot_ee' in path_keys: + assert ('env_infos/obs_dict/pos_ee' in path_keys and 'env_infos/obs_dict/rot_ee' in path_keys), "Both pose_ee and rot_ee are required" dataset['derived/pose_ee'] = np.hstack([path['env_infos/obs_dict/pos_ee'], path['env_infos/obs_dict/rot_ee']]) # Config =====