Skip to content

Commit

Permalink
Upgrade Esquilax and remove unused random keys (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
zombie-einstein authored Dec 27, 2024
1 parent 1e66e78 commit 407ff79
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 55 deletions.
13 changes: 4 additions & 9 deletions jumanji/environments/swarms/common/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from typing import List, Tuple

import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -56,16 +55,14 @@ def test_velocity_update(
actions: List[float],
expected: Tuple[float, float],
) -> None:
key = jax.random.PRNGKey(101)

state = types.AgentState(
pos=jnp.zeros((1, 2)),
heading=jnp.array([heading]),
speed=jnp.array([speed]),
)
actions = jnp.array([actions])

new_heading, new_speed = updates.update_velocity(key, params, (actions, state))
new_heading, new_speed = updates.update_velocity(params, (actions, state))

assert jnp.isclose(new_heading[0], expected[0])
assert jnp.isclose(new_speed[0], expected[1])
Expand Down Expand Up @@ -117,16 +114,14 @@ def test_state_update(
expected_speed: float,
env_size: float,
) -> None:
key = jax.random.PRNGKey(101)

state = types.AgentState(
pos=jnp.array([pos]),
heading=jnp.array([heading]),
speed=jnp.array([speed]),
)
actions = jnp.array([actions])

new_state = updates.update_state(key, env_size, params, state, actions)
new_state = updates.update_state(env_size, params, state, actions)

assert isinstance(new_state, types.AgentState)
assert jnp.allclose(new_state.pos, jnp.array([expected_pos]))
Expand All @@ -137,7 +132,7 @@ def test_state_update(
def test_view_reduction() -> None:
view_a = jnp.array([-1.0, -1.0, 0.2, 0.2, 0.5])
view_b = jnp.array([-1.0, 0.2, -1.0, 0.5, 0.2])
result = updates.view_reduction(view_a, view_b)
result = updates.view_reduction_fn(view_a, view_b)
assert jnp.allclose(result, jnp.array([-1.0, 0.2, 0.2, 0.2, 0.2]))


Expand Down Expand Up @@ -170,7 +165,7 @@ def test_view(pos: List[float], view_angle: float, env_size: float, expected: Li
)

obs = updates.view(
None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1, env_size=env_size
(view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1, env_size=env_size
)
assert jnp.allclose(obs, jnp.array(expected))

Expand Down
11 changes: 2 additions & 9 deletions jumanji/environments/swarms/common/updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@

@esquilax.transforms.amap
def update_velocity(
_key: chex.PRNGKey,
params: types.AgentParams,
x: Tuple[chex.Array, types.AgentState],
) -> Tuple[chex.Numeric, chex.Numeric]:
"""
Get the updated agent heading and speeds from actions
Args:
_key: Dummy JAX random key.
params: Agent parameters.
x: Agent rotation and acceleration actions.
Expand Down Expand Up @@ -72,7 +70,6 @@ def move(pos: chex.Array, heading: chex.Array, speed: chex.Array, env_size: floa


def update_state(
key: chex.PRNGKey,
env_size: float,
params: types.AgentParams,
state: types.AgentState,
Expand All @@ -82,7 +79,6 @@ def update_state(
Update the state of a group of agents from a sample of actions
Args:
key: Dummy JAX random key.
env_size: Size of the environment.
params: Agent parameters.
state: Current agent states.
Expand All @@ -93,7 +89,7 @@ def update_state(
actions and updating positions.
"""
actions = jnp.clip(actions, min=-1.0, max=1.0)
headings, speeds = update_velocity(key, params, (actions, state))
headings, speeds = update_velocity(params, (actions, state))
positions = jax.vmap(move, in_axes=(0, 0, 0, None))(state.pos, headings, speeds, env_size)

return types.AgentState(
Expand All @@ -103,7 +99,7 @@ def update_state(
)


def view_reduction(view_a: chex.Array, view_b: chex.Array) -> chex.Array:
def view_reduction_fn(view_a: chex.Array, view_b: chex.Array) -> chex.Array:
"""
Binary view reduction function for use in Esquilax spatial transformation.
Expand Down Expand Up @@ -161,7 +157,6 @@ def angular_width(


def view(
_key: chex.PRNGKey,
params: Tuple[float, float],
viewing_agent: types.AgentState,
viewed_agent: types.AgentState,
Expand All @@ -181,8 +176,6 @@ def view(
Currently, this model assumes the viewed agent/objects are circular.
Args:
_key: Dummy JAX random key, required by esquilax API, but
not used during the interaction.
params: Tuple containing agent view angle and view-radius.
viewing_agent: Viewing agent state.
viewed_agent: State of agent being viewed.
Expand Down
10 changes: 3 additions & 7 deletions jumanji/environments/swarms/search_and_rescue/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Optional, Sequence, Tuple

import chex
import esquilax
import jax
import jax.numpy as jnp
from esquilax.transforms import spatial
Expand Down Expand Up @@ -222,25 +223,20 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser
state: Updated searcher and target positions and velocities.
timestep: Transition timestep with individual agent local observations.
"""
# Note: only one new key is needed for the target updates, as all other
# keys are just dummy values required by Esquilax
key, target_key = jax.random.split(state.key, num=2)
searchers = update_state(
key, self.generator.env_size, self.searcher_params, state.searchers, actions
self.generator.env_size, self.searcher_params, state.searchers, actions
)

targets = self._target_dynamics(target_key, state.targets, self.generator.env_size)

# Searchers return an array of flags of any targets they are in range of,
# and that have not already been located, result shape here is (n-searcher, n-targets)
targets_found = spatial(
utils.searcher_detect_targets,
reduction=jnp.logical_or,
default=jnp.zeros((self.generator.num_targets,), dtype=bool),
reduction=esquilax.reductions.logical_or((self.generator.num_targets,)),
i_range=self.target_contact_range,
dims=self.generator.env_size,
)(
key,
self.searcher_params.view_angle,
searchers,
(jnp.arange(self.generator.num_targets), targets),
Expand Down
45 changes: 19 additions & 26 deletions jumanji/environments/swarms/search_and_rescue/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@
from typing import Tuple

import chex
import esquilax
import jax.numpy as jnp
from esquilax.transforms import spatial

from jumanji.environments.swarms.common.types import AgentState
from jumanji.environments.swarms.common.updates import angular_width, view, view_reduction
from jumanji.environments.swarms.common.updates import angular_width, view, view_reduction_fn
from jumanji.environments.swarms.search_and_rescue.types import State, TargetState


def view_reduction(view_shape: Tuple[int, ...]) -> esquilax.reductions.Reduction:
    """
    Build the Esquilax reduction used to combine individual agent views.

    Args:
        view_shape: Shape of the view array being reduced, e.g.
            `(num_vision,)` or `(2, num_vision)`.

    Returns:
        Reduction pairing `view_reduction_fn` with an array of -1s of
        shape `view_shape` as its `id` value (presumably the identity,
        i.e. the result when nothing is in range; confirm against the
        Esquilax documentation).
    """
    empty_view = -jnp.ones(view_shape)
    return esquilax.reductions.Reduction(fn=view_reduction_fn, id=empty_view)


class ObservationFn(abc.ABC):
def __init__(
self,
Expand Down Expand Up @@ -109,15 +116,13 @@ def __call__(self, state: State) -> chex.Array:
Array of individual agent views of shape
(n-agents, 1, n-vision).
"""
searcher_views = spatial(
searcher_views = esquilax.transforms.spatial(
view,
reduction=view_reduction,
default=-jnp.ones((self.num_vision,)),
reduction=view_reduction((self.num_vision,)),
include_self=False,
i_range=self.vision_range,
dims=self.env_size,
)(
state.key,
(self.view_angle, self.agent_radius),
state.searchers,
state.searchers,
Expand All @@ -130,7 +135,6 @@ def __call__(self, state: State) -> chex.Array:


def found_target_view(
_key: chex.PRNGKey,
params: Tuple[float, float],
searcher: AgentState,
target: TargetState,
Expand All @@ -146,7 +150,6 @@ def found_target_view(
by Esquilax.
Args:
_key: Dummy random key (required by Esquilax).
params: View angle and target visual radius.
searcher: Searcher agent state
target: Target state
Expand Down Expand Up @@ -224,15 +227,13 @@ def __call__(self, state: State) -> chex.Array:
(n-agents, 2, n-vision). Other agents are shown
in channel 0, and located targets 1.
"""
searcher_views = spatial(
searcher_views = esquilax.transforms.spatial(
view,
reduction=view_reduction,
default=-jnp.ones((self.num_vision,)),
reduction=view_reduction((self.num_vision,)),
include_self=False,
i_range=self.vision_range,
dims=self.env_size,
)(
state.key,
(self.view_angle, self.agent_radius),
state.searchers,
state.searchers,
Expand All @@ -241,15 +242,13 @@ def __call__(self, state: State) -> chex.Array:
i_range=self.vision_range,
env_size=self.env_size,
)
target_views = spatial(
target_views = esquilax.transforms.spatial(
found_target_view,
reduction=view_reduction,
default=-jnp.ones((self.num_vision,)),
reduction=view_reduction((self.num_vision,)),
include_self=False,
i_range=self.vision_range,
dims=self.env_size,
)(
state.key,
(self.view_angle, self.agent_radius),
state.searchers,
state.targets,
Expand All @@ -263,7 +262,6 @@ def __call__(self, state: State) -> chex.Array:


def all_target_view(
_key: chex.PRNGKey,
params: Tuple[float, float],
searcher: AgentState,
target: TargetState,
Expand All @@ -279,7 +277,6 @@ def all_target_view(
by Esquilax.
Args:
_key: Dummy random key (required by Esquilax).
params: View angle and target visual radius.
searcher: Searcher agent state
target: Target state
Expand Down Expand Up @@ -361,15 +358,13 @@ def __call__(self, state: State) -> chex.Array:
in channel 0, located targets 1, and un-located
targets at index 2.
"""
searcher_views = spatial(
searcher_views = esquilax.transforms.spatial(
view,
reduction=view_reduction,
default=-jnp.ones((self.num_vision,)),
reduction=view_reduction((self.num_vision,)),
include_self=False,
i_range=self.vision_range,
dims=self.env_size,
)(
state.key,
(self.view_angle, self.agent_radius),
state.searchers,
state.searchers,
Expand All @@ -378,15 +373,13 @@ def __call__(self, state: State) -> chex.Array:
i_range=self.vision_range,
env_size=self.env_size,
)
target_views = spatial(
target_views = esquilax.transforms.spatial(
all_target_view,
reduction=view_reduction,
default=-jnp.ones((2, self.num_vision)),
reduction=view_reduction((2, self.num_vision)),
include_self=False,
i_range=self.vision_range,
dims=self.env_size,
)(
state.key,
(self.view_angle, self.agent_radius),
state.searchers,
state.targets,
Expand Down
2 changes: 0 additions & 2 deletions jumanji/environments/swarms/search_and_rescue/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def _check_target_in_view(


def searcher_detect_targets(
_key: chex.PRNGKey,
searcher_view_angle: float,
searcher: AgentState,
target: Tuple[chex.Array, TargetState],
Expand All @@ -65,7 +64,6 @@ def searcher_detect_targets(
searchers view cone, and has not already been detected.
Args:
_key: Dummy random key (required by Esquilax).
searcher_view_angle: View angle of searching agents
representing a fraction of pi from the agents heading.
searcher: State of the searching agent (i.e. the agent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def test_target_found(
)

found = jax.jit(partial(searcher_detect_targets, env_size=env_size, n_targets=1))(
None,
view_angle,
searcher,
(jnp.arange(1), target),
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
chex>=0.1.3
dm-env>=1.5
esquilax>=1.0.3
esquilax>=2.0.0
gymnasium>=1.0
huggingface-hub
jax>=0.2.26
Expand Down

0 comments on commit 407ff79

Please sign in to comment.