diff --git a/jumanji/environments/logic/sliding_tile_puzzle/constants.py b/jumanji/environments/logic/sliding_tile_puzzle/constants.py new file mode 100644 index 000000000..bcd00a544 --- /dev/null +++ b/jumanji/environments/logic/sliding_tile_puzzle/constants.py @@ -0,0 +1,15 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +EMPTY_TILE = 0 diff --git a/jumanji/environments/logic/sliding_tile_puzzle/env.py b/jumanji/environments/logic/sliding_tile_puzzle/env.py index 23d42b307..001b582c5 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/env.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/env.py @@ -84,7 +84,8 @@ def __init__( reward_fn: RewardFn whose `__call__` method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action and the next state. - Implemented options are [`DenseRewardFn`, `SparseRewardFn`]. Defaults to `DenseRewardFn`. + Implemented options are [`DenseRewardFn`, `SparseRewardFn`]. + Defaults to `DenseRewardFn`. viewer: environment viewer for rendering. """ self.generator = generator or RandomGenerator(grid_size=5) diff --git a/jumanji/environments/logic/sliding_tile_puzzle/generator.py b/jumanji/environments/logic/sliding_tile_puzzle/generator.py index 99c032dde..6923730ae 100644 --- a/jumanji/environments/logic/sliding_tile_puzzle/generator.py +++ b/jumanji/environments/logic/sliding_tile_puzzle/generator.py @@ -19,6 +19,8 @@ import jax from jax import numpy as jnp +from jumanji.environments.logic.sliding_tile_puzzle.constants import EMPTY_TILE + class Generator(abc.ABC): @property @@ -76,7 +78,7 @@ def __call__(self, key: chex.PRNGKey) -> Tuple[chex.Array, chex.Array]: # Find the position of the empty tile empty_tile_position = jnp.stack( - jnp.unravel_index(jnp.argmax(puzzle == 0), puzzle.shape) + jnp.unravel_index(jnp.argmax(puzzle == EMPTY_TILE), puzzle.shape) ) return puzzle, empty_tile_position @@ -127,7 +129,7 @@ def __call__(self, key: chex.PRNGKey) -> Tuple[chex.Array, chex.Array]: # Find the position of the empty tile empty_tile_position = jnp.stack( - jnp.unravel_index(jnp.argmax(puzzle == 0), puzzle.shape) + jnp.unravel_index(jnp.argmax(puzzle == EMPTY_TILE), puzzle.shape) ) return puzzle, empty_tile_position