diff --git a/flybody/fly_envs.py b/flybody/fly_envs.py index c7ee17d..1ee347d 100755 --- a/flybody/fly_envs.py +++ b/flybody/fly_envs.py @@ -21,7 +21,7 @@ def flight_imitation(wpg_pattern_path: str, ref_path: str, - random_state: np.random.RandomState = None, + random_state: np.random.RandomState | None = None, terminal_com_dist: float = 2.0): """Requires a fruitfly to track a flying reference. @@ -61,7 +61,7 @@ def flight_imitation(wpg_pattern_path: str, def walk_imitation(ref_path: str, - random_state: np.random.RandomState = None, + random_state: np.random.RandomState | None = None, terminal_com_dist: float = 0.3): """Requires a fruitfly to track a reference walking fly. @@ -77,8 +77,8 @@ def walk_imitation(ref_path: str, walker = fruitfly.FruitFly arena = floors.Floor() # Initialize a walking trajectory loader. - traj_generator = HDF5WalkingTrajectoryLoader(path=ref_path, - random_state=random_state) + traj_generator = HDF5WalkingTrajectoryLoader( + path=ref_path, random_state=random_state) # Build a task that rewards the agent for tracking a walking ghost. time_limit = 10.0 task = WalkImitation(walker=walker, @@ -97,7 +97,7 @@ def walk_imitation(ref_path: str, strip_singleton_obs_buffer_dim=True) -def walk_on_ball(random_state: np.random.RandomState = None): +def walk_on_ball(random_state: np.random.RandomState | None = None): """Requires a tethered fruitfly to walk on a floating ball. Args: @@ -128,7 +128,7 @@ def walk_on_ball(random_state: np.random.RandomState = None): def vision_guided_flight(wpg_pattern_path: str, bumps_or_trench: str = 'bumps', - random_state: np.random.RandomState = None, + random_state: np.random.RandomState | None = None, **kwargs_arena): """Vision-guided flight tasks: 'bumps' and 'trench'. diff --git a/flybody/tasks/walk_imitation.py b/flybody/tasks/walk_imitation.py index 373d374..75f5abd 100755 --- a/flybody/tasks/walk_imitation.py +++ b/flybody/tasks/walk_imitation.py @@ -1,7 +1,7 @@ """Walking imitation task for fruit fly.""" # ruff: noqa: F821 -from typing import Optional, Sequence +from typing import Sequence import numpy as np from flybody.tasks.base import Walking @@ -22,7 +22,7 @@ def __init__(self, mocap_joint_names: Sequence[str], mocap_site_names: Sequence[str], terminal_com_dist: float = 0.33, - claw_friction: Optional[float] = 1.0, + claw_friction: float | None = 1.0, trajectory_sites: bool = True, **kwargs): """This task is a combination of imitation walking and ghost tracking.