Skip to content

Commit

Permalink
[FrozenLake] Add seed in random map generation (openai#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
kir0ul authored Nov 24, 2022
1 parent f747e45 commit 075267a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
10 changes: 8 additions & 2 deletions gymnasium/envs/toy_text/frozen_lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from gymnasium import Env, spaces, utils
from gymnasium.envs.toy_text.utils import categorical_sample
from gymnasium.error import DependencyNotInstalled
from gymnasium.utils import seeding

LEFT = 0
DOWN = 1
Expand Down Expand Up @@ -51,22 +52,27 @@ def is_valid(board: List[List[str]], max_size: int) -> bool:
return False


def generate_random_map(size: int = 8, p: float = 0.8) -> List[str]:
def generate_random_map(
size: int = 8, p: float = 0.8, seed: Optional[int] = None
) -> List[str]:
"""Generates a random valid map (one that has a path from start to goal)
Args:
size: size of each side of the grid
p: probability that a tile is frozen
seed: optional seed to ensure the generation of reproducible maps
Returns:
A random valid map
"""
valid = False
board = [] # initialize to make pyright happy

np_random, _ = seeding.np_random(seed)

while not valid:
p = min(1, p)
board = np.random.choice(["F", "H"], (size, size), p=[p, 1 - p])
board = np_random.choice(["F", "H"], (size, size), p=[p, 1 - p])
board[0][0] = "S"
board[-1][-1] = "G"
valid = is_valid(board, size)
Expand Down
10 changes: 10 additions & 0 deletions tests/envs/test_env_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,16 @@ def test_frozenlake_dfs_map_generation(map_size: int):
raise AssertionError("No path through the frozenlake was found.")


@pytest.mark.parametrize("map_size, seed", [(5, 123), (10, 42), (16, 987)])
def test_frozenlake_map_generation_with_seed(map_size: int, seed: int):
map1 = generate_random_map(size=map_size, seed=seed)
map2 = generate_random_map(size=map_size, seed=seed)
assert map1 == map2
map1 = generate_random_map(size=map_size, seed=seed)
map2 = generate_random_map(size=map_size, seed=seed + 1)
assert map1 != map2


def test_taxi_action_mask():
env = TaxiEnv()

Expand Down

0 comments on commit 075267a

Please sign in to comment.