diff --git a/loco_mujoco/utils/dataset.py b/loco_mujoco/utils/dataset.py index d289471..18a4625 100644 --- a/loco_mujoco/utils/dataset.py +++ b/loco_mujoco/utils/dataset.py @@ -132,7 +132,6 @@ def adapt_mocap(path, joint_conf, unavailable_keys, rename_map=None, discard_fir Dictionary containing the joint angles and velocities of the modified mocap dataset. """ - # extract the euler keys euler_keys = list(joint_conf.keys()) @@ -153,16 +152,6 @@ def adapt_mocap(path, joint_conf, unavailable_keys, rename_map=None, discard_fir n_datapoint = len(joint_pos[0]) joint_pos = dict(zip(joint_names, joint_pos)) joint_vel = dict(zip(joint_names, joint_vel)) - if type(unavailable_keys) == list: - for ukey in unavailable_keys: - joint_pos[ukey] = np.zeros(n_datapoint) - joint_vel[ukey] = np.zeros(n_datapoint) - elif type(unavailable_keys) == dict: - for ukey, val in unavailable_keys.items(): - joint_pos[ukey] = np.ones(n_datapoint) * val - joint_vel[ukey] = np.zeros(n_datapoint) - else: - raise TypeError # get the relevant data joint_pos = np.array([joint_pos[k] for k in euler_keys]) @@ -183,7 +172,7 @@ def adapt_mocap(path, joint_conf, unavailable_keys, rename_map=None, discard_fir i = euler_keys.index(k) euler_keys[i] = v - keys = ["q_" + k for k in euler_keys] + ["dq_" + k for k in euler_keys] + keys = ["q_" + k for k in euler_keys ] + ["dq_" + k for k in euler_keys] # add goal if available if "goal" in data.keys(): @@ -194,6 +183,18 @@ def adapt_mocap(path, joint_conf, unavailable_keys, rename_map=None, discard_fir # create dataset dataset = dict(zip(keys, trajec)) + # add unavailable keys + if type(unavailable_keys) == list: + for ukey in unavailable_keys: + dataset["q_" + ukey] = np.zeros(n_datapoint) + dataset["dq_" + ukey] = np.zeros(n_datapoint) + elif type(unavailable_keys) == dict: + for ukey, val in unavailable_keys.items(): + dataset["q_" + ukey] = np.ones(n_datapoint) * val + dataset["dq_" + ukey] = np.zeros(n_datapoint) + else: + raise TypeError + # if needed discard first and last part of the dataset for j_name, val in dataset.items(): val_temp = val[discard_first:]