diff --git a/flybody/fly_envs.py b/flybody/fly_envs.py
index c5d5881..ee7caa6 100755
--- a/flybody/fly_envs.py
+++ b/flybody/fly_envs.py
@@ -11,6 +11,7 @@ from flybody.tasks.walk_imitation import WalkImitation
 from flybody.tasks.walk_on_ball import WalkOnBall
 from flybody.tasks.vision_flight import VisionFlightImitationWBPG
+from flybody.tasks.template_task import TemplateTask
 from flybody.tasks.arenas.ball import BallFloor
 from flybody.tasks.arenas.hills import SineBumps, SineTrench
 
 
@@ -178,3 +179,33 @@ def vision_guided_flight(wpg_pattern_path: str,
                                 task=task,
                                 random_state=random_state,
                                 strip_singleton_obs_buffer_dim=True)
+
+
+def template_task(random_state: np.random.RandomState | None = None,
+                  joint_filter: float = 0.01,
+                  adhesion_filter: float = 0.007,
+                  time_limit: float = 1.):
+    """Fake template walking task for testing.
+
+    Args:
+        random_state: Random state for reproducibility.
+        joint_filter: Timescale of filter for joint actuators. 0: disabled.
+        adhesion_filter: Timescale of filter for adhesion actuators. 0: disabled.
+
+    Returns:
+        Template walking environment.
+    """
+    # Build a fruitfly walker and arena.
+    walker = fruitfly.FruitFly
+    arena = floors.Floor()
+    # Build the template walking task.
+    task = TemplateTask(walker=walker,
+                        arena=arena,
+                        joint_filter=joint_filter,
+                        adhesion_filter=adhesion_filter,
+                        time_limit=time_limit)
+
+    return composer.Environment(time_limit=time_limit,
+                                task=task,
+                                random_state=random_state,
+                                strip_singleton_obs_buffer_dim=True)
diff --git a/flybody/tasks/template_task.py b/flybody/tasks/template_task.py
index 8ad4ea0..f7fb1b3 100755
--- a/flybody/tasks/template_task.py
+++ b/flybody/tasks/template_task.py
@@ -1,17 +1,15 @@
 """Template class for walking fly tasks."""
 # ruff: noqa: F821
-from typing import Optional
 
 import numpy as np
 
 from flybody.tasks.base import Walking
-from flybody.tasks.constants import (_TERMINAL_ANGVEL, _TERMINAL_LINVEL)
 
 
 class TemplateTask(Walking):
     """Template class for walking fly tasks."""
 
-    def __init__(self, claw_friction: Optional[float] = 1.0, **kwargs):
+    def __init__(self, claw_friction: float = 1.0, **kwargs):
         """Template class for walking fly tasks.
 
         Args:
@@ -47,11 +45,10 @@ def before_step(self, physics: 'mjcf.Physics', action,
     def get_reward_factors(self, physics):
         """Returns factorized reward terms."""
         # Calculate reward factors here.
-        return (1, )
+        return (1,)
 
     def check_termination(self, physics: 'mjcf.Physics') -> bool:
         """Check various termination conditions."""
-        linvel = np.linalg.norm(self._walker.observables.velocimeter(physics))
-        angvel = np.linalg.norm(self._walker.observables.gyro(physics))
-        return (linvel > _TERMINAL_LINVEL or angvel > _TERMINAL_ANGVEL
-                or super().check_termination(physics))
+        # Maybe add some termination conditions.
+        should_terminate = False
+        return should_terminate or super().check_termination(physics)
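
A minimal usage sketch of the new `template_task` factory, for smoke-testing only. This snippet is not part of the diff above: it assumes `flybody` and `dm_control` are installed, and the zero-action rollout is purely illustrative.

```python
# Illustrative sketch (assumption): drive the new template environment with
# zero actions for one episode and print the rewards from TemplateTask.
import numpy as np

from flybody.fly_envs import template_task

env = template_task(random_state=np.random.RandomState(0))
action_spec = env.action_spec()

timestep = env.reset()
while not timestep.last():
    # Zero actions; the reward comes from TemplateTask.get_reward_factors.
    action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
    timestep = env.step(action)
    print(timestep.reward)
```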