From df75a5be2ad1e45a14ed87d60e7c3fba59a4f056 Mon Sep 17 00:00:00 2001 From: Roxana Radulescu <8026679+rradules@users.noreply.github.com> Date: Mon, 30 Oct 2023 15:06:00 +0100 Subject: [PATCH] map randomisation at each episode and argument --- .../envs/item_gathering/item_gathering.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/momadm_benchmarks/envs/item_gathering/item_gathering.py b/momadm_benchmarks/envs/item_gathering/item_gathering.py index 3e8caf58..148fc3eb 100644 --- a/momadm_benchmarks/envs/item_gathering/item_gathering.py +++ b/momadm_benchmarks/envs/item_gathering/item_gathering.py @@ -24,7 +24,7 @@ from gymnasium.spaces import Box, Discrete from pettingzoo.utils import wrappers -from momadm_benchmarks.envs.item_gathering.map_utils import DEFAULT_MAP +from momadm_benchmarks.envs.item_gathering.map_utils import DEFAULT_MAP, randomise_map from momadm_benchmarks.utils.conversions import mo_parallel_to_aec from momadm_benchmarks.utils.env import MOParallelEnv @@ -81,6 +81,7 @@ def __init__( self, num_timesteps=10, initial_map=DEFAULT_MAP, + randomise=False, render_mode=None, ): """Initializes the item gathering domain. @@ -88,11 +89,13 @@ def __init__( Args: num_timesteps: number of timesteps to run the environment for initial_map: map of the environment + randomise: whether to randomise the map, at each episode render_mode: render mode for the environment """ self.num_timesteps = num_timesteps self.current_timestep = 0 self.render_mode = render_mode + self.randomise = randomise # check is the initial map has any entries equal to 2 assert ( @@ -184,7 +187,11 @@ def reset(self, seed=None, options=None): self.terminations = {agent: False for agent in self.agents} self.truncations = {agent: False for agent in self.agents} - self.env_map = deepcopy(self.initial_map) # Reset the environment map to the initial map provided + # Reset the environment map + if self.randomise: + self.env_map = deepcopy(randomise_map(self.initial_map)) + else: + self.env_map = deepcopy(self.initial_map) self.agent_positions = np.argwhere(self.env_map == 1) # store agent positions in separate list self.env_map[self.env_map == 1] = 0 # remove agent starting positions from map