From 47a50c1bcecc030665a1b03e706bee102c737467 Mon Sep 17 00:00:00 2001
From: Avi Revah
Date: Fri, 19 Jan 2024 22:52:52 +0000
Subject: [PATCH 01/16] test: implement spec access smoke test

---
 .../environments/logic/game_2048/env_test.py | 10 ++++++++-
 .../logic/graph_coloring/env_test.py | 10 ++++++++-
 .../logic/minesweeper/env_test.py | 10 ++++++++-
 .../logic/rubiks_cube/env_test.py | 15 ++++++++++++-
 jumanji/environments/logic/sudoku/env_test.py | 10 ++++++++-
 .../environments/packing/bin_pack/env_test.py | 11 +++++++++-
 .../environments/packing/job_shop/env_test.py | 9 +++++++-
 .../environments/packing/knapsack/env_test.py | 11 +++++++++-
 .../environments/packing/tetris/env_test.py | 14 +++++++++++++
 .../environments/routing/cleaner/env_test.py | 9 +++++++-
 .../routing/connector/env_test.py | 10 ++++++++-
 jumanji/environments/routing/cvrp/env_test.py | 9 +++++++-
 jumanji/environments/routing/maze/env_test.py | 9 +++++++-
 jumanji/environments/routing/mmst/env_test.py | 21 ++++++++++++-------
 .../routing/multi_cvrp/env_test.py | 11 +++++++++-
 .../routing/robot_warehouse/env_test.py | 20 +++++++++++++-----
 .../environments/routing/snake/env_test.py | 10 ++++++++-
 jumanji/environments/routing/tsp/env_test.py | 11 +++++++++-
 jumanji/testing/env_not_smoke.py | 10 ++++++++-
 jumanji/testing/env_not_smoke_test.py | 6 ++++++
 20 files changed, 197 insertions(+), 29 deletions(-)

diff --git a/jumanji/environments/logic/game_2048/env_test.py b/jumanji/environments/logic/game_2048/env_test.py
index 7985e0278..125fba401 100644
--- a/jumanji/environments/logic/game_2048/env_test.py
+++ b/jumanji/environments/logic/game_2048/env_test.py
@@ -19,7 +19,10 @@
 
 from jumanji.environments.logic.game_2048.env import Game2048
 from jumanji.environments.logic.game_2048.types import Board, State
-from jumanji.testing.env_not_smoke import check_env_does_not_smoke
+from jumanji.testing.env_not_smoke import (
+    check_env_does_not_smoke,
+    check_env_specs_does_not_smoke,
+)
 from jumanji.testing.pytrees import assert_is_jax_array_tree
 from jumanji.types import TimeStep
 
@@ -154,3 +157,8 @@ def test_game_2048__get_action_mask(game_2048: Game2048, board: Board) -> None:
 def test_game_2048__does_not_smoke(game_2048: Game2048) -> None:
     """Test that we can run an episode without any errors."""
     check_env_does_not_smoke(game_2048)
+
+
+def test_game_2048__specs_does_not_smoke(game_2048: Game2048) -> None:
+    """Test that we can access specs without any errors."""
+    check_env_specs_does_not_smoke(game_2048)
diff --git a/jumanji/environments/logic/graph_coloring/env_test.py b/jumanji/environments/logic/graph_coloring/env_test.py
index d0418da77..f7b618b1d 100644
--- a/jumanji/environments/logic/graph_coloring/env_test.py
+++ b/jumanji/environments/logic/graph_coloring/env_test.py
@@ -18,7 +18,10 @@
 
 from jumanji.environments.logic.graph_coloring import GraphColoring
 from jumanji.environments.logic.graph_coloring.types import State
-from jumanji.testing.env_not_smoke import check_env_does_not_smoke
+from jumanji.testing.env_not_smoke import (
+    check_env_does_not_smoke,
+    check_env_specs_does_not_smoke,
+)
 from jumanji.testing.pytrees import assert_is_jax_array_tree
 from jumanji.types import TimeStep
 
@@ -90,3 +93,8 @@ def test_graph_coloring_get_action_mask(graph_coloring: GraphColoring) -> None:
 def test_graph_coloring_does_not_smoke(graph_coloring: GraphColoring) -> None:
     """Test that we can run an episode without any errors."""
     check_env_does_not_smoke(graph_coloring)
+
+
+def test_graph_coloring_specs_does_not_smoke(graph_coloring: 
GraphColoring) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(graph_coloring) diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 197675f0e..3cf52b620 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -24,7 +24,10 @@ from jumanji.environments.logic.minesweeper.env import Minesweeper from jumanji.environments.logic.minesweeper.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -154,6 +157,11 @@ def test_minesweeper__does_not_smoke(minesweeper_env: Minesweeper) -> None: check_env_does_not_smoke(env=minesweeper_env) +def test_minesweeper__specs_does_not_smoke(minesweeper_env: Minesweeper) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(minesweeper_env) + + def test_minesweeper__render( monkeypatch: pytest.MonkeyPatch, minesweeper_env: Minesweeper ) -> None: diff --git a/jumanji/environments/logic/rubiks_cube/env_test.py b/jumanji/environments/logic/rubiks_cube/env_test.py index 3cbf7ac55..a1ae0717b 100644 --- a/jumanji/environments/logic/rubiks_cube/env_test.py +++ b/jumanji/environments/logic/rubiks_cube/env_test.py @@ -23,7 +23,10 @@ from jumanji.environments.logic.rubiks_cube.env import RubiksCube from jumanji.environments.logic.rubiks_cube.generator import ScramblingGenerator from jumanji.environments.logic.rubiks_cube.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -84,6 +87,16 @@ def test_rubiks_cube__does_not_smoke(cube_size: int) -> None: check_env_does_not_smoke(env) +@pytest.mark.parametrize("cube_size", [3, 4, 5]) +def test_rubiks_cube__specs_does_not_smoke(cube_size: int) -> None: + """Test that we can access specs without any errors.""" + env = RubiksCube( + time_limit=10, + generator=ScramblingGenerator(cube_size=cube_size, num_scrambles_on_reset=5), + ) + check_env_specs_does_not_smoke(env) + + def test_rubiks_cube__render( monkeypatch: pytest.MonkeyPatch, rubiks_cube: RubiksCube ) -> None: diff --git a/jumanji/environments/logic/sudoku/env_test.py b/jumanji/environments/logic/sudoku/env_test.py index 9e55cdc12..3152339c3 100644 --- a/jumanji/environments/logic/sudoku/env_test.py +++ b/jumanji/environments/logic/sudoku/env_test.py @@ -23,7 +23,10 @@ from jumanji.environments.logic.sudoku.env import Sudoku from jumanji.environments.logic.sudoku.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -75,6 +78,11 @@ def test_sudoku__does_not_smoke(sudoku_env: Sudoku) -> None: check_env_does_not_smoke(env=sudoku_env) +def test_sudoku__specs_does_not_smoke(sudoku_env: Sudoku) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(env=sudoku_env) + + def 
test_sudoku__render(monkeypatch: pytest.MonkeyPatch, sudoku_env: Sudoku) -> None: """Check that the render method builds the figure but does not display it.""" monkeypatch.setattr(plt, "show", lambda fig: None) diff --git a/jumanji/environments/packing/bin_pack/env_test.py b/jumanji/environments/packing/bin_pack/env_test.py index 921ce7025..e536738f2 100644 --- a/jumanji/environments/packing/bin_pack/env_test.py +++ b/jumanji/environments/packing/bin_pack/env_test.py @@ -33,7 +33,11 @@ item_from_space, location_from_space, ) -from jumanji.testing.env_not_smoke import SelectActionFn, check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + SelectActionFn, + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -168,6 +172,11 @@ def test_bin_pack__does_not_smoke( check_env_does_not_smoke(bin_pack, bin_pack_random_select_action) +def test_bin_pack__specs_does_not_smoke(bin_pack: BinPack) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(bin_pack) + + def test_bin_pack__pack_all_items_dummy_instance( bin_pack: BinPack, bin_pack_random_select_action: SelectActionFn ) -> None: diff --git a/jumanji/environments/packing/job_shop/env_test.py b/jumanji/environments/packing/job_shop/env_test.py index 737f42bea..964042dac 100644 --- a/jumanji/environments/packing/job_shop/env_test.py +++ b/jumanji/environments/packing/job_shop/env_test.py @@ -19,7 +19,10 @@ from jumanji.environments.packing.job_shop.env import JobShop from jumanji.environments.packing.job_shop.generator import ToyGenerator from jumanji.environments.packing.job_shop.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.types import TimeStep @@ -816,3 +819,7 @@ def test_job_shop__toy_generator_reward(self) -> None: def test_job_shop_env__does_not_smoke(self, job_shop_env: JobShop) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(job_shop_env) + + def test_job_shop_env__specs_does_not_smoke(self, job_shop_env: JobShop) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(job_shop_env) diff --git a/jumanji/environments/packing/knapsack/env_test.py b/jumanji/environments/packing/knapsack/env_test.py index 139c9b3fb..32ea10cf7 100644 --- a/jumanji/environments/packing/knapsack/env_test.py +++ b/jumanji/environments/packing/knapsack/env_test.py @@ -17,7 +17,10 @@ from jax import numpy as jnp from jumanji.environments.packing.knapsack import Knapsack, State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -75,6 +78,12 @@ def test_knapsack_sparse__does_not_smoke( """Test that we can run an episode without any errors.""" check_env_does_not_smoke(knapsack_sparse_reward) + def test_knapsack_sparse__specs_does_not_smoke( + self, knapsack_sparse_reward: Knapsack + ) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(knapsack_sparse_reward) + def test_knapsack_sparse__trajectory_action( self, knapsack_sparse_reward: Knapsack ) -> None: diff --git 
a/jumanji/environments/packing/tetris/env_test.py b/jumanji/environments/packing/tetris/env_test.py index a46017d22..d68a7026f 100644 --- a/jumanji/environments/packing/tetris/env_test.py +++ b/jumanji/environments/packing/tetris/env_test.py @@ -19,6 +19,10 @@ from jumanji.environments.packing.tetris.env import Tetris from jumanji.environments.packing.tetris.types import State +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -115,3 +119,13 @@ def test_calculate_action_mask(tetris_env: Tetris, grid: chex.Array) -> None: ] ) assert (action_mask == expected_action_mask).all() + + +def test_tetris__does_not_smoke(tetris_env: Tetris) -> None: + """Test that we can run an episode without any errors.""" + check_env_does_not_smoke(tetris_env) + + +def test_tetris__specs_does_not_smoke(tetris_env: Tetris) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(tetris_env) diff --git a/jumanji/environments/routing/cleaner/env_test.py b/jumanji/environments/routing/cleaner/env_test.py index 7a8eb3192..f386a5dad 100644 --- a/jumanji/environments/routing/cleaner/env_test.py +++ b/jumanji/environments/routing/cleaner/env_test.py @@ -21,7 +21,10 @@ from jumanji.environments.routing.cleaner.env import Cleaner from jumanji.environments.routing.cleaner.generator import Generator from jumanji.environments.routing.cleaner.types import Observation, State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -191,6 +194,10 @@ def select_action( check_env_does_not_smoke(cleaner, select_actions) + def test_cleaner__specs_does_not_smoke(self, cleaner: Cleaner) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(cleaner) + def test_cleaner__compute_extras(self, cleaner: Cleaner, key: chex.PRNGKey) -> None: state, _ = cleaner.reset(key) diff --git a/jumanji/environments/routing/connector/env_test.py b/jumanji/environments/routing/connector/env_test.py index 2468f1db2..19478400a 100644 --- a/jumanji/environments/routing/connector/env_test.py +++ b/jumanji/environments/routing/connector/env_test.py @@ -24,7 +24,10 @@ from jumanji.environments.routing.connector.env import Connector from jumanji.environments.routing.connector.types import Agent, State from jumanji.environments.routing.connector.utils import get_position, get_target -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.tree_utils import tree_slice from jumanji.types import StepType, TimeStep @@ -230,6 +233,11 @@ def test_connector__does_not_smoke(connector: Connector) -> None: check_env_does_not_smoke(connector) +def test_connector__specs_does_not_smoke(connector: Connector) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(connector) + + def test_connector__get_action_mask(state: State, connector: Connector) -> None: """Validates the action masking.""" action_masks = jax.vmap(connector._get_action_mask, (0, None))( diff --git a/jumanji/environments/routing/cvrp/env_test.py 
b/jumanji/environments/routing/cvrp/env_test.py index 6ad2d2184..c0f828db2 100644 --- a/jumanji/environments/routing/cvrp/env_test.py +++ b/jumanji/environments/routing/cvrp/env_test.py @@ -19,7 +19,10 @@ from jumanji.environments.routing.cvrp.constants import DEPOT_IDX from jumanji.environments.routing.cvrp.env import CVRP from jumanji.environments.routing.cvrp.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -95,6 +98,10 @@ def test_cvrp_sparse__does_not_smoke(self, cvrp_sparse_reward: CVRP) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(cvrp_sparse_reward) + def test_cvrp_sparse__specs_does_not_smoke(self, cvrp_sparse_reward: CVRP) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(cvrp_sparse_reward) + def test_cvrp_sparse__trajectory_action(self, cvrp_sparse_reward: CVRP) -> None: """Tests a trajectory by visiting nodes in increasing and cyclic order, visiting the depot when the next node in the list surpasses the current capacity of the agent. diff --git a/jumanji/environments/routing/maze/env_test.py b/jumanji/environments/routing/maze/env_test.py index e84c9b942..5cf2a69a4 100644 --- a/jumanji/environments/routing/maze/env_test.py +++ b/jumanji/environments/routing/maze/env_test.py @@ -20,7 +20,10 @@ from jumanji.environments.routing.maze.env import Maze from jumanji.environments.routing.maze.generator import RandomGenerator, ToyGenerator from jumanji.environments.routing.maze.types import Position, State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import StepType, TimeStep @@ -227,3 +230,7 @@ def test_maze__toy_generator(self) -> None: def test_maze__does_not_smoke(self, maze: Maze) -> None: check_env_does_not_smoke(maze) + + def test_maze__specs_does_not_smoke(self, maze: Maze) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(maze) diff --git a/jumanji/environments/routing/mmst/env_test.py b/jumanji/environments/routing/mmst/env_test.py index ccd494f86..06f8e57d7 100644 --- a/jumanji/environments/routing/mmst/env_test.py +++ b/jumanji/environments/routing/mmst/env_test.py @@ -24,13 +24,16 @@ ) from jumanji.environments.routing.mmst.env import MMST from jumanji.environments.routing.mmst.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep def test__mmst_agent_observation( - deterministic_mmst_env: Tuple[MMST, State, TimeStep] + deterministic_mmst_env: Tuple[MMST, State, TimeStep], ) -> None: """Test that agent observation view of the node types is correct""" @@ -49,7 +52,7 @@ def test__mmst_agent_observation( def test__mmst_action_tie_break( - deterministic_mmst_env: Tuple[MMST, State, TimeStep] + deterministic_mmst_env: Tuple[MMST, State, TimeStep], ) -> None: """Test if the actions are mask correctly if multiple agents select the same node as next 
nodes. @@ -131,10 +134,14 @@ def test__mmst_does_not_smoke( check_env_does_not_smoke(mmst_split_gn_env) +def test__mmst_specs_does_not_smoke(mmst_split_gn_env: MMST) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(mmst_split_gn_env) + + def test__mmst_termination( - deterministic_mmst_env: Tuple[MMST, State, TimeStep] + deterministic_mmst_env: Tuple[MMST, State, TimeStep], ) -> None: - env, state, timestep = deterministic_mmst_env step_fn = jax.jit(env.step) @@ -170,7 +177,6 @@ def test__mmst_termination( def test__mmst_truncation(deterministic_mmst_env: Tuple[MMST, State, TimeStep]) -> None: - env, state, timestep = deterministic_mmst_env step_fn = jax.jit(env.step) @@ -182,9 +188,8 @@ def test__mmst_truncation(deterministic_mmst_env: Tuple[MMST, State, TimeStep]) def test__mmst_action_masking( - deterministic_mmst_env: Tuple[MMST, State, TimeStep] + deterministic_mmst_env: Tuple[MMST, State, TimeStep], ) -> None: - env, state, _ = deterministic_mmst_env step_fn = jax.jit(env.step) diff --git a/jumanji/environments/routing/multi_cvrp/env_test.py b/jumanji/environments/routing/multi_cvrp/env_test.py index a93aaf17e..f3292cf40 100644 --- a/jumanji/environments/routing/multi_cvrp/env_test.py +++ b/jumanji/environments/routing/multi_cvrp/env_test.py @@ -23,7 +23,10 @@ test_node_demand, ) from jumanji.environments.routing.multi_cvrp.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -278,3 +281,9 @@ def select_action( return select_action(subkeys, observation.action_mask) check_env_does_not_smoke(multicvrp_env, select_actions) + + def test_env_multicvrp__specs_does_not_smoke( + self, multicvrp_env: MultiCVRP + ) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(multicvrp_env) diff --git a/jumanji/environments/routing/robot_warehouse/env_test.py b/jumanji/environments/routing/robot_warehouse/env_test.py index e5d60b94d..8b5840756 100644 --- a/jumanji/environments/routing/robot_warehouse/env_test.py +++ b/jumanji/environments/routing/robot_warehouse/env_test.py @@ -21,7 +21,10 @@ from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse from jumanji.environments.routing.robot_warehouse.types import State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.tree_utils import tree_slice from jumanji.types import TimeStep @@ -60,7 +63,7 @@ def test_robot_warehouse__reset(robot_warehouse_env: RobotWarehouse) -> None: def test_robot_warehouse__agent_observation( - deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep], ) -> None: """Validate the agent observation function.""" env, state, timestep = deterministic_robot_warehouse_env @@ -163,6 +166,13 @@ def test_robot_warehouse__does_not_smoke(robot_warehouse_env: RobotWarehouse) -> check_env_does_not_smoke(robot_warehouse_env) +def test_robot_warehouse__specs_does_not_smoke( + robot_warehouse_env: RobotWarehouse, +) -> None: + """Test that we can access specs without any errors.""" + 
check_env_specs_does_not_smoke(robot_warehouse_env) + + def test_robot_warehouse__time_limit(robot_warehouse_env: RobotWarehouse) -> None: """Validate the terminal reward.""" step_fn = jax.jit(robot_warehouse_env.step) @@ -179,7 +189,7 @@ def test_robot_warehouse__time_limit(robot_warehouse_env: RobotWarehouse) -> Non def test_robot_warehouse__truncation( - deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep], ) -> None: """Validate episode truncation based on set time limit.""" robot_warehouse_env, state, timestep = deterministic_robot_warehouse_env @@ -197,7 +207,7 @@ def test_robot_warehouse__truncation( def test_robot_warehouse__truncate_upon_collision( - deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep], ) -> None: """Validate episode terminates upon collision of agents.""" robot_warehouse_env, state, timestep = deterministic_robot_warehouse_env @@ -217,7 +227,7 @@ def test_robot_warehouse__truncate_upon_collision( def test_robot_warehouse__reward_in_goal( - deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep], ) -> None: """Validate goal reward behavior.""" robot_warehouse_env, state, timestep = deterministic_robot_warehouse_env diff --git a/jumanji/environments/routing/snake/env_test.py b/jumanji/environments/routing/snake/env_test.py index df37aff3a..e684011f7 100644 --- a/jumanji/environments/routing/snake/env_test.py +++ b/jumanji/environments/routing/snake/env_test.py @@ -22,7 +22,10 @@ from jumanji.environments.routing.snake.env import Snake, State from jumanji.environments.routing.snake.types import Position -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -94,6 +97,11 @@ def test_snake__does_not_smoke(snake: Snake) -> None: check_env_does_not_smoke(snake) +def test_snake__specs_does_not_smoke(snake: Snake) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(snake) + + def test_update_head_position(snake: Snake) -> None: """Validates _update_head_position method. 
Checks that starting from a certain position, taking some actions
diff --git a/jumanji/environments/routing/tsp/env_test.py b/jumanji/environments/routing/tsp/env_test.py
index d43e51f13..7d320b13d 100644
--- a/jumanji/environments/routing/tsp/env_test.py
+++ b/jumanji/environments/routing/tsp/env_test.py
@@ -19,7 +19,10 @@
 
 from jumanji.environments.routing.tsp.env import TSP
 from jumanji.environments.routing.tsp.types import State
-from jumanji.testing.env_not_smoke import check_env_does_not_smoke
+from jumanji.testing.env_not_smoke import (
+    check_env_does_not_smoke,
+    check_env_specs_does_not_smoke,
+)
 from jumanji.testing.pytrees import assert_is_jax_array_tree
 from jumanji.types import StepType, TimeStep
 
@@ -198,6 +201,12 @@ def test_tsp_sparse__does_not_smoke(
         """Test that we can run an episode without any errors."""
         check_env_does_not_smoke(tsp_sparse_reward)
 
+    def test_tsp_sparse__specs_does_not_smoke(
+        self, tsp_sparse_reward: TSP
+    ) -> None:
+        """Test that we can access specs without any errors."""
+        check_env_specs_does_not_smoke(tsp_sparse_reward)
+
     def test_tsp_sparse__trajectory_action(self, tsp_sparse_reward: TSP) -> None:
         """Checks that the agent stops when there are no more cities to be selected and that
         the appropriate reward is received. The testing loop ensures that no city is selected twice.
diff --git a/jumanji/testing/env_not_smoke.py b/jumanji/testing/env_not_smoke.py
index 8a3cb34b4..411a0f80c 100644
--- a/jumanji/testing/env_not_smoke.py
+++ b/jumanji/testing/env_not_smoke.py
@@ -29,7 +29,7 @@
 def make_random_select_action_fn(
     action_spec: Union[
         specs.BoundedArray, specs.DiscreteArray, specs.MultiDiscreteArray
-    ]
+    ],
 ) -> SelectActionFn:
     """Create select action function that chooses random actions."""
 
@@ -97,3 +97,11 @@ def check_env_does_not_smoke(
     env.observation_spec().validate(timestep.observation)
     if assert_finite_check:
         chex.assert_tree_all_finite((state, timestep))
+
+
+def check_env_specs_does_not_smoke(env: Environment) -> None:
+    """Access specs of the environment in a jitted function to check no errors occur."""
+    jax.jit(env.observation_spec())
+    jax.jit(env.action_spec())
+    jax.jit(env.reward_spec())
+    jax.jit(env.discount_spec())
diff --git a/jumanji/testing/env_not_smoke_test.py b/jumanji/testing/env_not_smoke_test.py
index d11f357d1..ca8f0a28a 100644
--- a/jumanji/testing/env_not_smoke_test.py
+++ b/jumanji/testing/env_not_smoke_test.py
@@ -20,6 +20,7 @@
 from jumanji.testing.env_not_smoke import (
     SelectActionFn,
     check_env_does_not_smoke,
+    check_env_specs_does_not_smoke,
     make_random_select_action_fn,
 )
 from jumanji.testing.fakes import FakeEnvironment
@@ -61,3 +62,8 @@ def test_random_select_action(fake_env: FakeEnvironment) -> None:
     action_2 = select_action(key3, timestep.observation)
     fake_env.action_spec().validate(action_1)
     assert not jnp.all(action_1 == action_2)
+
+
+def test_env_specs_not_smoke(fake_env: FakeEnvironment) -> None:
+    """Test that the specs can be accessed without any errors."""
+    check_env_specs_does_not_smoke(fake_env)

From 84d170b603aa78f9b7be5c2ae6abfe290ca726f4 Mon Sep 17 00:00:00 2001
From: Avi Revah
Date: Sat, 20 Jan 2024 04:44:44 +0000
Subject: [PATCH 02/16] feat: implement specs as properties

---
 README.md | 2 +-
 docs/guides/advanced_usage.md | 2 +-
 docs/guides/wrappers.md | 4 +-
 jumanji/env.py | 51 ++++++++++++++--
 jumanji/environments/logic/game_2048/env.py | 9 +--
 .../environments/logic/graph_coloring/env.py | 9 +--
 jumanji/environments/logic/minesweeper/env.py | 9 +--
 .../logic/minesweeper/env_test.py | 6 +-
jumanji/environments/logic/rubiks_cube/env.py | 9 +-- .../logic/rubiks_cube/env_test.py | 6 +- jumanji/environments/logic/sudoku/env.py | 11 ++-- jumanji/environments/logic/sudoku/env_test.py | 4 +- .../environments/packing/bin_pack/conftest.py | 2 +- jumanji/environments/packing/bin_pack/env.py | 9 ++- .../environments/packing/bin_pack/env_test.py | 4 +- jumanji/environments/packing/job_shop/env.py | 7 ++- jumanji/environments/packing/knapsack/env.py | 11 ++-- jumanji/environments/packing/tetris/env.py | 9 +-- jumanji/environments/routing/cleaner/env.py | 7 ++- jumanji/environments/routing/connector/env.py | 9 +-- jumanji/environments/routing/cvrp/env.py | 11 ++-- jumanji/environments/routing/maze/env.py | 9 +-- jumanji/environments/routing/mmst/env.py | 10 ++-- .../environments/routing/multi_cvrp/env.py | 9 +-- .../routing/robot_warehouse/env.py | 9 +-- .../routing/robot_warehouse/env_test.py | 4 +- jumanji/environments/routing/snake/env.py | 12 ++-- .../environments/routing/snake/env_test.py | 6 +- jumanji/environments/routing/tsp/env.py | 11 ++-- jumanji/registration_test.py | 2 +- jumanji/testing/env_not_smoke.py | 26 ++++---- jumanji/testing/env_not_smoke_test.py | 4 +- jumanji/testing/fakes.py | 28 +++++---- jumanji/testing/fakes_test.py | 8 +-- jumanji/training/agents/a2c/a2c_agent.py | 2 +- .../training/agents/random/random_agent.py | 2 +- .../networks/bin_pack/actor_critic.py | 2 +- jumanji/training/networks/bin_pack/random.py | 2 +- .../training/networks/cleaner/actor_critic.py | 2 +- .../networks/connector/actor_critic.py | 2 +- .../training/networks/cvrp/actor_critic.py | 2 +- .../networks/game_2048/actor_critic.py | 2 +- .../networks/graph_coloring/actor_critic.py | 2 +- .../networks/job_shop/actor_critic.py | 2 +- .../networks/knapsack/actor_critic.py | 2 +- .../training/networks/maze/actor_critic.py | 2 +- .../networks/minesweeper/actor_critic.py | 2 +- .../training/networks/minesweeper/random.py | 2 +- .../training/networks/mmst/actor_critic.py | 4 +- .../networks/multi_cvrp/actor_critic.py | 4 +- .../networks/robot_warehouse/actor_critic.py | 2 +- .../networks/rubiks_cube/actor_critic.py | 2 +- .../training/networks/rubiks_cube/random.py | 4 +- .../training/networks/snake/actor_critic.py | 2 +- .../training/networks/sudoku/actor_critic.py | 4 +- jumanji/training/networks/sudoku/random.py | 2 +- .../training/networks/tetris/actor_critic.py | 2 +- jumanji/training/networks/tetris/random.py | 2 +- jumanji/training/networks/tsp/actor_critic.py | 2 +- jumanji/wrappers.py | 36 +++++------ jumanji/wrappers_test.py | 59 +++++++++---------- 61 files changed, 269 insertions(+), 213 deletions(-) diff --git a/README.md b/README.md index 38eb32850..59e9b394e 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,7 @@ state, timestep = jax.jit(env.reset)(key) env.render(state) # Interact with the (jit-able) environment -action = env.action_spec().generate_value() # Action selection (dummy value here) +action = env.action_spec.generate_value() # Action selection (dummy value here) state, timestep = jax.jit(env.step)(state, action) # Take a step and observe the next state and time step ``` diff --git a/docs/guides/advanced_usage.md b/docs/guides/advanced_usage.md index 19c29c358..0b1eed97a 100644 --- a/docs/guides/advanced_usage.md +++ b/docs/guides/advanced_usage.md @@ -16,7 +16,7 @@ env = AutoResetWrapper(env) # Automatically reset the environment when an ep batch_size = 7 rollout_length = 5 -num_actions = env.action_spec().num_values +num_actions = env.action_spec.num_values random_key 
= jax.random.PRNGKey(0)
 key1, key2 = jax.random.split(random_key)
diff --git a/docs/guides/wrappers.md b/docs/guides/wrappers.md
index 838131480..4623a1588 100644
--- a/docs/guides/wrappers.md
+++ b/docs/guides/wrappers.md
@@ -13,7 +13,7 @@ env = jumanji.make("Snake-6x6-v0")
 dm_env = jumanji.wrappers.JumanjiToDMEnvWrapper(env)
 
 timestep = dm_env.reset()
-action = dm_env.action_spec().generate_value()
+action = dm_env.action_spec.generate_value()
 next_timestep = dm_env.step(action)
 ...
 ```
@@ -52,7 +52,7 @@ key = jax.random.PRNGKey(0)
 state, timestep = env.reset(key)
 print("New episode")
 for i in range(100):
-    action = env.action_spec().generate_value() # Returns jnp.array(0) when using Snake.
+    action = env.action_spec.generate_value() # Returns jnp.array(0) when using Snake.
     state, timestep = env.step(state, action)
     if timestep.first():
         print("New episode")
diff --git a/jumanji/env.py b/jumanji/env.py
index d3ddac6bd..aaabe5471 100644
--- a/jumanji/env.py
+++ b/jumanji/env.py
@@ -45,6 +45,13 @@ class Environment(abc.ABC, Generic[State]):
     def __repr__(self) -> str:
         return "Environment."
 
+    def __init__(self) -> None:
+        """Initialize environment."""
+        self._observation_spec = self._make_observation_spec()
+        self._action_spec = self._make_action_spec()
+        self._reward_spec = self._make_reward_spec()
+        self._discount_spec = self._make_discount_spec()
+
     @abc.abstractmethod
     def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
         """Resets the environment to an initial state.
@@ -70,34 +77,68 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
             timestep: TimeStep object corresponding the timestep returned by the environment,
         """
 
-    @abc.abstractmethod
+    @property
     def observation_spec(self) -> specs.Spec:
         """Returns the observation spec.
 
         Returns:
             observation_spec: a NestedSpec tree of spec.
         """
+        return self._observation_spec
 
     @abc.abstractmethod
+    def _make_observation_spec(self) -> specs.Spec:
+        """Returns new observation spec.
+
+        Returns:
+            observation_spec: a NestedSpec tree of spec.
+        """
+
+    @property
     def action_spec(self) -> specs.Spec:
         """Returns the action spec.
 
         Returns:
             action_spec: a NestedSpec tree of spec.
         """
+        return self._action_spec
+
+    @abc.abstractmethod
+    def _make_action_spec(self) -> specs.Spec:
+        """Returns new action spec.
+
+        Returns:
+            action_spec: a NestedSpec tree of spec.
+        """
+
+    @property
     def reward_spec(self) -> specs.Array:
-        """Describes the reward returned by the environment. By default, this is assumed to be a
-        single float.
+        """Returns the reward spec. By default, this is assumed to be a single float.
+
+        Returns:
+            reward_spec: a `specs.Array` spec.
+        """
+        return self._reward_spec
+
+    def _make_reward_spec(self) -> specs.Array:
+        """Returns new reward spec. By default, this is assumed to be a single float.
 
         Returns:
             reward_spec: a `specs.Array` spec.
         """
         return specs.Array(shape=(), dtype=float, name="reward")
 
+    @property
     def discount_spec(self) -> specs.BoundedArray:
-        """Describes the discount returned by the environment. By default, this is assumed to be a
-        single float between 0 and 1.
+        """Returns the discount spec. By default, this is assumed to be a single
+        float between 0 and 1.
+
+        Returns:
+            discount_spec: a `specs.BoundedArray` spec.
+        """
+        return self._discount_spec
+
+    def _make_discount_spec(self) -> specs.BoundedArray:
+        """Returns new discount spec. By default, this is assumed to be a single
+        float between 0 and 1.
 
         Returns:
             discount_spec: a `specs.BoundedArray` spec.
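The hunk above is the heart of the refactor: `Environment.__init__` now builds every spec once through the new `_make_*_spec` hooks, and the public `observation_spec`/`action_spec`/`reward_spec`/`discount_spec` names become cached properties. A minimal sketch of a subclass conforming to the new contract — the `ConstantEnv` name and its one-step dynamics are invented here purely for illustration and appear nowhere in this patch; it assumes the patch is applied:

```python
from typing import Tuple

import chex
import jax.numpy as jnp

from jumanji import specs
from jumanji.env import Environment
from jumanji.types import TimeStep, restart, termination


class ConstantEnv(Environment):
    """Toy single-step environment showing the new spec hooks."""

    def __init__(self) -> None:
        # Attributes that the spec factories read must be set *before*
        # super().__init__(), which calls _make_*_spec and caches the results.
        self._num_actions = 4
        super().__init__()

    def reset(self, key: chex.PRNGKey) -> Tuple[None, TimeStep]:
        return None, restart(observation=jnp.zeros((3,), float))

    def step(self, state: None, action: chex.Array) -> Tuple[None, TimeStep]:
        return state, termination(
            reward=jnp.float32(0), observation=jnp.zeros((3,), float)
        )

    def _make_observation_spec(self) -> specs.Array:
        return specs.Array(shape=(3,), dtype=float, name="observation")

    def _make_action_spec(self) -> specs.DiscreteArray:
        return specs.DiscreteArray(num_values=self._num_actions, name="action")


env = ConstantEnv()
assert env.action_spec is env.action_spec  # built once, then cached
```

This ordering constraint — set attributes first, then call `super().__init__()` — is why most of the `__init__` hunks below add the call at the end of the constructor rather than at the top.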
diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index ba8115c16..36e3cdd37 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -69,7 +69,7 @@ class Game2048(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -85,6 +85,7 @@ def __init__( viewer: `Viewer` used for rendering. Defaults to `Game2048Viewer`. """ self.board_size = board_size + super().__init__() # Create viewer used for rendering self._viewer = viewer or Game2048Viewer("2048", board_size) @@ -97,7 +98,7 @@ def __repr__(self) -> str: """ return f"2048 Game(board_size={self.board_size})" - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Game2048` environment. Returns: @@ -122,8 +123,8 @@ def observation_spec(self) -> specs.Spec[Observation]: ), ) - def action_spec(self) -> specs.DiscreteArray: - """Returns the action spec. + def _make_action_spec(self) -> specs.DiscreteArray: + """Returns new action spec. 4 actions: [0, 1, 2, 3] -> [Up, Right, Down, Left]. diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index 36970d7da..c2bf663bf 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -76,7 +76,7 @@ class GraphColoring(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -100,6 +100,7 @@ def __init__( num_nodes=20, edge_probability=0.8 ) self.num_nodes = self.generator.num_nodes + super().__init__() # Create viewer used for rendering self._env_viewer = viewer or GraphColoringViewer(name="GraphColoring") @@ -206,8 +207,8 @@ def step( ) return next_state, timestep - def observation_spec(self) -> specs.Spec[Observation]: - """Returns the observation spec. + def _make_observation_spec(self) -> specs.Spec[Observation]: + """Returns new observation spec. Returns: Spec for the `Observation` whose fields are: @@ -253,7 +254,7 @@ def observation_spec(self) -> specs.Spec[Observation]: ), ) - def action_spec(self) -> specs.DiscreteArray: + def _make_action_spec(self) -> specs.DiscreteArray: """Specification of the action for the `GraphColoring` environment. 
Returns: diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index a5d9e5f01..eddfe6202 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -81,7 +81,7 @@ class Minesweeper(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -127,6 +127,7 @@ def __init__( self.num_rows = self.generator.num_rows self.num_cols = self.generator.num_cols self.num_mines = self.generator.num_mines + super().__init__() self._viewer = viewer or MinesweeperViewer( num_rows=self.num_rows, num_cols=self.num_cols ) @@ -182,7 +183,7 @@ def step( ) return next_state, next_timestep - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Minesweeper` environment. Returns: @@ -229,8 +230,8 @@ def observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) - def action_spec(self) -> specs.MultiDiscreteArray: - """Returns the action spec. + def _make_action_spec(self) -> specs.MultiDiscreteArray: + """Returns new action spec. An action consists of the height and width of the square to be explored. Returns: diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 3cf52b620..7ae532deb 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -126,7 +126,7 @@ def test_minesweeper__step(minesweeper_env: Minesweeper) -> None: key = jax.random.PRNGKey(0) state, timestep = jax.jit(minesweeper_env.reset)(key) # For this board, this action will be a non-mined square - action = minesweeper_env.action_spec().generate_value() + action = minesweeper_env.action_spec.generate_value() next_state, next_timestep = step_fn(state, action) # Check that the state has changed @@ -170,7 +170,7 @@ def test_minesweeper__render( state, timestep = jax.jit(minesweeper_env.reset)(jax.random.PRNGKey(0)) minesweeper_env.render(state) minesweeper_env.close() - action = minesweeper_env.action_spec().generate_value() + action = minesweeper_env.action_spec.generate_value() state, timestep = jax.jit(minesweeper_env.step)(state, action) minesweeper_env.render(state) minesweeper_env.close() @@ -179,7 +179,7 @@ def test_minesweeper__render( def test_minesweeper__done_invalid_action(minesweeper_env: Minesweeper) -> None: """Test that the strict done signal is sent correctly""" # Note that this action corresponds to not stepping on a mine - action = minesweeper_env.action_spec().generate_value() + action = minesweeper_env.action_spec.generate_value() *_, episode_length = play_and_get_episode_stats( env=minesweeper_env, actions=[action for _ in range(10)], time_limit=10 ) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index bd01f9809..5dc5be8d8 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -75,7 +75,7 @@ class RubiksCube(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, 
action) env.render(state) ``` @@ -113,6 +113,7 @@ def __init__( cube_size=3, num_scrambles_on_reset=100, ) + super().__init__() self._viewer = viewer or RubiksCubeViewer( sticker_colors=DEFAULT_STICKER_COLORS, cube_size=self.generator.cube_size ) @@ -173,7 +174,7 @@ def step( ) return next_state, next_timestep - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `RubiksCube` environment. Returns: @@ -202,8 +203,8 @@ def observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) - def action_spec(self) -> specs.MultiDiscreteArray: - """Returns the action spec. An action is composed of 3 elements that range in: 6 faces, each + def _make_action_spec(self) -> specs.MultiDiscreteArray: + """Returns new action spec. An action is composed of 3 elements that range in: 6 faces, each with cube_size//2 possible depths, and 3 possible directions. Returns: diff --git a/jumanji/environments/logic/rubiks_cube/env_test.py b/jumanji/environments/logic/rubiks_cube/env_test.py index a1ae0717b..59d56cfa3 100644 --- a/jumanji/environments/logic/rubiks_cube/env_test.py +++ b/jumanji/environments/logic/rubiks_cube/env_test.py @@ -54,7 +54,7 @@ def test_rubiks_cube__step(rubiks_cube: RubiksCube) -> None: step_fn = jax.jit(chex.assert_max_traces(rubiks_cube.step, n=1)) key = jax.random.PRNGKey(0) state, timestep = rubiks_cube.reset(key) - action = rubiks_cube.action_spec().generate_value() + action = rubiks_cube.action_spec.generate_value() next_state, next_timestep = step_fn(state, action) # Check that the state has changed @@ -105,7 +105,7 @@ def test_rubiks_cube__render( state, timestep = rubiks_cube.reset(jax.random.PRNGKey(0)) rubiks_cube.render(state) rubiks_cube.close() - action = rubiks_cube.action_spec().generate_value() + action = rubiks_cube.action_spec.generate_value() state, timestep = rubiks_cube.step(state, action) rubiks_cube.render(state) rubiks_cube.close() @@ -116,7 +116,7 @@ def test_rubiks_cube__done(time_limit: int) -> None: """Test that the done signal is sent correctly.""" env = RubiksCube(time_limit=time_limit) state, timestep = env.reset(jax.random.PRNGKey(0)) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() episode_length = 0 step_fn = jax.jit(env.step) while not timestep.last(): diff --git a/jumanji/environments/logic/sudoku/env.py b/jumanji/environments/logic/sudoku/env.py index 629da1d1a..03c222af1 100644 --- a/jumanji/environments/logic/sudoku/env.py +++ b/jumanji/environments/logic/sudoku/env.py @@ -66,7 +66,7 @@ class Sudoku(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -78,6 +78,7 @@ def __init__( reward_fn: Optional[RewardFn] = None, viewer: Optional[Viewer[State]] = None, ): + super().__init__() if generator is None: file_path = os.path.dirname(os.path.abspath(__file__)) database_file = DATABASES["mixed"] @@ -129,8 +130,8 @@ def step( return next_state, timestep - def observation_spec(self) -> specs.Spec[Observation]: - """Returns the observation spec containing the board and action_mask arrays. + def _make_observation_spec(self) -> specs.Spec[Observation]: + """Returns new observation spec containing the board and action_mask arrays. 
Returns: Spec containing all the specifications for all the `Observation` fields: @@ -158,8 +159,8 @@ def observation_spec(self) -> specs.Spec[Observation]: Observation, "ObservationSpec", board=board, action_mask=action_mask ) - def action_spec(self) -> specs.MultiDiscreteArray: - """Returns the action spec. An action is composed of 3 integers: the row index, + def _make_action_spec(self) -> specs.MultiDiscreteArray: + """Returns new action spec. An action is composed of 3 integers: the row index, the column index and the value to be placed in the cell. Returns: diff --git a/jumanji/environments/logic/sudoku/env_test.py b/jumanji/environments/logic/sudoku/env_test.py index 3152339c3..85bd9d672 100644 --- a/jumanji/environments/logic/sudoku/env_test.py +++ b/jumanji/environments/logic/sudoku/env_test.py @@ -56,7 +56,7 @@ def test_sudoku__step(sudoku_env: Sudoku) -> None: key = jax.random.PRNGKey(0) state, timestep = jax.jit(sudoku_env.reset)(key) - action = sudoku_env.action_spec().generate_value() + action = sudoku_env.action_spec.generate_value() next_state, next_timestep = step_fn(state, action) # Check that the state has changed @@ -89,7 +89,7 @@ def test_sudoku__render(monkeypatch: pytest.MonkeyPatch, sudoku_env: Sudoku) -> state, timestep = jax.jit(sudoku_env.reset)(jax.random.PRNGKey(0)) sudoku_env.render(state) sudoku_env.close() - action = sudoku_env.action_spec().generate_value() + action = sudoku_env.action_spec.generate_value() state, timestep = jax.jit(sudoku_env.step)(state, action) sudoku_env.render(state) sudoku_env.close() diff --git a/jumanji/environments/packing/bin_pack/conftest.py b/jumanji/environments/packing/bin_pack/conftest.py index 9325960ba..33ca6c437 100644 --- a/jumanji/environments/packing/bin_pack/conftest.py +++ b/jumanji/environments/packing/bin_pack/conftest.py @@ -111,7 +111,7 @@ def bin_pack(dummy_generator: DummyGenerator) -> BinPack: @pytest.fixture def obs_spec(bin_pack: BinPack) -> specs.Spec: - return bin_pack.observation_spec() + return bin_pack.observation_spec @pytest.fixture diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index 8cd6aa259..3e83ce8f3 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -106,7 +106,7 @@ class BinPack(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -154,6 +154,7 @@ def __init__( self.obs_num_ems = obs_num_ems self.reward_fn = reward_fn or DenseReward() self.normalize_dimensions = normalize_dimensions + super().__init__() self._viewer = viewer or BinPackViewer("BinPack", render_mode="human") self.debug = debug @@ -171,7 +172,7 @@ def __repr__(self) -> str: ] ) - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `BinPack` environment. Returns: @@ -248,7 +249,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def action_spec(self) -> specs.MultiDiscreteArray: + def _make_action_spec(self) -> specs.MultiDiscreteArray: """Specifications of the action expected by the `BinPack` environment. 
Returns: @@ -610,13 +611,11 @@ def _get_intersections_dict( _, direction_intersections_mask, ) in zip(intersections_ems_dict.items(), intersections_mask_dict.items()): - # Inner loop iterates through alternative directions. for (alt_direction, alt_direction_intersections_ems), ( _, alt_direction_intersections_mask, ) in zip(intersections_ems_dict.items(), intersections_mask_dict.items()): - # The current direction EMS is included in the alternative EMS. directions_included_in_alt_directions = jax.vmap( jax.vmap(Space.is_included, in_axes=(None, 0)), in_axes=(0, None) diff --git a/jumanji/environments/packing/bin_pack/env_test.py b/jumanji/environments/packing/bin_pack/env_test.py index e536738f2..967a56538 100644 --- a/jumanji/environments/packing/bin_pack/env_test.py +++ b/jumanji/environments/packing/bin_pack/env_test.py @@ -44,7 +44,7 @@ @pytest.fixture def bin_pack_random_select_action(bin_pack: BinPack) -> SelectActionFn: - num_ems, num_items = np.asarray(bin_pack.action_spec().num_values) + num_ems, num_items = np.asarray(bin_pack.action_spec.num_values) def select_action(key: chex.PRNGKey, observation: Observation) -> chex.Array: """Randomly sample valid actions, as determined by `observation.action_mask`.""" @@ -152,7 +152,7 @@ def test_bin_pack_step__jit(bin_pack: BinPack) -> None: key = jax.random.PRNGKey(0) state, timestep = bin_pack.reset(key) - action = bin_pack.action_spec().generate_value() + action = bin_pack.action_spec.generate_value() _ = step_fn(state, action) # Call again to check it does not compile twice. state, timestep = step_fn(state, action) diff --git a/jumanji/environments/packing/job_shop/env.py b/jumanji/environments/packing/job_shop/env.py index 1af25af87..98c4686b3 100644 --- a/jumanji/environments/packing/job_shop/env.py +++ b/jumanji/environments/packing/job_shop/env.py @@ -83,7 +83,7 @@ class JobShop(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -113,6 +113,7 @@ def __init__( self.num_machines = self.generator.num_machines self.max_num_ops = self.generator.max_num_ops self.max_op_duration = self.generator.max_op_duration + super().__init__() # Define the "job id" of a no-op action as the number of jobs self.no_op_idx = self.num_jobs @@ -356,7 +357,7 @@ def _update_machines( return updated_machines_job_ids, updated_machines_remaining_times - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `JobShop` environment. Returns: @@ -421,7 +422,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def action_spec(self) -> specs.MultiDiscreteArray: + def _make_action_spec(self) -> specs.MultiDiscreteArray: """Specifications of the action in the `JobShop` environment. The action gives each machine a job id ranging from 0, 1, ..., num_jobs where the last value corresponds to a no-op. 
diff --git a/jumanji/environments/packing/knapsack/env.py b/jumanji/environments/packing/knapsack/env.py index ce6c3838c..e063d29ab 100644 --- a/jumanji/environments/packing/knapsack/env.py +++ b/jumanji/environments/packing/knapsack/env.py @@ -76,7 +76,7 @@ class Knapsack(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -107,6 +107,7 @@ def __init__( total_budget=12.5, ) self.num_items = self.generator.num_items + super().__init__() self.total_budget = self.generator.total_budget self.reward_fn = reward_fn or DenseReward() self._viewer = viewer or KnapsackViewer( @@ -176,8 +177,8 @@ def step( return next_state, timestep - def observation_spec(self) -> specs.Spec[Observation]: - """Returns the observation spec. + def _make_observation_spec(self) -> specs.Spec[Observation]: + """Returns new observation spec. Returns: Spec for each field in the Observation: @@ -223,8 +224,8 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def action_spec(self) -> specs.DiscreteArray: - """Returns the action spec. + def _make_action_spec(self) -> specs.DiscreteArray: + """Returns new action spec. Returns: action_spec: a `specs.DiscreteArray` spec. diff --git a/jumanji/environments/packing/tetris/env.py b/jumanji/environments/packing/tetris/env.py index 37e608663..693035d6d 100644 --- a/jumanji/environments/packing/tetris/env.py +++ b/jumanji/environments/packing/tetris/env.py @@ -69,7 +69,7 @@ class Tetris(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -106,6 +106,7 @@ def __init__( self.TETROMINOES_LIST = jnp.array(TETROMINOES_LIST, jnp.int32) self.reward_list = jnp.array(REWARD_LIST, float) self.time_limit = time_limit + super().__init__() self._viewer = viewer or TetrisViewer( num_rows=self.num_rows, num_cols=self.num_cols, @@ -246,7 +247,7 @@ def render(self, state: State) -> Optional[NDArray]: """ return self._viewer.render(state) - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Tetris` environment. Returns: @@ -285,8 +286,8 @@ def observation_spec(self) -> specs.Spec[Observation]: ), ) - def action_spec(self) -> specs.MultiDiscreteArray: - """Returns the action spec. An action consists of two pieces of information: + def _make_action_spec(self) -> specs.MultiDiscreteArray: + """Returns new action spec. An action consists of two pieces of information: the amount of rotation (number of 90-degree rotations) and the x-position of the leftmost part of the tetromino. 
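Tetris above, like JobShop earlier, returns a `specs.MultiDiscreteArray` from `_make_action_spec`. As a reminder of how such a composite spec behaves once exposed as a property — the `[4, 10]` bounds below are arbitrary illustrative values, not taken from any environment in this patch:

```python
import jax.numpy as jnp

from jumanji import specs

# E.g. 4 rotations and 10 columns (illustrative values only).
action_spec = specs.MultiDiscreteArray(
    num_values=jnp.array([4, 10], jnp.int32), name="action"
)

action = action_spec.generate_value()  # dummy action: [0, 0]
action_spec.validate(jnp.array([3, 9], jnp.int32))  # in range: passes
```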
diff --git a/jumanji/environments/routing/cleaner/env.py b/jumanji/environments/routing/cleaner/env.py
index f1e1410c1..4dfb12a18 100644
--- a/jumanji/environments/routing/cleaner/env.py
+++ b/jumanji/environments/routing/cleaner/env.py
@@ -74,7 +74,7 @@ class Cleaner(Environment[State]):
     key = jax.random.key(0)
     state, timestep = jax.jit(env.reset)(key)
     env.render(state)
-    action = env.action_spec().generate_value()
+    action = env.action_spec.generate_value()
     state, timestep = jax.jit(env.step)(state, action)
     env.render(state)
     ```
@@ -107,6 +107,7 @@ def __init__(
         self.num_cols = self.generator.num_cols
         self.grid_shape = (self.num_rows, self.num_cols)
         self.time_limit = time_limit or (self.num_rows * self.num_cols)
+        super().__init__()
         self.penalty_per_timestep = penalty_per_timestep
 
         # Create viewer used for rendering
@@ -122,7 +123,7 @@ def __repr__(self) -> str:
             ")"
         )
 
-    def observation_spec(self) -> specs.Spec[Observation]:
+    def _make_observation_spec(self) -> specs.Spec[Observation]:
         """Specification of the observation of the `Cleaner` environment.
 
         Returns:
@@ -152,7 +153,7 @@ def observation_spec(self) -> specs.Spec[Observation]:
             step_count=step_count,
         )
 
-    def action_spec(self) -> specs.MultiDiscreteArray:
+    def _make_action_spec(self) -> specs.MultiDiscreteArray:
         """Specification of the action for the `Cleaner` environment.
 
         Returns:
diff --git a/jumanji/environments/routing/connector/env.py b/jumanji/environments/routing/connector/env.py
index 7a139f31b..653a63c26 100644
--- a/jumanji/environments/routing/connector/env.py
+++ b/jumanji/environments/routing/connector/env.py
@@ -88,7 +88,7 @@ class Connector(Environment[State]):
     key = jax.random.key(0)
     state, timestep = jax.jit(env.reset)(key)
     env.render(state)
-    action = env.action_spec().generate_value()
+    action = env.action_spec.generate_value()
    state, timestep = jax.jit(env.step)(state, action)
     env.render(state)
     ```
@@ -118,6 +118,7 @@ def __init__(
         self.time_limit = time_limit
         self.num_agents = self._generator.num_agents
         self.grid_size = self._generator.grid_size
+        super().__init__()
         self._agent_ids = jnp.arange(self.num_agents)
         self._viewer = viewer or ConnectorViewer(
             "Connector", self.num_agents, render_mode="human"
         )
@@ -318,7 +319,7 @@ def close(self) -> None:
         """
         self._viewer.close()
 
-    def observation_spec(self) -> specs.Spec[Observation]:
+    def _make_observation_spec(self) -> specs.Spec[Observation]:
         """Specifications of the observation of the `Connector` environment.
 
         Returns:
@@ -356,8 +357,8 @@ def observation_spec(self) -> specs.Spec[Observation]:
             step_count=step_count,
         )
 
-    def action_spec(self) -> specs.MultiDiscreteArray:
-        """Returns the action spec for the Connector environment.
+    def _make_action_spec(self) -> specs.MultiDiscreteArray:
+        """Returns new action spec for the Connector environment.
 
         5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left].
         Since this is an environment with a multi-dimensional action space, it expects an array of
         actions of shape (num_agents,).
diff --git a/jumanji/environments/routing/cvrp/env.py b/jumanji/environments/routing/cvrp/env.py index 6af507519..1f10a0467 100644 --- a/jumanji/environments/routing/cvrp/env.py +++ b/jumanji/environments/routing/cvrp/env.py @@ -89,7 +89,7 @@ class CVRP(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -121,6 +121,7 @@ def __init__( max_demand=10, ) self.num_nodes = self.generator.num_nodes + super().__init__() self.max_capacity = self.generator.max_capacity self.max_demand = self.generator.max_demand if self.max_capacity < self.max_demand: @@ -195,8 +196,8 @@ def step( ) return next_state, timestep - def observation_spec(self) -> specs.Spec[Observation]: - """Returns the observation spec. + def _make_observation_spec(self) -> specs.Spec[Observation]: + """Returns new observation spec. Returns: Spec for the `Observation` whose fields are: @@ -261,8 +262,8 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def action_spec(self) -> specs.DiscreteArray: - """Returns the action spec. + def _make_action_spec(self) -> specs.DiscreteArray: + """Returns new action spec. Returns: action_spec: a `specs.DiscreteArray` spec. diff --git a/jumanji/environments/routing/maze/env.py b/jumanji/environments/routing/maze/env.py index 4c56eaedf..25ba0b287 100644 --- a/jumanji/environments/routing/maze/env.py +++ b/jumanji/environments/routing/maze/env.py @@ -71,7 +71,7 @@ class Maze(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -100,6 +100,7 @@ def __init__( self.generator = generator or RandomGenerator(num_rows=10, num_cols=10) self.num_rows = self.generator.num_rows self.num_cols = self.generator.num_cols + super().__init__() self.shape = (self.num_rows, self.num_cols) self.time_limit = time_limit or self.num_rows * self.num_cols @@ -117,7 +118,7 @@ def __repr__(self) -> str: ] ) - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Maze` environment. Returns: @@ -159,8 +160,8 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def action_spec(self) -> specs.DiscreteArray: - """Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. + def _make_action_spec(self) -> specs.DiscreteArray: + """Returns new action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: action_spec: discrete action space with 4 values. diff --git a/jumanji/environments/routing/mmst/env.py b/jumanji/environments/routing/mmst/env.py index cac71ffe7..2e6f876f5 100644 --- a/jumanji/environments/routing/mmst/env.py +++ b/jumanji/environments/routing/mmst/env.py @@ -157,6 +157,7 @@ def __init__( self._env_viewer = viewer or MMSTViewer(num_agents=self.num_agents) self.time_limit = time_limit + super().__init__() def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment. 
@@ -198,7 +199,6 @@ def step_agent_fn( indices: chex.Array, agent_id: int, ) -> Tuple[chex.Array, ...]: - is_invalid_choice = jnp.any(action == INVALID_CHOICE) | jnp.any( action == INVALID_TIE_BREAK ) @@ -273,8 +273,8 @@ def step_agent_fn( state, timestep = self._state_to_timestep(state, action) return state, timestep - def action_spec(self) -> specs.MultiDiscreteArray: - """Returns the action spec. + def _make_action_spec(self) -> specs.MultiDiscreteArray: + """Returns new action spec. Returns: action_spec: a `specs.MultiDiscreteArray` spec. @@ -284,8 +284,8 @@ def action_spec(self) -> specs.MultiDiscreteArray: name="action", ) - def observation_spec(self) -> specs.Spec[Observation]: - """Returns the observation spec. + def _make_observation_spec(self) -> specs.Spec[Observation]: + """Returns new observation spec. Returns: Spec for the `Observation` whose fields are: diff --git a/jumanji/environments/routing/multi_cvrp/env.py b/jumanji/environments/routing/multi_cvrp/env.py index 23688829e..068222dda 100644 --- a/jumanji/environments/routing/multi_cvrp/env.py +++ b/jumanji/environments/routing/multi_cvrp/env.py @@ -127,6 +127,7 @@ def __init__( max_single_vehicle_distance(self._map_max, self._num_customers) / self._speed ) + super().__init__() def __repr__(self) -> str: return f"MultiCVRP(num_customers={self._num_customers}, num_vehicles={self._num_vehicles})" @@ -177,9 +178,9 @@ def step( return new_state, timestep - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """ - Returns the observation spec. + Returns new observation spec. Returns: observation_spec: a Tuple containing the spec for each of the constituent fields of an @@ -306,9 +307,9 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def action_spec(self) -> specs.BoundedArray: + def _make_action_spec(self) -> specs.BoundedArray: """ - Returns the action spec. + Returns new action spec. Returns: action_spec: a `specs.BoundedArray` spec. diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py index 908fe8053..1656139b4 100644 --- a/jumanji/environments/routing/robot_warehouse/env.py +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -127,7 +127,7 @@ class RobotWarehouse(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -182,6 +182,7 @@ def __init__( ) self.goals = self._generator.goals self.time_limit = time_limit + super().__init__() # create viewer for rendering environment self._viewer = viewer or RobotWarehouseViewer( @@ -334,7 +335,7 @@ def update_reward_and_request_queue_scan( ) return next_state, timestep - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Specification of the observation of the `RobotWarehouse` environment. Returns: Spec for the `Observation`, consisting of the fields: @@ -357,8 +358,8 @@ def observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) - def action_spec(self) -> specs.MultiDiscreteArray: - """Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. + def _make_action_spec(self) -> specs.MultiDiscreteArray: + """Returns new action spec. 
5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. Since this is a multi-agent environment, the environment expects an array of actions. This array is of shape (num_agents,). """ diff --git a/jumanji/environments/routing/robot_warehouse/env_test.py b/jumanji/environments/routing/robot_warehouse/env_test.py index 8b5840756..cf37e3b2e 100644 --- a/jumanji/environments/routing/robot_warehouse/env_test.py +++ b/jumanji/environments/routing/robot_warehouse/env_test.py @@ -32,8 +32,8 @@ def test_robot_warehouse__specs(robot_warehouse_env: RobotWarehouse) -> None: """Validate environment specs conform to the expected shapes and values""" - action_spec = robot_warehouse_env.action_spec() - observation_spec = robot_warehouse_env.observation_spec() + action_spec = robot_warehouse_env.action_spec + observation_spec = robot_warehouse_env.observation_spec assert observation_spec.agents_view.shape == (2, 66) # type: ignore assert action_spec.num_values.shape[0] == robot_warehouse_env.num_agents diff --git a/jumanji/environments/routing/snake/env.py b/jumanji/environments/routing/snake/env.py index 0231c3a94..c1ab16883 100644 --- a/jumanji/environments/routing/snake/env.py +++ b/jumanji/environments/routing/snake/env.py @@ -84,7 +84,7 @@ class Snake(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -108,11 +108,11 @@ def __init__( the episode ends. Defaults to 4000. viewer: `Viewer` used for rendering. Defaults to `SnakeViewer`. """ - super().__init__() self.num_rows = num_rows self.num_cols = num_cols self.board_shape = (num_rows, num_cols) self.time_limit = time_limit + super().__init__() self._viewer = viewer or SnakeViewer() def __repr__(self) -> str: @@ -235,8 +235,8 @@ def step( ) return next_state, timestep - def observation_spec(self) -> specs.Spec[Observation]: - """Returns the observation spec. + def _make_observation_spec(self) -> specs.Spec[Observation]: + """Returns new observation spec. Returns: Spec for the `Observation` whose fields are: @@ -269,8 +269,8 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def action_spec(self) -> specs.DiscreteArray: - """Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. + def _make_action_spec(self) -> specs.DiscreteArray: + """Returns new action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: action_spec: a `specs.DiscreteArray` spec. 
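With the spec built once in `__init__` and exposed as a cached property, call sites follow the pattern below. A short usage sketch mirroring the updated docstring examples (standard `Snake` constructor arguments; nothing here is new API beyond the property access):

```python
import jax
from jumanji.environments import Snake

env = Snake(num_rows=6, num_cols=6)
assert env.action_spec is env.action_spec  # cached: same spec object each access

key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
action = env.action_spec.generate_value()  # canonical action from the spec
state, timestep = jax.jit(env.step)(state, action)
```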
diff --git a/jumanji/environments/routing/snake/env_test.py b/jumanji/environments/routing/snake/env_test.py index e684011f7..c7b0e9711 100644 --- a/jumanji/environments/routing/snake/env_test.py +++ b/jumanji/environments/routing/snake/env_test.py @@ -63,7 +63,7 @@ def test_snake__step(snake: Snake) -> None: # Sample two different actions action1, action2 = jax.random.choice( action_key, - jnp.arange(snake.action_spec()._num_values), + jnp.arange(snake.action_spec._num_values), shape=(2,), replace=False, ) @@ -148,7 +148,7 @@ def test_snake__render(monkeypatch: pytest.MonkeyPatch, snake: Snake) -> None: monkeypatch.setattr(plt, "show", lambda fig: None) step_fn = jax.jit(snake.step) state, timestep = snake.reset(jax.random.PRNGKey(0)) - action = snake.action_spec().generate_value() + action = snake.action_spec.generate_value() state, timestep = step_fn(state, action) snake.render(state) snake.close() @@ -159,7 +159,7 @@ def test_snake__animation(snake: Snake, tmpdir: py.path.local) -> None: step_fn = jax.jit(snake.step) state, _ = snake.reset(jax.random.PRNGKey(0)) states = [state] - action = snake.action_spec().generate_value() + action = snake.action_spec.generate_value() state, _ = step_fn(state, action) states.append(state) animation = snake.animate(states) diff --git a/jumanji/environments/routing/tsp/env.py b/jumanji/environments/routing/tsp/env.py index 83236f872..82531a093 100644 --- a/jumanji/environments/routing/tsp/env.py +++ b/jumanji/environments/routing/tsp/env.py @@ -83,7 +83,7 @@ class TSP(Environment[State]): key = jax.random.key(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -112,6 +112,7 @@ def __init__( num_cities=20, ) self.num_cities = self.generator.num_cities + super().__init__() self.reward_fn = reward_fn or DenseReward() self._viewer = viewer or TSPViewer(name="TSP", render_mode="human") @@ -169,8 +170,8 @@ def step( ) return next_state, timestep - def observation_spec(self) -> specs.Spec[Observation]: - """Returns the observation spec. + def _make_observation_spec(self) -> specs.Spec[Observation]: + """Returns new observation spec. Returns: Spec for the `Observation` whose fields are: @@ -212,8 +213,8 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def action_spec(self) -> specs.DiscreteArray: - """Returns the action spec. + def _make_action_spec(self) -> specs.DiscreteArray: + """Returns new action spec. Returns: action_spec: a `specs.DiscreteArray` spec. 
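One subtlety in the testing changes that follow: the spec smoke check jits a function of the environment itself. An `Environment` is not a JAX pytree, so it has to be passed as a static argument; the property lookups then simply run at trace time. A sketch of the idea (using `Game2048` purely as a convenient concrete environment):

```python
import jax
from jumanji.environments import Game2048

env = Game2048()

def access_specs(env) -> None:
    env.observation_spec  # property access only; construction errors surface here
    env.action_spec
    env.reward_spec
    env.discount_spec

# static_argnums=0 marks `env` as static (hashable) rather than a traced pytree.
jax.jit(access_specs, static_argnums=0)(env)
```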
diff --git a/jumanji/registration_test.py b/jumanji/registration_test.py index 5ec06bd0a..218be5903 100644 --- a/jumanji/registration_test.py +++ b/jumanji/registration_test.py @@ -94,7 +94,7 @@ def test_register__override_kwargs(mocker: pytest_mock.MockerFixture) -> None: env: FakeEnvironment = registration.make( # type: ignore env_id, observation_shape=obs_shape ) - assert env.observation_spec().shape == obs_shape + assert env.observation_spec.shape == obs_shape def test_registration__make() -> None: diff --git a/jumanji/testing/env_not_smoke.py b/jumanji/testing/env_not_smoke.py index 411a0f80c..6c0617ee5 100644 --- a/jumanji/testing/env_not_smoke.py +++ b/jumanji/testing/env_not_smoke.py @@ -72,17 +72,16 @@ def check_env_does_not_smoke( assert_finite_check: bool = True, ) -> None: """Run an episode of the environment, with a jitted step function to check no errors occur.""" - action_spec = env.action_spec() if select_action is None: - if isinstance(action_spec, specs.BoundedArray) or isinstance( - action_spec, specs.DiscreteArray + if isinstance(env.action_spec, specs.BoundedArray) or isinstance( + env.action_spec, specs.DiscreteArray ): - select_action = make_random_select_action_fn(action_spec) + select_action = make_random_select_action_fn(env.action_spec) else: raise NotImplementedError( f"Currently the `make_random_select_action_fn` only works for environments with " f"either discrete actions or bounded continuous actions. The input environment to " - f"this test has an action spec of type {action_spec}, and therefore requires " + f"this test has an action spec of type {env.action_spec}, and therefore requires " f"a custom `SelectActionFn` to be provided to this test." ) key = jax.random.PRNGKey(0) @@ -92,16 +91,21 @@ def check_env_does_not_smoke( while not timestep.last(): key, action_key = jax.random.split(key) action = select_action(action_key, timestep.observation) - env.action_spec().validate(action) + env.action_spec.validate(action) state, timestep = step_fn(state, action) - env.observation_spec().validate(timestep.observation) + env.observation_spec.validate(timestep.observation) if assert_finite_check: chex.assert_tree_all_finite((state, timestep)) +def access_specs(env: Environment) -> None: + """Access specs of the environment.""" + env.observation_spec + env.action_spec + env.reward_spec + env.discount_spec + + def check_env_specs_does_not_smoke(env: Environment) -> None: """Access specs of the environment in a jitted function to check no errors occur.""" - jax.jit(env.observation_spec()) - jax.jit(env.action_spec()) - jax.jit(env.reward_spec()) - jax.jit(env.discount_spec()) + jax.jit(access_specs, static_argnums=0)(env) diff --git a/jumanji/testing/env_not_smoke_test.py b/jumanji/testing/env_not_smoke_test.py index ca8f0a28a..8900aaaeb 100644 --- a/jumanji/testing/env_not_smoke_test.py +++ b/jumanji/testing/env_not_smoke_test.py @@ -55,12 +55,12 @@ def test_random_select_action(fake_env: FakeEnvironment) -> None: """Validate that the `select_action` method returns random actions meeting the environment spec.""" key = jax.random.PRNGKey(0) - select_action = make_random_select_action_fn(fake_env.action_spec()) + select_action = make_random_select_action_fn(fake_env.action_spec) key1, key2, key3 = jax.random.split(key, 3) env_state, timestep = fake_env.reset(key1) action_1 = select_action(key2, timestep.observation) action_2 = select_action(key3, timestep.observation) - fake_env.action_spec().validate(action_1) + fake_env.action_spec.validate(action_1) assert not 
jnp.all(action_1 == action_2) diff --git a/jumanji/testing/fakes.py b/jumanji/testing/fakes.py index 0835a47ac..01cb689c1 100644 --- a/jumanji/testing/fakes.py +++ b/jumanji/testing/fakes.py @@ -56,10 +56,11 @@ def __init__( self.time_limit = time_limit self.observation_shape = observation_shape self.action_shape = action_shape - self._example_action = self.action_spec().generate_value() + super().__init__() + self._example_action = self.action_spec.generate_value() - def observation_spec(self) -> specs.Array: - """Returns the observation spec. + def _make_observation_spec(self) -> specs.Array: + """Returns new observation spec. Returns: observation_spec: a `specs.Array` spec. @@ -69,8 +70,8 @@ def observation_spec(self) -> specs.Array: shape=self.observation_shape, dtype=float, name="observation" ) - def action_spec(self) -> specs.BoundedArray: - """Returns the action spec. + def _make_action_spec(self) -> specs.BoundedArray: + """Returns new action spec. Returns: action_spec: a `specs.DiscreteArray` spec. @@ -169,14 +170,15 @@ def __init__( self.observation_shape = observation_shape self.num_action_values = num_action_values self.num_agents = num_agents + super().__init__() self.reward_per_step = reward_per_step assert ( observation_shape[0] == num_agents ), f"""a leading dimension of size 'num_agents': {num_agents} is expected for the observation, got shape: {observation_shape}.""" - def observation_spec(self) -> specs.Array: - """Returns the observation spec. + def _make_observation_spec(self) -> specs.Array: + """Returns new observation spec. Returns: observation_spec: a `specs.Array` spec. @@ -186,8 +188,8 @@ def observation_spec(self) -> specs.Array: shape=self.observation_shape, dtype=float, name="observation" ) - def action_spec(self) -> specs.BoundedArray: - """Returns the action spec. + def _make_action_spec(self) -> specs.BoundedArray: + """Returns new action spec. Returns: action_spec: a `specs.Array` spec. @@ -197,15 +199,15 @@ def action_spec(self) -> specs.BoundedArray: (self.num_agents,), int, 0, self.num_action_values - 1 ) - def reward_spec(self) -> specs.Array: - """Returns the reward spec. + def _make_reward_spec(self) -> specs.Array: + """Returns new reward spec. Returns: reward_spec: a `specs.Array` spec. """ return specs.Array(shape=(self.num_agents,), dtype=float, name="reward") - def discount_spec(self) -> specs.BoundedArray: + def _make_discount_spec(self) -> specs.BoundedArray: """Describes the discount returned by the environment. 
Returns: @@ -231,7 +233,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[FakeState, TimeStep]: """ state = FakeState(key=key, step=0) - observation = self.observation_spec().generate_value() + observation = self.observation_spec.generate_value() timestep = restart(observation=observation, shape=(self.num_agents,)) return state, timestep diff --git a/jumanji/testing/fakes_test.py b/jumanji/testing/fakes_test.py index 42cb98415..81d75dd16 100644 --- a/jumanji/testing/fakes_test.py +++ b/jumanji/testing/fakes_test.py @@ -31,7 +31,7 @@ def test_fake_environment__reset(fake_environment: fakes.FakeEnvironment) -> Non def test_fake_environment__step(fake_environment: fakes.FakeEnvironment) -> None: """Validates the step function of the fake environment.""" state, timestep = fake_environment.reset(random.PRNGKey(0)) - action = fake_environment.action_spec().generate_value() + action = fake_environment.action_spec.generate_value() next_state, timestep = fake_environment.step(state, action) # Check that the step value is now different assert state.step != next_state.step @@ -43,7 +43,7 @@ def test_fake_environment__does_not_smoke( ) -> None: """Validates the run of an episode in the fake environment. Check that it does not smoke.""" state, timestep = fake_environment.reset(random.PRNGKey(0)) - action = fake_environment.action_spec().generate_value() + action = fake_environment.action_spec.generate_value() while not timestep.last(): state, timestep = fake_environment.step(state, action) @@ -67,7 +67,7 @@ def test_fake_multi_environment__step( ) -> None: """Validates the step function of the fake multi agent environment.""" state, timestep = fake_multi_environment.reset(random.PRNGKey(0)) - action = fake_multi_environment.action_spec().generate_value() + action = fake_multi_environment.action_spec.generate_value() assert action.shape[0] == fake_multi_environment.num_agents next_state, timestep = fake_multi_environment.step(state, action) @@ -85,7 +85,7 @@ def test_fake_multi_environment__does_not_smoke( """Validates the run of an episode in the fake multi agent environment. 
Check that it does not smoke.""" state, timestep = fake_multi_environment.reset(random.PRNGKey(0)) - action = fake_multi_environment.action_spec().generate_value() + action = fake_multi_environment.action_spec.generate_value() assert action.shape[0] == fake_multi_environment.num_agents while not timestep.last(): state, timestep = fake_multi_environment.step(state, action) diff --git a/jumanji/training/agents/a2c/a2c_agent.py b/jumanji/training/agents/a2c/a2c_agent.py index 09eca211b..2392ab05c 100644 --- a/jumanji/training/agents/a2c/a2c_agent.py +++ b/jumanji/training/agents/a2c/a2c_agent.py @@ -51,7 +51,7 @@ def __init__( ) -> None: super().__init__(total_batch_size=total_batch_size) self.env = env - self.observation_spec = env.observation_spec() + self.observation_spec = env.observation_spec self.n_steps = n_steps self.actor_critic_networks = actor_critic_networks self.optimizer = optimizer diff --git a/jumanji/training/agents/random/random_agent.py b/jumanji/training/agents/random/random_agent.py index 4c17edd48..904fbe98f 100644 --- a/jumanji/training/agents/random/random_agent.py +++ b/jumanji/training/agents/random/random_agent.py @@ -33,7 +33,7 @@ def __init__( ) -> None: super().__init__(total_batch_size=total_batch_size) self.env = env - self.observation_spec = env.observation_spec() + self.observation_spec = env.observation_spec self.n_steps = n_steps self.random_policy = random_policy diff --git a/jumanji/training/networks/bin_pack/actor_critic.py b/jumanji/training/networks/bin_pack/actor_critic.py index 1643a5dc8..c3de93f8f 100644 --- a/jumanji/training/networks/bin_pack/actor_critic.py +++ b/jumanji/training/networks/bin_pack/actor_critic.py @@ -40,7 +40,7 @@ def make_actor_critic_networks_bin_pack( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `BinPack` environment.""" - num_values = np.asarray(bin_pack.action_spec().num_values) + num_values = np.asarray(bin_pack.action_spec.num_values) parametric_action_distribution = FactorisedActionSpaceParametricDistribution( action_spec_num_values=num_values ) diff --git a/jumanji/training/networks/bin_pack/random.py b/jumanji/training/networks/bin_pack/random.py index 537772239..add32d75a 100644 --- a/jumanji/training/networks/bin_pack/random.py +++ b/jumanji/training/networks/bin_pack/random.py @@ -21,7 +21,7 @@ def make_random_policy_bin_pack(bin_pack: BinPack) -> RandomPolicy: """Make random policy for BinPack.""" - action_spec_num_values = bin_pack.action_spec().num_values + action_spec_num_values = bin_pack.action_spec.num_values return make_masked_categorical_random_ndim( action_spec_num_values=action_spec_num_values ) diff --git a/jumanji/training/networks/cleaner/actor_critic.py b/jumanji/training/networks/cleaner/actor_critic.py index b8002df3b..2fedfe289 100644 --- a/jumanji/training/networks/cleaner/actor_critic.py +++ b/jumanji/training/networks/cleaner/actor_critic.py @@ -39,7 +39,7 @@ def make_actor_critic_networks_cleaner( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Cleaner` environment.""" - num_values = np.asarray(cleaner.action_spec().num_values) + num_values = np.asarray(cleaner.action_spec.num_values) parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=num_values ) diff --git a/jumanji/training/networks/connector/actor_critic.py b/jumanji/training/networks/connector/actor_critic.py index 1473de23b..edcc4946f 100644 --- a/jumanji/training/networks/connector/actor_critic.py +++ 
b/jumanji/training/networks/connector/actor_critic.py @@ -46,7 +46,7 @@ def make_actor_critic_networks_connector( conv_n_channels: int, ) -> ActorCriticNetworks: """Make actor-critic networks for the `Connector` environment.""" - num_values = np.asarray(connector.action_spec().num_values) + num_values = np.asarray(connector.action_spec.num_values) parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=num_values ) diff --git a/jumanji/training/networks/cvrp/actor_critic.py b/jumanji/training/networks/cvrp/actor_critic.py index 5bae24f7f..e5b498a2a 100644 --- a/jumanji/training/networks/cvrp/actor_critic.py +++ b/jumanji/training/networks/cvrp/actor_critic.py @@ -38,7 +38,7 @@ def make_actor_critic_networks_cvrp( mean_nodes_in_query: bool, ) -> ActorCriticNetworks: """Make actor-critic networks for the `CVRP` environment.""" - num_actions = cvrp.action_spec().num_values + num_actions = cvrp.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/game_2048/actor_critic.py b/jumanji/training/networks/game_2048/actor_critic.py index 6a7d66fa5..caa4351dd 100644 --- a/jumanji/training/networks/game_2048/actor_critic.py +++ b/jumanji/training/networks/game_2048/actor_critic.py @@ -37,7 +37,7 @@ def make_actor_critic_networks_game_2048( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Game2048` environment.""" - num_actions = game_2048.action_spec().num_values + num_actions = game_2048.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/graph_coloring/actor_critic.py b/jumanji/training/networks/graph_coloring/actor_critic.py index 2833061c0..6e2e336f6 100644 --- a/jumanji/training/networks/graph_coloring/actor_critic.py +++ b/jumanji/training/networks/graph_coloring/actor_critic.py @@ -38,7 +38,7 @@ def make_actor_critic_networks_graph_coloring( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `GraphColoring` environment.""" - num_actions = graph_coloring.action_spec().num_values + num_actions = graph_coloring.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/job_shop/actor_critic.py b/jumanji/training/networks/job_shop/actor_critic.py index 0c09c8bf6..77d070a17 100644 --- a/jumanji/training/networks/job_shop/actor_critic.py +++ b/jumanji/training/networks/job_shop/actor_critic.py @@ -42,7 +42,7 @@ def make_actor_critic_networks_job_shop( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Create an actor-critic network for the `JobShop` environment.""" - num_values = np.asarray(job_shop.action_spec().num_values) + num_values = np.asarray(job_shop.action_spec.num_values) parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=num_values ) diff --git a/jumanji/training/networks/knapsack/actor_critic.py b/jumanji/training/networks/knapsack/actor_critic.py index 799a08e4d..b8a676e23 100644 --- a/jumanji/training/networks/knapsack/actor_critic.py +++ b/jumanji/training/networks/knapsack/actor_critic.py @@ -36,7 +36,7 @@ def make_actor_critic_networks_knapsack( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Knapsack` environment.""" - num_actions = 
knapsack.action_spec().num_values + num_actions = knapsack.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/maze/actor_critic.py b/jumanji/training/networks/maze/actor_critic.py index ae93b0286..8d236ef5a 100644 --- a/jumanji/training/networks/maze/actor_critic.py +++ b/jumanji/training/networks/maze/actor_critic.py @@ -37,7 +37,7 @@ def make_actor_critic_networks_maze( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Maze` environment.""" - num_actions = np.asarray(maze.action_spec().num_values) + num_actions = np.asarray(maze.action_spec.num_values) parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/training/networks/minesweeper/actor_critic.py b/jumanji/training/networks/minesweeper/actor_critic.py index 3789be5f3..593673f9f 100644 --- a/jumanji/training/networks/minesweeper/actor_critic.py +++ b/jumanji/training/networks/minesweeper/actor_critic.py @@ -45,7 +45,7 @@ def make_actor_critic_networks_minesweeper( vocab_size = 1 + PATCH_SIZE**2 # unexplored, or 0, 1, ..., 8 parametric_action_distribution = FactorisedActionSpaceParametricDistribution( - action_spec_num_values=np.asarray(minesweeper.action_spec().num_values) + action_spec_num_values=np.asarray(minesweeper.action_spec.num_values) ) policy_network = make_network_cnn( vocab_size=vocab_size, diff --git a/jumanji/training/networks/minesweeper/random.py b/jumanji/training/networks/minesweeper/random.py index b7e80a3de..c7194091f 100644 --- a/jumanji/training/networks/minesweeper/random.py +++ b/jumanji/training/networks/minesweeper/random.py @@ -22,7 +22,7 @@ def make_random_policy_minesweeper(minesweeper: Minesweeper) -> RandomPolicy: """Make random policy for Minesweeper.""" - action_spec_num_values = minesweeper.action_spec().num_values + action_spec_num_values = minesweeper.action_spec.num_values return make_masked_categorical_random_ndim( action_spec_num_values=action_spec_num_values diff --git a/jumanji/training/networks/mmst/actor_critic.py b/jumanji/training/networks/mmst/actor_critic.py index ccc5d5674..45e776b4c 100644 --- a/jumanji/training/networks/mmst/actor_critic.py +++ b/jumanji/training/networks/mmst/actor_critic.py @@ -38,7 +38,7 @@ def make_actor_critic_networks_mmst( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `MMST` environment.""" - num_values = mmst.action_spec().num_values + num_values = mmst.action_spec.num_values parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=num_values ) @@ -96,12 +96,10 @@ def get_node_feats(node: chex.Array) -> chex.Array: return embeddings def embed_agents(self, agents: chex.Array) -> chex.Array: - embeddings = hk.Linear(self.model_size, name="agent_projection")(agents) return embeddings def __call__(self, observation: Observation) -> chex.Array: - batch_size, num_nodes = observation.node_types.shape num_agents = observation.positions.shape[1] agents_used = jnp.arange(num_agents).reshape(-1, 1) diff --git a/jumanji/training/networks/multi_cvrp/actor_critic.py b/jumanji/training/networks/multi_cvrp/actor_critic.py index 268426149..3300b7835 100644 --- a/jumanji/training/networks/multi_cvrp/actor_critic.py +++ b/jumanji/training/networks/multi_cvrp/actor_critic.py @@ -46,7 +46,7 @@ def make_actor_critic_networks_multicvrp( # Add depot to the number of customers num_customers += 1 - 
num_actions = MultiCVRP.action_spec().maximum + num_actions = MultiCVRP.action_spec.maximum parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=np.asarray(num_actions).reshape(1) ) @@ -161,7 +161,6 @@ def customer_encoder( o_customers: chex.Array, v_embedding: chex.Array, ) -> chex.Array: - # Embed the depot differently # (B, C, D) depot_projection = hk.Linear(self.model_size, name="depot_projection") @@ -211,7 +210,6 @@ def vehicle_encoder( v_embedding: chex.Array, c_embedding: chex.Array, ) -> chex.Array: - # Projection of the operations embeddings = hk.Linear(self.model_size, name="o_vehicle_projections")( v_embedding diff --git a/jumanji/training/networks/robot_warehouse/actor_critic.py b/jumanji/training/networks/robot_warehouse/actor_critic.py index a1aca10cd..965caf397 100644 --- a/jumanji/training/networks/robot_warehouse/actor_critic.py +++ b/jumanji/training/networks/robot_warehouse/actor_critic.py @@ -39,7 +39,7 @@ def make_actor_critic_networks_robot_warehouse( transformer_mlp_units: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `RobotWarehouse` environment.""" - num_values = np.asarray(robot_warehouse.action_spec().num_values) + num_values = np.asarray(robot_warehouse.action_spec.num_values) parametric_action_distribution = MultiCategoricalParametricDistribution( num_values=num_values ) diff --git a/jumanji/training/networks/rubiks_cube/actor_critic.py b/jumanji/training/networks/rubiks_cube/actor_critic.py index 5e79d9e38..53a2643ac 100644 --- a/jumanji/training/networks/rubiks_cube/actor_critic.py +++ b/jumanji/training/networks/rubiks_cube/actor_critic.py @@ -37,7 +37,7 @@ def make_actor_critic_networks_rubiks_cube( dense_layer_dims: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `RubiksCube` environment.""" - action_spec_num_values = np.asarray(rubiks_cube.action_spec().num_values) + action_spec_num_values = np.asarray(rubiks_cube.action_spec.num_values) num_actions = int(np.prod(action_spec_num_values)) parametric_action_distribution = FactorisedActionSpaceParametricDistribution( action_spec_num_values=action_spec_num_values diff --git a/jumanji/training/networks/rubiks_cube/random.py b/jumanji/training/networks/rubiks_cube/random.py index b8d18dd0d..3040db114 100644 --- a/jumanji/training/networks/rubiks_cube/random.py +++ b/jumanji/training/networks/rubiks_cube/random.py @@ -21,8 +21,8 @@ def make_random_policy_rubiks_cube(rubiks_cube: RubiksCube) -> RandomPolicy: """Make random policy for RubiksCube.""" - action_minimum = rubiks_cube.action_spec().minimum - action_maximum = rubiks_cube.action_spec().maximum + action_minimum = rubiks_cube.action_spec.minimum + action_maximum = rubiks_cube.action_spec.maximum def random_policy(observation: Observation, key: chex.PRNGKey) -> chex.Array: batch_size = observation.cube.shape[0] diff --git a/jumanji/training/networks/snake/actor_critic.py b/jumanji/training/networks/snake/actor_critic.py index f23906154..0be42e223 100644 --- a/jumanji/training/networks/snake/actor_critic.py +++ b/jumanji/training/networks/snake/actor_critic.py @@ -36,7 +36,7 @@ def make_actor_critic_networks_snake( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Snake` environment.""" - num_actions = snake.action_spec().num_values + num_actions = snake.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git 
a/jumanji/training/networks/sudoku/actor_critic.py b/jumanji/training/networks/sudoku/actor_critic.py index f3bd1092b..8ba664e65 100644 --- a/jumanji/training/networks/sudoku/actor_critic.py +++ b/jumanji/training/networks/sudoku/actor_critic.py @@ -40,7 +40,7 @@ def make_cnn_actor_critic_networks_sudoku( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Sudoku` environment. Uses the CNN network architecture.""" - num_actions = sudoku.action_spec().num_values + num_actions = sudoku.action_spec.num_values parametric_action_distribution = FactorisedActionSpaceParametricDistribution( action_spec_num_values=np.asarray(num_actions) ) @@ -71,7 +71,7 @@ def make_equivariant_actor_critic_networks_sudoku( ) -> ActorCriticNetworks: """Make actor-critic networks for the `Sudoku` environment. Uses the digits-permutation equivariant network architecture.""" - num_actions = sudoku.action_spec().num_values + num_actions = sudoku.action_spec.num_values parametric_action_distribution = FactorisedActionSpaceParametricDistribution( action_spec_num_values=np.asarray(num_actions) ) diff --git a/jumanji/training/networks/sudoku/random.py b/jumanji/training/networks/sudoku/random.py index 3b394d7b2..e2cce1fde 100644 --- a/jumanji/training/networks/sudoku/random.py +++ b/jumanji/training/networks/sudoku/random.py @@ -23,7 +23,7 @@ def make_random_policy_sudoku(sudoku: Sudoku) -> RandomPolicy: """Make random policy for the `Sudoku` environment.""" - action_spec_num_values = sudoku.action_spec().num_values + action_spec_num_values = sudoku.action_spec.num_values return make_masked_categorical_random_ndim( action_spec_num_values=action_spec_num_values diff --git a/jumanji/training/networks/tetris/actor_critic.py b/jumanji/training/networks/tetris/actor_critic.py index 5ac0488de..4e37052fd 100644 --- a/jumanji/training/networks/tetris/actor_critic.py +++ b/jumanji/training/networks/tetris/actor_critic.py @@ -39,7 +39,7 @@ def make_actor_critic_networks_tetris( """Make actor-critic networks for the `Tetris` environment.""" parametric_action_distribution = FactorisedActionSpaceParametricDistribution( - action_spec_num_values=np.asarray(tetris.action_spec().num_values) + action_spec_num_values=np.asarray(tetris.action_spec.num_values) ) policy_network = make_network_cnn( conv_num_channels=conv_num_channels, diff --git a/jumanji/training/networks/tetris/random.py b/jumanji/training/networks/tetris/random.py index e7410f35c..309687d4a 100644 --- a/jumanji/training/networks/tetris/random.py +++ b/jumanji/training/networks/tetris/random.py @@ -21,7 +21,7 @@ def make_random_policy_tetris(tetris: Tetris) -> RandomPolicy: """Make random policy for `Tetris`.""" - action_spec_num_values = tetris.action_spec().num_values + action_spec_num_values = tetris.action_spec.m_values return make_masked_categorical_random_ndim( action_spec_num_values=action_spec_num_values ) diff --git a/jumanji/training/networks/tsp/actor_critic.py b/jumanji/training/networks/tsp/actor_critic.py index 6b0761411..cff891c5e 100644 --- a/jumanji/training/networks/tsp/actor_critic.py +++ b/jumanji/training/networks/tsp/actor_critic.py @@ -38,7 +38,7 @@ def make_actor_critic_networks_tsp( mean_cities_in_query: bool, ) -> ActorCriticNetworks: """Make actor-critic networks for the `TSP` environment.""" - num_actions = tsp.action_spec().num_values + num_actions = tsp.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index 
98d1ec91d..a37b32237 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -48,8 +48,8 @@ class Wrapper(Environment[State], Generic[State]): """ def __init__(self, env: Environment): - super().__init__() self._env = env + super().__init__() def __repr__(self) -> str: return f"{self.__class__.__name__}({repr(self._env)})" @@ -89,21 +89,21 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """ return self._env.step(state, action) - def observation_spec(self) -> specs.Spec: + def _make_observation_spec(self) -> specs.Spec: """Returns the observation spec.""" - return self._env.observation_spec() + return self._env._make_observation_spec() - def action_spec(self) -> specs.Spec: + def _make_action_spec(self) -> specs.Spec: """Returns the action spec.""" - return self._env.action_spec() + return self._env._make_action_spec() - def reward_spec(self) -> specs.Array: + def _make_reward_spec(self) -> specs.Array: """Returns the reward spec.""" - return self._env.reward_spec() + return self._env._make_reward_spec() - def discount_spec(self) -> specs.BoundedArray: + def _make_discount_spec(self) -> specs.BoundedArray: """Returns the discount spec.""" - return self._env.discount_spec() + return self._env._make_discount_spec() def render(self, state: State) -> Any: """Compute render frames during initialisation of the environment. @@ -165,7 +165,7 @@ def reset(self) -> dm_env.TimeStep: - observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the - specification returned by `observation_spec()`. + specification returned by `observation_spec`. """ reset_key, self._key = jax.random.split(self._key) self._state, timestep = self._jitted_reset(reset_key) @@ -184,21 +184,21 @@ def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep: Args: action: A NumPy array, or a nested dict, list or tuple of arrays - corresponding to `action_spec()`. + corresponding to `action_spec`. Returns: A `TimeStep` namedtuple containing: - step_type: A `StepType` value. - reward: Reward at this timestep, or None if step_type is `StepType.FIRST`. Must conform to the specification returned by - `reward_spec()`. + `reward_spec`. - discount: A discount in the range [0, 1], or None if step_type is `StepType.FIRST`. Must conform to the specification returned by - `discount_spec()`. + `discount_spec`. - observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the - specification returned by `observation_spec()`. + specification returned by `observation_spec`. 
""" self._state, timestep = self._jitted_step(self._state, action) return dm_env.TimeStep( @@ -210,11 +210,11 @@ def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep: def observation_spec(self) -> dm_env.specs.Array: """Returns the dm_env observation spec.""" - return specs.jumanji_specs_to_dm_env_specs(self._env.observation_spec()) + return specs.jumanji_specs_to_dm_env_specs(self._env.observation_spec) def action_spec(self) -> dm_env.specs.Array: """Returns the dm_env action spec.""" - return specs.jumanji_specs_to_dm_env_specs(self._env.action_spec()) + return specs.jumanji_specs_to_dm_env_specs(self._env.action_spec) @property def unwrapped(self) -> Environment: @@ -540,9 +540,9 @@ def __init__(self, env: Environment, seed: int = 0, backend: Optional[str] = Non self.backend = backend self._state = None self.observation_space = specs.jumanji_specs_to_gym_spaces( - self._env.observation_spec() + self._env.observation_spec ) - self.action_space = specs.jumanji_specs_to_gym_spaces(self._env.action_spec()) + self.action_space = specs.jumanji_specs_to_gym_spaces(self._env.action_spec) def reset(key: chex.PRNGKey) -> Tuple[State, Observation, Optional[Dict]]: """Reset function of a Jumanji environment to be jitted.""" diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index 4dc339694..e68964325 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -109,37 +109,39 @@ def test_wrapper__reset( mock_reset.assert_called_once_with(mock_key) - def test_wrapper__observation_spec( + def test_wrapper__make_observation_spec( self, mocker: pytest_mock.MockerFixture, wrapped_fake_environment: Wrapper, fake_environment: FakeEnvironment, ) -> None: - """Checks `Wrapper.observation_spec` calls the observation_spec function of + """Checks `Wrapper._make_observation_spec` calls the _make_observation_spec function of the underlying env. """ - mock_obs_spec = mocker.patch.object( - fake_environment, "observation_spec", autospec=True + mock_make_obs_spec = mocker.patch.object( + fake_environment, "_make_observation_spec", autospec=True ) - wrapped_fake_environment.observation_spec() + wrapped_fake_environment._make_observation_spec() - mock_obs_spec.assert_called_once() + mock_make_obs_spec.assert_called_once() - def test_wrapper__action_spec( + def test_wrapper__make_action_spec( self, mocker: pytest_mock.MockerFixture, wrapped_fake_environment: Wrapper, fake_environment: FakeEnvironment, ) -> None: - """Checks `Wrapper.action_spec` calls the action_spec function of the underlying env.""" - mock_action_spec = mocker.patch.object( - fake_environment, "action_spec", autospec=True + """Checks `Wrapper._make_action_spec` calls the _make_action_spec function of the underlying + env. 
+ """ + mock_make_action_spec = mocker.patch.object( + fake_environment, "_make_action_spec", autospec=True ) - wrapped_fake_environment.action_spec() + wrapped_fake_environment._make_action_spec() - mock_action_spec.assert_called_once() + mock_make_action_spec.assert_called_once() def test_wrapper__repr(self, wrapped_fake_environment: Wrapper) -> None: """Checks `Wrapper.__repr__` returns the expected representation string.""" @@ -302,7 +304,6 @@ def test_jumanji_environment_to_gym_env__render( mocker: pytest_mock.MockerFixture, fake_gym_env: JumanjiToGymWrapper, ) -> None: - mock_render = mocker.patch.object( fake_gym_env.unwrapped, "render", autospec=True ) @@ -317,7 +318,6 @@ def test_jumanji_environment_to_gym_env__close( mocker: pytest_mock.MockerFixture, fake_gym_env: JumanjiToGymWrapper, ) -> None: - mock_close = mocker.patch.object(fake_gym_env.unwrapped, "close", autospec=True) fake_gym_env.close() @@ -375,7 +375,7 @@ def test_multi_env__step( agent wrapped environment. """ state, timestep = fake_multi_to_single_env.reset(key) # type: ignore - action = fake_multi_to_single_env.action_spec().generate_value() + action = fake_multi_to_single_env.action_spec.generate_value() state, next_timestep = jax.jit(fake_multi_to_single_env.step)(state, action) assert next_timestep != timestep assert next_timestep.reward.shape == () @@ -398,7 +398,7 @@ def test_multi_env__different_reward_aggregator( fake_multi_environment, reward_aggregator=jnp.mean ) state, timestep = mean_fake_multi_to_single_env.reset(key) # type: ignore - action = mean_fake_multi_to_single_env.action_spec().generate_value() + action = mean_fake_multi_to_single_env.action_spec.generate_value() state, next_timestep = mean_fake_multi_to_single_env.step( state, action ) # type: Tuple[FakeState, TimeStep] @@ -416,9 +416,9 @@ def test_multi_env__observation_spec( """Validates observation_spec property of the multi agent to single agent wrapped environment. """ - obs_spec: specs.Array = fake_multi_to_single_env.observation_spec() # type: ignore + obs_spec: specs.Array = fake_multi_to_single_env.observation_spec # type: ignore assert isinstance(obs_spec, specs.Array) - assert obs_spec.shape == fake_multi_environment.observation_spec().shape + assert obs_spec.shape == fake_multi_environment.observation_spec.shape def test_multi_env__action_spec( self, @@ -428,9 +428,9 @@ def test_multi_env__action_spec( """Validates action_spec property of the multi agent to single agent wrapped environment. 
""" - action_spec: specs.Array = fake_multi_to_single_env.action_spec() # type: ignore - assert isinstance(fake_multi_to_single_env.action_spec(), specs.Array) - assert action_spec.shape == fake_multi_environment.action_spec().shape + action_spec: specs.Array = fake_multi_to_single_env.action_spec # type: ignore + assert isinstance(fake_multi_to_single_env.action_spec, specs.Array) + assert action_spec.shape == fake_multi_environment.action_spec.shape def test_multi_env__unwrapped( self, @@ -473,9 +473,9 @@ def test_vmap_env__step( state, timestep = fake_vmap_environment.reset( keys ) # type: Tuple[FakeState, TimeStep] - action = jax.vmap( - lambda _: fake_vmap_environment.action_spec().generate_value() - )(keys) + action = jax.vmap(lambda _: fake_vmap_environment.action_spec.generate_value())( + keys + ) state, next_timestep = jax.jit(fake_vmap_environment.step)( state, action @@ -547,7 +547,7 @@ def test_auto_reset_wrapper__step_no_reset( ) # type: Tuple[FakeState, TimeStep] # Generate an action - action = fake_auto_reset_environment.action_spec().generate_value() + action = fake_auto_reset_environment.action_spec.generate_value() state, timestep = jax.jit(fake_auto_reset_environment.step)( state, action @@ -573,7 +573,7 @@ def test_auto_reset_wrapper__step_reset( # Loop across time_limit so auto-reset occurs timestep = first_timestep for _ in range(fake_environment.time_limit): - action = fake_auto_reset_environment.action_spec().generate_value() + action = fake_auto_reset_environment.action_spec.generate_value() state, timestep = jax.jit(fake_auto_reset_environment.step)(state, action) assert timestep.step_type == StepType.LAST @@ -592,7 +592,7 @@ def action( self, fake_vmap_auto_reset_environment: VmapAutoResetWrapper, keys: chex.PRNGKey ) -> chex.Array: generate_action_fn = ( - lambda _: fake_vmap_auto_reset_environment.action_spec().generate_value() + lambda _: fake_vmap_auto_reset_environment.action_spec.generate_value() ) return jax.vmap(generate_action_fn)(keys) @@ -765,9 +765,8 @@ def test_jumanji_to_gym_obs__wrong_observation(self) -> None: def test_jumanji_to_gym_obs__bin_pack(self) -> None: """Check that an example bin_pack observation is correctly converted.""" - env = BinPack(obs_num_ems=1) - env.generator = bin_pack_conftest.DummyGenerator() - obs = env.observation_spec().generate_value() + env = BinPack(generator=bin_pack_conftest.DummyGenerator(), obs_num_ems=1) + obs = env.observation_spec.generate_value() converted_obs = jumanji_to_gym_obs(obs) correct_obs = { From 69aefba868fb03bcd9a9afe047d7c560f7b5223a Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Sat, 20 Jan 2024 05:56:40 +0000 Subject: [PATCH 03/16] feat: implement generic typevar ActionSpec on Environment --- jumanji/env.py | 7 ++++--- jumanji/environments/logic/game_2048/env.py | 2 +- jumanji/environments/logic/graph_coloring/env.py | 2 +- jumanji/environments/logic/minesweeper/env.py | 2 +- jumanji/environments/logic/rubiks_cube/env.py | 2 +- jumanji/environments/logic/sudoku/env.py | 2 +- jumanji/environments/packing/bin_pack/env.py | 2 +- jumanji/environments/packing/job_shop/env.py | 2 +- jumanji/environments/packing/knapsack/env.py | 2 +- jumanji/environments/packing/tetris/env.py | 2 +- jumanji/environments/routing/cleaner/env.py | 2 +- jumanji/environments/routing/connector/env.py | 2 +- jumanji/environments/routing/cvrp/env.py | 2 +- jumanji/environments/routing/maze/env.py | 2 +- jumanji/environments/routing/mmst/env.py | 2 +- jumanji/environments/routing/multi_cvrp/env.py | 2 +- 
jumanji/environments/routing/robot_warehouse/env.py | 2 +- jumanji/environments/routing/snake/env.py | 2 +- jumanji/environments/routing/tsp/env.py | 2 +- jumanji/registration_test.py | 8 ++++---- jumanji/testing/fakes.py | 4 ++-- jumanji/training/networks/tetris/random.py | 2 +- jumanji/wrappers.py | 8 ++++---- jumanji/wrappers_test.py | 7 ++++--- 24 files changed, 37 insertions(+), 35 deletions(-) diff --git a/jumanji/env.py b/jumanji/env.py index aaabe5471..64673ccf1 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -33,9 +33,10 @@ class StateProtocol(Protocol): State = TypeVar("State", bound="StateProtocol") +ActionSpec = TypeVar("ActionSpec", bound=specs.Array) -class Environment(abc.ABC, Generic[State]): +class Environment(abc.ABC, Generic[State, ActionSpec]): """Environment written in Jax that differs from the gym API to make the step and reset functions jittable. The state contains all the dynamics and data needed to step the environment, no computation stored in attributes of self. @@ -95,7 +96,7 @@ def _make_observation_spec(self) -> specs.Spec: """ @property - def action_spec(self) -> specs.Spec: + def action_spec(self) -> ActionSpec: """Returns the action spec. Returns: @@ -104,7 +105,7 @@ def action_spec(self) -> specs.Spec: return self._action_spec @abc.abstractmethod - def _make_action_spec(self) -> specs.Spec: + def _make_action_spec(self) -> ActionSpec: """Returns new action spec. Returns: diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index 36e3cdd37..1e4c253ec 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -29,7 +29,7 @@ from jumanji.viewer import Viewer -class Game2048(Environment[State]): +class Game2048(Environment[State, specs.DiscreteArray]): """Environment for the game 2048. The game consists of a board of size board_size x board_size (4x4 by default) in which the player can take actions to move the tiles on the board up, down, left, or right. The goal of the game is to combine tiles with the same number to create a tile diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index c2bf663bf..0ad435dc8 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -33,7 +33,7 @@ from jumanji.viewer import Viewer -class GraphColoring(Environment[State]): +class GraphColoring(Environment[State, specs.DiscreteArray]): """Environment for the GraphColoring problem. The problem is a combinatorial optimization task where the goal is to assign a color to each vertex of a graph diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index eddfe6202..d4b40797d 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -36,7 +36,7 @@ from jumanji.viewer import Viewer -class Minesweeper(Environment[State]): +class Minesweeper(Environment[State, specs.MultiDiscreteArray]): """A JAX implementation of the minesweeper game. 
- observation: `Observation` diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 5dc5be8d8..3f916f741 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -42,7 +42,7 @@ from jumanji.viewer import Viewer -class RubiksCube(Environment[State]): +class RubiksCube(Environment[State, specs.MultiDiscreteArray]): """A JAX implementation of the Rubik's Cube with a configurable cube size (by default, 3) and number of scrambles at reset. diff --git a/jumanji/environments/logic/sudoku/env.py b/jumanji/environments/logic/sudoku/env.py index 03c222af1..65eb43778 100644 --- a/jumanji/environments/logic/sudoku/env.py +++ b/jumanji/environments/logic/sudoku/env.py @@ -32,7 +32,7 @@ from jumanji.viewer import Viewer -class Sudoku(Environment[State]): +class Sudoku(Environment[State, specs.MultiDiscreteArray]): """A JAX implementation of the sudoku game. - observation: `Observation` diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index 3e83ce8f3..bf8b679c1 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -43,7 +43,7 @@ from jumanji.viewer import Viewer -class BinPack(Environment[State]): +class BinPack(Environment[State, specs.MultiDiscreteArray]): """Problem of 3D bin packing, where a set of items have to be placed in a 3D container with the goal of maximizing its volume utilization. This environment only supports 1 bin, meaning it is equivalent to the 3D-knapsack problem. We use the Empty Maximal Space (EMS) formulation of this diff --git a/jumanji/environments/packing/job_shop/env.py b/jumanji/environments/packing/job_shop/env.py index 98c4686b3..213b2d205 100644 --- a/jumanji/environments/packing/job_shop/env.py +++ b/jumanji/environments/packing/job_shop/env.py @@ -29,7 +29,7 @@ from jumanji.viewer import Viewer -class JobShop(Environment[State]): +class JobShop(Environment[State, specs.MultiDiscreteArray]): """The Job Shop Scheduling Problem, as described in [1], is one of the best known combinatorial optimization problems. We are given `num_jobs` jobs, each consisting of at most `max_num_ops` ops, which need to be processed on `num_machines` machines. diff --git a/jumanji/environments/packing/knapsack/env.py b/jumanji/environments/packing/knapsack/env.py index e063d29ab..cb9a275fc 100644 --- a/jumanji/environments/packing/knapsack/env.py +++ b/jumanji/environments/packing/knapsack/env.py @@ -30,7 +30,7 @@ from jumanji.viewer import Viewer -class Knapsack(Environment[State]): +class Knapsack(Environment[State, specs.DiscreteArray]): """Knapsack environment as described in [1]. - observation: Observation diff --git a/jumanji/environments/packing/tetris/env.py b/jumanji/environments/packing/tetris/env.py index 693035d6d..b0e0612cf 100644 --- a/jumanji/environments/packing/tetris/env.py +++ b/jumanji/environments/packing/tetris/env.py @@ -34,7 +34,7 @@ from jumanji.viewer import Viewer -class Tetris(Environment[State]): +class Tetris(Environment[State, specs.MultiDiscreteArray]): """RL Environment for the game of Tetris. The environment has a grid where the player can place tetrominoes. 
The environment has the following characteristics: diff --git a/jumanji/environments/routing/cleaner/env.py b/jumanji/environments/routing/cleaner/env.py index 4dfb12a18..5122d6f91 100644 --- a/jumanji/environments/routing/cleaner/env.py +++ b/jumanji/environments/routing/cleaner/env.py @@ -30,7 +30,7 @@ from jumanji.viewer import Viewer -class Cleaner(Environment[State]): +class Cleaner(Environment[State, specs.MultiDiscreteArray]): """A JAX implementation of the 'Cleaner' game where multiple agents have to clean all tiles of a maze. diff --git a/jumanji/environments/routing/connector/env.py b/jumanji/environments/routing/connector/env.py index 653a63c26..0784cc963 100644 --- a/jumanji/environments/routing/connector/env.py +++ b/jumanji/environments/routing/connector/env.py @@ -46,7 +46,7 @@ from jumanji.viewer import Viewer -class Connector(Environment[State]): +class Connector(Environment[State, specs.MultiDiscreteArray]): """The `Connector` environment is a gridworld problem where multiple pairs of points (sets) must be connected without overlapping the paths taken by any other set. This is achieved by allowing certain points to move to an adjacent cell at each step. However, each time a diff --git a/jumanji/environments/routing/cvrp/env.py b/jumanji/environments/routing/cvrp/env.py index 1f10a0467..0e7d612d6 100644 --- a/jumanji/environments/routing/cvrp/env.py +++ b/jumanji/environments/routing/cvrp/env.py @@ -30,7 +30,7 @@ from jumanji.viewer import Viewer -class CVRP(Environment[State]): +class CVRP(Environment[State, specs.DiscreteArray]): """Capacitated Vehicle Routing Problem (CVRP) environment as described in [1]. - observation: `Observation` diff --git a/jumanji/environments/routing/maze/env.py b/jumanji/environments/routing/maze/env.py index 25ba0b287..3c654324b 100644 --- a/jumanji/environments/routing/maze/env.py +++ b/jumanji/environments/routing/maze/env.py @@ -30,7 +30,7 @@ from jumanji.viewer import Viewer -class Maze(Environment[State]): +class Maze(Environment[State, specs.DiscreteArray]): """A JAX implementation of a 2D Maze. The goal is to navigate the maze to find the target position. diff --git a/jumanji/environments/routing/mmst/env.py b/jumanji/environments/routing/mmst/env.py index 2e6f876f5..8997a1c0a 100644 --- a/jumanji/environments/routing/mmst/env.py +++ b/jumanji/environments/routing/mmst/env.py @@ -42,7 +42,7 @@ from jumanji.viewer import Viewer -class MMST(Environment[State]): +class MMST(Environment[State, specs.MultiDiscreteArray]): """The `MMST` (Multi Minimum Spanning Tree) environment consists of a random connected graph with groups of nodes (same node types) that needs to be connected. diff --git a/jumanji/environments/routing/multi_cvrp/env.py b/jumanji/environments/routing/multi_cvrp/env.py index 068222dda..afc781e21 100644 --- a/jumanji/environments/routing/multi_cvrp/env.py +++ b/jumanji/environments/routing/multi_cvrp/env.py @@ -48,7 +48,7 @@ from jumanji.viewer import Viewer -class MultiCVRP(Environment[State]): +class MultiCVRP(Environment[State, specs.BoundedArray]): """ Multi-Vehicle Routing Problems with Soft Time Windows (MVRPSTW) environment as described in [1]. We simplfy the naming to multi-agent capacitated vehicle routing problem (MultiCVRP). 
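Parameterising each environment on its concrete action-spec type pays off at call sites: a static type checker (e.g. mypy, assumed but not required) can now infer the precise spec class from `action_spec` without casts. A small illustrative sketch:

```python
from jumanji.environments import CVRP, Cleaner

cvrp = CVRP()
cvrp_spec = cvrp.action_spec      # inferred as specs.DiscreteArray
num_nodes = cvrp_spec.num_values  # attribute known to exist, no cast needed

cleaner = Cleaner()
cleaner_spec = cleaner.action_spec   # inferred as specs.MultiDiscreteArray
per_agent = cleaner_spec.num_values  # per-dimension action counts
```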
diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py index 1656139b4..f13c79170 100644 --- a/jumanji/environments/routing/robot_warehouse/env.py +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -48,7 +48,7 @@ from jumanji.viewer import Viewer -class RobotWarehouse(Environment[State]): +class RobotWarehouse(Environment[State, specs.MultiDiscreteArray]): """A JAX implementation of the 'Robotic warehouse' environment: https://github.com/semitable/robotic-warehouse which is described in the paper [1]. diff --git a/jumanji/environments/routing/snake/env.py b/jumanji/environments/routing/snake/env.py index c1ab16883..9fde458d5 100644 --- a/jumanji/environments/routing/snake/env.py +++ b/jumanji/environments/routing/snake/env.py @@ -29,7 +29,7 @@ from jumanji.viewer import Viewer -class Snake(Environment[State]): +class Snake(Environment[State, specs.DiscreteArray]): """A JAX implementation of the 'Snake' game. - observation: `Observation` diff --git a/jumanji/environments/routing/tsp/env.py b/jumanji/environments/routing/tsp/env.py index 82531a093..009eee764 100644 --- a/jumanji/environments/routing/tsp/env.py +++ b/jumanji/environments/routing/tsp/env.py @@ -31,7 +31,7 @@ from jumanji.viewer import Viewer -class TSP(Environment[State]): +class TSP(Environment[State, specs.DiscreteArray]): """Traveling Salesman Problem (TSP) environment as described in [1]. - observation: Observation diff --git a/jumanji/registration_test.py b/jumanji/registration_test.py index 218be5903..a37e05935 100644 --- a/jumanji/registration_test.py +++ b/jumanji/registration_test.py @@ -18,7 +18,7 @@ import pytest_mock import jumanji -from jumanji import registration +from jumanji import registration, specs from jumanji.testing.fakes import FakeEnvironment @@ -91,10 +91,10 @@ def test_register__override_kwargs(mocker: pytest_mock.MockerFixture) -> None: id=env_id, entry_point="jumanji.testing.fakes:FakeEnvironment", ) - env: FakeEnvironment = registration.make( # type: ignore + obs_spec: specs.Array = registration.make( # type: ignore env_id, observation_shape=obs_shape - ) - assert env.observation_spec.shape == obs_shape + ).observation_spec + assert obs_spec.shape == obs_shape def test_registration__make() -> None: diff --git a/jumanji/testing/fakes.py b/jumanji/testing/fakes.py index 01cb689c1..8925233de 100644 --- a/jumanji/testing/fakes.py +++ b/jumanji/testing/fakes.py @@ -34,7 +34,7 @@ class FakeState: step: jnp.int32 -class FakeEnvironment(Environment[FakeState]): +class FakeEnvironment(Environment[FakeState, specs.BoundedArray]): """ A fake environment that inherits from Environment, for testing purposes. The observation is an array full of `state.step` of shape `(self.observation_shape,)` @@ -143,7 +143,7 @@ def _state_to_obs(self, state: FakeState) -> chex.Array: return state.step * jnp.ones(self.observation_shape, float) -class FakeMultiEnvironment(Environment[FakeState]): +class FakeMultiEnvironment(Environment[FakeState, specs.BoundedArray]): """ A fake multi agent environment that inherits from Environment, for testing purposes. 
""" diff --git a/jumanji/training/networks/tetris/random.py b/jumanji/training/networks/tetris/random.py index 309687d4a..eef995792 100644 --- a/jumanji/training/networks/tetris/random.py +++ b/jumanji/training/networks/tetris/random.py @@ -21,7 +21,7 @@ def make_random_policy_tetris(tetris: Tetris) -> RandomPolicy: """Make random policy for `Tetris`.""" - action_spec_num_values = tetris.action_spec.m_values + action_spec_num_values = tetris.action_spec.num_values return make_masked_categorical_random_ndim( action_spec_num_values=action_spec_num_values ) diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index a37b32237..8a7c3a1a9 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -33,7 +33,7 @@ import numpy as np from jumanji import specs, tree_utils -from jumanji.env import Environment, State +from jumanji.env import ActionSpec, Environment, State from jumanji.types import TimeStep Observation = TypeVar("Observation") @@ -42,12 +42,12 @@ GymObservation = Any -class Wrapper(Environment[State], Generic[State]): +class Wrapper(Environment[State, ActionSpec], Generic[State, ActionSpec]): """Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72 """ - def __init__(self, env: Environment): + def __init__(self, env: Environment[State, ActionSpec]): self._env = env super().__init__() @@ -93,7 +93,7 @@ def _make_observation_spec(self) -> specs.Spec: """Returns the observation spec.""" return self._env._make_observation_spec() - def _make_action_spec(self) -> specs.Spec: + def _make_action_spec(self) -> ActionSpec: """Returns the action spec.""" return self._env._make_action_spec() diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index e68964325..ba339faee 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -25,7 +25,7 @@ import pytest_mock from jumanji import specs -from jumanji.env import Environment +from jumanji.env import ActionSpec, Environment from jumanji.environments.packing.bin_pack import conftest as bin_pack_conftest from jumanji.environments.packing.bin_pack.env import BinPack from jumanji.testing.fakes import FakeEnvironment, FakeMultiEnvironment, FakeState @@ -48,7 +48,7 @@ @pytest.fixture def mock_wrapper_class() -> Type[Wrapper]: - class MockWrapper(Wrapper[FakeState]): + class MockWrapper(Wrapper[FakeState, ActionSpec]): pass return MockWrapper @@ -418,7 +418,8 @@ def test_multi_env__observation_spec( """ obs_spec: specs.Array = fake_multi_to_single_env.observation_spec # type: ignore assert isinstance(obs_spec, specs.Array) - assert obs_spec.shape == fake_multi_environment.observation_spec.shape + multi_obs_spec: specs.Array = fake_multi_environment.observation_spec # type: ignore + assert obs_spec.shape == multi_obs_spec.shape def test_multi_env__action_spec( self, From 7c2c6283e67c366ea7f384fac2a8403499030731 Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Sat, 27 Jan 2024 02:06:21 +0000 Subject: [PATCH 04/16] docs: Remove NestedSpec from docstrings --- jumanji/env.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jumanji/env.py b/jumanji/env.py index 64673ccf1..d488ac3c9 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -83,7 +83,7 @@ def observation_spec(self) -> specs.Spec: """Returns the observation spec. Returns: - observation_spec: a NestedSpec tree of spec. + observation_spec: a potentially nested `Spec` structure representing the observation. 
""" return self._observation_spec @@ -92,7 +92,7 @@ def _make_observation_spec(self) -> specs.Spec: """Returns new observation spec. Returns: - observation_spec: a NestedSpec tree of spec. + observation_spec: a potentially nested `Spec` structure representing the observation. """ @property @@ -100,7 +100,7 @@ def action_spec(self) -> ActionSpec: """Returns the action spec. Returns: - action_spec: a NestedSpec tree of spec. + action_spec: a potentially nested `Spec` structure representing the action. """ return self._action_spec @@ -109,7 +109,7 @@ def _make_action_spec(self) -> ActionSpec: """Returns new action spec. Returns: - action_spec: a NestedSpec tree of spec. + action_spec: a potentially nested `Spec` structure representing the action. """ @property From 44e6837905909e8e4a802bd2e2fc424f28d40546 Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Sat, 27 Jan 2024 00:17:51 -0600 Subject: [PATCH 05/16] feat(env): implement generic Observation typevar on Environment --- jumanji/env.py | 15 +++++---- jumanji/environments/logic/game_2048/env.py | 2 +- .../environments/logic/graph_coloring/env.py | 2 +- jumanji/environments/logic/minesweeper/env.py | 2 +- jumanji/environments/logic/rubiks_cube/env.py | 2 +- jumanji/environments/logic/sudoku/env.py | 2 +- jumanji/environments/packing/bin_pack/env.py | 2 +- jumanji/environments/packing/job_shop/env.py | 2 +- jumanji/environments/packing/knapsack/env.py | 2 +- jumanji/environments/packing/tetris/env.py | 2 +- jumanji/environments/routing/cleaner/env.py | 2 +- jumanji/environments/routing/connector/env.py | 2 +- jumanji/environments/routing/cvrp/env.py | 2 +- jumanji/environments/routing/maze/env.py | 2 +- jumanji/environments/routing/mmst/env.py | 2 +- .../environments/routing/multi_cvrp/env.py | 2 +- .../routing/robot_warehouse/env.py | 2 +- jumanji/environments/routing/snake/env.py | 2 +- jumanji/environments/routing/tsp/env.py | 2 +- jumanji/testing/fakes.py | 4 +-- jumanji/wrappers.py | 32 +++++++------------ jumanji/wrappers_test.py | 4 +-- 22 files changed, 43 insertions(+), 48 deletions(-) diff --git a/jumanji/env.py b/jumanji/env.py index d488ac3c9..cbd57268d 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -34,9 +34,10 @@ class StateProtocol(Protocol): State = TypeVar("State", bound="StateProtocol") ActionSpec = TypeVar("ActionSpec", bound=specs.Array) +Observation = TypeVar("Observation") -class Environment(abc.ABC, Generic[State, ActionSpec]): +class Environment(abc.ABC, Generic[State, ActionSpec, Observation]): """Environment written in Jax that differs from the gym API to make the step and reset functions jittable. The state contains all the dynamics and data needed to step the environment, no computation stored in attributes of self. @@ -54,7 +55,7 @@ def __init__(self) -> None: self._discount_spec = self._make_discount_spec() @abc.abstractmethod - def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment to an initial state. Args: @@ -66,7 +67,9 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """ @abc.abstractmethod - def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: + def step( + self, state: State, action: chex.Array + ) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. 
Args: @@ -79,7 +82,7 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """ @property - def observation_spec(self) -> specs.Spec: + def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. Returns: @@ -88,7 +91,7 @@ def observation_spec(self) -> specs.Spec: return self._observation_spec @abc.abstractmethod - def _make_observation_spec(self) -> specs.Spec: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Returns new observation spec. Returns: @@ -149,7 +152,7 @@ def _make_discount_spec(self) -> specs.BoundedArray: ) @property - def unwrapped(self) -> Environment: + def unwrapped(self) -> Environment[State, ActionSpec, Observation]: return self def render(self, state: State) -> Any: diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index 1e4c253ec..f95ca173e 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -29,7 +29,7 @@ from jumanji.viewer import Viewer -class Game2048(Environment[State, specs.DiscreteArray]): +class Game2048(Environment[State, specs.DiscreteArray, Observation]): """Environment for the game 2048. The game consists of a board of size board_size x board_size (4x4 by default) in which the player can take actions to move the tiles on the board up, down, left, or right. The goal of the game is to combine tiles with the same number to create a tile diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index 0ad435dc8..a6ed516eb 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -33,7 +33,7 @@ from jumanji.viewer import Viewer -class GraphColoring(Environment[State, specs.DiscreteArray]): +class GraphColoring(Environment[State, specs.DiscreteArray, Observation]): """Environment for the GraphColoring problem. The problem is a combinatorial optimization task where the goal is to assign a color to each vertex of a graph diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index d4b40797d..6921a671f 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -36,7 +36,7 @@ from jumanji.viewer import Viewer -class Minesweeper(Environment[State, specs.MultiDiscreteArray]): +class Minesweeper(Environment[State, specs.MultiDiscreteArray, Observation]): """A JAX implementation of the minesweeper game. - observation: `Observation` diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 3f916f741..cf13ce388 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -42,7 +42,7 @@ from jumanji.viewer import Viewer -class RubiksCube(Environment[State, specs.MultiDiscreteArray]): +class RubiksCube(Environment[State, specs.MultiDiscreteArray, Observation]): """A JAX implementation of the Rubik's Cube with a configurable cube size (by default, 3) and number of scrambles at reset. 
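As an aside, a sketch of what the new `Observation` parameter gives callers: `reset` and `step` are now typed as returning `TimeStep[Observation]`, so the observation type flows out of the environment statically. The snippet assumes `Observation` is importable from the env module, as the class declarations in these hunks imply, and constructs `Game2048` with its defaults:

    import jax

    from jumanji.environments.logic.game_2048.env import Game2048, Observation

    env = Game2048()
    state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
    obs: Observation = timestep.observation  # TimeStep[Observation] narrows here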
diff --git a/jumanji/environments/logic/sudoku/env.py b/jumanji/environments/logic/sudoku/env.py index 65eb43778..8857c7931 100644 --- a/jumanji/environments/logic/sudoku/env.py +++ b/jumanji/environments/logic/sudoku/env.py @@ -32,7 +32,7 @@ from jumanji.viewer import Viewer -class Sudoku(Environment[State, specs.MultiDiscreteArray]): +class Sudoku(Environment[State, specs.MultiDiscreteArray, Observation]): """A JAX implementation of the sudoku game. - observation: `Observation` diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index bf8b679c1..2716bb6e2 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -43,7 +43,7 @@ from jumanji.viewer import Viewer -class BinPack(Environment[State, specs.MultiDiscreteArray]): +class BinPack(Environment[State, specs.MultiDiscreteArray, Observation]): """Problem of 3D bin packing, where a set of items have to be placed in a 3D container with the goal of maximizing its volume utilization. This environment only supports 1 bin, meaning it is equivalent to the 3D-knapsack problem. We use the Empty Maximal Space (EMS) formulation of this diff --git a/jumanji/environments/packing/job_shop/env.py b/jumanji/environments/packing/job_shop/env.py index 213b2d205..b25e2eedc 100644 --- a/jumanji/environments/packing/job_shop/env.py +++ b/jumanji/environments/packing/job_shop/env.py @@ -29,7 +29,7 @@ from jumanji.viewer import Viewer -class JobShop(Environment[State, specs.MultiDiscreteArray]): +class JobShop(Environment[State, specs.MultiDiscreteArray, Observation]): """The Job Shop Scheduling Problem, as described in [1], is one of the best known combinatorial optimization problems. We are given `num_jobs` jobs, each consisting of at most `max_num_ops` ops, which need to be processed on `num_machines` machines. diff --git a/jumanji/environments/packing/knapsack/env.py b/jumanji/environments/packing/knapsack/env.py index cb9a275fc..6ee9ec919 100644 --- a/jumanji/environments/packing/knapsack/env.py +++ b/jumanji/environments/packing/knapsack/env.py @@ -30,7 +30,7 @@ from jumanji.viewer import Viewer -class Knapsack(Environment[State, specs.DiscreteArray]): +class Knapsack(Environment[State, specs.DiscreteArray, Observation]): """Knapsack environment as described in [1]. - observation: Observation diff --git a/jumanji/environments/packing/tetris/env.py b/jumanji/environments/packing/tetris/env.py index b0e0612cf..7e4bbc216 100644 --- a/jumanji/environments/packing/tetris/env.py +++ b/jumanji/environments/packing/tetris/env.py @@ -34,7 +34,7 @@ from jumanji.viewer import Viewer -class Tetris(Environment[State, specs.MultiDiscreteArray]): +class Tetris(Environment[State, specs.MultiDiscreteArray, Observation]): """RL Environment for the game of Tetris. The environment has a grid where the player can place tetrominoes. The environment has the following characteristics: diff --git a/jumanji/environments/routing/cleaner/env.py b/jumanji/environments/routing/cleaner/env.py index 5122d6f91..53dcdf5d6 100644 --- a/jumanji/environments/routing/cleaner/env.py +++ b/jumanji/environments/routing/cleaner/env.py @@ -30,7 +30,7 @@ from jumanji.viewer import Viewer -class Cleaner(Environment[State, specs.MultiDiscreteArray]): +class Cleaner(Environment[State, specs.MultiDiscreteArray, Observation]): """A JAX implementation of the 'Cleaner' game where multiple agents have to clean all tiles of a maze. 
diff --git a/jumanji/environments/routing/connector/env.py b/jumanji/environments/routing/connector/env.py index 0784cc963..e2d621729 100644 --- a/jumanji/environments/routing/connector/env.py +++ b/jumanji/environments/routing/connector/env.py @@ -46,7 +46,7 @@ from jumanji.viewer import Viewer -class Connector(Environment[State, specs.MultiDiscreteArray]): +class Connector(Environment[State, specs.MultiDiscreteArray, Observation]): """The `Connector` environment is a gridworld problem where multiple pairs of points (sets) must be connected without overlapping the paths taken by any other set. This is achieved by allowing certain points to move to an adjacent cell at each step. However, each time a diff --git a/jumanji/environments/routing/cvrp/env.py b/jumanji/environments/routing/cvrp/env.py index 0e7d612d6..f5d9f47d5 100644 --- a/jumanji/environments/routing/cvrp/env.py +++ b/jumanji/environments/routing/cvrp/env.py @@ -30,7 +30,7 @@ from jumanji.viewer import Viewer -class CVRP(Environment[State, specs.DiscreteArray]): +class CVRP(Environment[State, specs.DiscreteArray, Observation]): """Capacitated Vehicle Routing Problem (CVRP) environment as described in [1]. - observation: `Observation` diff --git a/jumanji/environments/routing/maze/env.py b/jumanji/environments/routing/maze/env.py index 3c654324b..647bc79e3 100644 --- a/jumanji/environments/routing/maze/env.py +++ b/jumanji/environments/routing/maze/env.py @@ -30,7 +30,7 @@ from jumanji.viewer import Viewer -class Maze(Environment[State, specs.DiscreteArray]): +class Maze(Environment[State, specs.DiscreteArray, Observation]): """A JAX implementation of a 2D Maze. The goal is to navigate the maze to find the target position. diff --git a/jumanji/environments/routing/mmst/env.py b/jumanji/environments/routing/mmst/env.py index 8997a1c0a..d81146c74 100644 --- a/jumanji/environments/routing/mmst/env.py +++ b/jumanji/environments/routing/mmst/env.py @@ -42,7 +42,7 @@ from jumanji.viewer import Viewer -class MMST(Environment[State, specs.MultiDiscreteArray]): +class MMST(Environment[State, specs.MultiDiscreteArray, Observation]): """The `MMST` (Multi Minimum Spanning Tree) environment consists of a random connected graph with groups of nodes (same node types) that need to be connected. diff --git a/jumanji/environments/routing/multi_cvrp/env.py b/jumanji/environments/routing/multi_cvrp/env.py index afc781e21..666f2ba2c 100644 --- a/jumanji/environments/routing/multi_cvrp/env.py +++ b/jumanji/environments/routing/multi_cvrp/env.py @@ -48,7 +48,7 @@ from jumanji.viewer import Viewer -class MultiCVRP(Environment[State, specs.BoundedArray]): +class MultiCVRP(Environment[State, specs.BoundedArray, Observation]): """ Multi-Vehicle Routing Problems with Soft Time Windows (MVRPSTW) environment as described in [1]. We simplify the naming to multi-agent capacitated vehicle routing problem (MultiCVRP). diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py index f13c79170..290b75020 100644 --- a/jumanji/environments/routing/robot_warehouse/env.py +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -48,7 +48,7 @@ from jumanji.viewer import Viewer -class RobotWarehouse(Environment[State, specs.MultiDiscreteArray]): +class RobotWarehouse(Environment[State, specs.MultiDiscreteArray, Observation]): """A JAX implementation of the 'Robotic warehouse' environment: https://github.com/semitable/robotic-warehouse which is described in the paper [1].
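Taken together, the `State`, `ActionSpec`, and `Observation` parameters make it possible to write helpers that are generic over any Jumanji environment while keeping full static typing. A sketch (the imports match the ones introduced in the wrappers.py hunk below; the helper name `first_step` is hypothetical):

    from typing import Tuple

    import chex

    from jumanji.env import ActionSpec, Environment, Observation, State
    from jumanji.types import TimeStep

    def first_step(
        env: Environment[State, ActionSpec, Observation], key: chex.PRNGKey
    ) -> Tuple[State, TimeStep[Observation]]:
        # Generic over all three parameters: the action spec's static type and
        # the observation type both survive through reset and step.
        state, timestep = env.reset(key)
        action = env.action_spec.generate_value()
        return env.step(state, action)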
diff --git a/jumanji/environments/routing/snake/env.py b/jumanji/environments/routing/snake/env.py index 9fde458d5..90a179d6e 100644 --- a/jumanji/environments/routing/snake/env.py +++ b/jumanji/environments/routing/snake/env.py @@ -29,7 +29,7 @@ from jumanji.viewer import Viewer -class Snake(Environment[State, specs.DiscreteArray]): +class Snake(Environment[State, specs.DiscreteArray, Observation]): """A JAX implementation of the 'Snake' game. - observation: `Observation` diff --git a/jumanji/environments/routing/tsp/env.py b/jumanji/environments/routing/tsp/env.py index 009eee764..7f4c084a5 100644 --- a/jumanji/environments/routing/tsp/env.py +++ b/jumanji/environments/routing/tsp/env.py @@ -31,7 +31,7 @@ from jumanji.viewer import Viewer -class TSP(Environment[State, specs.DiscreteArray]): +class TSP(Environment[State, specs.DiscreteArray, Observation]): """Traveling Salesman Problem (TSP) environment as described in [1]. - observation: Observation diff --git a/jumanji/testing/fakes.py b/jumanji/testing/fakes.py index 8925233de..002778a31 100644 --- a/jumanji/testing/fakes.py +++ b/jumanji/testing/fakes.py @@ -34,7 +34,7 @@ class FakeState: step: jnp.int32 -class FakeEnvironment(Environment[FakeState, specs.BoundedArray]): +class FakeEnvironment(Environment[FakeState, specs.BoundedArray, chex.Array]): """ A fake environment that inherits from Environment, for testing purposes. The observation is an array full of `state.step` of shape `(self.observation_shape,)` @@ -143,7 +143,7 @@ def _state_to_obs(self, state: FakeState) -> chex.Array: return state.step * jnp.ones(self.observation_shape, float) -class FakeMultiEnvironment(Environment[FakeState, specs.BoundedArray]): +class FakeMultiEnvironment(Environment[FakeState, specs.BoundedArray, chex.Array]): """ A fake multi agent environment that inherits from Environment, for testing purposes. """ diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index 8a7c3a1a9..bd49371b7 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -13,17 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import ( - Any, - Callable, - ClassVar, - Dict, - Generic, - Optional, - Tuple, - TypeVar, - Union, -) +from typing import Any, Callable, ClassVar, Dict, Generic, Optional, Tuple, Union import chex import dm_env.specs @@ -33,21 +23,21 @@ import numpy as np from jumanji import specs, tree_utils -from jumanji.env import ActionSpec, Environment, State +from jumanji.env import ActionSpec, Environment, Observation, State from jumanji.types import TimeStep -Observation = TypeVar("Observation") - # Type alias that corresponds to ObsType in the Gym API GymObservation = Any -class Wrapper(Environment[State, ActionSpec], Generic[State, ActionSpec]): +class Wrapper( + Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] +): """Wraps the environment to allow modular transformations. 
Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72 """ - def __init__(self, env: Environment[State, ActionSpec]): + def __init__(self, env: Environment[State, ActionSpec, Observation]): self._env = env super().__init__() @@ -60,11 +50,11 @@ def __getattr__(self, name: str) -> Any: return getattr(self._env, name) @property - def unwrapped(self) -> Environment: + def unwrapped(self) -> Environment[State, ActionSpec, Observation]: """Returns the wrapped env.""" return self._env.unwrapped - def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment to an initial state. Args: @@ -76,7 +66,9 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """ return self._env.reset(key) - def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: + def step( + self, state: State, action: chex.Array + ) -> Tuple[State, TimeStep[Observation]]: """Run one timestep of the environment's dynamics. Args: @@ -89,7 +81,7 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """ return self._env.step(state, action) - def _make_observation_spec(self) -> specs.Spec: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec.""" return self._env._make_observation_spec() diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index ba339faee..6f93b9771 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -25,7 +25,7 @@ import pytest_mock from jumanji import specs -from jumanji.env import ActionSpec, Environment +from jumanji.env import Environment from jumanji.environments.packing.bin_pack import conftest as bin_pack_conftest from jumanji.environments.packing.bin_pack.env import BinPack from jumanji.testing.fakes import FakeEnvironment, FakeMultiEnvironment, FakeState @@ -48,7 +48,7 @@ @pytest.fixture def mock_wrapper_class() -> Type[Wrapper]: - class MockWrapper(Wrapper[FakeState, ActionSpec]): + class MockWrapper(Wrapper[FakeState, specs.BoundedArray, chex.Array]): pass return MockWrapper From 094831eaeaea9f0020e658ae396942d161de2569 Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Tue, 30 Jan 2024 03:19:30 +0000 Subject: [PATCH 06/16] feat(wrappers): Update wrappers to inherit type hints from Environment --- jumanji/wrappers.py | 51 +++++++++++++++++++++++++++++----------- jumanji/wrappers_test.py | 3 +++ 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index bd49371b7..92f47c480 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -120,10 +120,16 @@ def __exit__(self, *args: Any) -> None: self.close() -class JumanjiToDMEnvWrapper(dm_env.Environment): +class JumanjiToDMEnvWrapper( + dm_env.Environment, Generic[State, ActionSpec, Observation] +): """A wrapper that converts Environment to dm_env.Environment.""" - def __init__(self, env: Environment, key: Optional[chex.PRNGKey] = None): + def __init__( + self, + env: Environment[State, ActionSpec, Observation], + key: Optional[chex.PRNGKey] = None, + ): """Create the wrapped environment. 
Args: @@ -209,16 +215,18 @@ def action_spec(self) -> dm_env.specs.Array: return specs.jumanji_specs_to_dm_env_specs(self._env.action_spec) @property - def unwrapped(self) -> Environment: + def unwrapped(self) -> Environment[State, ActionSpec, Observation]: return self._env -class MultiToSingleWrapper(Wrapper): +class MultiToSingleWrapper( + Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] +): """A wrapper that converts a multi-agent Environment to a single-agent Environment.""" def __init__( self, - env: Environment, + env: Environment[State, ActionSpec, Observation], reward_aggregator: Callable = jnp.sum, discount_aggregator: Callable = jnp.max, ): @@ -235,7 +243,9 @@ def __init__( self._reward_aggregator = reward_aggregator self._discount_aggregator = discount_aggregator - def _aggregate_timestep(self, timestep: TimeStep) -> TimeStep: + def _aggregate_timestep( + self, timestep: TimeStep[Observation] + ) -> TimeStep[Observation]: """Apply the reward and discount aggregator to a multi-agent timestep object to create a new timestep object that consists of a scalar reward and discount value. @@ -290,7 +300,9 @@ def step( return state, timestep -class VmapWrapper(Wrapper): +class VmapWrapper( + Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] +): """Vectorized Jax env. Please note that all methods that return arrays do not return a batch dimension because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include: @@ -352,7 +364,9 @@ def render(self, state: State) -> Any: return super().render(state_0) -class AutoResetWrapper(Wrapper): +class AutoResetWrapper( + Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] +): """Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep is reset to the reset observation and StepType.LAST, respectively. @@ -404,7 +418,9 @@ def step( return state, timestep -class VmapAutoResetWrapper(Wrapper): +class VmapAutoResetWrapper( + Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] +): """Efficient combination of VmapWrapper and AutoResetWrapper, to be used as a replacement of the combination of both wrappers. `env = VmapAutoResetWrapper(env)` is equivalent to `env = VmapWrapper(AutoResetWrapper(env))` @@ -463,7 +479,7 @@ def step( return state, timestep def _auto_reset( - self, state: State, timestep: TimeStep + self, state: State, timestep: TimeStep[Observation] ) -> Tuple[State, TimeStep[Observation]]: """Reset the state and overwrite `timestep.observation` with the reset observation if the episode has terminated. @@ -487,7 +503,7 @@ def _auto_reset( return state, timestep def _maybe_reset( - self, state: State, timestep: TimeStep + self, state: State, timestep: TimeStep[Observation] ) -> Tuple[State, TimeStep[Observation]]: """Overwrite the state and timestep appropriately if the episode terminates.""" state, timestep = jax.lax.cond( @@ -511,14 +527,19 @@ def render(self, state: State) -> Any: return super().render(state_0) -class JumanjiToGymWrapper(gym.Env): +class JumanjiToGymWrapper(gym.Env, Generic[State, ActionSpec, Observation]): """A wrapper that converts a Jumanji `Environment` to one that follows the `gym.Env` API.""" # Flag that prevents `gym.register` from misinterpreting the `_step` and # `_reset` as signs of a deprecated gym Env API. 
_gym_disable_underscore_compat: ClassVar[bool] = True - def __init__(self, env: Environment, seed: int = 0, backend: Optional[str] = None): + def __init__( + self, + env: Environment[State, ActionSpec, Observation], + seed: int = 0, + backend: Optional[str] = None, + ): """Create the Gym environment. Args: @@ -622,6 +643,8 @@ def render(self, mode: str = "human") -> Any: mode: currently not used since Jumanji does not currently support modes. """ del mode + if self._state is None: + raise ValueError("Cannot render when _state is None.") return self._env.render(self._state) def close(self) -> None: @@ -629,7 +652,7 @@ def close(self) -> None: self._env.close() @property - def unwrapped(self) -> Environment: + def unwrapped(self) -> Environment[State, ActionSpec, Observation]: return self._env diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index 6f93b9771..cbf87a57e 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -309,6 +309,9 @@ def test_jumanji_environment_to_gym_env__render( ) mock_state = mocker.MagicMock() + with pytest.raises(ValueError): + fake_gym_env.render(mock_state) + fake_gym_env.reset() fake_gym_env.render(mock_state) mock_render.assert_called_once() From 732229bc317bc9bedaaa9e96a091168edd255b8c Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Mon, 29 Jan 2024 22:56:41 -0600 Subject: [PATCH 07/16] feat(wrappers): Add types to wrapper unit tests --- jumanji/wrappers_test.py | 197 ++++++++++++++++++++++----------------- 1 file changed, 110 insertions(+), 87 deletions(-) diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index cbf87a57e..f779faf7d 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -44,10 +44,11 @@ State = TypeVar("State") Observation = TypeVar("Observation") +FakeWrapper = Wrapper[FakeState, specs.BoundedArray, chex.Array] @pytest.fixture -def mock_wrapper_class() -> Type[Wrapper]: +def mock_wrapper_class() -> Type[FakeWrapper]: class MockWrapper(Wrapper[FakeState, specs.BoundedArray, chex.Array]): pass @@ -69,13 +70,13 @@ class TestBaseWrapper: @pytest.fixture def wrapped_fake_environment( - self, mock_wrapper_class: Type[Wrapper], fake_environment: FakeEnvironment - ) -> Wrapper: + self, mock_wrapper_class: Type[FakeWrapper], fake_environment: FakeEnvironment + ) -> FakeWrapper: wrapped_env = mock_wrapper_class(fake_environment) return wrapped_env def test_wrapper__unwrapped( - self, wrapped_fake_environment: Wrapper, fake_environment: FakeEnvironment + self, wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment ) -> None: """Checks `Wrapper.unwrapped` returns the unwrapped env.""" assert wrapped_fake_environment.unwrapped is fake_environment @@ -83,7 +84,7 @@ def test_wrapper__unwrapped( def test_wrapper__step( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment, ) -> None: """Checks `Wrapper.step` calls the step method of the underlying env.""" @@ -98,7 +99,7 @@ def test_wrapper__step( def test_wrapper__reset( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment, ) -> None: """Checks `Wrapper.reset` calls the reset method of the underlying env.""" @@ -112,7 +113,7 @@ def test_wrapper__reset( def test_wrapper__make_observation_spec( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + wrapped_fake_environment: FakeWrapper, fake_environment: 
FakeEnvironment, ) -> None: """Checks `Wrapper._make_observation_spec` calls the _make_observation_spec function of @@ -129,7 +130,7 @@ def test_wrapper__make_observation_spec( def test_wrapper__make_action_spec( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment, ) -> None: """Checks `Wrapper._make_action_spec` calls the _make_action_spec function of the underlying @@ -143,7 +144,7 @@ def test_wrapper__make_action_spec( mock_make_action_spec.assert_called_once() - def test_wrapper__repr(self, wrapped_fake_environment: Wrapper) -> None: + def test_wrapper__repr(self, wrapped_fake_environment: FakeWrapper) -> None: """Checks `Wrapper.__repr__` returns the expected representation string.""" repr_str = repr(wrapped_fake_environment) assert "MockWrapper" in repr_str @@ -151,7 +152,7 @@ def test_wrapper__repr(self, wrapped_fake_environment: Wrapper) -> None: def test_wrapper__render( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment, ) -> None: """Checks `Wrapper.render` calls the render method of the underlying env.""" @@ -168,7 +169,7 @@ def test_wrapper__render( def test_wrapper__close( self, mocker: pytest_mock.MockerFixture, - wrapped_fake_environment: Wrapper, + wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment, ) -> None: """Checks `Wrapper.close` calls the close method of the underlying env.""" @@ -180,13 +181,18 @@ def test_wrapper__close( mock_action_spec.assert_called_once() def test_wrapper__getattr( - self, wrapped_fake_environment: Wrapper, fake_environment: FakeEnvironment + self, wrapped_fake_environment: FakeWrapper, fake_environment: FakeEnvironment ) -> None: """Checks `Wrapper.__getattr__` calls the underlying env for unknown attr.""" # time_limit is defined in the mock env assert wrapped_fake_environment.time_limit == fake_environment.time_limit +FakeJumanjiToDMEnvWrapper = JumanjiToDMEnvWrapper[ + FakeState, specs.BoundedArray, chex.Array +] + + class TestJumanjiEnvironmentToDeepMindEnv: """Test the JumanjiEnvironmentToDeepMindEnv that transforms an Environment into a dm_env.Environment format. 
@@ -206,14 +212,14 @@ def test_jumanji_environment_to_deep_mind_env__init( ) assert isinstance(dm_environment_with_key, dm_env.Environment) - def test_dm_env__reset(self, fake_dm_env: JumanjiToDMEnvWrapper) -> None: + def test_dm_env__reset(self, fake_dm_env: FakeJumanjiToDMEnvWrapper) -> None: """Validates reset function and timestep type of the wrapped environment.""" timestep = fake_dm_env.reset() assert isinstance(timestep, dm_env.TimeStep) assert timestep.step_type == dm_env.StepType.FIRST def test_jumanji_environment_to_deep_mind_env__step( - self, fake_dm_env: JumanjiToDMEnvWrapper + self, fake_dm_env: FakeJumanjiToDMEnvWrapper ) -> None: """Validates step function of the wrapped environment.""" timestep = fake_dm_env.reset() @@ -222,31 +228,34 @@ def test_jumanji_environment_to_deep_mind_env__step( assert next_timestep != timestep def test_jumanji_environment_to_deep_mind_env__observation_spec( - self, fake_dm_env: JumanjiToDMEnvWrapper + self, fake_dm_env: FakeJumanjiToDMEnvWrapper ) -> None: """Validates observation_spec property of the wrapped environment.""" assert isinstance(fake_dm_env.observation_spec(), dm_env.specs.Array) def test_jumanji_environment_to_deep_mind_env__action_spec( - self, fake_dm_env: JumanjiToDMEnvWrapper + self, fake_dm_env: FakeJumanjiToDMEnvWrapper ) -> None: """Validates action_spec property of the wrapped environment.""" assert isinstance(fake_dm_env.action_spec(), dm_env.specs.Array) def test_jumanji_environment_to_deep_mind_env__unwrapped( - self, fake_dm_env: JumanjiToDMEnvWrapper + self, fake_dm_env: FakeJumanjiToDMEnvWrapper ) -> None: """Validates unwrapped property of the wrapped environment.""" assert isinstance(fake_dm_env.unwrapped, Environment) +FakeJumanjiToGymWrapper = JumanjiToGymWrapper[FakeState, specs.BoundedArray, chex.Array] + + class TestJumanjiEnvironmentToGymEnv: """ Test the JumanjiEnvironmentToGymEnv that transforms an Environment into a gym.Env format. 
""" @pytest.fixture - def fake_gym_env(self, time_limit: int = 10) -> gym.Env: + def fake_gym_env(self, time_limit: int = 10) -> FakeJumanjiToGymWrapper: """Creates a fake environment wrapped as a gym.Env.""" return JumanjiToGymWrapper(FakeEnvironment(time_limit=time_limit)) @@ -260,12 +269,12 @@ def test_jumanji_environment_to_gym_env__init( assert isinstance(gym_environment_with_seed, gym.Env) def test_jumanji_environment_to_gym_env__reset( - self, fake_gym_env: JumanjiToGymWrapper + self, fake_gym_env: FakeJumanjiToGymWrapper ) -> None: """Validates reset function of the wrapped environment.""" - observation1 = fake_gym_env.reset() # type: ignore + observation1 = fake_gym_env.reset() state1 = fake_gym_env._state - observation2 = fake_gym_env.reset() # type: ignore + observation2 = fake_gym_env.reset() state2 = fake_gym_env._state # Observation is typically numpy array @@ -277,24 +286,24 @@ def test_jumanji_environment_to_gym_env__reset( assert_trees_are_different(state1, state2) def test_jumanji_environment_to_gym_env__step( - self, fake_gym_env: JumanjiToGymWrapper + self, fake_gym_env: FakeJumanjiToGymWrapper ) -> None: """Validates step function of the wrapped environment.""" - observation = fake_gym_env.reset() # type: ignore + observation = fake_gym_env.reset() action = fake_gym_env.action_space.sample() - next_observation, reward, terminated, info = fake_gym_env.step(action) # type: ignore + next_observation, reward, terminated, info = fake_gym_env.step(action) assert_trees_are_different(observation, next_observation) assert isinstance(reward, float) assert isinstance(terminated, bool) def test_jumanji_environment_to_gym_env__observation_space( - self, fake_gym_env: JumanjiToGymWrapper + self, fake_gym_env: FakeJumanjiToGymWrapper ) -> None: """Validates observation_space attribute of the wrapped environment.""" assert isinstance(fake_gym_env.observation_space, gym.spaces.Space) def test_jumanji_environment_to_gym_env__action_space( - self, fake_gym_env: JumanjiToGymWrapper + self, fake_gym_env: FakeJumanjiToGymWrapper ) -> None: """Validates action_space attribute of the wrapped environment.""" assert isinstance(fake_gym_env.action_space, gym.spaces.Space) @@ -302,7 +311,7 @@ def test_jumanji_environment_to_gym_env__action_space( def test_jumanji_environment_to_gym_env__render( self, mocker: pytest_mock.MockerFixture, - fake_gym_env: JumanjiToGymWrapper, + fake_gym_env: FakeJumanjiToGymWrapper, ) -> None: mock_render = mocker.patch.object( fake_gym_env.unwrapped, "render", autospec=True @@ -319,7 +328,7 @@ def test_jumanji_environment_to_gym_env__render( def test_jumanji_environment_to_gym_env__close( self, mocker: pytest_mock.MockerFixture, - fake_gym_env: JumanjiToGymWrapper, + fake_gym_env: FakeJumanjiToGymWrapper, ) -> None: mock_close = mocker.patch.object(fake_gym_env.unwrapped, "close", autospec=True) @@ -328,17 +337,22 @@ def test_jumanji_environment_to_gym_env__close( mock_close.assert_called_once() def test_jumanji_environment_to_gym_env__unwrapped( - self, fake_gym_env: JumanjiToGymWrapper + self, fake_gym_env: FakeJumanjiToGymWrapper ) -> None: """Validates unwrapped property of the wrapped environment.""" assert isinstance(fake_gym_env.unwrapped, Environment) +FakeMultiToSingleWrapper = MultiToSingleWrapper[ + FakeState, specs.BoundedArray, chex.Array +] + + class TestMultiToSingleEnvironment: @pytest.fixture def fake_multi_to_single_env( self, fake_multi_environment: FakeMultiEnvironment - ) -> MultiToSingleWrapper: + ) -> FakeMultiToSingleWrapper: """Creates a fake 
wrapper that converts a multi-agent Environment to a single-agent Environment.""" return MultiToSingleWrapper(fake_multi_environment) @@ -346,7 +360,7 @@ def fake_multi_to_single_env( def test_multi_env_wrapper__init( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, ) -> None: """Validates initialization of the multi agent to single agent wrapper.""" single_agent_env = MultiToSingleWrapper(fake_multi_environment) @@ -355,7 +369,7 @@ def test_multi_env_wrapper__init( def test_multi_env__reset( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, key: chex.PRNGKey, ) -> None: """Validates (jitted) reset function and timestep type of the multi agent @@ -371,13 +385,13 @@ def test_multi_env__reset( def test_multi_env__step( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, key: chex.PRNGKey, ) -> None: """Validates (jitted) step function of the multi agent to single agent wrapped environment. """ - state, timestep = fake_multi_to_single_env.reset(key) # type: ignore + state, timestep = fake_multi_to_single_env.reset(key) action = fake_multi_to_single_env.action_spec.generate_value() state, next_timestep = jax.jit(fake_multi_to_single_env.step)(state, action) assert next_timestep != timestep @@ -393,18 +407,16 @@ def test_multi_env__step( def test_multi_env__different_reward_aggregator( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, key: chex.PRNGKey, ) -> None: """Checks that using a different reward aggregator is correct.""" mean_fake_multi_to_single_env = MultiToSingleWrapper( fake_multi_environment, reward_aggregator=jnp.mean ) - state, timestep = mean_fake_multi_to_single_env.reset(key) # type: ignore + state, timestep = mean_fake_multi_to_single_env.reset(key) action = mean_fake_multi_to_single_env.action_spec.generate_value() - state, next_timestep = mean_fake_multi_to_single_env.step( - state, action - ) # type: Tuple[FakeState, TimeStep] + state, next_timestep = mean_fake_multi_to_single_env.step(state, action) assert next_timestep != timestep assert next_timestep.reward.shape == () assert next_timestep.reward == fake_multi_environment.reward_per_step @@ -414,7 +426,7 @@ def test_multi_env__different_reward_aggregator( def test_multi_env__observation_spec( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, ) -> None: """Validates observation_spec property of the multi agent to single agent wrapped environment. @@ -427,19 +439,19 @@ def test_multi_env__observation_spec( def test_multi_env__action_spec( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, ) -> None: """Validates action_spec property of the multi agent to single agent wrapped environment. 
""" - action_spec: specs.Array = fake_multi_to_single_env.action_spec # type: ignore + action_spec = fake_multi_to_single_env.action_spec assert isinstance(fake_multi_to_single_env.action_spec, specs.Array) assert action_spec.shape == fake_multi_environment.action_spec.shape def test_multi_env__unwrapped( self, fake_multi_environment: FakeMultiEnvironment, - fake_multi_to_single_env: MultiToSingleWrapper, + fake_multi_to_single_env: FakeMultiToSingleWrapper, ) -> None: """Validates unwrapped property of the multi agent to single agent wrapped environment. @@ -448,9 +460,14 @@ def test_multi_env__unwrapped( assert fake_multi_to_single_env._env is fake_multi_environment +FakeVmapWrapper = Wrapper[FakeState, specs.BoundedArray, chex.Array] + + class TestVmapWrapper: @pytest.fixture - def fake_vmap_environment(self, fake_environment: FakeEnvironment) -> VmapWrapper: + def fake_vmap_environment( + self, fake_environment: FakeEnvironment + ) -> FakeVmapWrapper: return VmapWrapper(fake_environment) def test_vmap_wrapper__init(self, fake_environment: FakeEnvironment) -> None: @@ -459,7 +476,7 @@ def test_vmap_wrapper__init(self, fake_environment: FakeEnvironment) -> None: assert isinstance(vmap_env, Environment) def test_vmap_env__reset( - self, fake_vmap_environment: VmapWrapper, keys: chex.PRNGKey + self, fake_vmap_environment: FakeVmapWrapper, keys: chex.PRNGKey ) -> None: """Validates reset function and timestep type of the vmap wrapped environment.""" _, timestep = jax.jit(fake_vmap_environment.reset)(keys) @@ -471,19 +488,15 @@ def test_vmap_env__reset( assert timestep.discount.shape == (keys.shape[0],) def test_vmap_env__step( - self, fake_vmap_environment: VmapWrapper, keys: chex.PRNGKey + self, fake_vmap_environment: FakeVmapWrapper, keys: chex.PRNGKey ) -> None: """Validates step function of the vmap environment.""" - state, timestep = fake_vmap_environment.reset( - keys - ) # type: Tuple[FakeState, TimeStep] + state, timestep = fake_vmap_environment.reset(keys) action = jax.vmap(lambda _: fake_vmap_environment.action_spec.generate_value())( keys ) - state, next_timestep = jax.jit(fake_vmap_environment.step)( - state, action - ) # type: Tuple[FakeState, TimeStep] + state, next_timestep = jax.jit(fake_vmap_environment.step)(state, action) assert_trees_are_different(next_timestep, timestep) chex.assert_trees_all_equal(next_timestep.reward, 0) @@ -493,45 +506,46 @@ def test_vmap_env__step( assert next_timestep.observation.shape[0] == keys.shape[0] def test_vmap_env__render( - self, fake_vmap_environment: VmapWrapper, keys: chex.PRNGKey + self, fake_vmap_environment: FakeVmapWrapper, keys: chex.PRNGKey ) -> None: - states, _ = fake_vmap_environment.reset( - keys - ) # type: Tuple[FakeState, TimeStep] + states, _ = fake_vmap_environment.reset(keys) result = fake_vmap_environment.render(states) assert result == (keys.shape[1:], ()) def test_vmap_env__unwrapped( - self, fake_environment: Environment, fake_vmap_environment: VmapWrapper + self, fake_environment: FakeEnvironment, fake_vmap_environment: FakeVmapWrapper ) -> None: """Validates unwrapped property of the vmap environment.""" assert isinstance(fake_vmap_environment.unwrapped, Environment) assert fake_vmap_environment._env is fake_environment +FakeAutoResetWrapper = AutoResetWrapper[FakeState, specs.BoundedArray, chex.Array] + + class TestAutoResetWrapper: @pytest.fixture def fake_auto_reset_environment( - self, fake_environment: Environment - ) -> AutoResetWrapper: + self, fake_environment: FakeEnvironment + ) -> FakeAutoResetWrapper: 
return AutoResetWrapper(fake_environment) @pytest.fixture def fake_state_and_timestep( - self, fake_auto_reset_environment: AutoResetWrapper, key: chex.PRNGKey - ) -> Tuple[State, TimeStep[Observation]]: + self, fake_auto_reset_environment: FakeAutoResetWrapper, key: chex.PRNGKey + ) -> Tuple[FakeState, TimeStep[chex.Array]]: state, timestep = jax.jit(fake_auto_reset_environment.reset)(key) return state, timestep - def test_auto_reset_wrapper__init(self, fake_environment: Environment) -> None: + def test_auto_reset_wrapper__init(self, fake_environment: FakeEnvironment) -> None: """Validates initialization of the AutoResetWrapper.""" auto_reset_env = AutoResetWrapper(fake_environment) assert isinstance(auto_reset_env, Environment) def test_auto_reset_wrapper__auto_reset( self, - fake_auto_reset_environment: AutoResetWrapper, - fake_state_and_timestep: Tuple[State, TimeStep[Observation]], + fake_auto_reset_environment: FakeAutoResetWrapper, + fake_state_and_timestep: Tuple[FakeState, TimeStep[chex.Array]], ) -> None: """Validates the auto_reset function of the AutoResetWrapper.""" state, timestep = fake_state_and_timestep @@ -541,21 +555,19 @@ def test_auto_reset_wrapper__auto_reset( chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation) def test_auto_reset_wrapper__step_no_reset( - self, fake_auto_reset_environment: AutoResetWrapper, key: chex.PRNGKey + self, fake_auto_reset_environment: FakeAutoResetWrapper, key: chex.PRNGKey ) -> None: """Validates that step function of the AutoResetWrapper does not do an auto-reset when the terminal state is not reached. """ - state, first_timestep = fake_auto_reset_environment.reset( - key - ) # type: Tuple[FakeState, TimeStep] + state, first_timestep = fake_auto_reset_environment.reset(key) # Generate an action action = fake_auto_reset_environment.action_spec.generate_value() state, timestep = jax.jit(fake_auto_reset_environment.step)( state, action - ) # type: Tuple[FakeState, TimeStep] + ) # type: Tuple[FakeState, TimeStep[chex.Array]] assert timestep.step_type == StepType.MID assert_trees_are_different(timestep, first_timestep) @@ -564,7 +576,7 @@ def test_auto_reset_wrapper__step_no_reset( def test_auto_reset_wrapper__step_reset( self, fake_environment: FakeEnvironment, - fake_auto_reset_environment: AutoResetWrapper, + fake_auto_reset_environment: FakeAutoResetWrapper, key: chex.PRNGKey, ) -> None: """Validates that the auto-reset is done correctly by the step function @@ -584,16 +596,23 @@ def test_auto_reset_wrapper__step_reset( chex.assert_trees_all_equal(timestep.observation, first_timestep.observation) +FakeVmapAutoResetWrapper = VmapAutoResetWrapper[ + FakeState, specs.BoundedArray, chex.Array +] + + class TestVmapAutoResetWrapper: @pytest.fixture def fake_vmap_auto_reset_environment( self, fake_environment: FakeEnvironment - ) -> VmapAutoResetWrapper: + ) -> FakeVmapAutoResetWrapper: return VmapAutoResetWrapper(fake_environment) @pytest.fixture def action( - self, fake_vmap_auto_reset_environment: VmapAutoResetWrapper, keys: chex.PRNGKey + self, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, + keys: chex.PRNGKey, ) -> chex.Array: generate_action_fn = ( lambda _: fake_vmap_auto_reset_environment.action_spec.generate_value() @@ -608,7 +627,9 @@ def test_vmap_auto_reset_wrapper__init( assert isinstance(vmap_auto_reset_env, Environment) def test_vmap_auto_reset_wrapper__reset( - self, fake_vmap_auto_reset_environment: VmapAutoResetWrapper, keys: chex.PRNGKey + self, + fake_vmap_auto_reset_environment: 
FakeVmapAutoResetWrapper, + keys: chex.PRNGKey, ) -> None: """Validates reset function and timestep type of the wrapper.""" _, timestep = jax.jit(fake_vmap_auto_reset_environment.reset)(keys) @@ -621,11 +642,11 @@ def test_vmap_auto_reset_wrapper__reset( def test_vmap_auto_reset_wrapper__auto_reset( self, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, ) -> None: """Validates the auto_reset function of the wrapper.""" - state, timestep = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + state, timestep = fake_vmap_auto_reset_environment.reset(keys) _, reset_timestep = jax.lax.map( lambda args: fake_vmap_auto_reset_environment._auto_reset(*args), (state, timestep), @@ -634,11 +655,11 @@ def test_vmap_auto_reset_wrapper__auto_reset( def test_vmap_auto_reset_wrapper__maybe_reset( self, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, ) -> None: """Validates the auto_reset function of the wrapper.""" - state, timestep = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + state, timestep = fake_vmap_auto_reset_environment.reset(keys) _, reset_timestep = jax.lax.map( lambda args: fake_vmap_auto_reset_environment._maybe_reset(*args), (state, timestep), @@ -647,14 +668,14 @@ def test_vmap_auto_reset_wrapper__maybe_reset( def test_vmap_auto_reset_wrapper__step_no_reset( self, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, action: chex.Array, ) -> None: """Validates that step function of the wrapper does not do an auto-reset when the terminal state is not reached. """ - state, first_timestep = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + state, first_timestep = fake_vmap_auto_reset_environment.reset(keys) state, timestep = jax.jit(fake_vmap_auto_reset_environment.step)(state, action) assert jnp.all(timestep.step_type == StepType.MID) @@ -663,15 +684,15 @@ def test_vmap_auto_reset_wrapper__step_no_reset( def test_vmap_auto_reset_wrapper__step_reset( self, - fake_environment: Environment, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_environment: FakeEnvironment, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, action: chex.Array, ) -> None: """Validates that the auto-reset is done correctly by the step function of the wrapper when the terminal timestep is reached. 
""" - state, first_timestep = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + state, first_timestep = fake_vmap_auto_reset_environment.reset(keys) fake_vmap_auto_reset_environment.unwrapped.time_limit = 5 # type: ignore # Loop across time_limit so auto-reset occurs @@ -685,12 +706,12 @@ def test_vmap_auto_reset_wrapper__step_reset( def test_vmap_auto_reset_wrapper__step( self, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, keys: chex.PRNGKey, action: chex.Array, ) -> None: """Validates step function of the vmap environment.""" - state, timestep = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + state, timestep = fake_vmap_auto_reset_environment.reset(keys) state, next_timestep = jax.jit(fake_vmap_auto_reset_environment.step)( state, action ) @@ -702,16 +723,18 @@ def test_vmap_auto_reset_wrapper__step( assert next_timestep.observation.shape[0] == keys.shape[0] def test_vmap_auto_reset_wrapper__render( - self, fake_vmap_auto_reset_environment: VmapAutoResetWrapper, keys: chex.PRNGKey + self, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, + keys: chex.PRNGKey, ) -> None: - states, _ = fake_vmap_auto_reset_environment.reset(keys) # type: ignore + states, _ = fake_vmap_auto_reset_environment.reset(keys) result = fake_vmap_auto_reset_environment.render(states) assert result == (keys.shape[1:], ()) def test_vmap_auto_reset_wrapper__unwrapped( self, fake_environment: FakeEnvironment, - fake_vmap_auto_reset_environment: VmapAutoResetWrapper, + fake_vmap_auto_reset_environment: FakeVmapAutoResetWrapper, ) -> None: """Validates unwrapped property of the vmap environment.""" assert isinstance(fake_vmap_auto_reset_environment.unwrapped, FakeEnvironment) From bd73139dab9a3bb3a7bda4c0cf65812dfea10897 Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Mon, 29 Jan 2024 23:22:28 -0600 Subject: [PATCH 08/16] fix: remove duplicate TypeVar definitions --- jumanji/env.py | 3 +-- jumanji/testing/env_not_smoke.py | 2 +- jumanji/wrappers_test.py | 4 +--- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/jumanji/env.py b/jumanji/env.py index cbd57268d..d728f86b9 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -23,7 +23,7 @@ from typing_extensions import Protocol from jumanji import specs -from jumanji.types import TimeStep +from jumanji.types import Observation, TimeStep class StateProtocol(Protocol): @@ -34,7 +34,6 @@ class StateProtocol(Protocol): State = TypeVar("State", bound="StateProtocol") ActionSpec = TypeVar("ActionSpec", bound=specs.Array) -Observation = TypeVar("Observation") class Environment(abc.ABC, Generic[State, ActionSpec, Observation]): diff --git a/jumanji/testing/env_not_smoke.py b/jumanji/testing/env_not_smoke.py index 6c0617ee5..4a9603c9d 100644 --- a/jumanji/testing/env_not_smoke.py +++ b/jumanji/testing/env_not_smoke.py @@ -20,8 +20,8 @@ from jumanji import specs from jumanji.env import Environment +from jumanji.types import Observation -Observation = TypeVar("Observation") Action = TypeVar("Action") SelectActionFn = Callable[[chex.PRNGKey, Observation], Action] diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index f779faf7d..aae6d39d6 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
from collections import namedtuple -from typing import Tuple, Type, TypeVar +from typing import Tuple, Type import chex import dm_env.specs @@ -42,8 +42,6 @@ jumanji_to_gym_obs, ) -State = TypeVar("State") -Observation = TypeVar("Observation") FakeWrapper = Wrapper[FakeState, specs.BoundedArray, chex.Array] From 77af200e5760fc8a23db6b6292fd54fe54117f07 Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Tue, 30 Jan 2024 08:22:14 +0000 Subject: [PATCH 09/16] feat(pacman): change pacman specs to properties --- jumanji/environments/routing/pac_man/env.py | 11 +++++------ jumanji/environments/routing/pac_man/env_test.py | 12 +++++++++--- jumanji/training/networks/pac_man/actor_critic.py | 2 +- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/jumanji/environments/routing/pac_man/env.py b/jumanji/environments/routing/pac_man/env.py index 6db84f16c..0d866d8d4 100644 --- a/jumanji/environments/routing/pac_man/env.py +++ b/jumanji/environments/routing/pac_man/env.py @@ -35,7 +35,7 @@ from jumanji.viewer import Viewer -class PacMan(Environment[State]): +class PacMan(Environment[State, specs.DiscreteArray, Observation]): """A JAX implementation of the 'PacMan' game where a single agent must navigate a maze to collect pellets and avoid 4 heuristic agents. The game takes place on a 31x28 grid where the player can move in 4 directions (left, right, up, down) and collect @@ -103,7 +103,7 @@ class PacMan(Environment[State]): key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) - action = env.action_spec().generate_value() + action = env.action_spec.generate_value() state, timestep = jax.jit(env.step)(state, action) env.render(state) ``` @@ -129,10 +129,11 @@ def __init__( self.x_size = self.generator.x_size self.y_size = self.generator.y_size self.pellet_spaces = self.generator.pellet_spaces + super().__init__() self._viewer = viewer or PacManViewer("Pacman", render_mode="human") self.time_limit = 1000 or time_limit - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `PacMan` environment. Returns: @@ -199,7 +200,7 @@ def observation_spec(self) -> specs.Spec[Observation]: score=score, ) - def action_spec(self) -> specs.DiscreteArray: + def _make_action_spec(self) -> specs.DiscreteArray: """Returns the action spec. 5 actions: [0,1,2,3,4] -> [Up, Right, Down, Left, No-op]. @@ -210,7 +211,6 @@ def action_spec(self) -> specs.DiscreteArray: return specs.DiscreteArray(5, name="action") def __repr__(self) -> str: - return ( f"PacMan(\n" f"\tnum_rows={self.x_size!r},\n" @@ -460,7 +460,6 @@ def check_power_up( return power_up_locations, eat, reward def check_wall_collisions(self, state: State, new_player_pos: Position) -> Any: - """ Check if the new player position collides with a wall. 
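The PacMan conversion above follows the same recipe as the rest of the series: specs are built once by the `_make_*_spec` methods during construction (hence the added `super().__init__()` call) and read back through properties, so call sites use `env.action_spec` rather than `env.action_spec()`. A stripped-down sketch of the pattern, with a hypothetical class name:

    from jumanji import specs

    class MiniEnv:  # hypothetical; the real base class is jumanji.env.Environment
        def __init__(self) -> None:
            # Built once at construction time, then cached.
            self._action_spec = self._make_action_spec()

        def _make_action_spec(self) -> specs.DiscreteArray:
            return specs.DiscreteArray(5, name="action")

        @property
        def action_spec(self) -> specs.DiscreteArray:
            # Call sites now read `env.action_spec` with no parentheses.
            return self._action_spec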
diff --git a/jumanji/environments/routing/pac_man/env_test.py b/jumanji/environments/routing/pac_man/env_test.py index f2ab5a7ec..54bc13d5c 100644 --- a/jumanji/environments/routing/pac_man/env_test.py +++ b/jumanji/environments/routing/pac_man/env_test.py @@ -19,7 +19,10 @@ from jumanji.environments.routing.pac_man.env import PacMan from jumanji.environments.routing.pac_man.types import Position, State -from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) from jumanji.testing.pytrees import assert_is_jax_array_tree from jumanji.types import TimeStep @@ -96,12 +99,15 @@ def test_pac_man_step_invalid(pac_man: PacMan) -> None: def test_pac_man_does_not_smoke(pac_man: PacMan) -> None: - check_env_does_not_smoke(pac_man) -def test_power_pellet(pac_man: PacMan) -> None: +def test_env_pac_man_specs_does_not_smoke(pac_man: PacMan) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(pac_man) + +def test_power_pellet(pac_man: PacMan) -> None: key = jax.random.PRNGKey(0) state, timestep = pac_man.reset(key) diff --git a/jumanji/training/networks/pac_man/actor_critic.py b/jumanji/training/networks/pac_man/actor_critic.py index 566adf7b9..59ca353ad 100644 --- a/jumanji/training/networks/pac_man/actor_critic.py +++ b/jumanji/training/networks/pac_man/actor_critic.py @@ -37,7 +37,7 @@ def make_actor_critic_networks_pacman( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `PacMan` environment.""" - num_actions = np.asarray(pac_man.action_spec().num_values) + num_actions = np.asarray(pac_man.action_spec.num_values) parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) From e2d7e6a52b5a051b67c8fc60758a8d9c5aa341d4 Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Wed, 13 Mar 2024 17:21:46 +0000 Subject: [PATCH 10/16] Merge branch 'main' into 98-spec-props --- .gitignore | 1 + MANIFEST.in | 1 + README.md | 2 + docs/api/environments/sokoban.md | 9 + docs/env_anim/sokoban.gif | Bin 0 -> 109938 bytes docs/env_img/sokoban.png | Bin 0 -> 9925 bytes docs/environments/sokoban.md | 76 +++ jumanji/__init__.py | 4 +- jumanji/environments/__init__.py | 2 + jumanji/environments/logic/game_2048/env.py | 2 +- .../environments/logic/graph_coloring/env.py | 2 +- jumanji/environments/logic/minesweeper/env.py | 2 +- jumanji/environments/logic/rubiks_cube/env.py | 2 +- jumanji/environments/logic/sudoku/env.py | 2 +- jumanji/environments/packing/bin_pack/env.py | 2 +- jumanji/environments/packing/job_shop/env.py | 2 +- jumanji/environments/packing/knapsack/env.py | 2 +- jumanji/environments/packing/tetris/env.py | 2 +- jumanji/environments/routing/cleaner/env.py | 2 +- jumanji/environments/routing/connector/env.py | 2 +- .../routing/connector/env_test.py | 2 +- jumanji/environments/routing/cvrp/env.py | 2 +- .../environments/routing/cvrp/reward_test.py | 4 +- jumanji/environments/routing/maze/env.py | 2 +- jumanji/environments/routing/mmst/env.py | 11 + jumanji/environments/routing/mmst/utils.py | 7 +- .../environments/routing/multi_cvrp/env.py | 11 + jumanji/environments/routing/snake/env.py | 2 +- .../environments/routing/sokoban/__init__.py | 16 + .../environments/routing/sokoban/constants.py | 37 ++ jumanji/environments/routing/sokoban/env.py | 575 ++++++++++++++++++ .../environments/routing/sokoban/env_test.py | 217 +++++++ .../environments/routing/sokoban/generator.py | 
448 ++++++++++++++ .../routing/sokoban/generator_test.py | 196 ++++++ .../routing/sokoban/imgs/agent.png | Bin 0 -> 134 bytes .../routing/sokoban/imgs/agent_on_target.png | Bin 0 -> 162 bytes .../environments/routing/sokoban/imgs/box.png | Bin 0 -> 175 bytes .../routing/sokoban/imgs/box_on_target.png | Bin 0 -> 165 bytes .../routing/sokoban/imgs/box_target.png | Bin 0 -> 141 bytes .../routing/sokoban/imgs/floor.png | Bin 0 -> 105 bytes .../routing/sokoban/imgs/wall.png | Bin 0 -> 120 bytes .../environments/routing/sokoban/reward.py | 120 ++++ .../routing/sokoban/reward_test.py | 74 +++ jumanji/environments/routing/sokoban/types.py | 53 ++ .../environments/routing/sokoban/viewer.py | 219 +++++++ jumanji/environments/routing/tsp/env.py | 2 +- .../environments/routing/tsp/reward_test.py | 2 +- jumanji/training/configs/config.yaml | 2 +- jumanji/training/configs/env/sokoban.yaml | 26 + jumanji/training/loggers.py | 44 +- jumanji/training/networks/__init__.py | 4 + jumanji/training/networks/sokoban/__init__.py | 13 + .../training/networks/sokoban/actor_critic.py | 115 ++++ jumanji/training/networks/sokoban/random.py | 35 ++ jumanji/training/setup_train.py | 12 + jumanji/wrappers.py | 73 ++- jumanji/wrappers_test.py | 72 ++- mkdocs.yml | 2 + pyproject.toml | 4 + requirements/requirements-dev.txt | 3 + requirements/requirements-train.txt | 1 - requirements/requirements.txt | 1 + setup.cfg | 2 +- 63 files changed, 2463 insertions(+), 63 deletions(-) create mode 100644 docs/api/environments/sokoban.md create mode 100644 docs/env_anim/sokoban.gif create mode 100644 docs/env_img/sokoban.png create mode 100644 docs/environments/sokoban.md create mode 100644 jumanji/environments/routing/sokoban/__init__.py create mode 100644 jumanji/environments/routing/sokoban/constants.py create mode 100644 jumanji/environments/routing/sokoban/env.py create mode 100644 jumanji/environments/routing/sokoban/env_test.py create mode 100644 jumanji/environments/routing/sokoban/generator.py create mode 100644 jumanji/environments/routing/sokoban/generator_test.py create mode 100644 jumanji/environments/routing/sokoban/imgs/agent.png create mode 100644 jumanji/environments/routing/sokoban/imgs/agent_on_target.png create mode 100644 jumanji/environments/routing/sokoban/imgs/box.png create mode 100644 jumanji/environments/routing/sokoban/imgs/box_on_target.png create mode 100644 jumanji/environments/routing/sokoban/imgs/box_target.png create mode 100644 jumanji/environments/routing/sokoban/imgs/floor.png create mode 100644 jumanji/environments/routing/sokoban/imgs/wall.png create mode 100644 jumanji/environments/routing/sokoban/reward.py create mode 100644 jumanji/environments/routing/sokoban/reward_test.py create mode 100644 jumanji/environments/routing/sokoban/types.py create mode 100644 jumanji/environments/routing/sokoban/viewer.py create mode 100644 jumanji/training/configs/env/sokoban.yaml create mode 100644 jumanji/training/networks/sokoban/__init__.py create mode 100644 jumanji/training/networks/sokoban/actor_critic.py create mode 100644 jumanji/training/networks/sokoban/random.py diff --git a/.gitignore b/.gitignore index 4e62578a6..7a4e033c2 100644 --- a/.gitignore +++ b/.gitignore @@ -154,3 +154,4 @@ cython_debug/ jumanji_env/ **/outputs/ *.xml +.sokoban_cache/ diff --git a/MANIFEST.in b/MANIFEST.in index b9e91a0e3..d66b0a7c6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,7 @@ include LICENSE include requirements/* recursive-include * *.npy +recursive-include jumanji *.png # remove the test specific files 
recursive-exclude * *_test.py diff --git a/README.md b/README.md index 3313e2612..42ca7e6b1 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@
RobotWarehouse + RobotWarehouse
@@ -112,6 +113,7 @@ problems. | 📬 TSP (Travelling Salesman Problem) | Routing | `TSP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/tsp/) | [doc](https://instadeepai.github.io/jumanji/environments/tsp/) | | Multi Minimum Spanning Tree Problem | Routing | `MMST-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/mmst) | [doc](https://instadeepai.github.io/jumanji/environments/mmst/) | | ᗧ•••ᗣ•• PacMan | Routing | `PacMan-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/pacman/) | [doc](https://instadeepai.github.io/jumanji/environments/pacman/) +| 👾 Sokoban | Routing | `Sokoban-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/sokoban/) | [doc](https://instadeepai.github.io/jumanji/environments/sokoban/) |

## Installation 🎬
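
A minimal usage sketch for the newly registered `Sokoban-v0` ID (not part of the patch itself; it assumes the `register(id="Sokoban-v0", ...)` call added to `jumanji/__init__.py` later in this patch, and samples a valid action from the spec rather than a learned policy):

```python
import jax

import jumanji

# Instantiate the newly registered environment through the registry.
env = jumanji.make("Sokoban-v0")
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

# Sample an arbitrary valid action: 0=Up, 1=Right, 2=Down, 3=Left.
action = env.action_spec().generate_value()
state, timestep = jax.jit(env.step)(state, action)
```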

diff --git a/docs/api/environments/sokoban.md b/docs/api/environments/sokoban.md
new file mode 100644
index 000000000..1554480fc
--- /dev/null
+++ b/docs/api/environments/sokoban.md
@@ -0,0 +1,9 @@
+::: jumanji.environments.routing.sokoban.env.Sokoban
+    selection:
+        members:
+            - __init__
+            - observation_spec
+            - action_spec
+            - reset
+            - step
+            - render
diff --git a/docs/env_anim/sokoban.gif b/docs/env_anim/sokoban.gif
new file mode 100644
index 0000000000000000000000000000000000000000..fffcb85ff8e9f7a3554816284823427f74a15982
GIT binary patch
[binary data omitted: docs/env_anim/sokoban.gif, literal 109938 bytes]

diff --git a/docs/env_img/sokoban.png b/docs/env_img/sokoban.png
new file mode 100644
index 0000000000000000000000000000000000000000..fbc25a80eb0d11115ffadd465db4d3668ea73397
GIT binary patch
[binary data omitted: docs/env_img/sokoban.png, literal 9925 bytes]

diff --git a/docs/environments/sokoban.md b/docs/environments/sokoban.md
new file mode 100644
index 000000000..41ca8be4b
--- /dev/null
+++ b/docs/environments/sokoban.md
@@ -0,0 +1,76 @@
+# Sokoban Environment 👾
+
+![Sokoban animation](../env_anim/sokoban.gif)
+
+This is a JAX implementation of the _Sokoban_ puzzle, a dynamic box-pushing environment where the agent's goal is to place all boxes on their targets. This version follows the rules from the DeepMind paper on [Imagination Augmented Agents](https://arxiv.org/abs/1707.06203), with levels based on the Boxoban dataset from [Guez et al., 2018](https://github.com/deepmind/boxoban-levels) [[1]](#ref1). The graphical assets were taken from [gym-sokoban](https://github.com/mpSchrader/gym-sokoban) by Schrader, a diverse Sokoban library implementing many versions of the game in the OpenAI gym framework [[2]](#ref2).
+
+## Observation
+
+- `grid`: An Array (uint8) of shape `(10, 10, 2)`. It represents the variable grid (containing movable objects: boxes and the agent) and the fixed grid (containing fixed objects: walls and targets).
+- `step_count`: An Array (int32) of shape `()`, representing the current number of steps in the episode.
+
+## Object Encodings
+
+| Object       | Encoding |
+|--------------|----------|
+| Empty Space  | 0        |
+| Wall         | 1        |
+| Target       | 2        |
+| Agent        | 3        |
+| Box          | 4        |
+
+## Actions
+
+The agent's action space is an Array (int32) with potential values of `[0,1,2,3]` (corresponding to `[Up, Right, Down, Left]`). If the agent attempts to move into a wall, off the grid, or push a box into a wall or off the grid, the grid state remains unchanged; however, the step count is still incremented by one. Chained box pushes (pushing a box into another box) are not allowed and likewise leave the grid unchanged.
+
+## Reward
+
+The reward function comprises:
+- `-0.1` for each step taken in the environment.
+- `+1` for each box moved onto a target location and `-1` for each box moved off a target location.
+- `+10` upon successful placement of all four boxes on their targets.
+
+For example, the step that pushes the final box onto its target earns `-0.1 + 1 + 10 = 10.9`.
+
+## Episode Termination
+
+The episode concludes when:
+- The step limit of 120 is reached.
+- All 4 boxes are placed on targets (i.e., the problem is solved).
+
+## Dataset
+
+The Boxoban dataset offers a collection of puzzle levels. Each level features four boxes and four targets. The dataset has three levels of difficulty: 'unfiltered', 'medium', and 'hard'.
+
+| Dataset Split           | Number of Levels |
+|-------------------------|------------------|
+| Unfiltered (Training)   | 900,000          |
+| Unfiltered (Validation) | 100,000          |
+| Unfiltered (Test)       | 1000             |
+| Medium (Training)       | 450,000          |
+| Medium (Validation)     | 50,000           |
+| Hard                    | 3332             |
+
+The dataset generation procedure and more details can be found in Guez et al., 2018 [[1]](#ref1).
+
+## Graphics
+
+| Type             | Graphic                                                                                  |
+|------------------|------------------------------------------------------------------------------------------|
+| Wall             | ![Wall](../../jumanji/environments/routing/sokoban/imgs/wall.png)                         |
+| Floor            | ![Floor](../../jumanji/environments/routing/sokoban/imgs/floor.png)                       |
+| Target           | ![BoxTarget](../../jumanji/environments/routing/sokoban/imgs/box_target.png)              |
+| Box on Target    | ![BoxOnTarget](../../jumanji/environments/routing/sokoban/imgs/box_on_target.png)         |
+| Box Off Target   | ![BoxOffTarget](../../jumanji/environments/routing/sokoban/imgs/box.png)                  |
+| Agent Off Target | ![PlayerOffTarget](../../jumanji/environments/routing/sokoban/imgs/agent.png)             |
+| Agent On Target  | ![PlayerOnTarget](../../jumanji/environments/routing/sokoban/imgs/agent_on_target.png)    |
+
+## Registered Versions 📖
+
+- `Sokoban-v0`: Sokoban game with levels generated using the DeepMind Boxoban dataset (unfiltered train).
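+## Example: Inspecting an Observation 🔍
+
+A minimal sketch of how the two grid channels and the encodings above fit
+together (it assumes the `Sokoban-v0` registration added elsewhere in this
+patch):
+
+```python
+import jax
+import jax.numpy as jnp
+
+import jumanji
+
+env = jumanji.make("Sokoban-v0")
+key = jax.random.PRNGKey(0)
+state, timestep = jax.jit(env.reset)(key)
+
+variable = timestep.observation.grid[..., 0]  # movable objects: agent (3), boxes (4)
+fixed = timestep.observation.grid[..., 1]     # static objects: walls (1), targets (2)
+
+# A box sits on a target wherever the two channels line up.
+boxes_on_targets = jnp.sum((variable == 4) & (fixed == 2))
+```
+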
+ +## References +[1] Guez, A., Mirza, M., Gregor, K., Kabra, R., Racaniere, S., Weber, T., Raposo, D., Santoro, A., Orseau, L., Eccles, T., Wayne, G., Silver, D., Lillicrap, T., Valdes, V. (2018). An investigation of Model-free planning: boxoban levels. Available at [https://github.com/deepmind/boxoban-levels](https://github.com/deepmind/boxoban-levels) + +[2] Schrader, M. (2018). Gym-sokoban. Available at [https://github.com/mpSchrader/gym-sokoban](https://github.com/mpSchrader/gym-sokoban) diff --git a/jumanji/__init__.py b/jumanji/__init__.py index 11f15cb08..49b3fdcfd 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -73,7 +73,6 @@ kwargs={"generator": very_easy_sudoku_generator}, ) - ### # Packing Environments ### @@ -93,7 +92,6 @@ # Tetris - the game of tetris with a grid size of 10x10 and a time limit of 400. register(id="Tetris-v0", entry_point="jumanji.environments:Tetris") - ### # Routing Environments ### @@ -128,5 +126,7 @@ # TSP with 20 randomly generated cities and a dense reward function. register(id="TSP-v1", entry_point="jumanji.environments:TSP") +# Sokoban with deepmind dataset generator +register(id="Sokoban-v0", entry_point="jumanji.environments:Sokoban") # Pacman - minimal version of Atarti Pacman game register(id="PacMan-v0", entry_point="jumanji.environments:PacMan") diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py index 444eaa616..239ef8f51 100644 --- a/jumanji/environments/__init__.py +++ b/jumanji/environments/__init__.py @@ -35,6 +35,7 @@ pac_man, robot_warehouse, snake, + sokoban, tsp, ) from jumanji.environments.routing.cleaner.env import Cleaner @@ -46,6 +47,7 @@ from jumanji.environments.routing.pac_man.env import PacMan from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse from jumanji.environments.routing.snake.env import Snake +from jumanji.environments.routing.sokoban.env import Sokoban from jumanji.environments.routing.tsp.env import TSP diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index f95ca173e..6f217478c 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -66,7 +66,7 @@ class Game2048(Environment[State, specs.DiscreteArray, Observation]): ```python from jumanji.environments import Game2048 env = Game2048() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec.generate_value() diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index a6ed516eb..c57a82b22 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -73,7 +73,7 @@ class GraphColoring(Environment[State, specs.DiscreteArray, Observation]): ```python from jumanji.environments import GraphColoring env = GraphColoring() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec.generate_value() diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 6921a671f..4715145b1 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -78,7 +78,7 @@ class Minesweeper(Environment[State, specs.MultiDiscreteArray, Observation]): ```python from jumanji.environments import Minesweeper env = Minesweeper() - key = jax.random.key(0) 
+ key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec.generate_value() diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index cf13ce388..f34ab3893 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -72,7 +72,7 @@ class RubiksCube(Environment[State, specs.MultiDiscreteArray, Observation]): ```python from jumanji.environments import RubiksCube env = RubiksCube() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec.generate_value() diff --git a/jumanji/environments/logic/sudoku/env.py b/jumanji/environments/logic/sudoku/env.py index 8857c7931..9890be728 100644 --- a/jumanji/environments/logic/sudoku/env.py +++ b/jumanji/environments/logic/sudoku/env.py @@ -63,7 +63,7 @@ class Sudoku(Environment[State, specs.MultiDiscreteArray, Observation]): ```python from jumanji.environments import Sudoku env = Sudoku() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec.generate_value() diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index 2716bb6e2..f026bae5c 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -103,7 +103,7 @@ class BinPack(Environment[State, specs.MultiDiscreteArray, Observation]): ```python from jumanji.environments import BinPack env = BinPack() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec.generate_value() diff --git a/jumanji/environments/packing/job_shop/env.py b/jumanji/environments/packing/job_shop/env.py index b25e2eedc..f835e7436 100644 --- a/jumanji/environments/packing/job_shop/env.py +++ b/jumanji/environments/packing/job_shop/env.py @@ -80,7 +80,7 @@ class JobShop(Environment[State, specs.MultiDiscreteArray, Observation]): ```python from jumanji.environments import JobShop env = JobShop() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec.generate_value() diff --git a/jumanji/environments/packing/knapsack/env.py b/jumanji/environments/packing/knapsack/env.py index 6ee9ec919..f24548a6f 100644 --- a/jumanji/environments/packing/knapsack/env.py +++ b/jumanji/environments/packing/knapsack/env.py @@ -73,7 +73,7 @@ class Knapsack(Environment[State, specs.DiscreteArray, Observation]): ```python from jumanji.environments import Knapsack env = Knapsack() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec.generate_value() diff --git a/jumanji/environments/packing/tetris/env.py b/jumanji/environments/packing/tetris/env.py index 7e4bbc216..379518576 100644 --- a/jumanji/environments/packing/tetris/env.py +++ b/jumanji/environments/packing/tetris/env.py @@ -66,7 +66,7 @@ class Tetris(Environment[State, specs.MultiDiscreteArray, Observation]): ```python from jumanji.environments import Tetris env = Tetris() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec.generate_value() diff --git a/jumanji/environments/routing/cleaner/env.py 
b/jumanji/environments/routing/cleaner/env.py
index 53dcdf5d6..739b23de9 100644
--- a/jumanji/environments/routing/cleaner/env.py
+++ b/jumanji/environments/routing/cleaner/env.py
@@ -71,7 +71,7 @@ class Cleaner(Environment[State, specs.MultiDiscreteArray, Observation]):
     ```python
     from jumanji.environments import Cleaner
     env = Cleaner()
-    key = jax.random.key(0)
+    key = jax.random.PRNGKey(0)
     state, timestep = jax.jit(env.reset)(key)
     env.render(state)
     action = env.action_spec.generate_value()
diff --git a/jumanji/environments/routing/connector/env.py b/jumanji/environments/routing/connector/env.py
index e2d621729..bc8177dc6 100644
--- a/jumanji/environments/routing/connector/env.py
+++ b/jumanji/environments/routing/connector/env.py
@@ -85,7 +85,7 @@ class Connector(Environment[State, specs.MultiDiscreteArray, Observation]):
     ```python
     from jumanji.environments import Connector
     env = Connector()
-    key = jax.random.key(0)
+    key = jax.random.PRNGKey(0)
     state, timestep = jax.jit(env.reset)(key)
     env.render(state)
     action = env.action_spec.generate_value()
diff --git a/jumanji/environments/routing/connector/env_test.py b/jumanji/environments/routing/connector/env_test.py
index 19478400a..12a6c1d94 100644
--- a/jumanji/environments/routing/connector/env_test.py
+++ b/jumanji/environments/routing/connector/env_test.py
@@ -44,7 +44,7 @@ def is_target_on_grid(agent: Agent, grid: chex.Array) -> chex.Array:
     return jnp.any(grid[agent.target] == get_target(agent.id))
 
 
-def test_connector__reset(connector: Connector, key: jax.random.KeyArray) -> None:
+def test_connector__reset(connector: Connector, key: chex.PRNGKey) -> None:
     """Test that all heads and targets are on the board."""
     state, timestep = connector.reset(key)
 
diff --git a/jumanji/environments/routing/cvrp/env.py b/jumanji/environments/routing/cvrp/env.py
index f5d9f47d5..bd68552e3 100644
--- a/jumanji/environments/routing/cvrp/env.py
+++ b/jumanji/environments/routing/cvrp/env.py
@@ -86,7 +86,7 @@ class CVRP(Environment[State, specs.DiscreteArray, Observation]):
     ```python
     from jumanji.environments import CVRP
     env = CVRP()
-    key = jax.random.key(0)
+    key = jax.random.PRNGKey(0)
     state, timestep = jax.jit(env.reset)(key)
     env.render(state)
     action = env.action_spec.generate_value()
diff --git a/jumanji/environments/routing/cvrp/reward_test.py b/jumanji/environments/routing/cvrp/reward_test.py
index 8a607611d..acb35a9e4 100644
--- a/jumanji/environments/routing/cvrp/reward_test.py
+++ b/jumanji/environments/routing/cvrp/reward_test.py
@@ -46,10 +46,10 @@ def test_sparse_reward__compute_tour_length() -> None:
     assert jnp.isclose(tour_length, 6.8649917)
 
     trajectory = jnp.array([0, 7, 8, 9, 10, 0, 1, 2, 3, 4, 5, 0, 6])
-    assert compute_tour_length(coordinates, trajectory) == 6.8649917
+    assert jnp.isclose(compute_tour_length(coordinates, trajectory), 6.8649917)
 
     trajectory = jnp.array([0, 7, 8, 9, 10, 0, 1, 2, 3, 4, 5, 0, 6, 0])
-    assert compute_tour_length(coordinates, trajectory) == 6.8649917
+    assert jnp.isclose(compute_tour_length(coordinates, trajectory), 6.8649917)
 
 
 def test_dense_reward(cvrp_dense_reward: CVRP, dense_reward: DenseReward) -> None:
diff --git a/jumanji/environments/routing/maze/env.py b/jumanji/environments/routing/maze/env.py
index 647bc79e3..8de273f3f 100644
--- a/jumanji/environments/routing/maze/env.py
+++ b/jumanji/environments/routing/maze/env.py
@@ -68,7 +68,7 @@ class Maze(Environment[State, specs.DiscreteArray, Observation]):
     ```python
     from jumanji.environments import Maze
     env = Maze()
-    key = jax.random.key(0)
+    key = jax.random.PRNGKey(0)
     state, timestep = jax.jit(env.reset)(key)
     env.render(state)
     action = env.action_spec.generate_value()
diff --git a/jumanji/environments/routing/mmst/env.py b/jumanji/environments/routing/mmst/env.py
index d81146c74..ca7394f12 100644
--- a/jumanji/environments/routing/mmst/env.py
+++ b/jumanji/environments/routing/mmst/env.py
@@ -117,6 +117,17 @@ class MMST(Environment[State, specs.MultiDiscreteArray, Observation]):
     - INVALID_CHOICE = -1
     - INVALID_TIE_BREAK = -2
     - INVALID_ALREADY_TRAVERSED = -3
+
+    ```python
+    from jumanji.environments import MMST
+    env = MMST()
+    key = jax.random.PRNGKey(0)
+    state, timestep = jax.jit(env.reset)(key)
+    env.render(state)
+    action = env.action_spec.generate_value()
+    state, timestep = jax.jit(env.step)(state, action)
+    env.render(state)
+    ```
     """
 
     def __init__(
diff --git a/jumanji/environments/routing/mmst/utils.py b/jumanji/environments/routing/mmst/utils.py
index 0b73be496..18f653822 100644
--- a/jumanji/environments/routing/mmst/utils.py
+++ b/jumanji/environments/routing/mmst/utils.py
@@ -596,10 +596,10 @@ def multi_random_walk(
     ]
 
     # Get the total number of edges we need to add when merging the graphs.
-    sum_ratio: int = np.sum(np.arange(1, num_agents))
+    sum_ratio = np.arange(1, num_agents).sum()
     frac = np.cumsum(
-        [total_edges_merge_graph * (i) / sum_ratio for i in range(1, num_agents - 1)]
-    )
+        total_edges_merge_graph * np.arange(1, num_agents - 1) / sum_ratio
+    ).astype(np.int32)
 
     edges_per_merge_graph = jnp.split(jnp.arange(total_edges_merge_graph), frac)
     num_edges_per_merge_graph = [len(edges) for edges in edges_per_merge_graph]
@@ -608,6 +608,7 @@ def multi_random_walk(
     total_edges = num_edges_per_sub_graph[0]
 
     merge_graph_keys = jax.random.split(base_key, num_agents - 1)
+    # TODO: could do a scan to speed up compilation
     for i in range(1, num_agents):
         total_edges += num_edges_per_sub_graph[i] + num_edges_per_merge_graph[i - 1]
         graph_i = correct_graph_offset(graphs[i - 1], nodes_offsets[i])
diff --git a/jumanji/environments/routing/multi_cvrp/env.py b/jumanji/environments/routing/multi_cvrp/env.py
index 666f2ba2c..2590b1539 100644
--- a/jumanji/environments/routing/multi_cvrp/env.py
+++ b/jumanji/environments/routing/multi_cvrp/env.py
@@ -64,6 +64,17 @@ class MultiCVRP(Environment[State, specs.BoundedArray, Observation]):
 
     [1] Zhang et al. (2020). "Multi-Vehicle Routing Problems with Soft Time
         Windows: A Multi-Agent Reinforcement Learning Approach".
+
+    ```python
+    from jumanji.environments import MultiCVRP
+    env = MultiCVRP()
+    key = jax.random.PRNGKey(0)
+    state, timestep = jax.jit(env.reset)(key)
+    env.render(state)
+    action = env.action_spec.generate_value()
+    state, timestep = jax.jit(env.step)(state, action)
+    env.render(state)
+    ```
     """
 
     def __init__(
diff --git a/jumanji/environments/routing/snake/env.py b/jumanji/environments/routing/snake/env.py
index 90a179d6e..a39e9cb97 100644
--- a/jumanji/environments/routing/snake/env.py
+++ b/jumanji/environments/routing/snake/env.py
@@ -81,7 +81,7 @@ class Snake(Environment[State, specs.DiscreteArray, Observation]):
     ```python
     from jumanji.environments import Snake
     env = Snake()
-    key = jax.random.key(0)
+    key = jax.random.PRNGKey(0)
     state, timestep = jax.jit(env.reset)(key)
     env.render(state)
     action = env.action_spec.generate_value()
diff --git a/jumanji/environments/routing/sokoban/__init__.py b/jumanji/environments/routing/sokoban/__init__.py
new file mode 100644
index 000000000..ff8aa3a7c
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from jumanji.environments.routing.sokoban.env import Sokoban
+from jumanji.environments.routing.sokoban.types import Observation, State
diff --git a/jumanji/environments/routing/sokoban/constants.py b/jumanji/environments/routing/sokoban/constants.py
new file mode 100644
index 000000000..782395e31
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/constants.py
@@ -0,0 +1,37 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax.numpy as jnp
+
+# Translating actions to coordinate changes: (row, col) deltas for
+# [Up, Right, Down, Left]
+MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
+NOOP = -1
+
+# Object encodings
+EMPTY = 0
+WALL = 1
+TARGET = 2
+AGENT = 3
+BOX = 4
+TARGET_AGENT = 5
+TARGET_BOX = 6
+
+# Environment Variables
+N_BOXES = 4
+GRID_SIZE = 10
+
+# Reward Function
+LEVEL_COMPLETE_BONUS = 10
+SINGLE_BOX_BONUS = 1
+STEP_BONUS = -0.1
diff --git a/jumanji/environments/routing/sokoban/env.py b/jumanji/environments/routing/sokoban/env.py
new file mode 100644
index 000000000..c56fcbf89
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/env.py
@@ -0,0 +1,575 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional, Sequence, Tuple
+
+import chex
+import jax
+import jax.numpy as jnp
+import matplotlib.animation
+
+from jumanji import specs
+from jumanji.env import Environment
+from jumanji.environments.routing.sokoban.constants import (
+    AGENT,
+    BOX,
+    EMPTY,
+    GRID_SIZE,
+    MOVES,
+    N_BOXES,
+    NOOP,
+    TARGET,
+    TARGET_AGENT,
+    TARGET_BOX,
+    WALL,
+)
+from jumanji.environments.routing.sokoban.generator import (
+    Generator,
+    HuggingFaceDeepMindGenerator,
+)
+from jumanji.environments.routing.sokoban.reward import DenseReward, RewardFn
+from jumanji.environments.routing.sokoban.types import Observation, State
+from jumanji.environments.routing.sokoban.viewer import BoxViewer
+from jumanji.types import TimeStep, restart, termination, transition
+from jumanji.viewer import Viewer
+
+
+class Sokoban(Environment[State]):
+    """A JAX implementation of the 'Sokoban' game from DeepMind.
+
+    - observation: `Observation`
+        - grid: jax array (uint8) of shape (num_rows, num_cols, 2)
+            Array that includes information about the agent, boxes, and
+            targets in the game.
+        - step_count: jax array (int32) of shape ()
+            current number of steps in the episode.
+
+    - action: jax array (int32) of shape ()
+        [0,1,2,3] -> [Up, Right, Down, Left].
+
+    - reward: jax array (float) of shape ()
+        A reward of 1.0 is given for each box placed on a target, -1
+        when a box is removed from a target, and -0.1 for each timestep.
+        10 is awarded when all boxes are on targets.
+
+    - episode termination:
+        - if the time limit is reached.
+        - if all boxes are on targets.
+
+    - state: `State`
+        - key: jax array (uint32) of shape (2,) used for auto-reset
+        - fixed_grid: jax array (uint8) of shape (num_rows, num_cols)
+            array indicating the walls and targets in the level.
+        - variable_grid: jax array (uint8) of shape (num_rows, num_cols)
+            array indicating the current location of the agent and boxes.
+        - agent_location: jax array (int32) of shape (2,)
+            the agent's current location.
+        - step_count: jax array (int32) of shape ()
+            current number of steps in the episode.
+
+    ```python
+    from jumanji.environments import Sokoban
+    from jumanji.environments.routing.sokoban.generator import (
+        HuggingFaceDeepMindGenerator,
+    )
+
+    env_train = Sokoban(
+        generator=HuggingFaceDeepMindGenerator(
+            dataset_name="unfiltered-train",
+            proportion_of_files=1,
+        )
+    )
+
+    env_test = Sokoban(
+        generator=HuggingFaceDeepMindGenerator(
+            dataset_name="unfiltered-test",
+            proportion_of_files=1,
+        )
+    )
+
+    # Train...
+    key_train = jax.random.PRNGKey(0)
+    state, timestep = jax.jit(env_train.reset)(key_train)
+    env_train.render(state)
+    action = env_train.action_spec().generate_value()
+    state, timestep = jax.jit(env_train.step)(state, action)
+    env_train.render(state)
+    ```
+    """
+
+    def __init__(
+        self,
+        generator: Optional[Generator] = None,
+        reward_fn: Optional[RewardFn] = None,
+        viewer: Optional[Viewer] = None,
+        time_limit: int = 120,
+    ) -> None:
+        """
+        Instantiates a `Sokoban` environment with a specific generator,
+        reward function, time limit, and viewer.
+
+        Args:
+            generator: `Generator` whose `__call__` instantiates an environment
+                instance (an initial state). Implemented options are [`ToyGenerator`,
+                `DeepMindGenerator`, and `HuggingFaceDeepMindGenerator`].
+                Defaults to `HuggingFaceDeepMindGenerator` with
+                `dataset_name="unfiltered-train", proportion_of_files=1`.
+            reward_fn: `RewardFn` used to compute the reward at each step.
+                Defaults to `DenseReward`.
+            time_limit: int, max steps for the environment, defaults to 120.
+            viewer: `Viewer` object, used to render the environment.
+                If not provided, defaults to `BoxViewer`.
+        """
+
+        self.num_rows = GRID_SIZE
+        self.num_cols = GRID_SIZE
+        self.shape = (self.num_rows, self.num_cols)
+        self.time_limit = time_limit
+
+        self.generator = generator or HuggingFaceDeepMindGenerator(
+            "unfiltered-train",
+            proportion_of_files=1,
+        )
+
+        self._viewer = viewer or BoxViewer(
+            name="Sokoban",
+            grid_combine=self.grid_combine,
+        )
+        self.reward_fn = reward_fn or DenseReward()
+
+    def __repr__(self) -> str:
+        """
+        Returns a printable representation of the Sokoban environment.
+
+        Returns:
+            str: A string representation of the Sokoban environment.
+        """
+        return "\n".join(
+            [
+                "Sokoban environment:",
+                f" - num_rows: {self.num_rows}",
+                f" - num_cols: {self.num_cols}",
+                f" - time_limit: {self.time_limit}",
+                f" - generator: {self.generator}",
+            ]
+        )
+
+    def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
+        """
+        Resets the environment by calling the instance generator for a
+        new instance.
+
+        Args:
+            key: random key used to sample new Sokoban problem.
+
+        Returns:
+            state: `State` object corresponding to the new state of the
+                environment after a reset.
+            timestep: `TimeStep` object corresponding to the first timestep
+                returned by the environment after a reset.
+        """
+
+        generator_key, key = jax.random.split(key)
+
+        state = self.generator(generator_key)
+
+        timestep = restart(
+            self._state_to_observation(state),
+            extras=self._get_extras(state),
+        )
+
+        return state, timestep
+
+    def step(
+        self, state: State, action: chex.Array
+    ) -> Tuple[State, TimeStep[Observation]]:
+        """
+        Executes one timestep of the environment's dynamics.
+
+        Args:
+            state: `State` object representing the current state of the
+                environment.
+            action: Array (int32) of shape ().
+                - 0: move up.
+                - 1: move right.
+                - 2: move down.
+                - 3: move left.
+
+        Returns:
+            state, timestep: next state of the environment and timestep to be
+                observed.
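+
+        A minimal usage sketch (assumes `env` is a `Sokoban` instance and
+        `state` comes from `env.reset`; action 1 moves right):
+
+        ```python
+        import jax.numpy as jnp
+        state, timestep = jax.jit(env.step)(state, jnp.array(1, jnp.int32))
+        ```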
+ """ + + # switch to noop if action will have no impact on variable grid + action = self.detect_noop_action( + state.variable_grid, state.fixed_grid, action, state.agent_location + ) + + next_variable_grid, next_agent_location = jax.lax.cond( + jnp.all(action == NOOP), + lambda: (state.variable_grid, state.agent_location), + lambda: self.move_agent(state.variable_grid, action, state.agent_location), + ) + + next_state = State( + key=state.key, + fixed_grid=state.fixed_grid, + variable_grid=next_variable_grid, + agent_location=next_agent_location, + step_count=state.step_count + 1, + ) + + target_reached = self.level_complete(next_state) + time_limit_exceeded = next_state.step_count >= self.time_limit + + done = jnp.logical_or(target_reached, time_limit_exceeded) + + reward = jnp.asarray(self.reward_fn(state, action, next_state), float) + + observation = self._state_to_observation(next_state) + + extras = self._get_extras(next_state) + + timestep = jax.lax.cond( + done, + lambda: termination( + reward=reward, + observation=observation, + extras=extras, + ), + lambda: transition( + reward=reward, + observation=observation, + extras=extras, + ), + ) + + return next_state, timestep + + def observation_spec(self) -> specs.Spec[Observation]: + """ + Returns the specifications of the observation of the `Sokoban` + environment. + + Returns: + specs.Spec[Observation]: The specifications of the observations. + """ + grid = specs.BoundedArray( + shape=(self.num_rows, self.num_cols, 2), + dtype=jnp.uint8, + minimum=0, + maximum=4, + name="grid", + ) + step_count = specs.Array((), jnp.int32, "step_count") + return specs.Spec( + Observation, + "ObservationSpec", + grid=grid, + step_count=step_count, + ) + + def action_spec(self) -> specs.DiscreteArray: + """ + Returns the action specification for the Sokoban environment. + There are 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. + + Returns: + specs.DiscreteArray: Discrete action specifications. + """ + return specs.DiscreteArray(4, name="action", dtype=jnp.int32) + + def _state_to_observation(self, state: State) -> Observation: + """Maps an environment state to an observation. + + Args: + state: `State` object containing the dynamics of the environment. + + Returns: + The observation derived from the state. + """ + + total_grid = jnp.stack([state.variable_grid, state.fixed_grid], axis=-1) + + return Observation( + grid=total_grid, + step_count=state.step_count, + ) + + def _get_extras(self, state: State) -> Dict: + """ + Computes extras metrics to be returned within the timestep. + + Args: + state: 'State' object representing the current state of the + environment. + + Returns: + extras: Dict object containing current proportion of boxes on + targets and whether the problem is solved. + """ + num_boxes_on_targets = self.reward_fn.count_targets(state) + total_num_boxes = N_BOXES + extras = { + "prop_correct_boxes": num_boxes_on_targets / total_num_boxes, + "solved": num_boxes_on_targets == 4, + } + return extras + + def grid_combine( + self, variable_grid: chex.Array, fixed_grid: chex.Array + ) -> chex.Array: + """ + Combines the variable grid and fixed grid into one single grid + representation of the current Sokoban state required for visual + representation of the Sokoban state. Takes care of two possible + overlaps of fixed and variable entries (an agent on a target or a box + on a target), introducing two additional encodings. + + Args: + variable_grid: Array (uint8) of shape (num_rows, num_cols). 
fixed_grid: Array (uint8) of shape (num_rows, num_cols).
+
+        Returns:
+            single_grid: Array (uint8) of shape (num_rows, num_cols).
+        """
+
+        mask_target_agent = jnp.logical_and(
+            fixed_grid == TARGET,
+            variable_grid == AGENT,
+        )
+
+        mask_target_box = jnp.logical_and(
+            fixed_grid == TARGET,
+            variable_grid == BOX,
+        )
+
+        single_grid = jnp.where(
+            mask_target_agent,
+            TARGET_AGENT,
+            jnp.where(
+                mask_target_box,
+                TARGET_BOX,
+                jnp.maximum(variable_grid, fixed_grid),
+            ),
+        ).astype(jnp.uint8)
+
+        return single_grid
+
+    def level_complete(self, state: State) -> chex.Array:
+        """
+        Checks if the sokoban level is complete.
+
+        Args:
+            state: `State` object representing the current state of the environment.
+
+        Returns:
+            complete: Boolean indicating whether the level is complete
+            or not.
+        """
+        return self.reward_fn.count_targets(state) == N_BOXES
+
+    def check_space(
+        self,
+        grid: chex.Array,
+        location: chex.Array,
+        value: int,
+    ) -> chex.Array:
+        """
+        Checks if a specific location in the grid contains a given value.
+
+        Args:
+            grid: Array (uint8) shape (num_rows, num_cols) The grid to check.
+            location: Tuple size 2 of Array (int32) shape () containing the x
+                and y coordinate of the location to check in the grid.
+            value: int The value to look for.
+
+        Returns:
+            present: Array (bool) shape () indicating whether the location
+            in the grid contains the given value or not.
+        """
+
+        return grid[tuple(location)] == value
+
+    def in_grid(self, coordinates: chex.Array) -> chex.Array:
+        """
+        Checks if given coordinates are within the grid size.
+
+        Args:
+            coordinates: Array (int32) shape (2,) the coordinates to check.
+
+        Returns:
+            in_grid: Array (bool) shape () Boolean indicating whether the
+                coordinates are within the grid.
+        """
+        return jnp.all((0 <= coordinates) & (coordinates < GRID_SIZE))
+
+    def detect_noop_action(
+        self,
+        variable_grid: chex.Array,
+        fixed_grid: chex.Array,
+        action: chex.Array,
+        agent_location: chex.Array,
+    ) -> chex.Array:
+        """
+        Masks actions to NOOP (-1) if they would have no effect on the
+        variable grid: either the destination square is a wall or outside
+        the grid, or it contains a box that cannot be pushed.
+
+        Args:
+            variable_grid: Array (uint8) shape (num_rows, num_cols).
+            fixed_grid: Array (uint8) shape (num_rows, num_cols).
+            action: Array (int32) shape () The action to check.
+            agent_location: Array (int32) shape (2,) The agent's location.
+
+        Returns:
+            updated_action: Array (int32) shape () The updated action after
+            detecting noop action.
+        """
+
+        new_location = agent_location + MOVES[action].squeeze()
+
+        # the destination is blocked if it is a wall or outside the grid
+        invalid_destination = self.check_space(
+            fixed_grid, new_location, WALL
+        ) | ~self.in_grid(new_location)
+
+        updated_action = jax.lax.select(
+            invalid_destination,
+            jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32),
+            jax.lax.select(
+                self.check_space(variable_grid, new_location, BOX),
+                self.update_box_push_action(
+                    fixed_grid,
+                    variable_grid,
+                    new_location,
+                    action,
+                ),
+                action,
+            ),
+        )
+
+        return updated_action
+
+    def update_box_push_action(
+        self,
+        fixed_grid: chex.Array,
+        variable_grid: chex.Array,
+        new_location: chex.Array,
+        action: chex.Array,
+    ) -> chex.Array:
+        """
+        Masks actions to NOOP (-1) if pushing the box is not a valid move:
+        if the box would be pushed out of the grid, or the resulting square
+        is either a wall or another box.
+
+        Args:
+            fixed_grid: Array (uint8) shape (num_rows, num_cols) The fixed grid.
+            variable_grid: Array (uint8) shape (num_rows, num_cols) The
+                variable grid.
+ new_location: Array (int32) shape (2,) The new location of the agent. + action: Array (int32) shape () The action to be executed. + + Returns: + updated_action: Array (int32) shape () The updated action after + checking if pushing the box is a valid move. + """ + + return jax.lax.select( + self.check_space( + variable_grid, + new_location + MOVES[action].squeeze(), + BOX, + ) + | ~self.in_grid(new_location + MOVES[action].squeeze()), + jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32), + jax.lax.select( + self.check_space( + fixed_grid, + new_location + MOVES[action].squeeze(), + WALL, + ), + jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32), + action, + ), + ) + + def move_agent( + self, + variable_grid: chex.Array, + action: chex.Array, + current_location: chex.Array, + ) -> Tuple[chex.Array, chex.Array]: + """ + Executes the movement of the agent specified by the action and + executes the movement of a box if present at the destination. + + Args: + variable_grid: Array (uint8) shape (num_rows, num_cols) + action: Array (int32) shape () The action to take. + current_location: Array (int32) shape (2,) + + Returns: + next_variable_grid: Array (uint8) shape (num_rows, num_cols) + next_location: Array (int32) shape (2,) + """ + + next_location = current_location + MOVES[action] + box_location = next_location + MOVES[action] + + # remove agent from current location + next_variable_grid = variable_grid.at[tuple(current_location)].set(EMPTY) + + # either move agent or move agent and box + + next_variable_grid = jax.lax.select( + self.check_space(variable_grid, next_location, BOX), + next_variable_grid.at[tuple(next_location)] + .set(AGENT) + .at[tuple(box_location)] + .set(BOX), + next_variable_grid.at[tuple(next_location)].set(AGENT), + ) + + return next_variable_grid, next_location + + def render(self, state: State) -> None: + """ + Renders the current state of Sokoban. + + Args: + state: 'State' object , the current state to be rendered. + """ + + self._viewer.render(state=state) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """ + Creates an animated gif of the Sokoban environment based on the + sequence of states. + + Args: + states: Sequence of 'State' object + interval: int, The interval between frames in the animation. + Defaults to 200. + save_path: str The path where to save the animation. If not + provided, the animation is not saved. + + Returns: + animation: 'matplotlib.animation.FuncAnimation'. + """ + return self._viewer.animate(states, interval, save_path) diff --git a/jumanji/environments/routing/sokoban/env_test.py b/jumanji/environments/routing/sokoban/env_test.py new file mode 100644 index 000000000..8c3d8da93 --- /dev/null +++ b/jumanji/environments/routing/sokoban/env_test.py @@ -0,0 +1,217 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
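+
+# The tests below exercise reset/step invariants, termination (time limit and
+# solved level), the dense reward on a trivially solvable level, and a smoke
+# test over a full episode.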
+
+import random
+
+import chex
+import jax
+import jax.numpy as jnp
+import pytest
+
+from jumanji.environments.routing.sokoban.constants import AGENT, BOX, TARGET, WALL
+from jumanji.environments.routing.sokoban.env import Sokoban
+from jumanji.environments.routing.sokoban.generator import (
+    DeepMindGenerator,
+    SimpleSolveGenerator,
+)
+from jumanji.environments.routing.sokoban.types import State
+from jumanji.testing.env_not_smoke import check_env_does_not_smoke
+from jumanji.types import TimeStep
+
+
+@pytest.fixture(scope="session")
+def sokoban() -> Sokoban:
+    env = Sokoban(
+        generator=DeepMindGenerator(
+            difficulty="unfiltered",
+            split="train",
+            proportion_of_files=0.005,
+        )
+    )
+    return env
+
+
+@pytest.fixture(scope="session")
+def sokoban_simple() -> Sokoban:
+    env = Sokoban(generator=SimpleSolveGenerator())
+    return env
+
+
+def test_sokoban__reset(sokoban: Sokoban) -> None:
+    chex.clear_trace_counter()
+    reset_fn = jax.jit(chex.assert_max_traces(sokoban.reset, n=1))
+    key = jax.random.PRNGKey(0)
+    state, timestep = reset_fn(key)
+    assert isinstance(timestep, TimeStep)
+    assert isinstance(state, State)
+    assert state.step_count == 0
+    assert timestep.observation.step_count == 0
+    key2 = jax.random.PRNGKey(1)
+    state2, timestep2 = reset_fn(key2)
+    assert not jnp.array_equal(state2.fixed_grid, state.fixed_grid)
+    assert not jnp.array_equal(state2.variable_grid, state.variable_grid)
+
+
+def test_sokoban__multi_step(sokoban: Sokoban) -> None:
+    """Validates the jitted step of the Sokoban environment."""
+    chex.clear_trace_counter()
+    step_fn = jax.jit(chex.assert_max_traces(sokoban.step, n=1))
+
+    # Repeat test for 5 different state initializations
+    for j in range(5):
+        step_count = 0
+        key = jax.random.PRNGKey(j)
+        reset_key, step_key = jax.random.split(key)
+        state, timestep = sokoban.reset(reset_key)
+        fixed_grid_at_reset = state.fixed_grid
+
+        # Repeating random step 120 times
+        for _ in range(120):
+            # random.randint is inclusive of both endpoints; valid actions are 0..3
+            action = jnp.array(random.randint(0, 3), jnp.int32)
+            state, timestep = step_fn(state, action)
+
+            # Check step_count increases after each step
+            step_count += 1
+            assert state.step_count == step_count
+            assert timestep.observation.step_count == step_count
+
+            # Check that the fixed part of the state has not changed
+            assert jnp.array_equal(state.fixed_grid, fixed_grid_at_reset)
+
+            # Check that there are always four boxes in the variable grid and 0 elsewhere
+            num_boxes = jnp.sum(state.variable_grid == BOX)
+            assert num_boxes == jnp.array(4, jnp.int32)
+
+            num_boxes = jnp.sum(state.fixed_grid == BOX)
+            assert num_boxes == jnp.array(0, jnp.int32)
+
+            # Check that there are always 4 targets in the fixed grid and 0 elsewhere
+            num_targets = jnp.sum(state.variable_grid == TARGET)
+            assert num_targets == jnp.array(0, jnp.int32)
+
+            num_targets = jnp.sum(state.fixed_grid == TARGET)
+            assert num_targets == jnp.array(4, jnp.int32)
+
+            # Check that there is one agent in variable grid and 0 elsewhere
+            num_agents = jnp.sum(state.variable_grid == AGENT)
+            assert num_agents == jnp.array(1, jnp.int32)
+
+            num_agents = jnp.sum(state.fixed_grid == AGENT)
+            assert num_agents == jnp.array(0, jnp.int32)
+
+            # Check that the grid size remains constant
+            assert state.fixed_grid.shape == (10, 10)
+
+            # Check the agent is never in the same location as a wall
+            mask_agent = state.variable_grid == AGENT
+            mask_wall = state.fixed_grid == WALL
+            num_agents_on_wall = jnp.sum(mask_agent & mask_wall)
+            assert num_agents_on_wall == jnp.array(0, jnp.int32)
+
+            # Check the boxes are never on a wall
+            mask_boxes = state.variable_grid == BOX
+            mask_wall = state.fixed_grid == WALL
+            num_boxes_on_wall = jnp.sum(mask_boxes & mask_wall)
+            assert num_boxes_on_wall == jnp.array(0, jnp.int32)
+
+
+def test_sokoban__termination_timelimit(sokoban: Sokoban) -> None:
+    """Check that with random actions the environment terminates after
+    120 steps."""
+
+    chex.clear_trace_counter()
+    step_fn = jax.jit(chex.assert_max_traces(sokoban.step, n=1))
+
+    key = jax.random.PRNGKey(0)
+    reset_key, step_key = jax.random.split(key)
+    state, timestep = sokoban.reset(reset_key)
+
+    for _ in range(119):
+        action = jnp.array(random.randint(0, 3), jnp.int32)
+        state, timestep = step_fn(state, action)
+
+    assert not timestep.last()
+
+    action = jnp.array(random.randint(0, 3), jnp.int32)
+    state, timestep = step_fn(state, action)
+
+    assert timestep.last()
+
+
+def test_sokoban__termination_solved(sokoban_simple: Sokoban) -> None:
+    """Check that with the correct sequence of actions to solve a trivial
+    problem, the environment terminates."""
+
+    correct_actions = [0, 2, 1] * 3 + [0]
+    wrong_actions = [0, 2, 1] * 3 + [2]
+
+    chex.clear_trace_counter()
+    step_fn = jax.jit(chex.assert_max_traces(sokoban_simple.step, n=1))
+
+    # Check that environment does terminate with right series of actions
+    key = jax.random.PRNGKey(0)
+    reset_key, step_key = jax.random.split(key)
+    state, timestep = sokoban_simple.reset(reset_key)
+
+    for action in correct_actions:
+        assert not timestep.last()
+
+        action = jnp.array(action, jnp.int32)
+        state, timestep = step_fn(state, action)
+
+    assert timestep.last()
+
+    # Check that environment does not terminate with wrong series of actions
+    key = jax.random.PRNGKey(0)
+    reset_key, step_key = jax.random.split(key)
+    state, timestep = sokoban_simple.reset(reset_key)
+
+    for action in wrong_actions:
+        assert not timestep.last()
+
+        action = jnp.array(action, jnp.int32)
+        state, timestep = step_fn(state, action)
+
+    assert not timestep.last()
+
+
+def test_sokoban__reward_function_solved(sokoban_simple: Sokoban) -> None:
+    """Check the reward function is correct when solving the trivial problem.
+    Every step should give -0.1, each box added to a target adds 1, and
+    solving adds an additional 10."""
+
+    # Correct actions that lead to placing a box every 3 actions
+    correct_actions = [0, 2, 1] * 3 + [0]
+
+    chex.clear_trace_counter()
+    step_fn = jax.jit(chex.assert_max_traces(sokoban_simple.step, n=1))
+
+    key = jax.random.PRNGKey(0)
+    reset_key, step_key = jax.random.split(key)
+    state, timestep = sokoban_simple.reset(reset_key)
+
+    for i, action in enumerate(correct_actions):
+        action = jnp.array(action, jnp.int32)
+        state, timestep = step_fn(state, action)
+
+        if i % 3 == 0 and i != 9:
+            assert timestep.reward == jnp.array(0.9, jnp.float32)
+        elif i != 9:
+            assert timestep.reward == jnp.array(-0.1, jnp.float32)
+        else:
+            assert timestep.reward == jnp.array(10.9, jnp.float32)
+
+
+def test_sokoban__does_not_smoke(sokoban: Sokoban) -> None:
+    """Test that we can run an episode without any errors."""
+    check_env_does_not_smoke(sokoban)
diff --git a/jumanji/environments/routing/sokoban/generator.py b/jumanji/environments/routing/sokoban/generator.py
new file mode 100644
index 000000000..e3ace0a4a
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/generator.py
@@ -0,0 +1,448 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import os
+import zipfile
+from os import listdir
+from os.path import isfile, join
+from typing import List, Tuple
+
+import chex
+import jax
+import jax.numpy as jnp
+import numpy as np
+import requests
+from huggingface_hub import hf_hub_download
+from tqdm import tqdm
+
+from jumanji.environments.routing.sokoban.constants import AGENT
+from jumanji.environments.routing.sokoban.types import State
+
+
+class Generator(abc.ABC):
+    """Defines the abstract `Generator` base class. A `Generator` is responsible
+    for generating a problem instance when the environment is reset.
+    """
+
+    def __init__(
+        self,
+    ) -> None:
+        """Declares the dataset buffers populated by concrete generators."""
+
+        self._fixed_grids: chex.Array
+        self._variable_grids: chex.Array
+
+    @abc.abstractmethod
+    def __call__(self, rng_key: chex.PRNGKey) -> State:
+        """Generate a problem instance.
+
+        Args:
+            rng_key: the Jax random number generation key.
+
+        Returns:
+            state: the generated problem instance.
+        """
+
+    def get_agent_coordinates(self, grid: chex.Array) -> chex.Array:
+        """Extracts the coordinates of the agent from a given grid with the
+        assumption there is only one agent in the grid.
+
+        Args:
+            grid: Array (uint8) of shape (num_rows, num_cols)
+
+        Returns:
+            location: (int32) of shape (2,)
+        """
+
+        coordinates = jnp.where(grid == AGENT, size=1)
+
+        x_coord = jnp.squeeze(coordinates[0])
+        y_coord = jnp.squeeze(coordinates[1])
+
+        return jnp.array([x_coord, y_coord])
+
+
+class DeepMindGenerator(Generator):
+    """Instance generator backed by the DeepMind Boxoban text files,
+    downloaded from GitHub and cached locally."""
+
+    def __init__(
+        self,
+        difficulty: str,
+        split: str,
+        proportion_of_files: float = 1.0,
+        verbose: bool = False,
+    ) -> None:
+        self.difficulty = difficulty
+        self.verbose = verbose
+        self.proportion_of_files = proportion_of_files
+
+        # Set the cache path to user's home directory's .cache sub-directory
+        self.cache_path = os.path.join(
+            os.path.expanduser("~"), ".cache", "sokoban_dataset"
+        )
+
+        # Downloads data if not already downloaded
+        self._download_data()
+
+        self.train_data_dir = os.path.join(
+            self.cache_path, "boxoban-levels-master", self.difficulty
+        )
+
+        if self.difficulty in ["unfiltered", "medium"]:
+            if self.difficulty == "medium" and split == "test":
+                raise Exception(
+                    "not a valid DeepMind Boxoban difficulty/split combination"
+                )
+            self.train_data_dir = os.path.join(
+                self.train_data_dir,
+                split,
+            )
+
+        # Generates the dataset of sokoban levels
+        self._fixed_grids, self._variable_grids = self._generate_dataset()
+
+    def __call__(self, rng_key: chex.PRNGKey) -> State:
+        """Generate a random Boxoban problem from the DeepMind dataset.
+
+        Args:
+            rng_key: the Jax random number generation key.
+
+        Returns:
+            state: `State` object containing the fixed and variable grids of
+                the sampled problem.
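+
+        A minimal sketch (constructing the generator downloads and caches the
+        dataset on first use):
+
+        ```python
+        generator = DeepMindGenerator(
+            "unfiltered", split="train", proportion_of_files=0.01
+        )
+        state = generator(jax.random.PRNGKey(0))
+        ```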
+        """
+
+        key, idx_key = jax.random.split(rng_key)
+        idx = jax.random.randint(
+            idx_key, shape=(), minval=0, maxval=self._fixed_grids.shape[0]
+        )
+        fixed_grid = self._fixed_grids.take(idx, axis=0)
+        variable_grid = self._variable_grids.take(idx, axis=0)
+
+        initial_agent_location = self.get_agent_coordinates(variable_grid)
+
+        state = State(
+            key=key,
+            fixed_grid=fixed_grid,
+            variable_grid=variable_grid,
+            agent_location=initial_agent_location,
+            step_count=jnp.array(0, jnp.int32),
+        )
+
+        return state
+
+    def _generate_dataset(
+        self,
+    ) -> Tuple[chex.Array, chex.Array]:
+        """Parses the text files into JAX arrays (fixed and variable grids)
+        representing the Boxoban dataset.
+
+        Returns:
+            fixed_grid: Array (uint8) shape (dataset_size, num_rows, num_cols)
+                the fixed components of the problem.
+            variable_grid: Array (uint8) shape (dataset_size, num_rows,
+                num_cols) the variable components of the problem.
+        """
+
+        all_files = [
+            f
+            for f in listdir(self.train_data_dir)
+            if isfile(join(self.train_data_dir, f))
+        ]
+        # Only keep a few files if specified
+        all_files = all_files[: int(self.proportion_of_files * len(all_files))]
+
+        fixed_grids_list: List[chex.Array] = []
+        variable_grids_list: List[chex.Array] = []
+
+        for file in all_files:
+
+            source_file = join(self.train_data_dir, file)
+            current_map: List[str] = []
+            # parses a game file containing multiple games
+            with open(source_file, "r") as sf:
+                for line in sf.readlines():
+                    if ";" in line and current_map:
+                        fixed_grid, variable_grid = convert_level_to_array(current_map)
+
+                        fixed_grids_list.append(jnp.array(fixed_grid, dtype=jnp.uint8))
+                        variable_grids_list.append(
+                            jnp.array(variable_grid, dtype=jnp.uint8)
+                        )
+
+                        current_map = []
+                    if "#" == line[0]:
+                        current_map.append(line.strip())
+
+            # the last map in each file has no trailing ";" separator
+            if current_map:
+                fixed_grid, variable_grid = convert_level_to_array(current_map)
+                fixed_grids_list.append(jnp.array(fixed_grid, dtype=jnp.uint8))
+                variable_grids_list.append(jnp.array(variable_grid, dtype=jnp.uint8))
+
+        fixed_grids = jnp.asarray(fixed_grids_list, jnp.uint8)
+        variable_grids = jnp.asarray(variable_grids_list, jnp.uint8)
+
+        return fixed_grids, variable_grids
+
+    def _download_data(self) -> None:
+        """Downloads the DeepMind Boxoban dataset from GitHub into text files."""
+
+        # Check if the cache directory exists, if not, create it
+        if not os.path.exists(self.cache_path):
+            os.makedirs(self.cache_path)
+
+        # Check if the dataset is already downloaded in the cache
+        dataset_path = os.path.join(self.cache_path, "boxoban-levels-master")
+        if not os.path.exists(dataset_path):
+            url = "https://github.com/deepmind/boxoban-levels/archive/master.zip"
+            if self.verbose:
+                print("Boxoban: Pregenerated levels not downloaded.")
+                print('Starting download from "{}"'.format(url))
+
+            response = requests.get(url, stream=True)
+
+            if response.status_code != 200:
+                raise Exception("Could not download levels")
+
+            path_to_zip_file = os.path.join(
+                self.cache_path, "boxoban_levels-master.zip"
+            )
+            with open(path_to_zip_file, "wb") as handle:
+                for data in tqdm(response.iter_content()):
+                    handle.write(data)
+
+            with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
+                zip_ref.extractall(self.cache_path)
+
+
+class HuggingFaceDeepMindGenerator(Generator):
+    """Instance generator that samples a random problem from the DeepMind
+    Boxoban dataset, a popular benchmark for comparing reinforcement learning
+    and planning algorithms. The dataset has unfiltered, medium and hard
+    versions. The unfiltered version contains train, test and valid splits.
+    The medium version has train and valid splits available. 
And the hard set + contains just a small number of problems. The problems are all guaranteed + to be solvable. + """ + + def __init__( + self, + dataset_name: str, + proportion_of_files: float = 1.0, + ) -> None: + """Instantiates a `DeepMindGenerator`. + + Args: + dataset_name: the name of the dataset to use. Choices are: + - unfiltered-train, + - unfiltered-valid, + - unfiltered-test, + - medium-train, + - medium-test, + - hard. + proportion_of_files: float between (0,1) for the proportion of + files to use in the dataset . + """ + + self.dataset_name = dataset_name + self.proportion_of_files = proportion_of_files + + dataset_file = hf_hub_download( + repo_id="InstaDeepAI/boxoban-levels", + filename=f"{dataset_name}.npy", + ) + + with open(dataset_file, "rb") as f: + dataset = np.load(f) + + # Convert to jax arrays and resize using proportion_of_files + length = int(proportion_of_files * dataset.shape[0]) + self._fixed_grids = jnp.asarray(dataset[:length, ..., 0], jnp.uint8) + self._variable_grids = jnp.asarray(dataset[:length, ..., 1], jnp.uint8) + + def __call__(self, rng_key: chex.PRNGKey) -> State: + """Generate a random Boxoban problem from the Deepmind dataset. + + Args: + rng_key: the Jax random number generation key. + + Returns: + fixed_grid: Array (uint8) shape (num_rows, num_cols) the fixed + components of the problem. + variable_grid: Array (uint8) shape (num_rows, num_cols) the + variable components of the problem. + """ + + key, idx_key = jax.random.split(rng_key) + idx = jax.random.randint( + idx_key, shape=(), minval=0, maxval=self._fixed_grids.shape[0] + ) + fixed_grid = self._fixed_grids.take(idx, axis=0) + variable_grid = self._variable_grids.take(idx, axis=0) + + initial_agent_location = self.get_agent_coordinates(variable_grid) + + state = State( + key=key, + fixed_grid=fixed_grid, + variable_grid=variable_grid, + agent_location=initial_agent_location, + step_count=jnp.array(0, jnp.int32), + ) + + return state + + +class ToyGenerator(Generator): + def __call__( + self, + rng_key: chex.PRNGKey, + ) -> State: + """Generate a random Boxoban problem from the toy 2 problem dataset. + + Args: + rng_key: the Jax random number generation key. + + Returns: + fixed_grid: Array (uint8) shape (num_rows, num_cols) the fixed + components of the problem. + variable_grid: Array (uint8) shape (num_rows, num_cols) the + variable components of the problem. + """ + + key, idx_key = jax.random.split(rng_key) + + level1 = [ + "##########", + "# @ #", + "# $ . #", + "# $# . #", + "# .#$ # ", + "# . # $ # ", + "# #", + "##########", + "##########", + "##########", + ] + + level2 = [ + "##########", + "# #", + "#$ # . #", + "# # $ # .#", + "# .# $ #", + "# @ # . 
$#", + "# #", + "##########", + "##########", + "##########", + ] + + game1_fixed, game1_variable = convert_level_to_array(level1) + game2_fixed, game2_variable = convert_level_to_array(level2) + + games_fixed = jnp.stack([game1_fixed, game2_fixed]) + games_variable = jnp.stack([game1_variable, game2_variable]) + + game_index = jax.random.randint( + key=idx_key, + shape=(), + minval=0, + maxval=games_fixed.shape[0], + ) + + initial_agent_location = self.get_agent_coordinates(games_variable[game_index]) + + state = State( + key=key, + fixed_grid=games_fixed[game_index], + variable_grid=games_variable[game_index], + agent_location=initial_agent_location, + step_count=jnp.array(0, jnp.int32), + ) + + return state + + +class SimpleSolveGenerator(Generator): + def __call__( + self, + key: chex.PRNGKey, + ) -> State: + """Generate a trivial Boxoban problem. + + Args: + key: the Jax random number generation key. + + Returns: + fixed_grid: Array (uint8) shape (num_rows, num_cols) the fixed + components of the problem. + variable_grid: Array (uint8) shape (num_rows, num_cols) the + variable components of the problem. + """ + level1 = [ + "##########", + "# ##", + "# .... #", + "# $$$$ ##", + "# @ # #", + "# # # ", + "# #", + "##########", + "##########", + "##########", + ] + + game_fixed, game_variable = convert_level_to_array(level1) + + initial_agent_location = self.get_agent_coordinates(game_variable) + + state = State( + key=key, + fixed_grid=game_fixed, + variable_grid=game_variable, + agent_location=initial_agent_location, + step_count=jnp.array(0, jnp.int32), + ) + + return state + + +def convert_level_to_array(level: List[str]) -> Tuple[chex.Array, chex.Array]: + """Converts text representation of levels to a tuple of Jax arrays + representing the fixed elements of the Boxoban problem and the variable + elements + + Args: + level: List of str representing a boxoban level. + + Returns: + fixed_grid: Array (uint8) shape (num_rows, num_cols) + the fixed components of the problem. + variable_grid: Array (uint8) shape (num_rows, + num_cols) the variable components of the problem. + """ + + # Define the mappings + mapping = { + "#": (1, 0), + ".": (2, 0), + "@": (0, 3), + "$": (0, 4), + " ": (0, 0), # empty cell + } + + fixed = [[mapping[cell][0] for cell in row] for row in level] + variable = [[mapping[cell][1] for cell in row] for row in level] + + return jnp.array(fixed, jnp.uint8), jnp.array(variable, jnp.uint8) diff --git a/jumanji/environments/routing/sokoban/generator_test.py b/jumanji/environments/routing/sokoban/generator_test.py new file mode 100644 index 000000000..0c766b632 --- /dev/null +++ b/jumanji/environments/routing/sokoban/generator_test.py @@ -0,0 +1,196 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
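+
+# The tests below check that both dataset-backed generators build valid problem
+# sets of the expected size and reset deterministically for a given key.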
+
+import random
+from typing import List
+
+import chex
+import jax
+import jax.numpy as jnp
+import pytest
+
+from jumanji.environments.routing.sokoban.env import Sokoban
+from jumanji.environments.routing.sokoban.generator import (
+    DeepMindGenerator,
+    HuggingFaceDeepMindGenerator,
+)
+
+
+def test_sokoban__hugging_generator_creation() -> None:
+    """checks we can create datasets for all valid Boxoban datasets and
+    perform a jitted step"""
+
+    datasets = [
+        "unfiltered-train",
+        "unfiltered-test",
+        "unfiltered-valid",
+        "medium-train",
+        "medium-valid",
+        "hard",
+    ]
+
+    for dataset in datasets:
+
+        chex.clear_trace_counter()
+
+        env = Sokoban(
+            generator=HuggingFaceDeepMindGenerator(
+                dataset_name=dataset,
+                proportion_of_files=1,
+            )
+        )
+
+        print(env.generator._fixed_grids.shape)
+
+        step_fn = jax.jit(chex.assert_max_traces(env.step, n=1))
+
+        key = jax.random.PRNGKey(0)
+        state, timestep = env.reset(key)
+
+        step_count = 0
+        for _ in range(120):
+            # random.randint is inclusive of both endpoints; valid actions are 0..3
+            action = jnp.array(random.randint(0, 3), jnp.int32)
+            state, timestep = step_fn(state, action)
+
+            # Check step_count increases after each step
+            step_count += 1
+            assert state.step_count == step_count
+            assert timestep.observation.step_count == step_count
+
+
+def test_sokoban__hugging_generator_different_problems() -> None:
+    """checks that resetting with different keys leads to different problems"""
+
+    chex.clear_trace_counter()
+
+    env = Sokoban(
+        generator=HuggingFaceDeepMindGenerator(
+            dataset_name="unfiltered-train",
+            proportion_of_files=1,
+        )
+    )
+
+    key1 = jax.random.PRNGKey(0)
+    state1, timestep1 = env.reset(key1)
+
+    key2 = jax.random.PRNGKey(1)
+    state2, timestep2 = env.reset(key2)
+
+    # Check that different resets lead to different problems
+    assert not jnp.array_equal(state2.fixed_grid, state1.fixed_grid)
+    assert not jnp.array_equal(state2.variable_grid, state1.variable_grid)
+
+
+def test_sokoban__hugging_generator_same_problems() -> None:
+    """checks that resetting with the same key leads to the same problems"""
+
+    chex.clear_trace_counter()
+
+    env = Sokoban(
+        generator=HuggingFaceDeepMindGenerator(
+            dataset_name="unfiltered-train",
+            proportion_of_files=1,
+        )
+    )
+
+    key1 = jax.random.PRNGKey(0)
+    state1, timestep1 = env.reset(key1)
+
+    key2 = jax.random.PRNGKey(0)
+    state2, timestep2 = env.reset(key2)
+
+    assert jnp.array_equal(state2.fixed_grid, state1.fixed_grid)
+    assert jnp.array_equal(state2.variable_grid, state1.variable_grid)
+
+
+def test_sokoban__hugging_generator_proportion_of_problems() -> None:
+    """checks that generator initialises the correct number of problems"""
+
+    chex.clear_trace_counter()
+
+    unfiltered_dataset_size = 900000
+
+    generator_full = HuggingFaceDeepMindGenerator(
+        dataset_name="unfiltered-train",
+        proportion_of_files=1,
+    )
+
+    assert jnp.array_equal(
+        generator_full._fixed_grids.shape,
+        (unfiltered_dataset_size, 10, 10),
+    )
+
+    generator_half_full = HuggingFaceDeepMindGenerator(
+        dataset_name="unfiltered-train",
+        proportion_of_files=0.5,
+    )
+
+    assert jnp.array_equal(
+        generator_half_full._fixed_grids.shape,
+        (unfiltered_dataset_size / 2, 10, 10),
+    )
+
+
+def test_sokoban__deepmind_generator_creation() -> None:
+    """checks we can create datasets for all valid Boxoban datasets"""
+
+    # Different datasets with varying proportion of files to keep runtime low
+    valid_datasets: List[List] = [
+        ["unfiltered", "train", 0.01],
+        ["unfiltered", "test", 1],
+        ["unfiltered", "valid", 0.02],
+        ["medium", "train", 0.01],
+        ["medium", "valid", 0.02],
+        ["hard", None, 1],
+    ]
+
+    
for dataset in valid_datasets: + + chex.clear_trace_counter() + + env = Sokoban( + generator=DeepMindGenerator( + difficulty=dataset[0], + split=dataset[1], + proportion_of_files=dataset[2], + ) + ) + + assert env.generator._fixed_grids.shape[0] > 0 + + +def test_sokoban__deepmind_invalid_creation() -> None: + """checks that asking for invalid difficulty, split, proportion leads to + exception""" + + # Different datasets with varying proportion of files to keep rutime low + valid_datasets: List[List] = [ + ["medium", "test", 0.01], + ["mediumy", "train", 0.01], + ["hardy", "train", 0.01], + ["unfiltered", None, 0.01], + ] + + for dataset in valid_datasets: + + chex.clear_trace_counter() + + with pytest.raises(Exception): + _ = Sokoban( + generator=DeepMindGenerator( + difficulty=dataset[0], + split=dataset[1], + proportion_of_files=dataset[2], + ) + ) diff --git a/jumanji/environments/routing/sokoban/imgs/agent.png b/jumanji/environments/routing/sokoban/imgs/agent.png new file mode 100644 index 0000000000000000000000000000000000000000..00298ce9b5a5c47404a770268b42d99cf68892d4 GIT binary patch literal 134 zcmeAS@N?(olHy`uVBq!ia0vp^0wB!D3?x-;bCrM;TYyi9E0ESaB^B4hBnM=17I;J! z11T#IX8e#c$qOi`=IP=X!f`!W;6cKK2?-DWGdN8$me6@_aOHpUkN@(lH~%L-5NJ?f YXiX6~yi>U`2B?F<)78&qol`;+04_%(r2qf` literal 0 HcmV?d00001 diff --git a/jumanji/environments/routing/sokoban/imgs/agent_on_target.png b/jumanji/environments/routing/sokoban/imgs/agent_on_target.png new file mode 100644 index 0000000000000000000000000000000000000000..3e8310d2501b6260a947be727a0d613b9b606672 GIT binary patch literal 162 zcmeAS@N?(olHy`uVBq!ia0vp^0wB!93?!50ihlx9oB=)|u0Z-71B2!%sgPqDk3oV3 z9+AaB$_j)TKcr0Z0t(uDx;Tb#Tu)9&O2|kMm@Q@2>X>jMD%9SgJ6s}&`(&CtX zb(JX3>IMdF#jROah1eR|)Wp~>u1r*PXr3X#5ELO}u>74y9?&QTPgg&ebxsLQ03$0b A@&Et; literal 0 HcmV?d00001 diff --git a/jumanji/environments/routing/sokoban/imgs/box.png b/jumanji/environments/routing/sokoban/imgs/box.png new file mode 100644 index 0000000000000000000000000000000000000000..9a2497df0efe811c14f6dafe57c662d897ddf35b GIT binary patch literal 175 zcmeAS@N?(olHy`uVBq!ia0vp^0wB!93?!50ihlx9oB=)|uGiBUt`{*}Kf=%ybSM-g zSl|&^45X|;nDIl(Brl+#kEe@c2*>s01P)%_goN0wS}l$xS679#d8`ec$TKre+gNhb zp=s+16aWAK literal 0 HcmV?d00001 diff --git a/jumanji/environments/routing/sokoban/imgs/box_on_target.png b/jumanji/environments/routing/sokoban/imgs/box_on_target.png new file mode 100644 index 0000000000000000000000000000000000000000..74629af03f4cce827aa072824950663d27cfae14 GIT binary patch literal 165 zcmeAS@N?(olHy`uVBq!ia0vp^0wB!93?!50ihlx9JOMr-uJ0Hat`{*}Kf-W5ond#^ zLlvMnXMsm#F_5wXVa5+Ble~a}_MR?|Asp9}6B?MWg|0R;DA*Szdf-6hR-?v7?VOiP zZ0puuZa5;PEwNGL$Xv5hM&|M>w^-Qzt+~pmkSi|G#=ziiCCZj{z3V2>C?O8D?ow; z9+AaB$_j)TKcr0Z0t#w*x;Tb#TuD>e&K>ciV}~Diw~<1=PH9RiB&=t ejAjfeA`H(>g!+uac#i;eF?hQAxvX literal 0 HcmV?d00001 diff --git a/jumanji/environments/routing/sokoban/imgs/wall.png b/jumanji/environments/routing/sokoban/imgs/wall.png new file mode 100644 index 0000000000000000000000000000000000000000..8da06c8d6b139c22dd3321d4ba95b1edd93d7978 GIT binary patch literal 120 zcmeAS@N?(olHy`uVBq!ia0vp^0wB!D3?x-;bCrM;TYyi9>jqnf4YT?dy?LSq6yYrJ zh%5$DRv^syA!U*mP*B3t#W95AdUC?Ugoefc|2uj(Ff=MCd1NtMN#MA-`g|N4Pz{5p LtDnm{r-UW|7U3T> literal 0 HcmV?d00001 diff --git a/jumanji/environments/routing/sokoban/reward.py b/jumanji/environments/routing/sokoban/reward.py new file mode 100644 index 000000000..775122293 --- /dev/null +++ b/jumanji/environments/routing/sokoban/reward.py @@ -0,0 +1,120 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax.numpy as jnp + +from jumanji.environments.routing.sokoban.constants import ( + BOX, + LEVEL_COMPLETE_BONUS, + N_BOXES, + SINGLE_BOX_BONUS, + STEP_BONUS, + TARGET, +) +from jumanji.environments.routing.sokoban.types import State + + +class RewardFn(abc.ABC): + @abc.abstractmethod + def __call__( + self, + state: State, + action: chex.Numeric, + next_state: State, + ) -> chex.Numeric: + """Compute the reward based on the current state, + the chosen action, the next state. + """ + + def count_targets(self, state: State) -> chex.Array: + """ + Calculates the number of boxes on targets. + + Args: + state: `State` object representing the current state of the + environment. + + Returns: + n_targets: Array (int32) of shape () specifying the number of boxes + on targets. + """ + + mask_box = state.variable_grid == BOX + mask_target = state.fixed_grid == TARGET + + num_boxes_on_targets = jnp.sum(mask_box & mask_target) + + return num_boxes_on_targets + + +class SparseReward(RewardFn): + def __call__( + self, + state: State, + action: chex.Array, + next_state: State, + ) -> chex.Array: + """ + Implements the sparse reward function in the Sokoban environment. + + Args: + state: `State` object The current state of the environment. + action: Array (int32) shape () representing the action taken. + next_state: `State` object The next state of the environment. + + Returns: + reward: Array (float32) of shape () specifying the reward received + at transition + """ + + next_num_box_target = self.count_targets(next_state) + + level_completed = next_num_box_target == N_BOXES + + return LEVEL_COMPLETE_BONUS * level_completed + + +class DenseReward(RewardFn): + def __call__( + self, + state: State, + action: chex.Array, + next_state: State, + ) -> chex.Array: + """ + Implements the dense reward function in the Sokoban environment. + + Args: + state: `State` object The current state of the environment. + action: Array (int32) shape () representing the action taken. + next_state: `State` object The next state of the environment. + + Returns: + reward: Array (float32) of shape () specifying the reward received + at transition + """ + + num_box_target = self.count_targets(state) + next_num_box_target = self.count_targets(next_state) + + level_completed = next_num_box_target == N_BOXES + + return ( + SINGLE_BOX_BONUS * (next_num_box_target - num_box_target) + + LEVEL_COMPLETE_BONUS * level_completed + + STEP_BONUS + ) diff --git a/jumanji/environments/routing/sokoban/reward_test.py b/jumanji/environments/routing/sokoban/reward_test.py new file mode 100644 index 000000000..c8ae5b08b --- /dev/null +++ b/jumanji/environments/routing/sokoban/reward_test.py @@ -0,0 +1,74 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+
+import chex
+import jax
+import jax.numpy as jnp
+import pytest
+
+from jumanji.environments.routing.sokoban.env import Sokoban
+from jumanji.environments.routing.sokoban.generator import SimpleSolveGenerator
+from jumanji.types import TimeStep
+
+
+@pytest.fixture(scope="session")
+def sokoban_simple() -> Sokoban:
+    env = Sokoban(generator=SimpleSolveGenerator())
+    return env
+
+
+def test_sokoban__reward_function_random(sokoban_simple: Sokoban) -> None:
+    """Check the reward function is correct when randomly acting in the
+    trivial problem, where accidentally pushing boxes onto targets is likely.
+    Every step should give -0.1, each box pushed onto a target adds 1, each
+    box removed from a target takes away 1, and solving adds an additional
+    10."""
+
+    def check_correct_reward(
+        timestep: TimeStep,
+        num_boxes_on_targets_new: chex.Array,
+        num_boxes_on_targets: chex.Array,
+    ) -> None:
+
+        if num_boxes_on_targets_new == jnp.array(4, jnp.int32):
+            assert timestep.reward == jnp.array(10.9, jnp.float32)
+        elif num_boxes_on_targets_new - num_boxes_on_targets > jnp.array(0, jnp.int32):
+            assert timestep.reward == jnp.array(0.9, jnp.float32)
+        elif num_boxes_on_targets_new - num_boxes_on_targets < jnp.array(0, jnp.int32):
+            assert timestep.reward == jnp.array(-1.1, jnp.float32)
+        else:
+            assert timestep.reward == jnp.array(-0.1, jnp.float32)
+
+    for i in range(5):
+        chex.clear_trace_counter()
+        step_fn = jax.jit(chex.assert_max_traces(sokoban_simple.step, n=1))
+
+        key = jax.random.PRNGKey(i)
+        reset_key, step_key = jax.random.split(key)
+        state, timestep = sokoban_simple.reset(reset_key)
+
+        num_boxes_on_targets = sokoban_simple.reward_fn.count_targets(state)
+
+        for _ in range(120):
+            # random.randint is inclusive of both endpoints; valid actions are 0..3
+            action = jnp.array(random.randint(0, 3), jnp.int32)
+            state, timestep = step_fn(state, action)
+
+            num_boxes_on_targets_new = sokoban_simple.reward_fn.count_targets(state)
+
+            check_correct_reward(
+                timestep, num_boxes_on_targets_new, num_boxes_on_targets
+            )
+
+            num_boxes_on_targets = num_boxes_on_targets_new
diff --git a/jumanji/environments/routing/sokoban/types.py b/jumanji/environments/routing/sokoban/types.py
new file mode 100644
index 000000000..eeb561e76
--- /dev/null
+++ b/jumanji/environments/routing/sokoban/types.py
@@ -0,0 +1,53 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, NamedTuple
+
+import chex
+
+if TYPE_CHECKING:  # https://github.com/python/mypy/issues/6239
+    from dataclasses import dataclass
+else:
+    from chex import dataclass
+
+
+@dataclass
+class State:
+    """
+    key: random key used for auto-reset.
+ fixed_grid: Array (uint8) shape (n_rows, n_cols) array representing the + fixed elements of a sokoban problem. + variable_grid: Array (uint8) shape (n_rows, n_cols) array representing the + variable elements of a sokoban problem. + agent_location: Array (int32) shape (2,) + step_count: Array (int32) shape () + """ + + key: chex.PRNGKey + fixed_grid: chex.Array + variable_grid: chex.Array + agent_location: chex.Array + step_count: chex.Array + + +class Observation(NamedTuple): + """ + The observation returned by the sokoban environment. + grid: Array (uint8) shape (n_rows, n_cols, 2) array representing the + variable and fixed grids. + step_count: Array (int32) shape () the index of the current step. + """ + + grid: chex.Array + step_count: chex.Array diff --git a/jumanji/environments/routing/sokoban/viewer.py b/jumanji/environments/routing/sokoban/viewer.py new file mode 100644 index 000000000..db5716704 --- /dev/null +++ b/jumanji/environments/routing/sokoban/viewer.py @@ -0,0 +1,219 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Sequence, Tuple + +import chex +import matplotlib.animation +import matplotlib.cm +import matplotlib.pyplot as plt +import numpy as np +import pkg_resources +from numpy.typing import NDArray +from PIL import Image + +import jumanji.environments +from jumanji.viewer import Viewer + + +class BoxViewer(Viewer): + FIGURE_SIZE = (10.0, 10.0) + + def __init__( + self, + name: str, + grid_combine: Callable, + ) -> None: + """ + Viewer for a `Sokoban` environment using images from + https://github.com/mpSchrader/gym-sokoban. + + Args: + name: the window name to be used when initialising the window. + grid_combine: function for combining fixed_grid and variable grid + """ + self._name = name + self.NUM_COLORS = 10 + self.grid_combine = grid_combine + self._display = self._display_rgb_array + self._animation: Optional[matplotlib.animation.Animation] = None + + image_names = [ + "floor", + "wall", + "box_target", + "agent", + "box", + "agent_on_target", + "box_on_target", + ] + + def get_image(image_name: str) -> Image.Image: + img_path = pkg_resources.resource_filename( + "jumanji", f"environments/routing/sokoban/imgs/{image_name}.png" + ) + return Image.open(img_path) + + self.images = [get_image(image_name) for image_name in image_names] + + def render(self, state: chex.Array) -> Optional[NDArray]: + """Render the given state of the `Sokoban` environment. + + Args: + state: the environment state to render. + """ + + self._clear_display() + fig, ax = self._get_fig_ax() + ax.clear() + self._add_grid_image(state, ax) + return self._display(fig) + + def animate( + self, + states: Sequence[chex.Array], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to + consecutive timesteps. 
+ interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If + it is None, the plot will not be saved. + + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. + """ + fig, ax = plt.subplots( + num=f"{self._name}Animation", figsize=BoxViewer.FIGURE_SIZE + ) + plt.close(fig) + + def make_frame(state_index: int) -> None: + ax.clear() + state = states[state_index] + self._add_grid_image(state, ax) + + # Create the animation object. + self._animation = matplotlib.animation.FuncAnimation( + fig, + make_frame, + frames=len(states), + interval=interval, + ) + + # Save the animation as a gif. + if save_path: + self._animation.save(save_path) + + return self._animation + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + plt.close(self._name) + + def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: + """ + Fetch or create a matplotlib figure and its associated axes. + + Returns: + fig: (plt.Figure) A matplotlib figure object + axes: (plt.Axes) The axes associated with the figure. + """ + recreate = not plt.fignum_exists(self._name) + fig = plt.figure(self._name, BoxViewer.FIGURE_SIZE) + if recreate: + if not plt.isinteractive(): + fig.show() + ax = fig.add_subplot() + else: + ax = fig.get_axes()[0] + return fig, ax + + def _add_grid_image(self, state: chex.Array, ax: plt.Axes) -> None: + """ + Add a grid image to the provided axes. + + Args: + state: 'State' object representing a state of Sokoban. + ax: (plt.Axes) object where the state image will be added. + """ + grid = self.grid_combine(state.variable_grid, state.fixed_grid) + + self._draw_grid(grid, ax) + ax.set_axis_off() + ax.set_aspect(1) + ax.relim() + ax.autoscale_view() + + def _draw_grid(self, grid: chex.Array, ax: plt.Axes) -> None: + """ + Draw a grid onto provided axes. + + Args: + grid: Array () of shape (). + ax: (plt.Axes) The axes on which to draw the grid. + """ + + cols, rows = grid.shape + + for col in range(cols): + for row in range(rows): + self._draw_grid_cell(grid[row, col], 9 - row, col, ax) + + def _draw_grid_cell( + self, cell_value: int, row: int, col: int, ax: plt.Axes + ) -> None: + """ + Draw a single cell of the grid. + + Args: + cell_value: int representing the cell's value determining its image. + row: int representing the cell's row index. + col: int representing the cell's col index. + ax: (plt.Axes) The axes on which to draw the cell. + """ + cell_value = int(cell_value) + image = self.images[cell_value] + ax.imshow(image, extent=(col, col + 1, row, row + 1)) + + def _clear_display(self) -> None: + """ + Clear the current notebook display if the environment is a notebook. + """ + + if jumanji.environments.is_notebook(): + import IPython.display + + IPython.display.clear_output(True) + + def _display_rgb_array(self, fig: plt.Figure) -> NDArray: + """ + Convert the given figure to an RGB array. + + Args: + fig: (plt.Figure) The figure to be converted. + + Returns: + NDArray: The RGB array representation of the figure. 
+ """ + fig.canvas.draw() + return np.asarray(fig.canvas.buffer_rgba()) diff --git a/jumanji/environments/routing/tsp/env.py b/jumanji/environments/routing/tsp/env.py index 7f4c084a5..c945b20ec 100644 --- a/jumanji/environments/routing/tsp/env.py +++ b/jumanji/environments/routing/tsp/env.py @@ -80,7 +80,7 @@ class TSP(Environment[State, specs.DiscreteArray, Observation]): ```python from jumanji.environments import TSP env = TSP() - key = jax.random.key(0) + key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) action = env.action_spec.generate_value() diff --git a/jumanji/environments/routing/tsp/reward_test.py b/jumanji/environments/routing/tsp/reward_test.py index 8eb55e999..b4c632aaa 100644 --- a/jumanji/environments/routing/tsp/reward_test.py +++ b/jumanji/environments/routing/tsp/reward_test.py @@ -69,7 +69,7 @@ def test_sparse_reward( # noqa: CCR001 tour_length = jnp.linalg.norm( sorted_cities - sorted_cities_rolled, axis=-1 ).sum() - assert reward == -tour_length + assert jnp.isclose(reward, -tour_length) else: # Check that the reward is 0 for every non-final valid action. assert reward == 0 diff --git a/jumanji/training/configs/config.yaml b/jumanji/training/configs/config.yaml index a110c89b8..458d60b07 100644 --- a/jumanji/training/configs/config.yaml +++ b/jumanji/training/configs/config.yaml @@ -1,6 +1,6 @@ defaults: - _self_ - - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, rubiks_cube, snake, sudoku, tetris, tsp] + - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, rubiks_cube, snake, sokoban, sudoku, tetris, tsp] agent: random # [random, a2c] diff --git a/jumanji/training/configs/env/sokoban.yaml b/jumanji/training/configs/env/sokoban.yaml new file mode 100644 index 000000000..6baf4f2d0 --- /dev/null +++ b/jumanji/training/configs/env/sokoban.yaml @@ -0,0 +1,26 @@ +name: sokoban +registered_version: Sokoban-v0 + +network: + channels: [256,256,512,512] + policy_layers: [64, 64] + value_layers: [128, 128] + +training: + num_epochs: 1000 + num_learner_steps_per_epoch: 500 + n_steps: 20 + total_batch_size: 128 + +evaluation: + eval_total_batch_size: 1024 + greedy_eval_total_batch_size: 1024 + +a2c: + normalize_advantage: True + discount_factor: 0.97 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.01 + learning_rate: 3e-4 diff --git a/jumanji/training/loggers.py b/jumanji/training/loggers.py index 722bd9987..a0ad74384 100644 --- a/jumanji/training/loggers.py +++ b/jumanji/training/loggers.py @@ -57,6 +57,14 @@ def close(self) -> None: def upload_checkpoint(self) -> None: """Uploads a checkpoint when exiting the logger.""" + def is_loggable(self, value: Any) -> bool: + """Returns True if the value is loggable.""" + if isinstance(value, (float, int)): + return True + if isinstance(value, (jnp.ndarray, np.ndarray)): + return bool(value.ndim == 0) + return False + def __enter__(self) -> Logger: logging.info("Starting logger.") self._variables_enter = self._get_variables() @@ -134,8 +142,9 @@ def __init__( def _format_values(self, data: Dict[str, Any]) -> str: return " | ".join( f"{key.replace('_', ' ').title()}: " - f"{(f'{value:.3f}' if isinstance(value, (float, jnp.ndarray)) else f'{value:,}')}" + f"{(f'{value:,}' if isinstance(value, int) else f'{value:.3f}')}" for key, value in sorted(data.items()) + if 
self.is_loggable(value) ) def write( @@ -166,7 +175,8 @@ def write( env_steps: Optional[int] = None, ) -> None: for key, value in data.items(): - self.history[key].append(value) + if self.is_loggable(value): + self.history[key].append(value) class TensorboardLogger(Logger): @@ -191,15 +201,12 @@ def write( self._env_steps = env_steps prefix = label and f"{label}/" for key, metric in data.items(): - if np.ndim(metric) == 0: - if not np.isnan(metric): - self.writer.add_scalar( - tag=f"{prefix}/{key}", - scalar_value=metric, - global_step=int(self._env_steps), - ) - else: - raise ValueError(f"Expected metric {key} to be a scalar, got {metric}.") + if self.is_loggable(metric) and not np.isnan(metric): + self.writer.add_scalar( + tag=f"{prefix}/{key}", + scalar_value=metric, + global_step=int(self._env_steps), + ) def close(self) -> None: self.writer.close() @@ -232,15 +239,12 @@ def write( self._env_steps = env_steps prefix = label and f"{label}/" for key, metric in data.items(): - if np.ndim(metric) == 0: - if not np.isnan(metric): - self.run[f"{prefix}/{key}"].log( - float(metric), - step=int(self._env_steps), - wait=True, - ) - else: - raise ValueError(f"Expected metric {key} to be a scalar, got {metric}.") + if self.is_loggable(metric) and not np.isnan(metric): + self.run[f"{prefix}/{key}"].log( + float(metric), + step=int(self._env_steps), + wait=True, + ) def close(self) -> None: self.run.stop() diff --git a/jumanji/training/networks/__init__.py b/jumanji/training/networks/__init__.py index 1fb1df18d..82ad0ae65 100644 --- a/jumanji/training/networks/__init__.py +++ b/jumanji/training/networks/__init__.py @@ -78,6 +78,10 @@ make_actor_critic_networks_snake, ) from jumanji.training.networks.snake.random import make_random_policy_snake +from jumanji.training.networks.sokoban.actor_critic import ( + make_actor_critic_networks_sokoban, +) +from jumanji.training.networks.sokoban.random import make_random_policy_sokoban from jumanji.training.networks.sudoku.actor_critic import ( make_cnn_actor_critic_networks_sudoku, make_equivariant_actor_critic_networks_sudoku, diff --git a/jumanji/training/networks/sokoban/__init__.py b/jumanji/training/networks/sokoban/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/training/networks/sokoban/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/training/networks/sokoban/actor_critic.py b/jumanji/training/networks/sokoban/actor_critic.py new file mode 100644 index 000000000..968180c60 --- /dev/null +++ b/jumanji/training/networks/sokoban/actor_critic.py @@ -0,0 +1,115 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence + +import chex +import haiku as hk +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.sokoban import Observation, Sokoban +from jumanji.training.networks.actor_critic import ( + ActorCriticNetworks, + FeedForwardNetwork, +) +from jumanji.training.networks.parametric_distribution import ( + CategoricalParametricDistribution, +) + + +def make_actor_critic_networks_sokoban( + sokoban: Sokoban, + channels: Sequence[int], + policy_layers: Sequence[int], + value_layers: Sequence[int], +) -> ActorCriticNetworks: + """Make actor-critic networks for the `Sokoban` environment.""" + num_actions = sokoban.action_spec().num_values + parametric_action_distribution = CategoricalParametricDistribution( + num_actions=num_actions + ) + + policy_network = make_sokoban_cnn( + num_outputs=num_actions, + mlp_units=policy_layers, + channels=channels, + time_limit=sokoban.time_limit, + ) + value_network = make_sokoban_cnn( + num_outputs=1, + mlp_units=value_layers, + channels=channels, + time_limit=sokoban.time_limit, + ) + + return ActorCriticNetworks( + policy_network=policy_network, + value_network=value_network, + parametric_action_distribution=parametric_action_distribution, + ) + + +def make_sokoban_cnn( + num_outputs: int, + mlp_units: Sequence[int], + channels: Sequence[int], + time_limit: int, +) -> FeedForwardNetwork: + def network_fn(observation: Observation) -> chex.Array: + + # Iterate over the channels sequence to create convolutional layers + layers = [] + for i, conv_n_channels in enumerate(channels): + layers.append(hk.Conv2D(conv_n_channels, (3, 3), stride=2 if i == 0 else 1)) + layers.append(jax.nn.relu) + + layers.append(hk.Flatten()) + + torso = hk.Sequential(layers) + + x_processed = preprocess_input(observation.grid) + + embedding = torso(x_processed) + + norm_step_count = jnp.expand_dims(observation.step_count / time_limit, axis=-1) + embedding = jnp.concatenate([embedding, norm_step_count], axis=-1) + head = hk.nets.MLP((*mlp_units, num_outputs), activate_final=False) + if num_outputs == 1: + value = jnp.squeeze(head(embedding), axis=-1) + return value + else: + logits = head(embedding) + + return logits + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) + + +def preprocess_input( + input_array: chex.Array, +) -> chex.Array: + + one_hot_array_fixed = jnp.equal(input_array[..., 0:1], jnp.array([3, 4])).astype( + jnp.float32 + ) + + one_hot_array_variable = jnp.equal(input_array[..., 1:2], jnp.array([1, 2])).astype( + jnp.float32 + ) + + total = jnp.concatenate((one_hot_array_fixed, one_hot_array_variable), axis=-1) + + return total diff --git a/jumanji/training/networks/sokoban/random.py b/jumanji/training/networks/sokoban/random.py new file mode 100644 index 000000000..8b428174f --- /dev/null +++ b/jumanji/training/networks/sokoban/random.py @@ -0,0 +1,35 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.sokoban import Observation +from jumanji.training.networks.protocols import RandomPolicy + + +def categorical_random( + observation: Observation, + key: chex.PRNGKey, +) -> chex.Array: + logits = jnp.zeros(shape=(observation.grid.shape[0], 4)) + + action = jax.random.categorical(key, logits) + return action + + +def make_random_policy_sokoban() -> RandomPolicy: + """Make random policy for the `Sokoban` environment.""" + return categorical_random diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py index dee956345..e2d2b9890 100644 --- a/jumanji/training/setup_train.py +++ b/jumanji/training/setup_train.py @@ -40,6 +40,7 @@ RobotWarehouse, RubiksCube, Snake, + Sokoban, Sudoku, Tetris, ) @@ -178,6 +179,9 @@ def _setup_random_policy( # noqa: CCR001 elif cfg.env.name == "maze": assert isinstance(env.unwrapped, Maze) random_policy = networks.make_random_policy_maze() + elif cfg.env.name == "sokoban": + assert isinstance(env.unwrapped, Sokoban) + random_policy = networks.make_random_policy_sokoban() elif cfg.env.name == "connector": assert isinstance(env.unwrapped, Connector) random_policy = networks.make_random_policy_connector() @@ -326,6 +330,14 @@ def _setup_actor_critic_neworks( # noqa: CCR001 policy_layers=cfg.env.network.policy_layers, value_layers=cfg.env.network.value_layers, ) + elif cfg.env.name == "sokoban": + assert isinstance(env.unwrapped, Sokoban) + actor_critic_networks = networks.make_actor_critic_networks_sokoban( + sokoban=env.unwrapped, + channels=cfg.env.network.channels, + policy_layers=cfg.env.network.policy_layers, + value_layers=cfg.env.network.value_layers, + ) elif cfg.env.name == "cleaner": assert isinstance(env.unwrapped, Cleaner) actor_critic_networks = networks.make_actor_critic_networks_cleaner( diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index 92f47c480..1025625a8 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -364,6 +364,25 @@ def render(self, state: State) -> Any: return super().render(state_0) +NEXT_OBS_KEY_IN_EXTRAS = "next_obs" + + +def add_obs_to_extras(timestep: TimeStep[Observation]) -> TimeStep[Observation]: + """Place the observation in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]. + Used when auto-resetting to store the observation from the terminal TimeStep (useful for + e.g. truncation). + + Args: + timestep: TimeStep object containing the timestep returned by the environment. + + Returns: + timestep where the observation is placed in timestep.extras["next_obs"]. + """ + extras = timestep.extras + extras[NEXT_OBS_KEY_IN_EXTRAS] = timestep.observation + return timestep.replace(extras=extras) # type: ignore + + class AutoResetWrapper( Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation] ): @@ -371,13 +390,29 @@ class AutoResetWrapper( the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep is reset to the reset observation and StepType.LAST, respectively. 
The reward, discount, and extras retrieved from the transition to the terminal state. + NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"]. WARNING: do not `jax.vmap` the wrapped environment (e.g. do not use with the `VmapWrapper`), which would lead to inefficient computation due to both the `step` and `reset` functions being processed each time `step` is called. Please use the `VmapAutoResetWrapper` instead. """ + def __init__(self, env: Environment, next_obs_in_extras: bool = False): + """Wrap an environment to automatically reset it when the episode terminates. + + Args: + env: the environment to wrap. + next_obs_in_extras: whether to store the next observation in the extras of the + terminal timestep. This is useful for e.g. truncation. + """ + super().__init__(env) + self.next_obs_in_extras = next_obs_in_extras + if next_obs_in_extras: + self._maybe_add_obs_to_extras = add_obs_to_extras + else: + self._maybe_add_obs_to_extras = lambda timestep: timestep # no-op + def _auto_reset( - self, state: State, timestep: TimeStep + self, state: State, timestep: TimeStep[Observation] ) -> Tuple[State, TimeStep[Observation]]: """Reset the state and overwrite `timestep.observation` with the reset observation if the episode has terminated. @@ -393,11 +428,17 @@ def _auto_reset( key, _ = jax.random.split(state.key) # type: ignore state, reset_timestep = self._env.reset(key) + # Place original observation in extras. + timestep = self._maybe_add_obs_to_extras(timestep) + # Replace observation with reset observation. - timestep = timestep.replace( # type: ignore - observation=reset_timestep.observation - ) + timestep = timestep.replace(observation=reset_timestep.observation) # type: ignore + + return state, timestep + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: + state, timestep = super().reset(key) + timestep = self._maybe_add_obs_to_extras(timestep) return state, timestep def step( @@ -410,7 +451,7 @@ def step( state, timestep = jax.lax.cond( timestep.last(), self._auto_reset, - lambda *x: x, + lambda s, t: (s, self._maybe_add_obs_to_extras(t)), state, timestep, ) @@ -429,8 +470,24 @@ class VmapAutoResetWrapper( - Homogeneous computation: call step function on all environments in the batch. - Heterogeneous computation: conditional auto-reset (call reset function for some environments within the batch because they have terminated). + NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"]. """ + def __init__(self, env: Environment, next_obs_in_extras: bool = False): + """Wrap an environment to vmap it and automatically reset it when the episode terminates. + + Args: + env: the environment to wrap. + next_obs_in_extras: whether to store the next observation in the extras of the + terminal timestep. This is useful for e.g. truncation. + """ + super().__init__(env) + self.next_obs_in_extras = next_obs_in_extras + if next_obs_in_extras: + self._maybe_add_obs_to_extras = add_obs_to_extras + else: + self._maybe_add_obs_to_extras = lambda timestep: timestep # no-op + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets a batch of environments to initial states. 
@@ -449,6 +506,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: environments, """ state, timestep = jax.vmap(self._env.reset)(key) + timestep = self._maybe_add_obs_to_extras(timestep) return state, timestep def step( @@ -495,6 +553,9 @@ def _auto_reset( key, _ = jax.random.split(state.key) state, reset_timestep = self._env.reset(key) + # Place original observation in extras. + timestep = self._maybe_add_obs_to_extras(timestep) + # Replace observation with reset observation. timestep = timestep.replace( # type: ignore observation=reset_timestep.observation @@ -509,7 +570,7 @@ def _maybe_reset( state, timestep = jax.lax.cond( timestep.last(), self._auto_reset, - lambda *x: x, + lambda s, t: (s, self._maybe_add_obs_to_extras(t)), state, timestep, ) diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index aae6d39d6..8f697071d 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -32,6 +32,7 @@ from jumanji.testing.pytrees import assert_trees_are_different from jumanji.types import StepType, TimeStep from jumanji.wrappers import ( + NEXT_OBS_KEY_IN_EXTRAS, AutoResetWrapper, JumanjiToDMEnvWrapper, JumanjiToGymWrapper, @@ -526,7 +527,7 @@ class TestAutoResetWrapper: def fake_auto_reset_environment( self, fake_environment: FakeEnvironment ) -> FakeAutoResetWrapper: - return AutoResetWrapper(fake_environment) + return AutoResetWrapper(fake_environment, next_obs_in_extras=True) @pytest.fixture def fake_state_and_timestep( @@ -551,6 +552,8 @@ def test_auto_reset_wrapper__auto_reset( state, timestep ) chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation) + # Expect that non-reset timestep obs and extras are the same. + assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_auto_reset_wrapper__step_no_reset( self, fake_auto_reset_environment: FakeAutoResetWrapper, key: chex.PRNGKey @@ -570,6 +573,8 @@ def test_auto_reset_wrapper__step_no_reset( assert timestep.step_type == StepType.MID assert_trees_are_different(timestep, first_timestep) chex.assert_trees_all_equal(timestep.reward, 0) + # no reset so expect extras and obs to be the same. + assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_auto_reset_wrapper__step_reset( self, @@ -580,18 +585,30 @@ def test_auto_reset_wrapper__step_reset( """Validates that the auto-reset is done correctly by the step function of the AutoResetWrapper when the terminal timestep is reached. 
""" - state, first_timestep = fake_auto_reset_environment.reset(key) + state, first_timestep = fake_auto_reset_environment.reset(key) # type: ignore fake_environment.time_limit = 5 # Loop across time_limit so auto-reset occurs - timestep = first_timestep - for _ in range(fake_environment.time_limit): + for _ in range(fake_environment.time_limit - 1): action = fake_auto_reset_environment.action_spec.generate_value() state, timestep = jax.jit(fake_auto_reset_environment.step)(state, action) + assert jnp.all( + timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) - assert timestep.step_type == StepType.LAST - chex.assert_trees_all_equal(timestep.observation, first_timestep.observation) + state, final_timestep = jax.jit(fake_auto_reset_environment.step)(state, action) + + assert final_timestep.step_type == StepType.LAST + chex.assert_trees_all_equal( + final_timestep.observation, first_timestep.observation + ) + assert not jnp.all( + final_timestep.observation == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) + assert jnp.all( + (timestep.observation + 1) == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) FakeVmapAutoResetWrapper = VmapAutoResetWrapper[ @@ -604,7 +621,7 @@ class TestVmapAutoResetWrapper: def fake_vmap_auto_reset_environment( self, fake_environment: FakeEnvironment ) -> FakeVmapAutoResetWrapper: - return VmapAutoResetWrapper(fake_environment) + return VmapAutoResetWrapper(fake_environment, next_obs_in_extras=True) @pytest.fixture def action( @@ -637,6 +654,8 @@ def test_vmap_auto_reset_wrapper__reset( assert timestep.observation.shape[0] == keys.shape[0] assert timestep.reward.shape == (keys.shape[0],) assert timestep.discount.shape == (keys.shape[0],) + # only reset so expect extras and obs to be the same. + assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) def test_vmap_auto_reset_wrapper__auto_reset( self, @@ -650,6 +669,10 @@ def test_vmap_auto_reset_wrapper__auto_reset( (state, timestep), ) chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation) + # expect rest timestep.extras to have the same obs as the original timestep + assert jnp.all( + timestep.observation == reset_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) def test_vmap_auto_reset_wrapper__maybe_reset( self, @@ -663,6 +686,10 @@ def test_vmap_auto_reset_wrapper__maybe_reset( (state, timestep), ) chex.assert_trees_all_equal(timestep.observation, reset_timestep.observation) + # expect rest timestep.extras to have the same obs as the original timestep + assert jnp.all( + timestep.observation == reset_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) def test_vmap_auto_reset_wrapper__step_no_reset( self, @@ -680,6 +707,13 @@ def test_vmap_auto_reset_wrapper__step_no_reset( assert_trees_are_different(timestep, first_timestep) chex.assert_trees_all_equal(timestep.reward, 0) + # no reset so expect extras and obs to be the same. + # and the first timestep should have different obs in extras. 
+ assert not jnp.all( + first_timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) + assert jnp.all(timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]) + def test_vmap_auto_reset_wrapper__step_reset( self, fake_environment: FakeEnvironment, @@ -694,13 +728,27 @@ def test_vmap_auto_reset_wrapper__step_reset( fake_vmap_auto_reset_environment.unwrapped.time_limit = 5 # type: ignore # Loop across time_limit so auto-reset occurs - for _ in range(fake_vmap_auto_reset_environment.time_limit): + for _ in range(fake_vmap_auto_reset_environment.time_limit - 1): state, timestep = jax.jit(fake_vmap_auto_reset_environment.step)( state, action ) + assert jnp.all( + timestep.observation == timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) - assert jnp.all(timestep.step_type == StepType.LAST) - chex.assert_trees_all_equal(timestep.observation, first_timestep.observation) + state, final_timestep = jax.jit(fake_vmap_auto_reset_environment.step)( + state, action + ) + assert jnp.all(final_timestep.step_type == StepType.LAST) + chex.assert_trees_all_equal( + final_timestep.observation, first_timestep.observation + ) + assert not jnp.all( + final_timestep.observation == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) + assert jnp.all( + (timestep.observation + 1) == final_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) def test_vmap_auto_reset_wrapper__step( self, @@ -719,6 +767,10 @@ def test_vmap_auto_reset_wrapper__step( assert next_timestep.reward.shape == (keys.shape[0],) assert next_timestep.discount.shape == (keys.shape[0],) assert next_timestep.observation.shape[0] == keys.shape[0] + # expect observation and extras to be the same, since no reset + assert jnp.all( + next_timestep.observation == next_timestep.extras[NEXT_OBS_KEY_IN_EXTRAS] + ) def test_vmap_auto_reset_wrapper__render( self, diff --git a/mkdocs.yml b/mkdocs.yml index 33c290c91..39dbca1cd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -38,6 +38,7 @@ nav: - RobotWarehouse: environments/robot_warehouse.md - Snake: environments/snake.md - TSP: environments/tsp.md + - Sokoban: environments/sokoban.md - PacMan: environments/pac_man.md - User Guides: - Advanced Usage: guides/advanced_usage.md @@ -68,6 +69,7 @@ nav: - RobotWarehouse: api/environments/robot_warehouse.md - Snake: api/environments/snake.md - TSP: api/environments/tsp.md + - Sokoban: api/environments/sokoban.md - PacMan: api/environments/pac_man.md - Wrappers: api/wrappers.md - Types: api/types.md diff --git a/pyproject.toml b/pyproject.toml index a6321cfd9..79ed22fe2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,5 +41,9 @@ module = [ "haiku.*", "hydra.*", "omegaconf.*", + "huggingface_hub.*", + "requests.*", + "pkg_resources.*", + "PIL.*", ] ignore_missing_imports = true diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index df33e2e2b..ea0b437c2 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -24,3 +24,6 @@ pytest-xdist pytype scipy>=1.7.3 testfixtures +types-Pillow +types-requests +types-setuptools diff --git a/requirements/requirements-train.txt b/requirements/requirements-train.txt index 644e7f1bd..087ca0369 100644 --- a/requirements/requirements-train.txt +++ b/requirements/requirements-train.txt @@ -1,5 +1,4 @@ dm-haiku==0.0.9 -huggingface-hub hydra-core==1.3 neptune-client==0.16.15 optax>=0.1.4 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 4e43bdd4e..2e398c025 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ 
-1,6 +1,7 @@ chex>=0.1.3 dm-env>=1.5 gym>=0.22.0 +huggingface-hub jax>=0.2.26 matplotlib~=3.7.4 numpy>=1.19.5 diff --git a/setup.cfg b/setup.cfg index 6b8becfe0..d032b15c6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,7 @@ exclude = .cache, .eggs max-line-length=100 -max-cognitive-complexity=10 +max-cognitive-complexity=14 import-order-style = google application-import-names = jumanji doctests = True From 68282afd2b338c1c0f0d2c07f599b122f239f419 Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Wed, 13 Mar 2024 17:42:52 +0000 Subject: [PATCH 11/16] feat(wrappers): add typevars to autoreset wrappers --- jumanji/wrappers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index 1025625a8..e76e1c878 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -396,7 +396,11 @@ class AutoResetWrapper( being processed each time `step` is called. Please use the `VmapAutoResetWrapper` instead. """ - def __init__(self, env: Environment, next_obs_in_extras: bool = False): + def __init__( + self, + env: Environment[State, ActionSpec, Observation], + next_obs_in_extras: bool = False, + ): """Wrap an environment to automatically reset it when the episode terminates. Args: @@ -473,7 +477,11 @@ class VmapAutoResetWrapper( NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"]. """ - def __init__(self, env: Environment, next_obs_in_extras: bool = False): + def __init__( + self, + env: Environment[State, ActionSpec, Observation], + next_obs_in_extras: bool = False, + ): """Wrap an environment to vmap it and automatically reset it when the episode terminates. Args: From d04c66d1787eea2e7b1c386770eaf2ca1a937e01 Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Wed, 13 Mar 2024 17:44:02 +0000 Subject: [PATCH 12/16] feat(sokoban): change sokoban specs to properties --- jumanji/environments/routing/sokoban/env.py | 8 +++++--- jumanji/training/networks/sokoban/actor_critic.py | 4 +--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/jumanji/environments/routing/sokoban/env.py b/jumanji/environments/routing/sokoban/env.py index c56fcbf89..60aeda0ce 100644 --- a/jumanji/environments/routing/sokoban/env.py +++ b/jumanji/environments/routing/sokoban/env.py @@ -45,7 +45,7 @@ from jumanji.viewer import Viewer -class Sokoban(Environment[State]): +class Sokoban(Environment[State, specs.DiscreteArray, Observation]): """A JAX implementation of the 'Sokoban' game from deepmind. - observation: `Observation` @@ -136,6 +136,8 @@ def __init__( self.shape = (self.num_rows, self.num_cols) self.time_limit = time_limit + super().__init__() + self.generator = generator or HuggingFaceDeepMindGenerator( "unfiltered-train", proportion_of_files=1, @@ -256,7 +258,7 @@ def step( return next_state, timestep - def observation_spec(self) -> specs.Spec[Observation]: + def _make_observation_spec(self) -> specs.Spec[Observation]: """ Returns the specifications of the observation of the `Sokoban` environment. @@ -279,7 +281,7 @@ def observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) - def action_spec(self) -> specs.DiscreteArray: + def _make_action_spec(self) -> specs.DiscreteArray: """ Returns the action specification for the Sokoban environment. There are 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. 
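
With the specs exposed as properties rather than methods, downstream call sites drop the trailing parentheses, as the training-network hunk below shows. A minimal usage sketch under the new API (not part of the patch; it assumes the default generator, which fetches the DeepMind Boxoban levels from Hugging Face and so may download data on first construction):

import jax
from jumanji.environments.routing.sokoban import Sokoban

env = Sokoban()  # default HuggingFaceDeepMindGenerator may download level data

# Specs are now attributes, so call sites drop the parentheses:
num_actions = env.action_spec.num_values  # 4 actions: Up, Right, Down, Left
obs_spec = env.observation_spec           # previously env.observation_spec()

key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
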
diff --git a/jumanji/training/networks/sokoban/actor_critic.py b/jumanji/training/networks/sokoban/actor_critic.py index 968180c60..97f9d8174 100644 --- a/jumanji/training/networks/sokoban/actor_critic.py +++ b/jumanji/training/networks/sokoban/actor_critic.py @@ -36,7 +36,7 @@ def make_actor_critic_networks_sokoban( value_layers: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Sokoban` environment.""" - num_actions = sokoban.action_spec().num_values + num_actions = sokoban.action_spec.num_values parametric_action_distribution = CategoricalParametricDistribution( num_actions=num_actions ) @@ -68,7 +68,6 @@ def make_sokoban_cnn( time_limit: int, ) -> FeedForwardNetwork: def network_fn(observation: Observation) -> chex.Array: - # Iterate over the channels sequence to create convolutional layers layers = [] for i, conv_n_channels in enumerate(channels): @@ -101,7 +100,6 @@ def network_fn(observation: Observation) -> chex.Array: def preprocess_input( input_array: chex.Array, ) -> chex.Array: - one_hot_array_fixed = jnp.equal(input_array[..., 0:1], jnp.array([3, 4])).astype( jnp.float32 ) From fcb952f3bb09cc4213af0e62dd218a1ca13c25de Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Wed, 13 Mar 2024 19:16:22 +0000 Subject: [PATCH 13/16] feat: implement specs as cached properties --- jumanji/env.py | 51 ++++--------------- jumanji/environments/logic/game_2048/env.py | 9 ++-- .../environments/logic/graph_coloring/env.py | 9 ++-- jumanji/environments/logic/minesweeper/env.py | 9 ++-- jumanji/environments/logic/rubiks_cube/env.py | 9 ++-- jumanji/environments/logic/sudoku/env.py | 11 ++-- jumanji/environments/packing/bin_pack/env.py | 7 ++- jumanji/environments/packing/job_shop/env.py | 7 ++- jumanji/environments/packing/knapsack/env.py | 11 ++-- jumanji/environments/packing/tetris/env.py | 9 ++-- jumanji/environments/routing/cleaner/env.py | 7 ++- jumanji/environments/routing/connector/env.py | 9 ++-- jumanji/environments/routing/cvrp/env.py | 11 ++-- jumanji/environments/routing/maze/env.py | 9 ++-- jumanji/environments/routing/mmst/env.py | 11 ++-- .../environments/routing/multi_cvrp/env.py | 11 ++-- jumanji/environments/routing/pac_man/env.py | 7 ++- .../routing/robot_warehouse/env.py | 9 ++-- jumanji/environments/routing/snake/env.py | 11 ++-- jumanji/environments/routing/sokoban/env.py | 7 ++- jumanji/environments/routing/tsp/env.py | 11 ++-- jumanji/testing/fakes.py | 29 +++++++---- jumanji/wrappers.py | 21 +++++--- jumanji/wrappers_test.py | 36 ++++++------- 24 files changed, 181 insertions(+), 140 deletions(-) diff --git a/jumanji/env.py b/jumanji/env.py index d728f86b9..48035a992 100644 --- a/jumanji/env.py +++ b/jumanji/env.py @@ -17,6 +17,7 @@ from __future__ import annotations import abc +from functools import cached_property from typing import Any, Generic, Tuple, TypeVar import chex @@ -48,10 +49,10 @@ def __repr__(self) -> str: def __init__(self) -> None: """Initialize environment.""" - self._observation_spec = self._make_observation_spec() - self._action_spec = self._make_action_spec() - self._reward_spec = self._make_reward_spec() - self._discount_spec = self._make_discount_spec() + self.observation_spec + self.action_spec + self.reward_spec + self.discount_spec @abc.abstractmethod def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: @@ -80,69 +81,37 @@ def step( timestep: TimeStep object corresponding the timestep returned by the environment, """ - @property + @abc.abstractmethod + @cached_property def observation_spec(self) -> 
specs.Spec[Observation]: """Returns the observation spec. Returns: observation_spec: a potentially nested `Spec` structure representing the observation. """ - return self._observation_spec @abc.abstractmethod - def _make_observation_spec(self) -> specs.Spec[Observation]: - """Returns new observation spec. - - Returns: - observation_spec: a potentially nested `Spec` structure representing the observation. - """ - - @property + @cached_property def action_spec(self) -> ActionSpec: """Returns the action spec. Returns: action_spec: a potentially nested `Spec` structure representing the action. """ - return self._action_spec - - @abc.abstractmethod - def _make_action_spec(self) -> ActionSpec: - """Returns new action spec. - - Returns: - action_spec: a potentially nested `Spec` structure representing the action. - """ - @property + @cached_property def reward_spec(self) -> specs.Array: """Returns the reward spec. By default, this is assumed to be a single float. - Returns: - reward_spec: a `specs.Array` spec. - """ - return self._reward_spec - - def _make_reward_spec(self) -> specs.Array: - """Returns new reward spec. By default, this is assumed to be a single float. - Returns: reward_spec: a `specs.Array` spec. """ return specs.Array(shape=(), dtype=float, name="reward") - @property + @cached_property def discount_spec(self) -> specs.BoundedArray: """Returns the discount spec. By default, this is assumed to be a single float between 0 and 1. - Returns: - discount_spec: a `specs.BoundedArray` spec. - """ - return self._discount_spec - - def _make_discount_spec(self) -> specs.BoundedArray: - """Returns new discount spec. By default, this is assumed to be a single float between 0 and 1. - Returns: discount_spec: a `specs.BoundedArray` spec. """ diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index 6f217478c..45d189d2e 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -98,7 +99,8 @@ def __repr__(self) -> str: """ return f"2048 Game(board_size={self.board_size})" - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Game2048` environment. Returns: @@ -123,8 +125,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: ), ) - def _make_action_spec(self) -> specs.DiscreteArray: - """Returns new action spec. + @cached_property + def action_spec(self) -> specs.DiscreteArray: + """Returns the action spec. 4 actions: [0, 1, 2, 3] -> [Up, Right, Down, Left]. diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py index c57a82b22..b5e65a3e5 100644 --- a/jumanji/environments/logic/graph_coloring/env.py +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -207,8 +208,9 @@ def step( ) return next_state, timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: - """Returns new observation spec. 
+ @cached_property + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec. Returns: Spec for the `Observation` whose fields are: @@ -254,7 +256,8 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: ), ) - def _make_action_spec(self) -> specs.DiscreteArray: + @cached_property + def action_spec(self) -> specs.DiscreteArray: """Specification of the action for the `GraphColoring` environment. Returns: diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 4715145b1..1e9d8d4f1 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -183,7 +184,8 @@ def step( ) return next_state, next_timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Minesweeper` environment. Returns: @@ -230,8 +232,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) - def _make_action_spec(self) -> specs.MultiDiscreteArray: - """Returns new action spec. + @cached_property + def action_spec(self) -> specs.MultiDiscreteArray: + """Returns the action spec. An action consists of the height and width of the square to be explored. Returns: diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index f34ab3893..a4472e0ed 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -174,7 +175,8 @@ def step( ) return next_state, next_timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `RubiksCube` environment. Returns: @@ -203,8 +205,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) - def _make_action_spec(self) -> specs.MultiDiscreteArray: - """Returns new action spec. An action is composed of 3 elements that range in: 6 faces, each + @cached_property + def action_spec(self) -> specs.MultiDiscreteArray: + """Returns the action spec. An action is composed of 3 elements that range in: 6 faces, each with cube_size//2 possible depths, and 3 possible directions. Returns: diff --git a/jumanji/environments/logic/sudoku/env.py b/jumanji/environments/logic/sudoku/env.py index 9890be728..64a91899d 100644 --- a/jumanji/environments/logic/sudoku/env.py +++ b/jumanji/environments/logic/sudoku/env.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from functools import cached_property from typing import Any, Optional, Sequence, Tuple import chex @@ -130,8 +131,9 @@ def step( return next_state, timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: - """Returns new observation spec containing the board and action_mask arrays. 
+ @cached_property + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec containing the board and action_mask arrays. Returns: Spec containing all the specifications for all the `Observation` fields: @@ -159,8 +161,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: Observation, "ObservationSpec", board=board, action_mask=action_mask ) - def _make_action_spec(self) -> specs.MultiDiscreteArray: - """Returns new action spec. An action is composed of 3 integers: the row index, + @cached_property + def action_spec(self) -> specs.MultiDiscreteArray: + """Returns the action spec. An action is composed of 3 integers: the row index, the column index and the value to be placed in the cell. Returns: diff --git a/jumanji/environments/packing/bin_pack/env.py b/jumanji/environments/packing/bin_pack/env.py index f026bae5c..4f410af62 100644 --- a/jumanji/environments/packing/bin_pack/env.py +++ b/jumanji/environments/packing/bin_pack/env.py @@ -13,6 +13,7 @@ # limitations under the License. import itertools +from functools import cached_property from typing import Dict, Optional, Sequence, Tuple import chex @@ -172,7 +173,8 @@ def __repr__(self) -> str: ] ) - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `BinPack` environment. Returns: @@ -249,7 +251,8 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def _make_action_spec(self) -> specs.MultiDiscreteArray: + @cached_property + def action_spec(self) -> specs.MultiDiscreteArray: """Specifications of the action expected by the `BinPack` environment. Returns: diff --git a/jumanji/environments/packing/job_shop/env.py b/jumanji/environments/packing/job_shop/env.py index f835e7436..ec1e22e79 100644 --- a/jumanji/environments/packing/job_shop/env.py +++ b/jumanji/environments/packing/job_shop/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Any, Optional, Sequence, Tuple import chex @@ -357,7 +358,8 @@ def _update_machines( return updated_machines_job_ids, updated_machines_remaining_times - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `JobShop` environment. Returns: @@ -422,7 +424,8 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def _make_action_spec(self) -> specs.MultiDiscreteArray: + @cached_property + def action_spec(self) -> specs.MultiDiscreteArray: """Specifications of the action in the `JobShop` environment. The action gives each machine a job id ranging from 0, 1, ..., num_jobs where the last value corresponds to a no-op. diff --git a/jumanji/environments/packing/knapsack/env.py b/jumanji/environments/packing/knapsack/env.py index f24548a6f..3d544132a 100644 --- a/jumanji/environments/packing/knapsack/env.py +++ b/jumanji/environments/packing/knapsack/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -177,8 +178,9 @@ def step( return next_state, timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: - """Returns new observation spec. + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec. Returns: Spec for each field in the Observation: @@ -224,8 +226,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def _make_action_spec(self) -> specs.DiscreteArray: - """Returns new action spec. + @cached_property + def action_spec(self) -> specs.DiscreteArray: + """Returns the action spec. Returns: action_spec: a `specs.DiscreteArray` spec. diff --git a/jumanji/environments/packing/tetris/env.py b/jumanji/environments/packing/tetris/env.py index 379518576..995cb1fd6 100644 --- a/jumanji/environments/packing/tetris/env.py +++ b/jumanji/environments/packing/tetris/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -247,7 +248,8 @@ def render(self, state: State) -> Optional[NDArray]: """ return self._viewer.render(state) - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Tetris` environment. Returns: @@ -286,8 +288,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: ), ) - def _make_action_spec(self) -> specs.MultiDiscreteArray: - """Returns new action spec. An action consists of two pieces of information: + @cached_property + def action_spec(self) -> specs.MultiDiscreteArray: + """Returns the action spec. An action consists of two pieces of information: the amount of rotation (number of 90-degree rotations) and the x-position of the leftmost part of the tetromino. diff --git a/jumanji/environments/routing/cleaner/env.py b/jumanji/environments/routing/cleaner/env.py index 739b23de9..7dd8c6423 100644 --- a/jumanji/environments/routing/cleaner/env.py +++ b/jumanji/environments/routing/cleaner/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Any, Dict, Optional, Sequence, Tuple import chex @@ -123,7 +124,8 @@ def __repr__(self) -> str: ")" ) - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specification of the observation of the `Cleaner` environment. Returns: @@ -153,7 +155,8 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) - def _make_action_spec(self) -> specs.MultiDiscreteArray: + @cached_property + def action_spec(self) -> specs.MultiDiscreteArray: """Specification of the action for the `Cleaner` environment. Returns: diff --git a/jumanji/environments/routing/connector/env.py b/jumanji/environments/routing/connector/env.py index bc8177dc6..fc78dd891 100644 --- a/jumanji/environments/routing/connector/env.py +++ b/jumanji/environments/routing/connector/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import cached_property from typing import Dict, Optional, Sequence, Tuple import chex @@ -319,7 +320,8 @@ def close(self) -> None: """ self._viewer.close() - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Connector` environment. Returns: @@ -357,8 +359,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) - def _make_action_spec(self) -> specs.MultiDiscreteArray: - """Returns new action spec for the Connector environment. + @cached_property + def action_spec(self) -> specs.MultiDiscreteArray: + """Returns the action spec for the Connector environment. 5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with a multi-dimensional action space, it expects an array of actions of shape (num_agents,). diff --git a/jumanji/environments/routing/cvrp/env.py b/jumanji/environments/routing/cvrp/env.py index bd68552e3..921dc646e 100644 --- a/jumanji/environments/routing/cvrp/env.py +++ b/jumanji/environments/routing/cvrp/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -196,8 +197,9 @@ def step( ) return next_state, timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: - """Returns new observation spec. + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec. Returns: Spec for the `Observation` whose fields are: @@ -262,8 +264,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def _make_action_spec(self) -> specs.DiscreteArray: - """Returns new action spec. + @cached_property + def action_spec(self) -> specs.DiscreteArray: + """Returns the action spec. Returns: action_spec: a `specs.DiscreteArray` spec. diff --git a/jumanji/environments/routing/maze/env.py b/jumanji/environments/routing/maze/env.py index 8de273f3f..c2f0100dd 100644 --- a/jumanji/environments/routing/maze/env.py +++ b/jumanji/environments/routing/maze/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -118,7 +119,8 @@ def __repr__(self) -> str: ] ) - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `Maze` environment. Returns: @@ -160,8 +162,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def _make_action_spec(self) -> specs.DiscreteArray: - """Returns new action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. + @cached_property + def action_spec(self) -> specs.DiscreteArray: + """Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: action_spec: discrete action space with 4 values. diff --git a/jumanji/environments/routing/mmst/env.py b/jumanji/environments/routing/mmst/env.py index ca7394f12..1b2b12746 100644 --- a/jumanji/environments/routing/mmst/env.py +++ b/jumanji/environments/routing/mmst/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import cached_property from typing import Any, Dict, Optional, Sequence, Tuple import chex @@ -284,8 +285,9 @@ def step_agent_fn( state, timestep = self._state_to_timestep(state, action) return state, timestep - def _make_action_spec(self) -> specs.MultiDiscreteArray: - """Returns new action spec. + @cached_property + def action_spec(self) -> specs.MultiDiscreteArray: + """Returns the action spec. Returns: action_spec: a `specs.MultiDiscreteArray` spec. @@ -295,8 +297,9 @@ def _make_action_spec(self) -> specs.MultiDiscreteArray: name="action", ) - def _make_observation_spec(self) -> specs.Spec[Observation]: - """Returns new observation spec. + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec. Returns: Spec for the `Observation` whose fields are: diff --git a/jumanji/environments/routing/multi_cvrp/env.py b/jumanji/environments/routing/multi_cvrp/env.py index 2590b1539..9eb0fc568 100644 --- a/jumanji/environments/routing/multi_cvrp/env.py +++ b/jumanji/environments/routing/multi_cvrp/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -189,9 +190,10 @@ def step( return new_state, timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """ - Returns new observation spec. + Returns the observation spec. Returns: observation_spec: a Tuple containing the spec for each of the constituent fields of an @@ -318,9 +320,10 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def _make_action_spec(self) -> specs.BoundedArray: + @cached_property + def action_spec(self) -> specs.BoundedArray: """ - Returns new action spec. + Returns the action spec. Returns: action_spec: a `specs.BoundedArray` spec. diff --git a/jumanji/environments/routing/pac_man/env.py b/jumanji/environments/routing/pac_man/env.py index 0d866d8d4..3007042b2 100644 --- a/jumanji/environments/routing/pac_man/env.py +++ b/jumanji/environments/routing/pac_man/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Any, Optional, Sequence, Tuple import chex @@ -133,7 +134,8 @@ def __init__( self._viewer = viewer or PacManViewer("Pacman", render_mode="human") self.time_limit = 1000 or time_limit - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `PacMan` environment. Returns: @@ -200,7 +202,8 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: score=score, ) - def _make_action_spec(self) -> specs.DiscreteArray: + @cached_property + def action_spec(self) -> specs.DiscreteArray: """Returns the action spec. 5 actions: [0,1,2,3,4] -> [Up, Right, Down, Left, No-op]. diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py index 290b75020..eb9c2c578 100644 --- a/jumanji/environments/routing/robot_warehouse/env.py +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -13,6 +13,7 @@ # limitations under the License. 
import functools +from functools import cached_property from typing import List, Optional, Sequence, Tuple import chex @@ -335,7 +336,8 @@ def update_reward_and_request_queue_scan( ) return next_state, timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """Specification of the observation of the `RobotWarehouse` environment. Returns: Spec for the `Observation`, consisting of the fields: @@ -358,8 +360,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) - def _make_action_spec(self) -> specs.MultiDiscreteArray: - """Returns new action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. + @cached_property + def action_spec(self) -> specs.MultiDiscreteArray: + """Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. Since this is a multi-agent environment, the environment expects an array of actions. This array is of shape (num_agents,). """ diff --git a/jumanji/environments/routing/snake/env.py b/jumanji/environments/routing/snake/env.py index a39e9cb97..0a1d0451c 100644 --- a/jumanji/environments/routing/snake/env.py +++ b/jumanji/environments/routing/snake/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -235,8 +236,9 @@ def step( ) return next_state, timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: - """Returns new observation spec. + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec. Returns: Spec for the `Observation` whose fields are: @@ -269,8 +271,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def _make_action_spec(self) -> specs.DiscreteArray: - """Returns new action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. + @cached_property + def action_spec(self) -> specs.DiscreteArray: + """Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: action_spec: a `specs.DiscreteArray` spec. diff --git a/jumanji/environments/routing/sokoban/env.py b/jumanji/environments/routing/sokoban/env.py index 60aeda0ce..6fada2037 100644 --- a/jumanji/environments/routing/sokoban/env.py +++ b/jumanji/environments/routing/sokoban/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Dict, Optional, Sequence, Tuple import chex @@ -258,7 +259,8 @@ def step( return next_state, timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: """ Returns the specifications of the observation of the `Sokoban` environment. @@ -281,7 +283,8 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: step_count=step_count, ) - def _make_action_spec(self) -> specs.DiscreteArray: + @cached_property + def action_spec(self) -> specs.DiscreteArray: """ Returns the action specification for the Sokoban environment. There are 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. 
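
Every environment in this commit follows the same recipe: each spec becomes a @cached_property that is built once on first access and memoised on the instance, while Environment.__init__ touches all four specs so the cache is warmed at construction time. A minimal sketch of the pattern for a custom environment (ToyEnv and its size parameter are invented for illustration; note that any attribute a spec reads must be set before calling super().__init__()):

from functools import cached_property
from typing import Tuple

import chex
import jax.numpy as jnp

from jumanji import specs
from jumanji.env import Environment
from jumanji.types import TimeStep, restart, transition


class ToyEnv(Environment):
    """Hypothetical single-array environment, illustrative only."""

    def __init__(self, size: int = 4):
        self.size = size    # attributes the specs read must exist first...
        super().__init__()  # ...because this touches all four specs to warm the cache

    @cached_property
    def observation_spec(self) -> specs.Array:
        # Built once on first access, then memoised on the instance.
        return specs.Array(shape=(self.size,), dtype=float, name="observation")

    @cached_property
    def action_spec(self) -> specs.DiscreteArray:
        return specs.DiscreteArray(num_values=4, name="action")

    def reset(self, key: chex.PRNGKey) -> Tuple[chex.PRNGKey, TimeStep]:
        return key, restart(observation=jnp.zeros(self.size))

    def step(
        self, state: chex.PRNGKey, action: chex.Numeric
    ) -> Tuple[chex.PRNGKey, TimeStep]:
        return state, transition(
            reward=jnp.float32(0), observation=jnp.zeros(self.size)
        )

Because the spec is cached, repeated accesses cost an attribute lookup instead of rebuilding the spec tree each time.
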
diff --git a/jumanji/environments/routing/tsp/env.py b/jumanji/environments/routing/tsp/env.py index c945b20ec..f6d57bf93 100644 --- a/jumanji/environments/routing/tsp/env.py +++ b/jumanji/environments/routing/tsp/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -170,8 +171,9 @@ def step( ) return next_state, timestep - def _make_observation_spec(self) -> specs.Spec[Observation]: - """Returns new observation spec. + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec. Returns: Spec for the `Observation` whose fields are: @@ -213,8 +215,9 @@ def _make_observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) - def _make_action_spec(self) -> specs.DiscreteArray: - """Returns new action spec. + @cached_property + def action_spec(self) -> specs.DiscreteArray: + """Returns the action spec. Returns: action_spec: a `specs.DiscreteArray` spec. diff --git a/jumanji/testing/fakes.py b/jumanji/testing/fakes.py index 002778a31..a41246e1d 100644 --- a/jumanji/testing/fakes.py +++ b/jumanji/testing/fakes.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import TYPE_CHECKING, Tuple if TYPE_CHECKING: @@ -59,8 +60,9 @@ def __init__( super().__init__() self._example_action = self.action_spec.generate_value() - def _make_observation_spec(self) -> specs.Array: - """Returns new observation spec. + @cached_property + def observation_spec(self) -> specs.Array: + """Returns the observation spec. Returns: observation_spec: a `specs.Array` spec. @@ -70,8 +72,9 @@ def _make_observation_spec(self) -> specs.Array: shape=self.observation_shape, dtype=float, name="observation" ) - def _make_action_spec(self) -> specs.BoundedArray: - """Returns new action spec. + @cached_property + def action_spec(self) -> specs.BoundedArray: + """Returns the action spec. Returns: action_spec: a `specs.DiscreteArray` spec. @@ -177,8 +180,9 @@ def __init__( ), f"""a leading dimension of size 'num_agents': {num_agents} is expected for the observation, got shape: {observation_shape}.""" - def _make_observation_spec(self) -> specs.Array: - """Returns new observation spec. + @cached_property + def observation_spec(self) -> specs.Array: + """Returns the observation spec. Returns: observation_spec: a `specs.Array` spec. @@ -188,8 +192,9 @@ def _make_observation_spec(self) -> specs.Array: shape=self.observation_shape, dtype=float, name="observation" ) - def _make_action_spec(self) -> specs.BoundedArray: - """Returns new action spec. + @cached_property + def action_spec(self) -> specs.BoundedArray: + """Returns the action spec. Returns: action_spec: a `specs.Array` spec. @@ -199,15 +204,17 @@ def _make_action_spec(self) -> specs.BoundedArray: (self.num_agents,), int, 0, self.num_action_values - 1 ) - def _make_reward_spec(self) -> specs.Array: - """Returns new reward spec. + @cached_property + def reward_spec(self) -> specs.Array: + """Returns the reward spec. Returns: reward_spec: a `specs.Array` spec. """ return specs.Array(shape=(self.num_agents,), dtype=float, name="reward") - def _make_discount_spec(self) -> specs.BoundedArray: + @cached_property + def discount_spec(self) -> specs.BoundedArray: """Describes the discount returned by the environment. 
diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py
index e76e1c878..04a87db45 100644
--- a/jumanji/wrappers.py
+++ b/jumanji/wrappers.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations

+from functools import cached_property
 from typing import Any, Callable, ClassVar, Dict, Generic, Optional, Tuple, Union

 import chex
@@ -81,21 +82,25 @@ def step(
         """
         return self._env.step(state, action)

-    def _make_observation_spec(self) -> specs.Spec[Observation]:
+    @cached_property
+    def observation_spec(self) -> specs.Spec[Observation]:
         """Returns the observation spec."""
-        return self._env._make_observation_spec()
+        return self._env.observation_spec

-    def _make_action_spec(self) -> ActionSpec:
+    @cached_property
+    def action_spec(self) -> ActionSpec:
         """Returns the action spec."""
-        return self._env._make_action_spec()
+        return self._env.action_spec

-    def _make_reward_spec(self) -> specs.Array:
+    @cached_property
+    def reward_spec(self) -> specs.Array:
         """Returns the reward spec."""
-        return self._env._make_reward_spec()
+        return self._env.reward_spec

-    def _make_discount_spec(self) -> specs.BoundedArray:
+    @cached_property
+    def discount_spec(self) -> specs.BoundedArray:
         """Returns the discount spec."""
-        return self._env._make_discount_spec()
+        return self._env.discount_spec

     def render(self, state: State) -> Any:
         """Compute render frames during initialisation of the environment.
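With the delegation above, a `Wrapper` re-exposes the wrapped environment's specs as its own cached properties: the underlying property is read once, and subclasses override only the spec they actually change. A simplified stand-in (not Jumanji's real `Wrapper`) showing the shape of such an override:

```python
from functools import cached_property

class ToyWrapper:
    def __init__(self, env):
        self._env = env

    @cached_property
    def action_spec(self):
        # Plain attribute access on the wrapped env -- no call parentheses.
        return self._env.action_spec

class RelabelledWrapper(ToyWrapper):
    @cached_property
    def action_spec(self):
        # Subclasses transform the inner spec instead of delegating verbatim.
        inner = self._env.action_spec
        return ("relabelled", inner)
```

One consequence of caching at the wrapper level is that each spec is read from the wrapped environment at most once per wrapper instance, which is exactly what the updated tests below assert.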
- """ - mock_make_action_spec = mocker.patch.object( - fake_environment, "_make_action_spec", autospec=True + """Checks `Wrapper.__init__` calls the action_spec function of the underlying env.""" + mock_action_spec = mocker.patch.object( + FakeEnvironment, "action_spec", new_callable=mocker.PropertyMock ) - wrapped_fake_environment._make_action_spec() + wrapped_fake_environment = mock_wrapper_class(fake_environment) + mock_action_spec.assert_called_once() - mock_make_action_spec.assert_called_once() + wrapped_fake_environment.action_spec + mock_action_spec.assert_called_once() def test_wrapper__repr(self, wrapped_fake_environment: FakeWrapper) -> None: """Checks `Wrapper.__repr__` returns the expected representation string.""" From 063b108df37f4427104ebaacd53b5f60408e28de Mon Sep 17 00:00:00 2001 From: Avi Revah Date: Wed, 13 Mar 2024 19:55:29 +0000 Subject: [PATCH 14/16] feat(flatpak): change flatpak specs to cached properties --- jumanji/environments/packing/flat_pack/env.py | 7 +++++-- jumanji/environments/packing/flat_pack/env_test.py | 10 +++++++++- jumanji/environments/routing/sokoban/env_test.py | 10 +++++++++- jumanji/training/networks/flat_pack/actor_critic.py | 3 +-- jumanji/training/networks/flat_pack/random.py | 2 +- 5 files changed, 25 insertions(+), 7 deletions(-) diff --git a/jumanji/environments/packing/flat_pack/env.py b/jumanji/environments/packing/flat_pack/env.py index 573486a73..9f27df29c 100644 --- a/jumanji/environments/packing/flat_pack/env.py +++ b/jumanji/environments/packing/flat_pack/env.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import Optional, Sequence, Tuple import chex @@ -34,7 +35,7 @@ from jumanji.viewer import Viewer -class FlatPack(Environment[State]): +class FlatPack(Environment[State, specs.MultiDiscreteArray, Observation]): """The FlatPack environment with a configurable number of row and column blocks. Here the goal of an agent is to completely fill an empty grid by placing all @@ -129,6 +130,7 @@ def __init__( self.viewer = viewer or FlatPackViewer( "FlatPack", self.num_blocks, render_mode="human" ) + super().__init__() def __repr__(self) -> str: return ( @@ -141,7 +143,6 @@ def reset( self, key: chex.PRNGKey, ) -> Tuple[State, TimeStep[Observation]]: - """Resets the environment. Args: @@ -259,6 +260,7 @@ def close(self) -> None: self.viewer.close() + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec of the environment. @@ -307,6 +309,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask=action_mask, ) + @cached_property def action_spec(self) -> specs.MultiDiscreteArray: """Specifications of the action expected by the `FlatPack` environment. 
diff --git a/jumanji/environments/packing/flat_pack/env_test.py b/jumanji/environments/packing/flat_pack/env_test.py
index 923306349..36b82f77d 100644
--- a/jumanji/environments/packing/flat_pack/env_test.py
+++ b/jumanji/environments/packing/flat_pack/env_test.py
@@ -28,7 +28,10 @@
     CellDenseReward,
 )
 from jumanji.environments.packing.flat_pack.types import State
-from jumanji.testing.env_not_smoke import check_env_does_not_smoke
+from jumanji.testing.env_not_smoke import (
+    check_env_does_not_smoke,
+    check_env_specs_does_not_smoke,
+)
 from jumanji.testing.pytrees import assert_is_jax_array_tree
 from jumanji.types import StepType, TimeStep

@@ -182,6 +185,11 @@ def test_flat_pack__does_not_smoke(flat_pack: FlatPack) -> None:
     check_env_does_not_smoke(flat_pack)


+def test_flat_pack__specs_does_not_smoke(flat_pack: FlatPack) -> None:
+    """Test that we can access specs without any errors."""
+    check_env_specs_does_not_smoke(flat_pack)
+
+
 def test_flat_pack__is_done(flat_pack: FlatPack, key: chex.PRNGKey) -> None:
     """Test that the is_done method works as expected."""

diff --git a/jumanji/environments/routing/sokoban/env_test.py b/jumanji/environments/routing/sokoban/env_test.py
index 8c3d8da93..5579c94af 100644
--- a/jumanji/environments/routing/sokoban/env_test.py
+++ b/jumanji/environments/routing/sokoban/env_test.py
@@ -26,7 +26,10 @@
     SimpleSolveGenerator,
 )
 from jumanji.environments.routing.sokoban.types import State
-from jumanji.testing.env_not_smoke import check_env_does_not_smoke
+from jumanji.testing.env_not_smoke import (
+    check_env_does_not_smoke,
+    check_env_specs_does_not_smoke,
+)
 from jumanji.types import TimeStep


@@ -215,3 +218,8 @@ def test_sokoban__reward_function_solved(sokoban_simple: Sokoban) -> None:
 def test_sokoban__does_not_smoke(sokoban: Sokoban) -> None:
     """Test that we can run an episode without any errors."""
     check_env_does_not_smoke(sokoban)
+
+
+def test_sokoban__specs_does_not_smoke(sokoban: Sokoban) -> None:
+    """Test that we can access specs without any errors."""
+    check_env_specs_does_not_smoke(sokoban)
diff --git a/jumanji/training/networks/flat_pack/actor_critic.py b/jumanji/training/networks/flat_pack/actor_critic.py
index 5c6923b4d..d2c6fc787 100644
--- a/jumanji/training/networks/flat_pack/actor_critic.py
+++ b/jumanji/training/networks/flat_pack/actor_critic.py
@@ -40,7 +40,7 @@ def make_actor_critic_networks_flat_pack(
     hidden_size: int,
 ) -> ActorCriticNetworks:
     """Make actor-critic networks for the `FlatPack` environment."""
-    num_values = np.asarray(flat_pack.action_spec().num_values)
+    num_values = np.asarray(flat_pack.action_spec.num_values)
     parametric_action_distribution = FactorisedActionSpaceParametricDistribution(
         action_spec_num_values=num_values
     )
@@ -171,7 +171,6 @@ def __call__(self, observation: Observation) -> Tuple[chex.Array, chex.Array]:
         )  # (B, model_size), (B, num_rows-2, num_cols-2, hidden_size)

         for block_id in range(self.num_transformer_layers):
-
             (
                 self_attention_mask,  # (B, 1, num_blocks, num_blocks)
                 cross_attention_mask,  # (B, 1, num_blocks, 1)
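The new `check_env_specs_does_not_smoke` tests exercise spec construction itself, which matters now that specs are built lazily on first property access. The helper's implementation is not shown in this patch; a plausible minimal version would simply touch every spec and round-trip a generated value:

```python
# Hedged sketch only -- the real helper lives in jumanji.testing.env_not_smoke
# and may do more (or less) than this.
def check_env_specs_does_not_smoke_sketch(env) -> None:
    """Access every spec and validate a value generated from it."""
    for spec in (env.observation_spec, env.action_spec,
                 env.reward_spec, env.discount_spec):
        value = spec.generate_value()
        spec.validate(value)
```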
diff --git a/jumanji/training/networks/flat_pack/random.py b/jumanji/training/networks/flat_pack/random.py
index a81ba43f0..7c8c09463 100644
--- a/jumanji/training/networks/flat_pack/random.py
+++ b/jumanji/training/networks/flat_pack/random.py
@@ -21,7 +21,7 @@

 def make_random_policy_flat_pack(flat_pack: FlatPack) -> RandomPolicy:
     """Make random policy for FlatPack."""
-    action_spec_num_values = flat_pack.action_spec().num_values
+    action_spec_num_values = flat_pack.action_spec.num_values

     return make_masked_categorical_random_ndim(
         action_spec_num_values=action_spec_num_values

From ed66bf320216bbe6fec9b6941cdf636f56166ac8 Mon Sep 17 00:00:00 2001
From: Avi Revah
Date: Wed, 13 Mar 2024 15:50:13 -0500
Subject: [PATCH 15/16] fix: constrain types-requests version to <1.27 due to
 conflict with neptune-client

---
 requirements/requirements-dev.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt
index ea0b437c2..58afa9227 100644
--- a/requirements/requirements-dev.txt
+++ b/requirements/requirements-dev.txt
@@ -25,5 +25,5 @@ pytype
 scipy>=1.7.3
 testfixtures
 types-Pillow
-types-requests
+types-requests<1.27
 types-setuptools

From a1aec4f57ed2787e4123ba5a078dbaf975928551 Mon Sep 17 00:00:00 2001
From: Avi Revah
Date: Thu, 21 Mar 2024 03:57:47 +0000
Subject: [PATCH 16/16] feat(sliding tile): change sliding tile specs to
 properties

---
 .../environments/logic/sliding_tile_puzzle/env.py |  7 +++++--
 .../logic/sliding_tile_puzzle/env_test.py         | 12 +++++++++++-
 jumanji/environments/packing/flat_pack/env.py     |  2 +-
 jumanji/environments/routing/mmst/env.py          |  2 +-
 jumanji/environments/routing/multi_cvrp/env.py    |  2 +-
 jumanji/environments/routing/sokoban/env.py       |  2 +-
 .../networks/sliding_tile_puzzle/actor_critic.py  |  2 +-
 7 files changed, 21 insertions(+), 8 deletions(-)
diff --git a/jumanji/environments/logic/sliding_tile_puzzle/env.py b/jumanji/environments/logic/sliding_tile_puzzle/env.py
index fe6a29f3d..9ab17e6c4 100644
--- a/jumanji/environments/logic/sliding_tile_puzzle/env.py
+++ b/jumanji/environments/logic/sliding_tile_puzzle/env.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from functools import cached_property
 from typing import Dict, Optional, Sequence, Tuple

 import chex
@@ -40,7 +41,7 @@
 from jumanji.viewer import Viewer


-class SlidingTilePuzzle(Environment[State]):
+class SlidingTilePuzzle(Environment[State, specs.DiscreteArray, Observation]):
     """Environment for the Sliding Tile Puzzle problem.

     The problem is a combinatorial optimization task where the goal is
@@ -95,8 +96,8 @@ def __init__(
             grid_size=5, num_random_moves=200
         )
         self.reward_fn = reward_fn or DenseRewardFn()
-
         self.time_limit = time_limit
+        super().__init__()

         # Create viewer used for rendering
         self._env_viewer = viewer or SlidingTilePuzzleViewer(name="SlidingTilePuzzle")
@@ -205,6 +206,7 @@ def _get_extras(self, state: State) -> Dict[str, chex.Array]:
         num_correct_tiles = jnp.sum(self.solved_puzzle == state.puzzle)
         return {"prop_correctly_placed": num_correct_tiles / state.puzzle.size}

+    @cached_property
     def observation_spec(self) -> specs.Spec[Observation]:
         """Returns the observation spec."""
         grid_size = self.generator.grid_size
@@ -241,6 +243,7 @@ def observation_spec(self) -> specs.Spec[Observation]:
             ),
         )

+    @cached_property
     def action_spec(self) -> specs.DiscreteArray:
         """Returns the action spec."""
         # Up, Right, Down, Left
diff --git a/jumanji/environments/logic/sliding_tile_puzzle/env_test.py b/jumanji/environments/logic/sliding_tile_puzzle/env_test.py
index 31bab5f0e..5ed9ddc75 100644
--- a/jumanji/environments/logic/sliding_tile_puzzle/env_test.py
+++ b/jumanji/environments/logic/sliding_tile_puzzle/env_test.py
@@ -18,7 +18,10 @@

 from jumanji.environments.logic.sliding_tile_puzzle import SlidingTilePuzzle
 from jumanji.environments.logic.sliding_tile_puzzle.types import State
-from jumanji.testing.env_not_smoke import check_env_does_not_smoke
+from jumanji.testing.env_not_smoke import (
+    check_env_does_not_smoke,
+    check_env_specs_does_not_smoke,
+)
 from jumanji.testing.pytrees import assert_is_jax_array_tree
 from jumanji.types import TimeStep

@@ -88,6 +91,13 @@ def test_sliding_tile_puzzle_does_not_smoke(
     check_env_does_not_smoke(sliding_tile_puzzle)


+def test_sliding_tile_puzzle_specs_does_not_smoke(
+    sliding_tile_puzzle: SlidingTilePuzzle,
+) -> None:
+    """Test that we can access specs without any errors."""
+    check_env_specs_does_not_smoke(sliding_tile_puzzle)
+
+
 def test_env_one_move_to_solve(sliding_tile_puzzle: SlidingTilePuzzle) -> None:
     """Test that the environment correctly handles a situation where
     the puzzle is one move away from being solved.
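The `_get_extras` hunk above reports `prop_correctly_placed`, the fraction of tiles already in their solved position. A small worked example (made-up 2x2 boards, not a fixture from the test file):

```python
import jax.numpy as jnp

solved = jnp.array([[1, 2], [3, 0]])
puzzle = jnp.array([[1, 2], [0, 3]])  # bottom row swapped

num_correct = jnp.sum(solved == puzzle)  # 2 matching cells
prop = num_correct / puzzle.size         # 2 / 4 = 0.5
```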
diff --git a/jumanji/environments/packing/flat_pack/env.py b/jumanji/environments/packing/flat_pack/env.py
index 9f27df29c..e1125e98b 100644
--- a/jumanji/environments/packing/flat_pack/env.py
+++ b/jumanji/environments/packing/flat_pack/env.py
@@ -92,7 +92,7 @@ class FlatPack(Environment[State, specs.MultiDiscreteArray, Observation]):
     key = jax.random.PRNGKey(0)
     state, timestep = jax.jit(env.reset)(key)
     env.render(state)
-    action = env.action_spec().generate_value()
+    action = env.action_spec.generate_value()
     state, timestep = jax.jit(env.step)(state, action)
     env.render(state)
     ```
diff --git a/jumanji/environments/routing/mmst/env.py b/jumanji/environments/routing/mmst/env.py
index 1b2b12746..386f2dd3c 100644
--- a/jumanji/environments/routing/mmst/env.py
+++ b/jumanji/environments/routing/mmst/env.py
@@ -125,7 +125,7 @@ class MMST(Environment[State, specs.MultiDiscreteArray, Observation]):
     key = jax.random.PRNGKey(0)
     state, timestep = jax.jit(env.reset)(key)
     env.render(state)
-    action = env.action_spec().generate_value()
+    action = env.action_spec.generate_value()
     state, timestep = jax.jit(env.step)(state, action)
     env.render(state)
     ```
diff --git a/jumanji/environments/routing/multi_cvrp/env.py b/jumanji/environments/routing/multi_cvrp/env.py
index 9eb0fc568..2cd53c46c 100644
--- a/jumanji/environments/routing/multi_cvrp/env.py
+++ b/jumanji/environments/routing/multi_cvrp/env.py
@@ -72,7 +72,7 @@ class MultiCVRP(Environment[State, specs.BoundedArray, Observation]):
     key = jax.random.PRNGKey(0)
     state, timestep = jax.jit(env.reset)(key)
     env.render(state)
-    action = env.action_spec().generate_value()
+    action = env.action_spec.generate_value()
     state, timestep = jax.jit(env.step)(state, action)
     env.render(state)
     ```
diff --git a/jumanji/environments/routing/sokoban/env.py b/jumanji/environments/routing/sokoban/env.py
index 6fada2037..2433df322 100644
--- a/jumanji/environments/routing/sokoban/env.py
+++ b/jumanji/environments/routing/sokoban/env.py
@@ -104,7 +104,7 @@ class Sokoban(Environment[State, specs.DiscreteArray, Observation]):
     key_train = jax.random.PRNGKey(0)
     state, timestep = jax.jit(env_train.reset)(key_train)
     env_train.render(state)
-    action = env_train.action_spec().generate_value()
+    action = env_train.action_spec.generate_value()
     state, timestep = jax.jit(env_train.step)(state, action)
     env_train.render(state)
     ```
diff --git a/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py b/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py
index 71625f7cc..5c4a6752c 100644
--- a/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py
+++ b/jumanji/training/networks/sliding_tile_puzzle/actor_critic.py
@@ -40,7 +40,7 @@ def make_actor_critic_networks_sliding_tile_puzzle(
     value_layers: Sequence[int],
 ) -> ActorCriticNetworks:
     """Make actor-critic networks for the `SlidingTilePuzzle` environment."""
-    num_actions = sliding_tile_puzzle.action_spec().num_values
+    num_actions = sliding_tile_puzzle.action_spec.num_values
    parametric_action_distribution = CategoricalParametricDistribution(
         num_actions=num_actions
     )
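Taken together, the docstring updates above settle the user-facing API for the whole series: specs are attributes, not methods. An end-to-end usage sketch mirroring the Sokoban docstring (import path assumed to be `jumanji.environments`; any environment in the series works the same way):

```python
import jax
from jumanji.environments import Sokoban

env = Sokoban()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

action = env.action_spec.generate_value()  # was: env.action_spec().generate_value()
state, timestep = jax.jit(env.step)(state, action)
```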