Skip to content

Commit

Permalink
Merge branch 'master' of github.com:PettingZoo-Team/SuperSuit
Browse files Browse the repository at this point in the history
  • Loading branch information
benblack769 committed Oct 19, 2021
2 parents 101ccea + 24b007e commit 6997b7a
Show file tree
Hide file tree
Showing 20 changed files with 110 additions and 73 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ You can install SuperSuit via `pip install supersuit`

### Required multiagent environment attributes

Many wrappers require an environment to support the optional `possible_agents`, `observation_spaces`, and `action_spaces` attributes. These are required because the
Many wrappers require an environment to support the optional `possible_agents` attribute. This is required because the
wrapper needs to know all the spaces in advance. The following is a complete list of
wrappers which require these attributes:

Expand Down
19 changes: 11 additions & 8 deletions supersuit/aec_vector/async_vector_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import multiprocessing as mp
from pettingzoo.utils.agent_selector import agent_selector
import numpy as np
Expand Down Expand Up @@ -72,7 +71,7 @@ def __init__(self, env_constructors, num_envs):
self.env = self.envs[0]
self.possible_agents = self.env.possible_agents
self.agent_indexes = {agent: i for i, agent in enumerate(self.env.possible_agents)}
self.dead_obss = {agent: np.zeros_like(SpaceWrapper(obs_space).low) for agent, obs_space in self.env.observation_spaces.items()}
self.dead_obss = {agent: np.zeros_like(SpaceWrapper(self.env.observation_space(agent)).low) for agent in self.env.possible_agents}

