Skip to content

Commit

Permalink
Add inference mode flag.
Browse files Browse the repository at this point in the history
  • Loading branch information
vaxenburg committed Mar 26, 2024
1 parent fbe8807 commit 6d3b28b
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion flybody/tasks/walk_imitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self,
terminal_com_dist: float = 0.33,
claw_friction: float | None = 1.0,
trajectory_sites: bool = True,
inference_mode: bool = False,
**kwargs):
"""This task is a combination of imitation walking and ghost tracking.
Expand All @@ -36,6 +37,8 @@ def __init__(self,
from model to ghost exceeds terminal_com_dist.
claw_friction: Friction of claw.
trajectory_sites: Whether to render trajectory sites.
inference_mode: Whether to run in test mode and skip full-body
reward calculation.
**kwargs: Arguments passed to the superclass constructor.
"""

Expand All @@ -44,6 +47,7 @@ def __init__(self,
self._traj_generator = traj_generator
self._terminal_com_dist = terminal_com_dist
self._trajectory_sites = trajectory_sites
self._inference_mode = inference_mode
self._max_episode_steps = round(
self._time_limit / self.control_timestep) + 1
self._next_traj_idx = None
Expand Down Expand Up @@ -145,7 +149,8 @@ def before_step(self, physics: 'mjcf.Physics', action,

def get_reward_factors(self, physics):
"""Returns factorized reward terms."""

if self._inference_mode:
return (1,)
step = round(physics.time() / self.control_timestep)
walker_ft = get_walker_features(physics, self._mocap_joints,
self._mocap_sites)
Expand Down

0 comments on commit 6d3b28b

Please sign in to comment.