Skip to content

Commit

Permalink
Fix unavailable keys logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Paweł Budzianowski authored Jun 5, 2024
1 parent 5fd903c commit 3be2256
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions loco_mujoco/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def adapt_mocap(path, joint_conf, unavailable_keys, rename_map=None, discard_fir
Dictionary containing the joint angles and velocities of the modified mocap dataset.
"""

# extract the euler keys
euler_keys = list(joint_conf.keys())

Expand All @@ -153,16 +152,6 @@ def adapt_mocap(path, joint_conf, unavailable_keys, rename_map=None, discard_fir
n_datapoint = len(joint_pos[0])
joint_pos = dict(zip(joint_names, joint_pos))
joint_vel = dict(zip(joint_names, joint_vel))
if type(unavailable_keys) == list:
for ukey in unavailable_keys:
joint_pos[ukey] = np.zeros(n_datapoint)
joint_vel[ukey] = np.zeros(n_datapoint)
elif type(unavailable_keys) == dict:
for ukey, val in unavailable_keys.items():
joint_pos[ukey] = np.ones(n_datapoint) * val
joint_vel[ukey] = np.zeros(n_datapoint)
else:
raise TypeError

# get the relevant data
joint_pos = np.array([joint_pos[k] for k in euler_keys])
Expand All @@ -183,7 +172,7 @@ def adapt_mocap(path, joint_conf, unavailable_keys, rename_map=None, discard_fir
i = euler_keys.index(k)
euler_keys[i] = v

keys = ["q_" + k for k in euler_keys] + ["dq_" + k for k in euler_keys]
keys = ["q_" + k for k in euler_keys ] + ["dq_" + k for k in euler_keys]

# add goal if available
if "goal" in data.keys():
Expand All @@ -194,6 +183,18 @@ def adapt_mocap(path, joint_conf, unavailable_keys, rename_map=None, discard_fir
# create dataset
dataset = dict(zip(keys, trajec))

# add unavailable keys
if type(unavailable_keys) == list:
for ukey in unavailable_keys:
dataset["q_" + ukey] = np.zeros(n_datapoint)
dataset["dq_" + ukey] = np.zeros(n_datapoint)
elif type(unavailable_keys) == dict:
for ukey, val in unavailable_keys.items():
dataset["q_" + ukey] = np.ones(n_datapoint) * val
dataset["dq_" + ukey] = np.zeros(n_datapoint)
else:
raise TypeError

# if needed discard first and last part of the dataset
for j_name, val in dataset.items():
val_temp = val[discard_first:]
Expand Down

0 comments on commit 3be2256

Please sign in to comment.