diff --git a/pettingzoo/utils/wrappers/supersuit/__init__.py b/pettingzoo/utils/wrappers/supersuit/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pettingzoo/utils/wrappers/supersuit/basic_wrappers.py b/pettingzoo/utils/wrappers/supersuit/basic_wrappers.py
new file mode 100644
index 000000000..1ce7ffaae
--- /dev/null
+++ b/pettingzoo/utils/wrappers/supersuit/basic_wrappers.py
@@ -0,0 +1,95 @@
+from pettingzoo.utils.wrappers.supersuit.reward_lambda import reward_lambda_v0, AecRewardLambda
+from pettingzoo.utils.wrappers.supersuit.observation_lambda import observation_lambda_v0, AecObservationLambda
+from pettingzoo.utils.wrappers.supersuit.utils.basic_transforms import color_reduction
+from typing import Literal, Any
+from types import ModuleType
+from pettingzoo import AECEnv, ParallelEnv
+from gymnasium.spaces import Space
+import numpy as np
+
+
+def basic_obs_wrapper(env: AECEnv | ParallelEnv, module: ModuleType, param: Any) -> AecObservationLambda:
+    """
+    Wrap an environment to modify its observation space and observations using a specified module and parameter.
+
+    This function takes an environment, a module, and a parameter, and creates a new environment with an observation
+    space and observations modified based on the provided module and parameter.
+
+    Parameters:
+    - env (AECEnv | ParallelEnv): The environment to be wrapped.
+    - module: The module responsible for modifying the observation space and observations.
+    - param: The parameter used to modify the observation space and observations.
+
+    Returns:
+    - AecObservationLambda: A wrapped environment that applies the observation space and observation modifications.
+
+    Example:
+    ```python
+    modified_env = basic_obs_wrapper(original_env, my_module, my_param)
+    ```
+    In the above example, `modified_env` is a new environment that has its observation space and observations modified
+    according to `my_module` and `my_param`.
+    """
+
+    def change_space(space: Space):
+        module.check_param(space, param)
+        space = module.change_obs_space(space, param)
+        return space
+
+    def change_obs(obs: np.ndarray, obs_space: Space):
+        return module.change_observation(obs, obs_space, param)
+
+    return observation_lambda_v0(env, change_obs, change_space)
+
+
+def color_reduction_v0(env: AECEnv | ParallelEnv, mode: Literal["full", "R", "G", "B"] = "full") -> AecObservationLambda:
+    """
+    Wrap an environment to perform color reduction on its observations.
+
+    This function takes an environment and an optional mode to specify the color reduction technique. It then creates
+    a new environment that performs color reduction on the observations based on the specified mode.
+
+    Parameters:
+    - env (AECEnv | ParallelEnv): The environment to be wrapped.
+    - mode (Literal["full", "R", "G", "B"], optional): The color reduction mode to apply (default is "full").
+      "full" converts observations to grayscale; "R", "G" and "B" keep only the corresponding channel.
+
+    Returns:
+    - AecObservationLambda: A wrapped environment that applies color reduction to its observations.
+
+    Example:
+    ```python
+    reduced_color_env = color_reduction_v0(original_env, mode="full")
+    ```
+    In the above example, `reduced_color_env` is a new environment that performs grayscale color reduction on its
+    observations.
+    """
+
+    return basic_obs_wrapper(env, color_reduction, mode)
+
+
+def clip_reward_v0(env: AECEnv | ParallelEnv, lower_bound: float = -1, upper_bound: float = 1) -> AecRewardLambda:
+    """
+    Clip rewards in an environment using the specified lower and upper bounds.
+
+    This function applies a reward clipping transformation to an environment's rewards. It takes an environment and
+    two optional bounds: `lower_bound` and `upper_bound`. Any reward in the environment that falls below the
+    `lower_bound` will be set to `lower_bound`, and any reward that exceeds the `upper_bound` will be set to
+    `upper_bound`. Rewards within the specified range are left unchanged.
+
+    Parameters:
+    - env (AECEnv | ParallelEnv): The environment on which to apply the reward clipping.
+    - lower_bound (float, optional): The lower bound for clipping rewards (default is -1).
+    - upper_bound (float, optional): The upper bound for clipping rewards (default is 1).
+
+    Returns:
+    - AecRewardLambda: A wrapped environment whose per-agent rewards are clipped to the given bounds.
+
+    Example:
+    ```python
+    clipped_env = clip_reward_v0(my_environment, lower_bound=-0.5, upper_bound=0.5)
+    ```
+    In the above example, the rewards in `my_environment` will be clipped to the range [-0.5, 0.5].
+    """
+
+    return reward_lambda_v0(env, lambda rew: max(min(rew, upper_bound), lower_bound))
diff --git a/pettingzoo/utils/wrappers/supersuit/observation_lambda.py b/pettingzoo/utils/wrappers/supersuit/observation_lambda.py
new file mode 100644
index 000000000..295bc1313
--- /dev/null
+++ b/pettingzoo/utils/wrappers/supersuit/observation_lambda.py
@@ -0,0 +1,140 @@
+import functools
+import numpy as np
+from gymnasium.spaces import Box
+from pettingzoo.utils.wrappers.supersuit.utils.base_aec_wrapper import BaseWrapper
+from typing import Callable
+from pettingzoo import AECEnv, ParallelEnv
+from pettingzoo.utils.env import ActionType, AgentID
+
+
+class AecObservationLambda(BaseWrapper):
+    """
+    A wrapper for AEC environments that allows the modification of observation spaces and observations.
+
+    Args:
+        env (AECEnv | ParallelEnv): The environment to be wrapped.
+        change_observation_fn (Callable): A function that modifies observations.
+        change_obs_space_fn (Callable, optional): A function that modifies observation spaces. Default is None.
+
+    Raises:
+        AssertionError: If `change_observation_fn` is not callable, or if `change_obs_space_fn` is provided and is not callable.
+
+    Note:
+        - The `change_observation_fn` should be a function that accepts observation data and optionally the observation space and agent ID as arguments and returns a modified observation.
+        - The `change_obs_space_fn` should be a function that accepts an old observation space and optionally the agent ID as arguments and returns a modified observation space.
+
+    Attributes:
+        change_observation_fn (Callable): The function used to modify observations.
+        change_obs_space_fn (Callable, optional): The function used to modify observation spaces.
+
+    Methods:
+        _modify_action(agent: AgentID, action: ActionType) -> ActionType:
+            Modify the action.
+
+        _check_wrapper_params() -> None:
+            Check wrapper parameters for consistency.
+
+        observation_space(agent: AgentID) -> Box:
+            Get the modified observation space for a specific agent.
+
+        _modify_observation(agent: AgentID, observation: np.ndarray) -> np.ndarray:
+            Modify the observation.
+
+    """
+    def __init__(self, env: AECEnv | ParallelEnv, change_observation_fn: Callable, change_obs_space_fn: Callable | None = None):
+        assert callable(
+            change_observation_fn
+        ), "change_observation_fn needs to be a function. It is {}".format(
+            change_observation_fn
+        )
+        assert change_obs_space_fn is None or callable(
+            change_obs_space_fn
+        ), "change_obs_space_fn needs to be a function. It is {}".format(
+            change_obs_space_fn
+        )
+
+        self.change_observation_fn = change_observation_fn
+        self.change_obs_space_fn = change_obs_space_fn
+
+        super().__init__(env)
+
+        if hasattr(self, "possible_agents"):
+            for agent in self.possible_agents:
+                # call any validation logic in this function
+                self.observation_space(agent)
+
+    def _modify_action(self, agent: AgentID, action: ActionType) -> ActionType:
+        """
+        Modify the action.
+
+        Args:
+            agent (AgentID): The agent for which to modify the action.
+            action (ActionType): The original action.
+
+        Returns:
+            ActionType: The modified action.
+        """
+        return action
+
+    def _check_wrapper_params(self) -> None:
+        """
+        Check wrapper parameters for consistency.
+
+        Raises:
+            AssertionError: If the provided parameters are inconsistent.
+        """
+        if self.change_obs_space_fn is None and hasattr(self, "possible_agents"):
+            for agent in self.possible_agents:
+                assert isinstance(
+                    self.observation_space(agent), Box
+                ), "the observation_lambda_wrapper only allows the change_obs_space_fn argument to be optional for Box observation spaces"
+
+    @functools.lru_cache(maxsize=None)
+    def observation_space(self, agent: AgentID) -> Box:
+        """
+        Get the modified observation space for a specific agent.
+
+        Args:
+            agent (AgentID): The agent for which to retrieve the observation space.
+
+        Returns:
+            Box: The modified observation space.
+        """
+        if self.change_obs_space_fn is None:
+            space = self.env.observation_space(agent)
+            try:
+                trans_low = self.change_observation_fn(space.low, space, agent)
+                trans_high = self.change_observation_fn(space.high, space, agent)
+            except TypeError:
+                trans_low = self.change_observation_fn(space.low, space)
+                trans_high = self.change_observation_fn(space.high, space)
+            new_low = np.minimum(trans_low, trans_high)
+            new_high = np.maximum(trans_low, trans_high)
+
+            return Box(low=new_low, high=new_high, dtype=new_low.dtype)
+        else:
+            old_obs_space = self.env.observation_space(agent)
+            try:
+                return self.change_obs_space_fn(old_obs_space, agent)
+            except TypeError:
+                return self.change_obs_space_fn(old_obs_space)
+
+    def _modify_observation(self, agent: AgentID, observation: np.ndarray) -> np.ndarray:
+        """
+        Modify the observation.
+
+        Args:
+            agent (AgentID): The agent for which to modify the observation.
+            observation (np.ndarray): The original observation.
+
+        Returns:
+            np.ndarray: The modified observation.
+        """
+        old_obs_space = self.env.observation_space(agent)
+        try:
+            return self.change_observation_fn(observation, old_obs_space, agent)
+        except TypeError:
+            return self.change_observation_fn(observation, old_obs_space)
+
+
+observation_lambda_v0 = AecObservationLambda
diff --git a/pettingzoo/utils/wrappers/supersuit/reward_lambda.py b/pettingzoo/utils/wrappers/supersuit/reward_lambda.py
new file mode 100644
index 000000000..df347e1f3
--- /dev/null
+++ b/pettingzoo/utils/wrappers/supersuit/reward_lambda.py
@@ -0,0 +1,92 @@
+from pettingzoo.utils.wrappers.supersuit.utils.base_aec_wrapper import PettingzooWrap
+from pettingzoo.utils.wrappers.supersuit.utils.make_defaultdict import make_defaultdict
+from typing import Callable
+from pettingzoo import AECEnv, ParallelEnv
+from pettingzoo.utils.env import ActionType
+
+
+class AecRewardLambda(PettingzooWrap):
+    """
+    A wrapper for AEC environments that allows the modification of rewards.
+
+    Args:
+        env (AECEnv | ParallelEnv): The environment to be wrapped.
+        change_reward_fn (Callable): A function that modifies rewards.
+
+    Raises:
+        AssertionError: If `change_reward_fn` is not callable.
+
+    Attributes:
+        _change_reward_fn (Callable): The function used to modify rewards.
+ + Methods: + reset(seed: int = None, options: dict = None) -> None: + Reset the environment, applying the reward modification to initial rewards. + + step(action: ActionType) -> None: + Take a step in the environment, applying the reward modification to the received rewards. + + """ + def __init__(self, env: AECEnv | ParallelEnv, change_reward_fn: Callable): + assert callable( + change_reward_fn + ), f"change_reward_fn needs to be a function. It is {change_reward_fn}" + self._change_reward_fn = change_reward_fn + + super().__init__(env) + + def _check_wrapper_params(self) -> None: + """ + Check wrapper parameters for consistency. + + This method is currently empty and does not perform any checks. + """ + pass + + def _modify_spaces(self) -> None: + """ + Modify the spaces of the wrapped environment. + + This method is currently empty and does not modify the spaces. + """ + pass + + def reset(self, seed: int = None, options: dict = None) -> None: + """ + Reset the environment, applying the reward modification to initial rewards. + + Args: + seed (int, optional): A seed for environment randomization. Default is None. + options (dict, optional): Additional options for environment initialization. Default is None. + """ + super().reset(seed=seed, options=options) + self.rewards = { + agent: self._change_reward_fn(reward) + for agent, reward in self.rewards.items() + } + self.__cumulative_rewards = make_defaultdict({a: 0 for a in self.agents}) + self._accumulate_rewards() + + def step(self, action: ActionType) -> None: + """ + Take a step in the environment, applying the reward modification to the received rewards. + + Args: + action (ActionType): The action to be taken in the environment. + """ + agent = self.env.agent_selection + super().step(action) + self.rewards = { + agent: self._change_reward_fn(reward) + for agent, reward in self.rewards.items() + } + self.__cumulative_rewards[agent] = 0 + self._cumulative_rewards = self.__cumulative_rewards + self._accumulate_rewards() + + +reward_lambda_v0 = AecRewardLambda +""" example: +reward_lambda_v0 = WrapperChooser( + aec_wrapper=AecRewardLambda, par_wrapper=ParRewardLambda +)""" diff --git a/pettingzoo/utils/wrappers/supersuit/test/__init__.py b/pettingzoo/utils/wrappers/supersuit/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pettingzoo/utils/wrappers/supersuit/test/dummy_aec_env.py b/pettingzoo/utils/wrappers/supersuit/test/dummy_aec_env.py new file mode 100644 index 000000000..ae2f20964 --- /dev/null +++ b/pettingzoo/utils/wrappers/supersuit/test/dummy_aec_env.py @@ -0,0 +1,57 @@ +from pettingzoo import AECEnv +from pettingzoo.utils.agent_selector import agent_selector + + +class DummyEnv(AECEnv): + metadata = {"render_modes": ["human"], "is_parallelizable": True} + + def __init__(self, observations, observation_spaces, action_spaces): + super().__init__() + self._observations = observations + self._observation_spaces = observation_spaces + + self.agents = sorted([x for x in observation_spaces.keys()]) + self.possible_agents = self.agents[:] + self._agent_selector = agent_selector(self.agents) + self.agent_selection = self._agent_selector.reset() + self._action_spaces = action_spaces + + self.steps = 0 + + def observation_space(self, agent): + return self._observation_spaces[agent] + + def action_space(self, agent): + return self._action_spaces[agent] + + def observe(self, agent): + return self._observations[agent] + + def step(self, action, observe=True): + if ( + self.terminations[self.agent_selection] + or 
self.truncations[self.agent_selection] + ): + return self._was_dead_step(action) + self._cumulative_rewards[self.agent_selection] = 0 + self.agent_selection = self._agent_selector.next() + self.steps += 1 + if self.steps >= 5 * len(self.agents): + self.truncations = {a: True for a in self.agents} + + self._accumulate_rewards() + self._deads_step_first() + + def reset(self, seed=None, options=None): + self.agents = self.possible_agents[:] + self._agent_selector = agent_selector(self.agents) + self.agent_selection = self._agent_selector.reset() + self.rewards = {a: 1 for a in self.agents} + self._cumulative_rewards = {a: 0 for a in self.agents} + self.terminations = {a: False for a in self.agents} + self.truncations = {a: False for a in self.agents} + self.infos = {a: {} for a in self.agents} + self.steps = 0 + + def close(self): + pass diff --git a/pettingzoo/utils/wrappers/supersuit/test/test_wrappers.py b/pettingzoo/utils/wrappers/supersuit/test/test_wrappers.py new file mode 100644 index 000000000..d85effe93 --- /dev/null +++ b/pettingzoo/utils/wrappers/supersuit/test/test_wrappers.py @@ -0,0 +1,140 @@ +from pettingzoo.utils.wrappers.supersuit.basic_wrappers import color_reduction_v0, clip_reward_v0 +import numpy as np +from gymnasium.spaces import Box, Discrete +from pettingzoo.utils.wrappers import OrderEnforcingWrapper as PettingzooWrap +from pettingzoo.utils.wrappers.supersuit.test.dummy_aec_env import DummyEnv +import pytest +from pettingzoo.utils.wrappers.supersuit.reward_lambda import reward_lambda_v0 +from pettingzoo.utils.wrappers.supersuit.observation_lambda import observation_lambda_v0 + + +def new_base_env(): + base_obs = { + f"a{idx}": np.zeros([8, 8, 3], dtype=np.float32) + np.arange(3) + idx + for idx in range(2) + } + base_obs_space = { + f"a{idx}": Box(low=np.float32(0.0), high=np.float32(10.0), shape=[8, 8, 3]) + for idx in range(2) + } + base_act_spaces = {f"a{idx}": Discrete(5) for idx in range(2)} + + return DummyEnv(base_obs, base_obs_space, base_act_spaces) + + +def new_dummy(): + base_obs = { + f"a_{idx}": (np.zeros([8, 8, 3], dtype=np.float32) + np.arange(3) + idx).astype( + np.float32 + ) + for idx in range(2) + } + base_obs_space = { + f"a_{idx}": Box(low=np.float32(0.0), high=np.float32(10.0), shape=[8, 8, 3]) + for idx in range(2) + } + base_act_spaces = {f"a_{idx}": Discrete(5) for idx in range(2)} + + return PettingzooWrap(DummyEnv(base_obs, base_obs_space, base_act_spaces)) + + +wrappers = [ + color_reduction_v0(new_dummy(), "R"), + clip_reward_v0(new_dummy()), + reward_lambda_v0(new_dummy(), lambda x: x / 10), +] + + +@pytest.mark.parametrize("env", wrappers) +def test_basic_wrappers(env): + env.reset(seed=5) + obs, _, _, _, _ = env.last() + act_space = env.action_space(env.agent_selection) + obs_space = env.observation_space(env.agent_selection) + first_obs = env.observe("a_0") + assert obs_space.contains(first_obs) + assert first_obs.dtype == obs_space.dtype + env.step(act_space.sample()) + for agent in env.agent_iter(): + act_space = env.action_space(env.agent_selection) + env.step( + act_space.sample() + if not (env.truncations[agent] or env.terminations[agent]) + else None + ) + + +def test_rew_lambda(): + env = reward_lambda_v0(new_dummy(), lambda x: x / 10) + env.reset() + assert env.rewards[env.agent_selection] == 1.0 / 10 + + +def test_observation_lambda(): + def add1(obs, obs_space): + return obs + 1 + + base_env = new_base_env() + env = observation_lambda_v0(base_env, add1) + env.reset() + obs0, _, _, _, _ = env.last() + assert 
int(obs0[0][0][0]) == 1 + env = observation_lambda_v0(env, add1) + env.reset() + obs0, _, _, _, _ = env.last() + assert int(obs0[0][0][0]) == 2 + + def tile_obs(obs, obs_space): + shape_size = len(obs.shape) + tile_shape = [1] * shape_size + tile_shape[0] *= 2 + return np.tile(obs, tile_shape) + + env = observation_lambda_v0(env, tile_obs) + env.reset() + obs0, _, _, _, _ = env.last() + assert env.observation_space(env.agent_selection).shape == (16, 8, 3) + + def change_shape_fn(obs_space): + return Box(low=0, high=1, shape=(32, 8, 3)) + + env = observation_lambda_v0(env, tile_obs) + env.reset() + obs0, _, _, _, _ = env.last() + assert env.observation_space(env.agent_selection).shape == (32, 8, 3) + assert obs0.shape == (32, 8, 3) + + base_env = new_base_env() + env = observation_lambda_v0( + base_env, + lambda obs, obs_space, agent: obs + base_env.possible_agents.index(agent), + ) + env.reset() + obs0 = env.observe(env.agents[0]) + obs1 = env.observe(env.agents[1]) + + assert int(obs0[0][0][0]) == 0 + assert int(obs1[0][0][0]) == 2 + assert ( + env.observation_space(env.agents[0]).high + 1 + == env.observation_space(env.agents[1]).high + ).all() + + base_env = new_base_env() + env = observation_lambda_v0( + base_env, + lambda obs, obs_space, agent: obs + base_env.possible_agents.index(agent), + lambda obs_space, agent: Box( + obs_space.low, obs_space.high + base_env.possible_agents.index(agent) + ), + ) + env.reset() + obs0 = env.observe(env.agents[0]) + obs1 = env.observe(env.agents[1]) + + assert int(obs0[0][0][0]) == 0 + assert int(obs1[0][0][0]) == 2 + assert ( + env.observation_space(env.agents[0]).high + 1 + == env.observation_space(env.agents[1]).high + ).all() diff --git a/pettingzoo/utils/wrappers/supersuit/utils/base_aec_wrapper.py b/pettingzoo/utils/wrappers/supersuit/utils/base_aec_wrapper.py new file mode 100644 index 000000000..0fbc001a6 --- /dev/null +++ b/pettingzoo/utils/wrappers/supersuit/utils/base_aec_wrapper.py @@ -0,0 +1,115 @@ +from pettingzoo.utils.wrappers import OrderEnforcingWrapper as PettingzooWrap +from pettingzoo import AECEnv, ParallelEnv +from pettingzoo.utils.env import ActionType, AgentID +from gymnasium.spaces import Space + + +class BaseWrapper(PettingzooWrap): + def __init__(self, env: AECEnv | ParallelEnv): + """ + Creates a wrapper around `env`. Extend this class to create changes to the space. + """ + super().__init__(env) + + self._check_wrapper_params() + + self._modify_spaces() + + def _check_wrapper_params(self) -> None: + """ + Check wrapper parameters for consistency. + + This method is currently empty and does not perform any checks. + """ + pass + + def _modify_spaces(self) -> None: + """ + Modify the spaces of the wrapped environment. + + This method is currently empty and does not modify the spaces. + """ + pass + + def _modify_action(self, agent: AgentID, action: ActionType) -> None: + """ + Modify the action for the given agent. + + This method should be implemented by subclasses. + + Args: + agent (AgentID): The agent for which to modify the action. + action (ActionType): The original action to be modified. + + Raises: + NotImplementedError: This method should be implemented in subclasses. + """ + raise NotImplementedError() + + def _modify_observation(self, agent: AgentID, observation: Space) -> None: + """ + Modify the observation for the given agent. + + This method should be implemented by subclasses. + + Args: + agent (AgentID): The agent for which to modify the observation. 
+
+            observation (Space): The original observation to be modified.
+
+        Raises:
+            NotImplementedError: This method should be implemented in subclasses.
+        """
+        raise NotImplementedError()
+
+    def _update_step(self, agent: AgentID) -> None:
+        """
+        Update the step for the given agent.
+
+        This method can be implemented by subclasses if needed.
+
+        Args:
+            agent (AgentID): The agent for which to update the step.
+        """
+        pass
+
+    def reset(self, seed: int = None, options: dict = None) -> None:
+        """
+        Reset the environment, optionally setting a seed and providing additional options.
+
+        Args:
+            seed (int, optional): A seed for environment randomization. Default is None.
+            options (dict, optional): Additional options for environment initialization. Default is None.
+        """
+        super().reset(seed=seed, options=options)
+        self._update_step(self.agent_selection)
+
+    def observe(self, agent: AgentID) -> Space:
+        """
+        Observe the environment's state for the specified agent, modifying the observation if needed.
+
+        Args:
+            agent (AgentID): The agent for which to observe the environment.
+
+        Returns:
+            Space: The modified observation of the environment state for the specified agent.
+        """
+        obs = super().observe(
+            agent
+        )  # raw observation from the wrapped env; it may not yet match the transformed observation space
+        observation = self._modify_observation(agent, obs)
+        return observation
+
+    def step(self, action: ActionType) -> None:
+        """
+        Take a step in the environment with the given action, modifying the action if required.
+
+        Args:
+            action (ActionType): The action to be taken in the environment.
+        """
+        agent = self.env.agent_selection
+        if not (self.terminations[agent] or self.truncations[agent]):
+            action = self._modify_action(agent, action)
+
+        super().step(action)
+
+        self._update_step(self.agent_selection)
diff --git a/pettingzoo/utils/wrappers/supersuit/utils/basic_transforms/color_reduction.py b/pettingzoo/utils/wrappers/supersuit/utils/basic_transforms/color_reduction.py
new file mode 100644
index 000000000..735175a4d
--- /dev/null
+++ b/pettingzoo/utils/wrappers/supersuit/utils/basic_transforms/color_reduction.py
@@ -0,0 +1,70 @@
+import numpy as np
+from pettingzoo.utils.wrappers.supersuit.utils.convert_box import convert_box
+from gymnasium.spaces import Box
+
+
+COLOR_RED_LIST = ["full", "R", "G", "B"]
+GRAYSCALE_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)
+
+
+def check_param(space: Box, color_reduction: str) -> None:
+    """
+    Check if the provided parameters are valid for color reduction.
+
+    Args:
+        space (Box): The observation space as a Box object.
+        color_reduction (str): The desired color reduction method.
+
+    Raises:
+        AssertionError: If the color_reduction is not a string or is not in the COLOR_RED_LIST,
+                        or if the shape of the space is not a 3D image with the last dimension
+                        having a size of 3.
+    """
+    assert isinstance(
+        color_reduction, str
+    ), f"color_reduction must be str. It is {color_reduction}"
+    assert color_reduction in COLOR_RED_LIST, "color_reduction must be in {}".format(
+        COLOR_RED_LIST
+    )
+    assert (
+        len(space.low.shape) == 3 and space.low.shape[2] == 3
+    ), "To apply color_reduction, shape must be a 3d image with last dimension of size 3. Shape is {}".format(
+        space.low.shape
+    )
+
+
+def change_obs_space(obs_space: Box, param: str) -> Box:
+    """
+    Change the observation space based on a color reduction parameter.
+
+    Args:
+        obs_space (Box): The original observation space as a Box object.
+        param (str): The color reduction parameter.
+
+    Returns:
+        Box: The modified observation space.
+    """
+    return convert_box(lambda obs: change_observation(obs, obs_space, param), obs_space)
+
+
+def change_observation(obs: np.ndarray, obs_space: Box, color_reduction: str) -> np.ndarray:
+    """
+    Apply color reduction to an observation based on the specified color_reduction parameter.
+
+    Args:
+        obs (np.ndarray): The input observation as a NumPy array.
+        obs_space (Box): The original observation space as a Box object.
+        color_reduction (str): The color reduction method to be applied.
+
+    Returns:
+        np.ndarray: The modified observation after applying color reduction.
+    """
+    if color_reduction == "R":
+        obs = obs[:, :, 0]
+    if color_reduction == "G":
+        obs = obs[:, :, 1]
+    if color_reduction == "B":
+        obs = obs[:, :, 2]
+    if color_reduction == "full":
+        obs = (obs.astype(np.float32) @ GRAYSCALE_WEIGHTS).astype(obs.dtype)  # preserve the input dtype
+    return obs
diff --git a/pettingzoo/utils/wrappers/supersuit/utils/basic_transforms/dtype.py b/pettingzoo/utils/wrappers/supersuit/utils/basic_transforms/dtype.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pettingzoo/utils/wrappers/supersuit/utils/convert_box.py b/pettingzoo/utils/wrappers/supersuit/utils/convert_box.py
new file mode 100644
index 000000000..7ba06b4cc
--- /dev/null
+++ b/pettingzoo/utils/wrappers/supersuit/utils/convert_box.py
@@ -0,0 +1,21 @@
+from gymnasium.spaces import Box
+from typing import Callable
+
+
+def convert_box(convert_obs_fn: Callable, old_box: Box) -> Box:
+    """
+    Convert the bounds of a Box object using a given conversion function.
+
+    Args:
+        convert_obs_fn (Callable): A function that takes an ndarray and returns a transformed ndarray.
+        old_box (Box): The original Box object to be converted.
+
+    Returns:
+        Box: A new Box object with transformed lower and upper bounds based on the conversion function.
+
+    Note:
+        The `convert_obs_fn` should accept and return NumPy ndarrays; the returned array may have a different
+        shape or dtype than the input, and the new Box adopts both.
+    """
+    new_low = convert_obs_fn(old_box.low)
+    new_high = convert_obs_fn(old_box.high)
+    return Box(low=new_low, high=new_high, dtype=new_low.dtype)
diff --git a/pettingzoo/utils/wrappers/supersuit/utils/make_defaultdict.py b/pettingzoo/utils/wrappers/supersuit/utils/make_defaultdict.py
new file mode 100644
index 000000000..c91fda20e
--- /dev/null
+++ b/pettingzoo/utils/wrappers/supersuit/utils/make_defaultdict.py
@@ -0,0 +1,26 @@
+import warnings
+from collections import defaultdict
+from typing import Union
+
+
+def make_defaultdict(d: dict) -> Union[defaultdict, dict]:
+    """
+    Create a defaultdict with the same value type as the input dictionary, or return an empty dictionary.
+
+    Args:
+        d (dict): The input dictionary.
+
+    Returns:
+        Union[defaultdict, dict]: A defaultdict with the same value type as the input dictionary, or an empty dict.
+
+    Note:
+        If the input dictionary is empty, a warning is issued, and an empty dictionary is returned.
+    """
+    try:
+        dd = defaultdict(type(next(iter(d.values()))))
+        for k, v in d.items():
+            dd[k] = v
+        return dd
+    except StopIteration:
+        warnings.warn("No agents left in the environment!")
+        return {}
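For review convenience, here is a minimal usage sketch (not part of the diff) showing how the new wrappers compose. It reuses the `DummyEnv` test helper added above; everything else is the standard PettingZoo AEC loop, and the mode and bounds are arbitrary example values.

```python
import numpy as np
from gymnasium.spaces import Box, Discrete

from pettingzoo.utils.wrappers.supersuit.basic_wrappers import (
    clip_reward_v0,
    color_reduction_v0,
)
from pettingzoo.utils.wrappers.supersuit.test.dummy_aec_env import DummyEnv

# Build a tiny two-agent environment with 8x8 RGB observations.
observations = {f"a{i}": np.zeros((8, 8, 3), dtype=np.float32) for i in range(2)}
observation_spaces = {
    f"a{i}": Box(low=np.float32(0.0), high=np.float32(10.0), shape=(8, 8, 3))
    for i in range(2)
}
action_spaces = {f"a{i}": Discrete(5) for i in range(2)}

env = DummyEnv(observations, observation_spaces, action_spaces)
env = color_reduction_v0(env, "R")  # keep only the red channel -> (8, 8) observations
env = clip_reward_v0(env, lower_bound=-0.5, upper_bound=0.5)  # clip per-agent rewards

env.reset()
for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    action = None if termination or truncation else env.action_space(agent).sample()
    env.step(action)
```

The same composition should work for any AEC environment whose observations are 3-channel images, which is what `color_reduction.check_param` enforces before the observation space is rewritten.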