Skip to content

Commit

Permalink
Add template task env and clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
vaxenburg committed Apr 8, 2024
1 parent 74b90f7 commit febbb3b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
31 changes: 31 additions & 0 deletions flybody/fly_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from flybody.tasks.walk_imitation import WalkImitation
from flybody.tasks.walk_on_ball import WalkOnBall
from flybody.tasks.vision_flight import VisionFlightImitationWBPG
from flybody.tasks.template_task import TemplateTask

from flybody.tasks.arenas.ball import BallFloor
from flybody.tasks.arenas.hills import SineBumps, SineTrench
Expand Down Expand Up @@ -178,3 +179,33 @@ def vision_guided_flight(wpg_pattern_path: str,
task=task,
random_state=random_state,
strip_singleton_obs_buffer_dim=True)


def template_task(random_state: np.random.RandomState | None = None,
joint_filter: float = 0.01,
adhesion_filter: float = 0.007,
time_limit: float = 1.):
"""Fake template walking task for testing.
Args:
random_state: Random state for reproducibility.
joint_filter: Timescale of filter for joint actuators. 0: disabled.
adhesion_filter: Timescale of filter for adhesion actuators. 0: disabled.
Returns:
Template walking environment.
"""
# Build a fruitfly walker and arena.
walker = fruitfly.FruitFly
arena = floors.Floor()
# Build a task that rewards the agent for tracking a walking ghost.
task = TemplateTask(walker=walker,
arena=arena,
joint_filter=joint_filter,
adhesion_filter=adhesion_filter,
time_limit=time_limit)

return composer.Environment(time_limit=time_limit,
task=task,
random_state=random_state,
strip_singleton_obs_buffer_dim=True)
13 changes: 5 additions & 8 deletions flybody/tasks/template_task.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
"""Template class for walking fly tasks."""
# ruff: noqa: F821

from typing import Optional
import numpy as np

from flybody.tasks.base import Walking
from flybody.tasks.constants import (_TERMINAL_ANGVEL, _TERMINAL_LINVEL)


class TemplateTask(Walking):
"""Template class for walking fly tasks."""

def __init__(self, claw_friction: Optional[float] = 1.0, **kwargs):
def __init__(self, claw_friction: float = 1.0, **kwargs):
"""Template class for walking fly tasks.
Args:
Expand Down Expand Up @@ -47,11 +45,10 @@ def before_step(self, physics: 'mjcf.Physics', action,
def get_reward_factors(self, physics):
"""Returns factorized reward terms."""
# Calculate reward factors here.
return (1, )
return (1,)

def check_termination(self, physics: 'mjcf.Physics') -> bool:
"""Check various termination conditions."""
linvel = np.linalg.norm(self._walker.observables.velocimeter(physics))
angvel = np.linalg.norm(self._walker.observables.gyro(physics))
return (linvel > _TERMINAL_LINVEL or angvel > _TERMINAL_ANGVEL
or super().check_termination(physics))
# Maybe add some termination conditions.
should_terminate = False
return should_terminate or super().check_termination(physics)

0 comments on commit febbb3b

Please sign in to comment.