Skip to content

Commit

Permalink
feat: return individual reward for each agent from connector
Browse files Browse the repository at this point in the history
  • Loading branch information
WiemKhlifi committed Nov 13, 2024
1 parent 5ab7166 commit 69ba8b4
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 26 deletions.
33 changes: 27 additions & 6 deletions jumanji/environments/routing/connector/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ class Connector(Environment[State, specs.MultiDiscreteArray, Observation]):
- can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left].
- each value in the array corresponds to an agent's action.
- reward: jax array (float) of shape ():
- dense: reward is 1 for each successful connection on that step. Additionally,
each pair of points that have not connected receives a penalty reward of -0.03.
- reward: jax array (float) of shape (num_agents,):
- dense: for each agent the reward is 1 for each successful connection on that step.
Additionally, each pair of points that have not connected receives a
penalty reward of -0.03.
- episode termination:
- all agents either can't move (no available actions) or have connected to their target.
Expand Down Expand Up @@ -142,7 +143,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
step_count=state.step_count,
)
extras = self._get_extras(state)
timestep = restart(observation=observation, extras=extras)
timestep = restart(observation=observation, extras=extras, shape=(self.num_agents,))
return state, timestep

def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
Expand Down Expand Up @@ -171,19 +172,23 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observ
grid=grid, action_mask=action_mask, step_count=new_state.step_count
)

done = jnp.all(jax.vmap(connected_or_blocked)(agents, action_mask))
done = jax.vmap(connected_or_blocked)(agents, action_mask)
discount = (1 - done).astype(float)
extras = self._get_extras(new_state)
timestep = jax.lax.cond(
done | (new_state.step_count >= self.time_limit),
jnp.all(done) | (new_state.step_count >= self.time_limit),
lambda: termination(
reward=reward,
observation=observation,
extras=extras,
shape=(self.num_agents,),
),
lambda: transition(
reward=reward,
observation=observation,
extras=extras,
discount=discount,
shape=(self.num_agents,),
),
)

Expand Down Expand Up @@ -362,3 +367,19 @@ def action_spec(self) -> specs.MultiDiscreteArray:
dtype=jnp.int32,
name="action",
)

@cached_property
def reward_spec(self) -> specs.Array:
"""Returns: a reward per agent."""
return specs.Array(shape=(self.num_agents,), dtype=float, name="reward")

@cached_property
def discount_spec(self) -> specs.BoundedArray:
"""Returns: discount per agent."""
return specs.BoundedArray(
shape=(self.num_agents,),
dtype=float,
minimum=0.0,
maximum=1.0,
name="discount",
)
12 changes: 6 additions & 6 deletions jumanji/environments/routing/connector/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def test_connector__reset(connector: Connector, key: jax.random.PRNGKey) -> None
assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))

assert timestep.discount == 1.0
assert timestep.reward == 0.0
assert jnp.allclose(timestep.discount, jnp.ones((connector.num_agents,)))
assert jnp.allclose(timestep.reward, jnp.zeros((connector.num_agents,)))
assert timestep.step_type == StepType.FIRST


Expand Down Expand Up @@ -94,7 +94,7 @@ def test_connector__step_connected(
chex.assert_trees_all_equal(real_state2, state2)

assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))
reward = connector._reward_fn(real_state1, action2, real_state2)
assert jnp.array_equal(timestep.reward, reward)

Expand Down Expand Up @@ -146,7 +146,7 @@ def test_connector__step_blocked(

assert jnp.array_equal(state.grid, expected_grid)
assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))

assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))
Expand All @@ -165,12 +165,12 @@ def test_connector__step_horizon(connector: Connector, state: State) -> None:
state, timestep = step_fn(state, actions)

assert timestep.step_type != StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(1))
assert jnp.array_equal(timestep.discount, jnp.ones(connector.num_agents))

# step 5
state, timestep = step_fn(state, actions)
assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))


def test_connector__step_agents_collision(
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/routing/connector/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ def __call__(
~state.agents.connected & next_state.agents.connected, float
)
timestep_rewards = self.timestep_reward * jnp.asarray(~state.agents.connected, float)
return jnp.sum(connected_rewards + timestep_rewards)
return connected_rewards + timestep_rewards
23 changes: 12 additions & 11 deletions jumanji/environments/routing/connector/reward_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,27 @@ def test_dense_reward(

# Reward of moving between the same states should be 0.
reward = dense_reward_fn(state, jnp.array([0, 0, 0]), state)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.asarray(timestep_reward * 3))
chex.assert_rank(reward, 1)
assert jnp.allclose(reward, jnp.array([timestep_reward] * 3))

# Reward for no agents finished to 2 agents finished.
reward = dense_reward_fn(state, action1, state1)
chex.assert_rank(reward, 0)
expected_reward = connected_reward * 2 + timestep_reward * 3
assert jnp.isclose(reward, expected_reward)
chex.assert_rank(reward, 1)
expected_reward = jnp.array([connected_reward, 0, connected_reward]) + timestep_reward
assert jnp.allclose(reward, expected_reward)

# Reward for some agents finished to all agents finished.
reward = dense_reward_fn(state1, action2, state2)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.array(connected_reward + timestep_reward))
chex.assert_rank(reward, 1)
expected_reward = jnp.array([0, connected_reward + timestep_reward, 0])
assert jnp.allclose(reward, expected_reward)

# Reward for none finished to all finished
reward = dense_reward_fn(state, action1, state2)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.array((connected_reward + timestep_reward) * 3))
chex.assert_rank(reward, 1)
assert jnp.allclose(reward, jnp.array([connected_reward + timestep_reward] * 3))

# Reward of all finished to all finished.
reward = dense_reward_fn(state2, jnp.zeros(3), state2)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.zeros(1))
chex.assert_rank(reward, 1)
assert jnp.allclose(reward, jnp.zeros(1))
2 changes: 1 addition & 1 deletion jumanji/training/configs/config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defaults:
- _self_
- env: snake # [bin_pack, cleaner, connector, cvrp, flat_pack, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, lbf, rubiks_cube, sliding_tile_puzzle, snake, sokoban, sudoku, tetris, tsp]
- env: connector # [bin_pack, cleaner, connector, cvrp, flat_pack, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, lbf, rubiks_cube, sliding_tile_puzzle, snake, sokoban, sudoku, tetris, tsp]

agent: random # [random, a2c]

Expand Down
2 changes: 1 addition & 1 deletion jumanji/training/setup_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def setup_logger(cfg: DictConfig) -> Logger:

def _make_raw_env(cfg: DictConfig) -> Environment:
env = jumanji.make(cfg.env.registered_version)
if cfg.env.name in {"lbf"}:
if cfg.env.name in {"lbf", "connector"}:
# Convert a multi-agent environment to a single-agent environment
env = MultiToSingleWrapper(env)
return env
Expand Down

0 comments on commit 69ba8b4

Please sign in to comment.