From 075267a55fe01671bcbb65c3034ce5c38081f015 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrea=20PIERR=C3=89?= <6053592+kir0ul@users.noreply.github.com> Date: Thu, 24 Nov 2022 15:46:38 -0500 Subject: [PATCH] [FrozenLake] Add seed in random map generation (#139) --- gymnasium/envs/toy_text/frozen_lake.py | 10 ++++++++-- tests/envs/test_env_implementation.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/gymnasium/envs/toy_text/frozen_lake.py b/gymnasium/envs/toy_text/frozen_lake.py index cd4f91861..65e51f8cd 100644 --- a/gymnasium/envs/toy_text/frozen_lake.py +++ b/gymnasium/envs/toy_text/frozen_lake.py @@ -9,6 +9,7 @@ from gymnasium import Env, spaces, utils from gymnasium.envs.toy_text.utils import categorical_sample from gymnasium.error import DependencyNotInstalled +from gymnasium.utils import seeding LEFT = 0 DOWN = 1 @@ -51,12 +52,15 @@ def is_valid(board: List[List[str]], max_size: int) -> bool: return False -def generate_random_map(size: int = 8, p: float = 0.8) -> List[str]: +def generate_random_map( + size: int = 8, p: float = 0.8, seed: Optional[int] = None +) -> List[str]: """Generates a random valid map (one that has a path from start to goal) Args: size: size of each side of the grid p: probability that a tile is frozen + seed: optional seed to ensure the generation of reproducible maps Returns: A random valid map @@ -64,9 +68,11 @@ def generate_random_map(size: int = 8, p: float = 0.8) -> List[str]: valid = False board = [] # initialize to make pyright happy + np_random, _ = seeding.np_random(seed) + while not valid: p = min(1, p) - board = np.random.choice(["F", "H"], (size, size), p=[p, 1 - p]) + board = np_random.choice(["F", "H"], (size, size), p=[p, 1 - p]) board[0][0] = "S" board[-1][-1] = "G" valid = is_valid(board, size) diff --git a/tests/envs/test_env_implementation.py b/tests/envs/test_env_implementation.py index 708860622..6de341046 100644 --- a/tests/envs/test_env_implementation.py +++ b/tests/envs/test_env_implementation.py @@ -126,6 +126,16 @@ def test_frozenlake_dfs_map_generation(map_size: int): raise AssertionError("No path through the frozenlake was found.") +@pytest.mark.parametrize("map_size, seed", [(5, 123), (10, 42), (16, 987)]) +def test_frozenlake_map_generation_with_seed(map_size: int, seed: int): + map1 = generate_random_map(size=map_size, seed=seed) + map2 = generate_random_map(size=map_size, seed=seed) + assert map1 == map2 + map1 = generate_random_map(size=map_size, seed=seed) + map2 = generate_random_map(size=map_size, seed=seed + 1) + assert map1 != map2 + + def test_taxi_action_mask(): env = TaxiEnv()