feat: FlatPack environment #188

Merged
merged 91 commits into from
Mar 13, 2024
Changes from 65 commits
Commits
960ead0
feat: initial jigsaw commit.
RuanJohn May 21, 2023
0e6c994
feat: added puzzle numbers to env viewer
RuanJohn May 22, 2023
19f21f1
feat: initial code for random agent network.
RuanJohn May 22, 2023
3988f4d
feat: remove board action mask.
RuanJohn May 22, 2023
14a53ed
feat: add jigsaw random agent.
RuanJohn May 22, 2023
62625c0
chore: change board_dim to num_rows and num_cols
RuanJohn May 22, 2023
6828cfa
feat: register environment and add random networks
RuanJohn May 23, 2023
b81dda2
feat: full action mask working.
RuanJohn May 25, 2023
f67315c
feat: cleaner action mask generation.
RuanJohn May 25, 2023
941fee1
feat: added jigsaw documentation
RuanJohn May 28, 2023
2839da0
chore: typo fix.
RuanJohn May 28, 2023
3474efa
feat: added class doctring to env.
RuanJohn May 28, 2023
6378477
feat: import jigsaw actor critic network.
RuanJohn May 28, 2023
478b504
wip: work on actor critic networks.
RuanJohn May 28, 2023
316e733
chore: better variable naming
RuanJohn May 28, 2023
4e35b86
chore: variable renaming in jigsaw networks.
RuanJohn May 29, 2023
b681f83
chore: variable renaming in jigsaw networks.
RuanJohn May 29, 2023
ee87ec3
feat: jigsaw networks implemented.
RuanJohn May 29, 2023
8416969
fix: fix action spec off by one.
RuanJohn May 29, 2023
5868835
feat: added jigsaw training config.
RuanJohn May 29, 2023
4a8caad
chore: minor fixes.
RuanJohn May 29, 2023
d0aa02c
chore: fix action space in docs.
RuanJohn May 29, 2023
e876908
chore: docs action mask fix.
RuanJohn May 29, 2023
57b6a81
chore: action mask fix in docs.
RuanJohn May 29, 2023
0267299
chore: action mask fix in docs.
RuanJohn May 29, 2023
6adfaf3
chore: indent docstrings.
RuanJohn May 29, 2023
6206819
Merge branch 'main' into 143-implement-jigsaw-env
RuanJohn May 29, 2023
b344add
fix: action mask indexing bugfix.
RuanJohn May 30, 2023
416c137
test: first flatpack experiments.
RuanJohn May 30, 2023
f5621bd
feat: rename jigsaw to flat_pack
RuanJohn May 31, 2023
66597cb
chore: more renaming
RuanJohn May 31, 2023
2161d2e
feat: rename types and remove unnecessary fields from the state.
RuanJohn Jun 4, 2023
b1c8233
chore: rename viewer.
RuanJohn Jun 4, 2023
0c30053
feat: reword environment to flatpack.
RuanJohn Jun 4, 2023
9692882
feat: add env training config.
RuanJohn Jun 4, 2023
afad966
feat: update docs.
RuanJohn Jun 4, 2023
95677f1
Merge main into flatpack
RuanJohn Jun 4, 2023
661368d
chore: method renaming.
RuanJohn Jun 4, 2023
90f206c
chore: docstring indents.
RuanJohn Jun 4, 2023
bd39f9e
chore: reward descriptions.
RuanJohn Jun 4, 2023
cf30cb9
chore: typo fix.
RuanJohn Jun 4, 2023
989dc58
feat: new env images.
RuanJohn Jun 4, 2023
2f79724
feat: set default env.
RuanJohn Jun 4, 2023
580dfa5
Merge branch 'main' into flatpack-first-test
RuanJohn Jun 26, 2023
490974b
feat: added equally weighted block dense reward
RuanJohn Jun 26, 2023
78a6bb9
feat: flatpack improved training networks.
RuanJohn Jun 28, 2023
3ad9f6a
Merge branch 'main' into flatpack-first-test
RuanJohn Jul 3, 2023
2aace92
feat: correct final convolution and output projection in Unet.
RuanJohn Jul 14, 2023
6425261
chore: renamed variables
RuanJohn Jul 16, 2023
2d982ae
chore: minor test fixes.
RuanJohn Jul 16, 2023
dc3baee
fix: fix unet output bug.
RuanJohn Jul 16, 2023
29e3368
feat: added extra reward tests.
RuanJohn Jul 17, 2023
c88634a
chore: removed todo comments.
RuanJohn Jul 17, 2023
aa3c069
chore: cleaned up doctrings, variable names and comments in environme…
RuanJohn Jul 17, 2023
6f38714
chore: fixed comments and docstrings in reward file.
RuanJohn Jul 17, 2023
d7cb940
chore: fixed comments and docstrings in reward file.
RuanJohn Jul 17, 2023
ac7fae1
chore: fixed comments and docstrings in utils file.
RuanJohn Jul 17, 2023
23c21b7
feat: removed simplified env from registered environments.
RuanJohn Jul 17, 2023
2a16dc7
chore: fixed typos in documentation.
RuanJohn Jul 17, 2023
e05dfea
chore: fixed flatpack actor-critic networks comments.
RuanJohn Jul 17, 2023
7e2805f
feat: added new flatpack gif and png
RuanJohn Jul 19, 2023
bade29b
Merge branch 'main' into implement-flatpack-environment
clement-bonnet Aug 19, 2023
d638261
Merge branch 'main' into implement-flatpack-environment
sash-a Jan 10, 2024
d1b5291
Merge branch 'main' into implement-flatpack-environment
sash-a Feb 14, 2024
00e9f8f
Merge branch 'main' into implement-flatpack-environment
sash-a Mar 7, 2024
e83488d
chore: add action mask shape comment to types
RuanJohn Mar 11, 2024
747abfe
chore: docstring formatting
RuanJohn Mar 11, 2024
da20f74
chore: fix env name in config comments
RuanJohn Mar 11, 2024
cb613e5
chore: reset defualt training agent to random
RuanJohn Mar 11, 2024
aae8d1e
chore: docstring formatting fix
RuanJohn Mar 11, 2024
6f1c9dd
chore: docstring fix
RuanJohn Mar 11, 2024
67a77b6
chore: grammar fix
RuanJohn Mar 11, 2024
2ab98ed
chore: docstring fomatting
RuanJohn Mar 11, 2024
bdb11f3
chore: docstring formatting fix
RuanJohn Mar 11, 2024
1055147
chore: remove manually setting vmap in_axes to defaults
RuanJohn Mar 11, 2024
627ca7c
chore: remove manually setting vmap in_axes to defaults
RuanJohn Mar 11, 2024
e9e1a97
chore: code formatting
RuanJohn Mar 11, 2024
9d463fc
Merge branch 'main' into implement-flatpack-environment
RuanJohn Mar 11, 2024
a05b35f
chore: rename current_grid to grid and make grid the first attribute …
RuanJohn Mar 11, 2024
c08fc62
chore: remove jnp.where from _get_ones_like_expanded_block method
RuanJohn Mar 11, 2024
6a8a77b
chore: remove __init__ from random flatpack generator
RuanJohn Mar 11, 2024
ae48129
chore: rename nibs -> interlock
RuanJohn Mar 11, 2024
8ff3ee1
chore: add comment to explain why grid_with_block is initially empty
RuanJohn Mar 11, 2024
dd8cb49
Merge branch 'main' into implement-flatpack-environment
RuanJohn Mar 11, 2024
dbc8a8d
chore: remove sparse reward
RuanJohn Mar 12, 2024
69ce651
chore: rename env test methods
RuanJohn Mar 12, 2024
fc0d1c0
feat: add flat_pack documentation links to mkdocs.yml
RuanJohn Mar 12, 2024
6d4a797
chore rename action to placed_block in reward classes
RuanJohn Mar 12, 2024
9e2c6b0
feat: determine if action is legal from the state action mask
RuanJohn Mar 12, 2024
95389a6
chore: suggestions from review
sash-a Mar 13, 2024
a87bfd7
chore: change obs and state type to int
sash-a Mar 13, 2024
5 changes: 3 additions & 2 deletions README.md
@@ -98,8 +98,9 @@ problems.
| 🎨 GraphColoring | Logic | `GraphColoring-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/graph_coloring/) | [doc](https://instadeepai.github.io/jumanji/environments/graph_coloring/) |
| 💣 Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) |
| 🎲 RubiksCube | Logic | `RubiksCube-v0`<br/>`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) |
| ✏️ Sudoku | Logic | `Sudoku-v0` <br/>`Sudoku-very-easy-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/sudoku/) | [doc](https://instadeepai.github.io/jumanji/environments/sudoku/) |
| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v2` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
| ✏️ Sudoku | Logic | `Sudoku-v0` <br/>`Sudoku-very-easy-v0`| [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/sudoku/) | [doc](https://instadeepai.github.io/jumanji/environments/sudoku/) |
| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
| 🧩 FlatPack (2D Grid filling problem) | Packing | `FlatPack-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/flat_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/flat_pack/) |
| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
| 🎒 Knapsack | Packing | `Knapsack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/knapsack/) | [doc](https://instadeepai.github.io/jumanji/environments/knapsack/) |
| ▒ Tetris | Packing | `Tetris-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/tetris/) | [doc](https://instadeepai.github.io/jumanji/environments/tetris/) |
8 changes: 8 additions & 0 deletions docs/api/environments/flat_pack.md
@@ -0,0 +1,8 @@
::: jumanji.environments.packing.flat_pack.env.FlatPack
    selection:
      members:
        - __init__
        - reset
        - step
        - observation_spec
        - action_spec
Binary file added docs/env_anim/flat_pack.gif
Binary file added docs/env_img/flat_pack.png
57 changes: 57 additions & 0 deletions docs/environments/flat_pack.md
@@ -0,0 +1,57 @@
# FlatPack Environment

<p align="center">
<img src="../env_anim/flat_pack.gif" width="500"/>
</p>

We provide here a JAX JIT-able implementation of a packing environment named _flat pack_. The goal of
the agent is to place all the available blocks on an empty 2D grid.
Each time an episode resets, a new set of blocks is created and the grid is emptied. Blocks are randomly
shuffled and rotated, and all have shape `(3, 3)`.

## Observation
The observation gives the agent a view of the current state of the grid as well as
all blocks that can be placed.

- `current_grid`: jax array (float32) of shape `(num_rows, num_cols)` with values in the range
`[0, num_blocks]` (corresponding to the number of each block). This grid will have zeros
where no blocks have been placed and numbers corresponding to each block where that particular
block has been placed.

- `blocks`: jax array (float32) of shape `(num_blocks, 3, 3)` of all blocks
that can fit in the current grid. These blocks are shuffled, rotated and will always have shape `(3, 3)`.

- `action_mask`: jax array (bool) of shape `(num_blocks, 4, num_rows-2, num_cols-2)`, representing
which actions are possible given the current state of the grid. The first index selects the
block to place. The second index selects the number of 90 degree rotations applied to that block.
The third and fourth indices give the row and column coordinates at which the block's top-left corner is placed.
Blocks are placed by an agent by specifying the row and column coordinate on the grid where the top-left corner
of the selected block should be placed. The last two dimensions of the mask always have sizes `num_rows-2` and `num_cols-2`
respectively, which makes it impossible for an agent to place a `(3, 3)` block outside the current grid.
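As a rough illustration of the mask layout described above, the following plain-Python sketch computes the mask shape and checks whether a top-left coordinate keeps a `(3, 3)` block inside the grid. The helper names `action_mask_shape` and `top_left_in_bounds` are hypothetical and are not part of the environment's API; only the shape itself comes from the description above.

```python
from typing import Tuple


def action_mask_shape(num_blocks: int, num_rows: int, num_cols: int) -> Tuple[int, int, int, int]:
    """Shape (blocks, rotations, rows, cols) of the action mask described above."""
    num_rotations = 4  # 0, 90, 180 and 270 degrees
    return (num_blocks, num_rotations, num_rows - 2, num_cols - 2)


def top_left_in_bounds(row: int, col: int, num_rows: int, num_cols: int) -> bool:
    """True if a 3x3 block with its top-left corner at (row, col) stays inside the grid."""
    return 0 <= row <= num_rows - 3 and 0 <= col <= num_cols - 3


# A 5x5 grid with 4 blocks, matching the size of the test fixtures in this PR.
shape = action_mask_shape(num_blocks=4, num_rows=5, num_cols=5)
print(shape)  # (4, 4, 3, 3)
```

Every `(row, col)` pair that indexes the last two mask dimensions passes `top_left_in_bounds`, which is why the mask only needs `num_rows-2` by `num_cols-2` entries per block and rotation.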


## Action
The action space is a `MultiDiscreteArray`, specifically a tuple of an index in `[0, num_blocks)`,
an index in `[0, 4)` (since there are 4 possible rotations), an index in `[0, num_rows-2)`
(the possible row coordinates for placing a block) and an index in `[0, num_cols-2)`
(the possible column coordinates for placing a block). An action thus consists of four pieces of
information:

- Block to place,

- Number of 90 degree rotations to apply to the chosen block (a total rotation of 0, 90, 180 or 270 degrees),

- Row coordinate for placing the rotated block's top left corner,

- Column coordinate for placing the rotated block's top left corner.
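The four action components above can be sketched in plain Python as follows. This is only an illustration under stated assumptions, not the environment's implementation: `rotate_90` and `apply_action` are hypothetical helpers, the rotation direction (counter-clockwise) is assumed, block `k` is assumed to be written to the grid as `k + 1` so that `0` means "empty", and no legality check against the action mask is performed.

```python
from typing import List, Tuple

Grid = List[List[int]]


def rotate_90(block: Grid, times: int) -> Grid:
    """Rotate a square block counter-clockwise `times` times (assumed direction)."""
    for _ in range(times % 4):
        # Transpose, then reverse the row order: one counter-clockwise quarter turn.
        block = [list(row) for row in zip(*block)][::-1]
    return block


def apply_action(grid: Grid, blocks: List[Grid], action: Tuple[int, int, int, int]) -> Grid:
    """Return a copy of `grid` with the chosen, rotated block written into it."""
    block_idx, rotation, row, col = action
    rotated = rotate_90(blocks[block_idx], rotation)
    new_grid = [list(grid_row) for grid_row in grid]
    for i, block_row in enumerate(rotated):
        for j, cell in enumerate(block_row):
            if cell:
                # Assumption: block k is stored as k + 1 so that 0 means "empty".
                new_grid[row + i][col + j] = block_idx + 1
    return new_grid


blocks = [[[0, 1, 1], [0, 1, 1], [0, 0, 1]]]
empty = [[0] * 5 for _ in range(5)]
# Place block 0, unrotated, with its top-left corner at row 1, column 1.
new_grid = apply_action(empty, blocks, action=(0, 0, 1, 1))
for grid_row in new_grid:
    print(grid_row)
```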


## Reward
The reward function is configurable, but by default it is a dense reward: at each timestep the agent receives the number
of non-zero cells in the block it placed, normalised by the total number of cells in the grid. The episode
terminates if either the grid is filled or `num_blocks` steps have been taken by an agent.
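A minimal plain-Python sketch of the default reward as described above, assuming the placed block is given as a grid-sized array that is zero everywhere except where the block sits. The function name `dense_reward` is hypothetical; the example grid is the 5x5 solved grid used in the test fixtures of this PR.

```python
from typing import List


def dense_reward(placed_block: List[List[int]], num_rows: int, num_cols: int) -> float:
    """Number of non-zero cells in the placed block, normalised by the grid size."""
    filled = sum(1 for row in placed_block for cell in row if cell != 0)
    return filled / (num_rows * num_cols)


solved_grid = [
    [1, 1, 1, 2, 2],
    [1, 1, 2, 2, 2],
    [3, 1, 4, 4, 2],
    [3, 3, 4, 4, 4],
    [3, 3, 3, 4, 4],
]

rewards = []
for block_id in (1, 2, 3, 4):
    # Keep only the cells belonging to this block, as if it had just been placed.
    placed = [[cell if cell == block_id else 0 for cell in row] for row in solved_grid]
    rewards.append(dense_reward(placed, num_rows=5, num_cols=5))

print(rewards)
print(sum(rewards))
```

Under this reading, the per-step rewards of an episode that fills the whole grid add up to approximately 1, since every cell is counted exactly once.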


## Registered Versions 📖
- `FlatPack-v0`, a flat pack environment grid with 11 rows and 11 columns containing 5 row blocks and 5 column blocks
for a total of 25 blocks that can be placed on the grid. This version has a dense reward.
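
The grid size and block counts above are consistent with neighbouring `(3, 3)` blocks sharing one row or column of cells where they interlock. This relationship is an assumption inferred from the numbers in this PR (5 blocks per axis on an 11x11 grid, and 2 blocks per axis on the 5x5 grids in the test fixtures); the helper `grid_dim` below is hypothetical.

```python
def grid_dim(num_blocks_along_axis: int, block_size: int = 3) -> int:
    """Grid length along one axis, assuming adjacent blocks overlap by one cell."""
    # Each of the (n - 1) seams between neighbouring blocks removes one cell.
    return num_blocks_along_axis * block_size - (num_blocks_along_axis - 1)


print(grid_dim(5))  # 11
print(grid_dim(2))  # 5
```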
8 changes: 7 additions & 1 deletion examples/load_checkpoints.ipynb
@@ -111,8 +111,11 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"metadata": {
"collapsed": false
},
"source": [
"## Load configs"
]
@@ -194,6 +197,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -243,6 +247,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -279,6 +284,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
4 changes: 4 additions & 0 deletions jumanji/__init__.py
@@ -81,6 +81,10 @@
# given in the observation.
register(id="BinPack-v2", entry_point="jumanji.environments:BinPack")

# 2D grid filling problem with 25 blocks, an 11x11 grid and a random grid generator.
# The grid must be filled in `num_blocks` steps.
register(id="FlatPack-v0", entry_point="jumanji.environments:FlatPack")

# Job-shop scheduling problem with 20 jobs, 10 machines, at most
# 8 operations per job, and a max operation duration of 6 timesteps.
register(id="JobShop-v0", entry_point="jumanji.environments:JobShop")
3 changes: 2 additions & 1 deletion jumanji/environments/__init__.py
@@ -20,8 +20,9 @@
from jumanji.environments.logic.minesweeper import Minesweeper
from jumanji.environments.logic.rubiks_cube import RubiksCube
from jumanji.environments.logic.sudoku import Sudoku
from jumanji.environments.packing import bin_pack, job_shop, knapsack, tetris
from jumanji.environments.packing import bin_pack, flat_pack, job_shop, knapsack, tetris
from jumanji.environments.packing.bin_pack.env import BinPack
from jumanji.environments.packing.flat_pack.env import FlatPack
from jumanji.environments.packing.job_shop.env import JobShop
from jumanji.environments.packing.knapsack.env import Knapsack
from jumanji.environments.packing.tetris.env import Tetris
16 changes: 16 additions & 0 deletions jumanji/environments/packing/flat_pack/__init__.py
@@ -0,0 +1,16 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jumanji.environments.packing.flat_pack.env import FlatPack
from jumanji.environments.packing.flat_pack.types import Observation, State
162 changes: 162 additions & 0 deletions jumanji/environments/packing/flat_pack/conftest.py
@@ -0,0 +1,162 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import chex
import jax
import jax.numpy as jnp
import pytest


@pytest.fixture
def key() -> chex.PRNGKey:
    """A deterministic key."""

    return jax.random.PRNGKey(0)


@pytest.fixture
def block() -> chex.Array:
    """A mock block for testing."""

    return jnp.array(
        [
            [0.0, 1.0, 1.0],
            [0.0, 1.0, 1.0],
            [0.0, 0.0, 1.0],
        ]
    )


@pytest.fixture
def solved_grid() -> chex.Array:
    """A mock solved grid for testing."""

    return jnp.array(
        [
            [1.0, 1.0, 1.0, 2.0, 2.0],
            [1.0, 1.0, 2.0, 2.0, 2.0],
            [3.0, 1.0, 4.0, 4.0, 2.0],
            [3.0, 3.0, 4.0, 4.0, 4.0],
            [3.0, 3.0, 3.0, 4.0, 4.0],
        ],
    )


@pytest.fixture
def grid_with_block_one_placed() -> chex.Array:
    """A grid with only block one placed."""

    return jnp.array(
        [
            [1.0, 1.0, 1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
        ],
    )


@pytest.fixture()
def block_one_placed_at_0_0(grid_with_block_one_placed: chex.Array) -> chex.Array:
    """A 2D array of zeros where block one has been placed with its top-left
    corner at position (0, 0).
    """

    return grid_with_block_one_placed


@pytest.fixture()
def block_one_placed_at_1_1(grid_with_block_one_placed: chex.Array) -> chex.Array:
    """A 2D array of zeros where block one has been placed with its top-left
    corner at position (1, 1).
    """

    # Shift all elements in the array one down and one to the right.
    # jnp.roll wraps around, which is safe here because the last row and
    # column of the input grid are all zeros.
    partially_placed_block = jnp.roll(grid_with_block_one_placed, shift=1, axis=0)
    partially_placed_block = jnp.roll(partially_placed_block, shift=1, axis=1)

    return partially_placed_block


@pytest.fixture()
def action_mask_with_block_1_placed() -> chex.Array:
    """Action mask for a 4 piece grid where only block 1 has been placed with its
    top-left corner at (1, 1).
    """

    return jnp.array(
        [
            [
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
            ],
            [
                [[False, False, True], [False, False, True], [False, True, True]],
                [[False, False, True], [False, True, True], [False, True, True]],
                [[False, False, False], [False, False, True], [True, False, True]],
                [[False, False, False], [False, False, True], [False, False, True]],
            ],
            [
                [[False, False, False], [False, False, True], [True, False, True]],
                [[False, False, False], [False, False, True], [False, False, True]],
                [[False, False, False], [False, False, True], [False, False, True]],
                [[False, False, True], [False, True, True], [True, True, True]],
            ],
            [
                [[False, False, False], [False, False, True], [False, False, True]],
                [[False, False, True], [False, False, True], [False, True, True]],
                [[False, False, False], [False, False, True], [False, False, True]],
                [[False, False, True], [False, False, True], [False, True, True]],
            ],
        ]
    )


@pytest.fixture()
def action_mask_without_only_block_1_placed() -> chex.Array:
    """Action mask for a 4 piece grid where only block 1 can be placed, unrotated,
    with its top-left corner at (0, 0).
    """

    return jnp.array(
        [
            [
                [[True, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
            ],
            [
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
            ],
            [
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
            ],
            [
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
                [[False, False, False], [False, False, False], [False, False, False]],
            ],
        ]
    )