def reset(self):
for env in self.envs:
Expand Down Expand Up @@ -174,7 +173,7 @@ def env_worker(env_constructors, total_num_envs, idx_start, my_num_envs, agent_a
env = _SeperableAECWrapper(env_constructors, my_num_envs)
shared_datas = {
agent: AgentSharedData(
total_num_envs, SpaceWrapper(env.env.observation_spaces[agent]), SpaceWrapper(env.env.action_spaces[agent]), agent_arrays[agent]
total_num_envs, SpaceWrapper(env.env.observation_space(agent)), SpaceWrapper(env.env.action_space(agent)), agent_arrays[agent]
)
for agent in env.possible_agents
}
Expand Down Expand Up @@ -238,8 +237,6 @@ def __init__(self, env_constructors, num_cpus=None, return_copy=True):
self.env = env = env_constructors[0]()
self.max_num_agents = len(self.env.possible_agents)
self.possible_agents = self.env.possible_agents
self.observation_spaces = copy.copy(self.env.observation_spaces)
self.action_spaces = copy.copy(self.env.action_spaces)
self.order_is_nondeterministic = False
self.num_envs = num_envs

Expand All @@ -250,14 +247,14 @@ def __init__(self, env_constructors, num_cpus=None, return_copy=True):
all_arrays = {
agent: create_shared_data(
num_envs,
SpaceWrapper(self.observation_spaces[agent]),
SpaceWrapper(self.action_spaces[agent]),
SpaceWrapper(self.observation_space(agent)),
SpaceWrapper(self.action_space(agent)),
)
for agent in self.possible_agents
}

self.shared_datas = {
agent: AgentSharedData(num_envs, SpaceWrapper(env.observation_spaces[agent]), SpaceWrapper(env.action_spaces[agent]), all_arrays[agent])
agent: AgentSharedData(num_envs, SpaceWrapper(env.observation_space(agent)), SpaceWrapper(env.action_space(agent)), all_arrays[agent])
for agent in env.possible_agents
}

Expand Down Expand Up @@ -367,6 +364,12 @@ def seed(self, seed):

self._receive_info()

def action_space(self, agent):
    """Look up ``agent``'s action space on the wrapped environment ``self.env``."""
    wrapped_env = self.env
    return wrapped_env.action_space(agent)

def observation_space(self, agent):
    """Look up ``agent``'s observation space on the wrapped environment ``self.env``."""
    wrapped_env = self.env
    return wrapped_env.observation_space(agent)

def __del__(self):
for cin in self.con_ins:
cin.send(("terminate", None))
Expand Down
2 changes: 1 addition & 1 deletion supersuit/aec_vector/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def vectorize_aec_env_v0(aec_env, num_envs, num_cpus=0):
assert isinstance(aec_env, AECEnv), "pettingzoo_env_to_vec_env takes in a pettingzoo AECEnv."
assert hasattr(aec_env, 'action_spaces') and hasattr(aec_env, 'possible_agents') and hasattr(aec_env, 'observation_spaces'), "environment passed to vectorize_aec_env must have action_spaces, observation_spaces, and possible_agents attributes."
assert hasattr(aec_env, 'possible_agents'), "environment passed to vectorize_aec_env must have possible_agents attribute."

def env_fn():
return cloudpickle.loads(cloudpickle.dumps(aec_env))
Expand Down
11 changes: 7 additions & 4 deletions supersuit/aec_vector/vector_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
from pettingzoo.utils.agent_selector import agent_selector
import numpy as np
from .base_aec_vec_env import VectorAECEnv
Expand All @@ -14,10 +13,14 @@ def __init__(self, env_constructors):
self.env = self.envs[0]
self.max_num_agents = self.env.max_num_agents
self.possible_agents = self.env.possible_agents
self.observation_spaces = copy.copy(self.env.observation_spaces)
self.action_spaces = copy.copy(self.env.action_spaces)
self._agent_selector = agent_selector(self.possible_agents)

def action_space(self, agent):
    """Return the action space that the reference environment ``self.env`` reports for ``agent``."""
    reference_env = self.env
    return reference_env.action_space(agent)

def observation_space(self, agent):
    """Return the observation space that the reference environment ``self.env`` reports for ``agent``."""
    reference_env = self.env
    return reference_env.observation_space(agent)

def _find_active_agent(self):
cur_selection = self.agent_selection
while not any(cur_selection == env.agent_selection for env in self.envs):
Expand Down Expand Up @@ -58,7 +61,7 @@ def seed(self, seed=None):
def observe(self, agent):
    """Stack ``agent``'s observation from every sub-environment into one array.

    For a sub-environment where ``agent`` is not a key of ``env.dones``, a
    zero array shaped like the ``low`` bound of that environment's
    observation space is used instead of calling ``env.observe``.

    The space is read through the ``observation_space(agent)`` accessor; the
    older ``self.observation_spaces`` dict is no longer populated by this
    class, so it must not be consulted here.
    """
    observations = [
        env.observe(agent) if agent in env.dones else np.zeros_like(env.observation_space(agent).low)
        for env in self.envs
    ]
    return np.stack(observations)

Expand Down
11 changes: 6 additions & 5 deletions supersuit/generic_wrappers/frame_skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ def step(self, action):
self.agents = self.env.agents[:]
orig_agents = set(action.keys())

total_reward = make_defaultdict({agent: 0.0 for agent in self.agents})
total_dones = {}
total_infos = {}
total_obs = {}

for x in range(num_skips):
obs, rews, done, info = super().step(action)
if x == 0:
total_reward = make_defaultdict({agent: 0.0 for agent in self.env.agents})
total_dones = {}
total_infos = {}
total_obs = {}

for agent, rew in rews.items():
total_reward[agent] += rew
Expand All @@ -157,6 +157,7 @@ def step(self, action):
del total_infos[agent]
del total_obs[agent]

self.agents = self.env.agents[:]
return total_obs, total_reward, total_dones, total_infos


Expand Down
5 changes: 5 additions & 0 deletions supersuit/generic_wrappers/utils/shared_wrapper_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import gym
from pettingzoo.utils.wrappers import OrderEnforcingWrapper as PettingzooWrap
from supersuit.utils.wrapper_chooser import WrapperChooser
Expand All @@ -13,9 +14,11 @@ def __init__(self, env, modifier_class):
if hasattr(self.env, 'possible_agents'):
self.add_modifiers(self.env.possible_agents)

def observation_space(self, agent):
    """Return the modified observation space for ``agent``, computed once per wrapper.

    The wrapped environment's observation space is passed through the
    agent's modifier (``modify_obs_space``) on first request and the result
    is memoized, so repeated calls return the very same space object.

    The memo lives on the instance rather than in ``functools.lru_cache``:
    an ``lru_cache`` on a method keys on ``self`` and would keep every
    wrapper instance alive for the lifetime of the class.
    """
    cache = self.__dict__.setdefault("_shared_obs_space_cache", {})
    if agent not in cache:
        cache[agent] = self.modifiers[agent].modify_obs_space(self.env.observation_space(agent))
    return cache[agent]

def action_space(self, agent):
    """Return the modified action space for ``agent``, computed once per wrapper.

    The wrapped environment's action space is passed through the agent's
    modifier (``modify_action_space``) on first request and the result is
    memoized, so repeated calls return the very same space object.

    The memo lives on the instance rather than in ``functools.lru_cache``:
    an ``lru_cache`` on a method keys on ``self`` and would keep every
    wrapper instance alive for the lifetime of the class.
    """
    cache = self.__dict__.setdefault("_shared_action_space_cache", {})
    if agent not in cache:
        cache[agent] = self.modifiers[agent].modify_action_space(self.env.action_space(agent))
    return cache[agent]

Expand Down Expand Up @@ -65,9 +68,11 @@ def __init__(self, env, modifier_class):
if hasattr(self.env, 'possible_agents'):
self.add_modifiers(self.env.possible_agents)

def observation_space(self, agent):
    """Return the modified observation space for ``agent``, computed once per wrapper.

    The wrapped environment's observation space is passed through the
    agent's modifier (``modify_obs_space``) on first request and the result
    is memoized, so repeated calls return the very same space object.

    The memo lives on the instance rather than in ``functools.lru_cache``:
    an ``lru_cache`` on a method keys on ``self`` and would keep every
    wrapper instance alive for the lifetime of the class.
    """
    cache = self.__dict__.setdefault("_shared_obs_space_cache", {})
    if agent not in cache:
        cache[agent] = self.modifiers[agent].modify_obs_space(self.env.observation_space(agent))
    return cache[agent]

def action_space(self, agent):
    """Return the modified action space for ``agent``, computed once per wrapper.

    The wrapped environment's action space is passed through the agent's
    modifier (``modify_action_space``) on first request and the result is
    memoized, so repeated calls return the very same space object.

    The memo lives on the instance rather than in ``functools.lru_cache``:
    an ``lru_cache`` on a method keys on ``self`` and would keep every
    wrapper instance alive for the lifetime of the class.
    """
    cache = self.__dict__.setdefault("_shared_action_space_cache", {})
    if agent not in cache:
        cache[agent] = self.modifiers[agent].modify_action_space(self.env.action_space(agent))
    return cache[agent]

Expand Down
6 changes: 4 additions & 2 deletions supersuit/lambda_wrappers/action_lambda.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from gym.spaces import Space
from supersuit.utils.base_aec_wrapper import BaseWrapper
from supersuit.utils.wrapper_chooser import WrapperChooser
Expand All @@ -14,14 +15,15 @@ def __init__(self, env, change_action_fn, change_space_fn):
self.change_space_fn = change_space_fn

super().__init__(env)
if hasattr(self, 'action_spaces'):
for agent in self.action_spaces:
if hasattr(self, 'possible_agents'):
for agent in self.possible_agents:
# call any validation logic in this function
self.action_space(agent)

def _modify_observation(self, agent, observation):
    """Identity hook: hand ``observation`` back untouched, ignoring ``agent``."""
    return observation

@functools.lru_cache(maxsize=None)
def action_space(self, agent):
old_act_space = self.env.action_space(agent)
try:
Expand Down
13 changes: 7 additions & 6 deletions supersuit/lambda_wrappers/observation_lambda.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import gym
import numpy as np
from gym.spaces import Box, Space
Expand All @@ -15,20 +16,20 @@ def __init__(self, env, change_observation_fn, change_obs_space_fn=None):

super().__init__(env)

if hasattr(self, 'observation_spaces'):
for agent in self.observation_spaces:
if hasattr(self, 'possible_agents'):
for agent in self.possible_agents:
# call any validation logic in this function
self.observation_space(agent)

def _modify_action(self, agent, action):
    """Identity hook: hand ``action`` back untouched, ignoring ``agent``."""
    return action

def _check_wrapper_params(self):
if self.change_obs_space_fn is None and hasattr(self, 'observation_spaces'):
spaces = self.observation_spaces.values()
for space in spaces:
assert isinstance(space, Box), "the observation_lambda_wrapper only allows the change_obs_space_fn argument to be optional for Box observation spaces"
if self.change_obs_space_fn is None and hasattr(self, 'possible_agents'):
for agent in self.possible_agents:
assert isinstance(self.observation_space(agent), Box), "the observation_lambda_wrapper only allows the change_obs_space_fn argument to be optional for Box observation spaces"

@functools.lru_cache(maxsize=None)
def observation_space(self, agent):
if self.change_obs_space_fn is None:
space = self.env.observation_space(agent)
Expand Down
7 changes: 3 additions & 4 deletions supersuit/multiagent_wrappers/agent_indication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@

def agent_indicator_v0(env, type_only=False):
assert isinstance(env, AECEnv) or isinstance(env, ParallelEnv), "agent_indicator_v0 only accepts an AECEnv or ParallelEnv"
if not hasattr(env, 'observation_spaces') or not hasattr(env, 'possible_agents'):
raise AssertionError("environment passed to agent indicator wrapper must have the possible_agents and observation_spaces attributes.")
assert hasattr(env, 'possible_agents'), "environment passed to agent indicator wrapper must have the possible_agents attribute."

indicator_map = agent_ider.get_indicator_map(env.possible_agents, type_only)
num_indicators = len(set(indicator_map.values()))

if hasattr(env, 'observation_spaces'):
agent_ider.check_params(env.observation_spaces.values())
obs_spaces = [env.observation_space(agent) for agent in env.possible_agents]
agent_ider.check_params(obs_spaces)

return observation_lambda_v0(env,
lambda obs, obs_space, agent: agent_ider.change_observation(
Expand Down
9 changes: 6 additions & 3 deletions supersuit/multiagent_wrappers/black_death.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from supersuit.utils.base_aec_wrapper import BaseWrapper
from gym.spaces import Box
import numpy as np
Expand All @@ -13,12 +14,14 @@ def _modify_action(self, agent, action):

class black_death_aec(ObservationWrapper):
def _check_wrapper_params(self):
if not hasattr(self, 'observation_spaces') or not hasattr(self, 'possible_agents'):
raise AssertionError("environment passed to black death wrapper must have the possible_agents and observation_spaces attributes.")
if not hasattr(self, 'possible_agents'):
raise AssertionError("environment passed to black death wrapper must have the possible_agents attribute.")

for space in self.observation_spaces.values():
for agent in self.possible_agents:
space = self.observation_space(agent)
assert isinstance(space, gym.spaces.Box), f"observation sapces for black death must be Box spaces, is {space}"

def observation_space(self, agent):
    """Return ``agent``'s observation space widened so that zero lies inside it.

    The wrapped space's ``low`` is clamped to at most 0 and its ``high`` to
    at least 0, keeping the original dtype. The resulting ``Box`` is built
    once per agent and memoized on the instance, so repeated calls return
    the same object.

    The memo is stored on the instance instead of using
    ``functools.lru_cache``, which on a method keys on ``self`` and would
    keep every wrapper instance alive for the lifetime of the class.
    """
    cache = self.__dict__.setdefault("_black_death_obs_space_cache", {})
    if agent not in cache:
        old_obs_space = self.env.observation_space(agent)
        cache[agent] = Box(low=np.minimum(0, old_obs_space.low), high=np.maximum(0, old_obs_space.high), dtype=old_obs_space.dtype)
    return cache[agent]
Expand Down
11 changes: 6 additions & 5 deletions supersuit/multiagent_wrappers/padding_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@

def pad_action_space_v0(env):
    """Wrap ``env`` so that every agent exposes one common, padded action space.

    The action spaces of all ``env.possible_agents`` are collected through
    ``env.action_space(agent)``, checked with
    ``homogenize_ops.check_homogenize_spaces`` and merged with
    ``homogenize_ops.homogenize_spaces``. Actions are converted back via
    ``homogenize_ops.dehomogenize_actions`` before reaching the wrapped env.

    Raises:
        AssertionError: if ``env`` is neither an ``AECEnv`` nor a
            ``ParallelEnv``, or lacks ``possible_agents``.
    """
    assert isinstance(env, AECEnv) or isinstance(env, ParallelEnv), "pad_action_space_v0 only accepts an AECEnv or ParallelEnv"
    # The message names this wrapper, not pad_observations.
    assert hasattr(env, 'possible_agents'), "environment passed to pad_action_space must have a possible_agents list."
    spaces = [env.action_space(agent) for agent in env.possible_agents]
    homogenize_ops.check_homogenize_spaces(spaces)
    padded_space = homogenize_ops.homogenize_spaces(spaces)
    return action_lambda_v1(env,
                            lambda action, act_space: homogenize_ops.dehomogenize_actions(act_space, action),
                            lambda act_space: padded_space)


def pad_observations_v0(env):
assert isinstance(env, AECEnv) or isinstance(env, ParallelEnv), "pad_observations_v0 only accepts an AECEnv or ParallelEnv"
assert hasattr(env, 'observation_spaces'), "environment passed to pad_observations must have a observation_spaces dict."
spaces = list(env.observation_spaces.values())
assert hasattr(env, 'possible_agents'), "environment passed to pad_observations must have a possible_agents list."
spaces = [env.observation_space(agent) for agent in env.possible_agents]
homogenize_ops.check_homogenize_spaces(spaces)
padded_space = homogenize_ops.homogenize_spaces(spaces)
return observation_lambda_v0(env,
Expand Down
11 changes: 7 additions & 4 deletions supersuit/utils/make_defaultdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@


def make_defaultdict(d):
    """Copy ``d`` into a ``defaultdict`` whose factory is the type of its first value.

    Args:
        d: mapping to copy. Its first value (in iteration order) decides the
            default factory, e.g. ``float`` for ``{"a": 0.0}``.

    Returns:
        A ``defaultdict`` holding every item of ``d``. If ``d`` is empty
        there is no value to infer a factory from, so a plain empty ``dict``
        is returned instead.
    """
    # Only the probe for the first value can legitimately run dry; keeping the
    # try this narrow avoids swallowing a StopIteration raised anywhere else.
    try:
        first_value = next(iter(d.values()))
    except StopIteration:
        return {}
    return defaultdict(type(first_value), d)
8 changes: 4 additions & 4 deletions supersuit/vector/markov_vector_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def __init__(self, par_env, black_death=False):
"""
self.par_env = par_env
self.metadata = par_env.metadata
self.observation_space = list(par_env.observation_spaces.values())[0]
self.action_space = list(par_env.action_spaces.values())[0]
self.observation_space = par_env.observation_space(par_env.possible_agents[0])
self.action_space = par_env.action_space(par_env.possible_agents[0])
assert all(
self.observation_space == obs_space for obs_space in par_env.observation_spaces.values()
self.observation_space == par_env.observation_space(agent) for agent in par_env.possible_agents
), "observation spaces not consistent. Perhaps you should wrap with `supersuit.aec_wrappers.pad_observations`?"
assert all(
self.action_space == obs_space for obs_space in par_env.action_spaces.values()
self.action_space == par_env.action_space(agent) for agent in par_env.possible_agents
), "action spaces not consistent. Perhaps you should wrap with `supersuit.aec_wrappers.pad_actions`?"
self.num_envs = len(par_env.possible_agents)
self.black_death = black_death
Expand Down
2 changes: 1 addition & 1 deletion supersuit/vector/vector_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ def concat_vec_envs_v0(vec_env, num_vec_envs, num_cpus=0, base_class='gym'):

def pettingzoo_env_to_vec_env_v0(parallel_env):
    """Wrap a PettingZoo ``ParallelEnv`` in a ``MarkovVectorEnv``.

    Only ``possible_agents`` is required on the environment; the
    ``action_spaces`` / ``observation_spaces`` dict attributes are not
    demanded here.

    Raises:
        AssertionError: if ``parallel_env`` is not a ``ParallelEnv`` or has
            no ``possible_agents`` attribute.
    """
    assert isinstance(parallel_env, ParallelEnv), "pettingzoo_env_to_vec_env takes in a pettingzoo ParallelEnv. Can create a parallel_env with pistonball.parallel_env() or convert it from an AEC env with `from pettingzoo.utils.conversions import to_parallel; to_parallel(env)``"
    assert hasattr(parallel_env, 'possible_agents'), "environment passed to pettingzoo_env_to_vec_env must have possible_agents attribute."
    return MarkovVectorEnv(parallel_env)
Loading

0 comments on commit 6997b7a

Please sign in to comment.