Skip to content

Commit

Permalink
Add walking trajectory loader for testing/inference.
Browse files Browse the repository at this point in the history
  • Loading branch information
vaxenburg committed Mar 26, 2024
1 parent e4067e1 commit fbe8807
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions flybody/tasks/trajectory_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,38 @@ def get_site_names(self):
def get_joint_names(self):
"""Returns snippet joint names."""
return [s.decode('utf-8') for s in self._h5['id2name']['joints']]


class WalkingTrajectoryTestPlug():
"""A simple inference/test-time replacement for walking trajectory loader.
To use this class, create qpos and qvel for your test trajectory and then
set this trajectory for loading in the walking task by calling:
env.task._traj_generator.set_next_trajectory(qpos, qvel)
"""

def __init__(self):
# Nothing here!
pass

def set_next_trajectory(self, qpos: np.ndarray, qvel: np.ndarray):
"""Set new trajectory to be returned by get_trajectory.
Args:
qpos: Center-of-mass trajectory, (time, 7).
qvel: Velocity of CoM trajectory, (time, 6).
"""
self._snippet = {'qpos': qpos, 'qvel': qvel}

def get_trajectory(self, traj_idx: int):
del traj_idx # Unused.
if not hasattr(self, '_snippet'):
raise AttributeError(
'Trajectory not set yet. Call set_next_trajectory first.')
return self._snippet

def get_joint_names(self):
return []

def get_site_names(self):
return []

0 comments on commit fbe8807

Please sign in to comment.