From 8cf315b9336c12e74e2e82187cb472d70879a930 Mon Sep 17 00:00:00 2001 From: Ephraim Rusu Date: Wed, 3 Apr 2024 09:10:56 -0700 Subject: [PATCH 1/7] Updated actors to use internal supported function instead of limiting to supported agent type --- abmarl/sim/gridworld/actor.py | 40 +++++++++++++++++------------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/abmarl/sim/gridworld/actor.py b/abmarl/sim/gridworld/actor.py index df820ed3..90aadf2d 100644 --- a/abmarl/sim/gridworld/actor.py +++ b/abmarl/sim/gridworld/actor.py @@ -42,14 +42,17 @@ def key(self): """ pass - @property @abstractmethod - def supported_agent_type(self): + def _supported_agent(self, agent): """ - The type of Agent that this Actor works with. + The qualifications that the agent must satisfy in order to work with this Actor. + + For example, Attack Actors require the agent to be an Attacking Agent. - If an agent is this type, the Actor will add its entry to the - agent's action space and will process actions for this agent. + Args: + agent: The agent to inspect. + Returns: + True if agent satisfies qualities, otherwise False. """ pass @@ -61,7 +64,7 @@ class MoveActor(ActorBaseComponent): def __init__(self, **kwargs): super().__init__(**kwargs) for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): if agent.move_range == "FULL": agent.move_range = max(self.rows, self.cols) - 1 agent.action_space[self.key] = Box( @@ -76,12 +79,11 @@ def key(self): """ return "move" - @property - def supported_agent_type(self): + def _supported_agent(self, agent): """ This Actor works with MovingAgents. """ - return MovingAgent + return isinstance(agent, MovingAgent) def process_action(self, agent, action_dict, **kwargs): """ @@ -99,7 +101,7 @@ def process_action(self, agent, action_dict, **kwargs): Returns: True if the move is successful, False otherwise. """ - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): action = action_dict[self.key] new_position = agent.position + action if 0 <= new_position[0] < self.rows and \ @@ -125,7 +127,7 @@ class CrossMoveActor(ActorBaseComponent): def __init__(self, **kwargs): super().__init__(**kwargs) for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): agent.action_space[self.key] = Discrete(5) agent.null_action[self.key] = 0 @@ -136,12 +138,11 @@ def key(self): """ return "move" - @property - def supported_agent_type(self): + def _supported_agent(self, agent): """ This Actor works with MovingAgent, but the move_range parameter is ignored. """ - return MovingAgent + return isinstance(agent, MovingAgent) def grid_action(self, cross_action): """ @@ -178,7 +179,7 @@ def process_action(self, agent, action_dict, **kwargs): Returns: True if the move is successful, False otherwise. """ - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): cross_action = action_dict[self.key] action = self.grid_action(cross_action) new_position = agent.position + action @@ -252,7 +253,7 @@ def __init__(self, attack_mapping=None, stacked_attacks=False, **kwargs): self.attack_mapping = attack_mapping self.stacked_attacks = stacked_attacks for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): if agent.attack_range == "FULL": agent.attack_range = max(self.rows, self.cols) - 1 self._assign_space(agent) @@ -264,12 +265,11 @@ def key(self): """ return 'attack' - @property - def supported_agent_type(self): + def _supported_agent(self, agent): """ This Actor works with AttackingAgents. """ - return AttackingAgent + return isinstance(agent, AttackingAgent) @property def attack_mapping(self): @@ -349,7 +349,7 @@ def process_action(self, attacking_agent, action_dict, **kwargs): 2. An attack failed: True, [] 3. An attack was successful: True, [non-empty] """ - if isinstance(attacking_agent, self.supported_agent_type): + if self._supported_agent(attacking_agent): action = action_dict[self.key] attack_status, attacked_agents = self._determine_attack(attacking_agent, action) From 623c9a7c6942af1ab795b27ea735687c7d1b3e03 Mon Sep 17 00:00:00 2001 From: Ephraim Rusu Date: Wed, 3 Apr 2024 09:22:32 -0700 Subject: [PATCH 2/7] Observers use new supported agent function --- abmarl/sim/gridworld/observer.py | 56 +++++++++++++++++--------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/abmarl/sim/gridworld/observer.py b/abmarl/sim/gridworld/observer.py index 44fda6dc..34608fb8 100644 --- a/abmarl/sim/gridworld/observer.py +++ b/abmarl/sim/gridworld/observer.py @@ -27,14 +27,19 @@ def key(self): """ pass - @property @abstractmethod - def supported_agent_type(self): + def _supported_agent(self, agent): """ - The type of Agent that this Observer works with. + The qualifications that the agent must satisfy in order to work with this Actor. - If an agent is this type, the Observer will add its entry to the - agent's observation space and will produce observations for this agent. + For example, Grid Observers require the agent to be a Grid Observing Agent. + Ammo Observer requires the agent to be an AmmoAgent and to be an Observing + Agent. + + Args: + agent: The agent to inspect. + Returns: + True if agent satisfies qualities, otherwise False. """ pass @@ -70,7 +75,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) max_encoding = max(self._encodings_in_sim) for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): if agent.view_range == "FULL": agent.view_range = max(self.rows, self.cols) - 1 agent.observation_space[self.key] = Box( @@ -87,12 +92,11 @@ def key(self): """ return 'absolute_encoding' - @property - def supported_agent_type(self): + def _supported_agent(self, agent): """ This Observer work with GridObservingAgents """ - return GridObservingAgent + return isinstance(agent, GridObservingAgent) def get_obs(self, agent, **kwargs): """ @@ -103,7 +107,7 @@ def get_obs(self, agent, **kwargs): masked cells indicated as -2, which are masked either because they are too far away or because they are blocked from view by view-blocking agents. """ - if not isinstance(agent, self.supported_agent_type): + if not self._supported_agent(agent): return {} # To generate the observation, we first create a local grid and mask using @@ -166,7 +170,7 @@ def __init__(self, observe_self=True, **kwargs): self.observe_self = observe_self max_encoding = max(self._encodings_in_sim) for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): if agent.view_range == "FULL": agent.view_range = max(self.rows, self.cols) - 1 agent.observation_space[self.key] = Box( @@ -185,11 +189,11 @@ def key(self): return 'position_centered_encoding' @property - def supported_agent_type(self): + def _supported_agent(self, agent): """ This Observer works with GridObservingAgents. """ - return GridObservingAgent + return isinstance(agent, GridObservingAgent) @property def observe_self(self): @@ -215,7 +219,7 @@ def get_obs(self, agent, **kwargs): Returns: The observation as a dictionary. """ - if not isinstance(agent, self.supported_agent_type): + if not self._supported_agent(agent): return {} # Generate a local grid and an observation mask @@ -267,7 +271,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.number_of_encodings = max([agent.encoding for agent in self.agents.values()]) for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): if agent.view_range == "FULL": agent.view_range = max(self.rows, self.cols) - 1 agent.observation_space[self.key] = Box( @@ -289,11 +293,11 @@ def key(self): return 'stacked_position_centered_encoding' @property - def supported_agent_type(self): + def _supported_agent(self, agent): """ This Observer works with GridObservingAgents. """ - return GridObservingAgent + return isinstance(agent, GridObservingAgent) def get_obs(self, agent, **kwargs): """ @@ -307,7 +311,7 @@ def get_obs(self, agent, **kwargs): Returns: The observation as a dictionary. """ - if not isinstance(agent, self.supported_agent_type): + if not self._supported_agent(agent): return {} # Generate a local grid and an observation mask. @@ -347,7 +351,7 @@ class AbsolutePositionObserver(ObserverBaseComponent): def __init__(self, **kwargs): super().__init__(**kwargs) for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): agent.observation_space[self.key] = Box( np.array([0, 0], dtype=int), np.array([self.grid.rows - 1, self.grid.cols - 1], dtype=int), @@ -363,17 +367,17 @@ def key(self): return 'position' @property - def supported_agent_type(self): + def _supported_agent(self, agent): """ This Observer works with ObservingAgents """ - return ObservingAgent + return isinstance(agent, ObservingAgent) def get_obs(self, agent, **kwargs): """ Agents observe their absolute position. """ - if not isinstance(agent, self.supported_agent_type): + if not self._supported_agent(agent): return {} else: return {self.key: agent.position} @@ -386,7 +390,7 @@ class AmmoObserver(ObserverBaseComponent): def __init__(self, **kwargs): super().__init__(**kwargs) for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): agent.observation_space[self.key] = Box( 0, agent.initial_ammo, @@ -403,17 +407,17 @@ def key(self): return 'ammo' @property - def supported_agent_type(self): + def _supported_agent(self, agent): """ This Observer works with AmmoObservingAgents. """ - return AmmoObservingAgent + raise RuntimeError("FIX THIS!") def get_obs(self, agent, **kwargs): """ Agents observe their own ammo """ - if not isinstance(agent, self.supported_agent_type): + if not self._supported_agent(agent): return {} else: return {self.key: agent.ammo} From 307b9d25fe699d0a4e934736210028fd984e207a Mon Sep 17 00:00:00 2001 From: Ephraim Rusu Date: Wed, 3 Apr 2024 09:26:47 -0700 Subject: [PATCH 3/7] Actor wrapper uses supported agent interface --- abmarl/sim/gridworld/wrapper.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/abmarl/sim/gridworld/wrapper.py b/abmarl/sim/gridworld/wrapper.py index 00970c8b..1f9d6d30 100644 --- a/abmarl/sim/gridworld/wrapper.py +++ b/abmarl/sim/gridworld/wrapper.py @@ -106,10 +106,10 @@ def __init__(self, component): self.from_space = { agent.id: agent.action_space[self.key] for agent in self.agents.values() - if isinstance(agent, self.supported_agent_type) + if self._supported_agent(agent) } for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): assert self.check_space(agent.action_space[self.key]), \ f"Cannot wrap {self.key} action channel for agent {agent.id}" agent.action_space[self.key] = self.wrap_space(agent.action_space[self.key]) @@ -133,12 +133,11 @@ def key(self): """ return self.wrapped_component.key - @property - def supported_agent_type(self): + def _supported_agent(self, agent): """ - The supported agent type is the same as the wrapped actor's supported agent type. + The supported agent is the same as the wrapped actor's supported agent. """ - return self.wrapped_component.supported_agent_type + return self.wrapped_component._supported_agent(agent) def process_action(self, agent, action_dict, **kwargs): """ @@ -149,7 +148,7 @@ def process_action(self, agent, action_dict, **kwargs): action_dict: The action dictionary for this agent in this step. The action in this channel comes in the wrapped space. """ - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): action = action_dict[self.key] unwrapped_action = self.wrap_point(self.from_space[agent.id], action) return self.wrapped_component.process_action( From 65266be68f21e6bc321a99b2ae0953317fd747ec Mon Sep 17 00:00:00 2001 From: Ephraim Rusu Date: Wed, 3 Apr 2024 09:40:36 -0700 Subject: [PATCH 4/7] Resolved issue with ammo observing agent --- abmarl/sim/gridworld/agent.py | 18 ----------------- abmarl/sim/gridworld/observer.py | 6 +++--- docs/src/api.rst | 6 ------ docs/src/gridworld.rst | 4 ---- tests/sim/gridworld/test_agent.py | 29 +-------------------------- tests/sim/gridworld/test_observer.py | 30 +--------------------------- 6 files changed, 5 insertions(+), 88 deletions(-) diff --git a/abmarl/sim/gridworld/agent.py b/abmarl/sim/gridworld/agent.py index 257ae522..76055f62 100644 --- a/abmarl/sim/gridworld/agent.py +++ b/abmarl/sim/gridworld/agent.py @@ -315,24 +315,6 @@ def initial_ammo(self, value): self._initial_ammo = value -class AmmoObservingAgentMeta(type): - """ - AmmoObservingAgentMeta class defines an AmmoObservingAgent as an instance of - AmmoAgent and ObservingAgent. Then, when we check if an agent is an instance - of AmmoObservingAgent, it doesn't have to directly derive from it; it just has - to derive from both AmmoAgent and ObservingAgent. - """ - def __instancecheck__(self, instance): - return isinstance(instance, ObservingAgent) and isinstance(instance, AmmoAgent) - - -class AmmoObservingAgent(AmmoAgent, ObservingAgent, metaclass=AmmoObservingAgentMeta): - """ - Boilterplate class required to work with the AmmoObserver. - """ - pass - - class OrientationAgent(GridWorldAgent): """ Agent that has an orientation, either 1: Left, 2: Down, 3: Right, 4: Up. diff --git a/abmarl/sim/gridworld/observer.py b/abmarl/sim/gridworld/observer.py index 34608fb8..2bda61cb 100644 --- a/abmarl/sim/gridworld/observer.py +++ b/abmarl/sim/gridworld/observer.py @@ -6,7 +6,7 @@ from abmarl.tools import Box from abmarl.sim.agent_based_simulation import ObservingAgent from abmarl.sim.gridworld.base import GridWorldBaseComponent -from abmarl.sim.gridworld.agent import GridObservingAgent, AmmoObservingAgent +from abmarl.sim.gridworld.agent import GridObservingAgent, AmmoAgent import abmarl.sim.gridworld.utils as gu @@ -409,9 +409,9 @@ def key(self): @property def _supported_agent(self, agent): """ - This Observer works with AmmoObservingAgents. + This Observer works with agents that are both AmmoAgents and ObservingAgents. """ - raise RuntimeError("FIX THIS!") + return isinstance(agent, AmmoAgent) and isinstance(agent, ObservingAgent) def get_obs(self, agent, **kwargs): """ diff --git a/docs/src/api.rst b/docs/src/api.rst index 2b336853..e4c9c6db 100644 --- a/docs/src/api.rst +++ b/docs/src/api.rst @@ -199,12 +199,6 @@ Agents :members: :undoc-members: -.. _api_gridworld_agent_ammo_observing: - -.. autoclass:: abmarl.sim.gridworld.agent.AmmoObservingAgent - :members: - :undoc-members: - State ````` diff --git a/docs/src/gridworld.rst b/docs/src/gridworld.rst index 7d69c7cc..a8ff3ea4 100644 --- a/docs/src/gridworld.rst +++ b/docs/src/gridworld.rst @@ -1049,10 +1049,6 @@ Each :ref:`Attack Actor ` interprets the ammo, but the gene idea is that when the attacking agent runs out of ammo, its attacks are no longer successful. -:ref:`AmmoObservingAgents ` work in conjuction -with the :ref:`AmmoObserver ` to observe their own -ammo. - .. _gridworld_attacking: diff --git a/tests/sim/gridworld/test_agent.py b/tests/sim/gridworld/test_agent.py index 2e97d21e..3bd4f478 100644 --- a/tests/sim/gridworld/test_agent.py +++ b/tests/sim/gridworld/test_agent.py @@ -3,7 +3,7 @@ import pytest from abmarl.sim.gridworld.agent import GridWorldAgent, GridObservingAgent, MovingAgent, \ - AttackingAgent, AmmoAgent, AmmoObservingAgent, OrientationAgent + AttackingAgent, AmmoAgent, OrientationAgent from abmarl.sim import PrincipleAgent, ActingAgent, ObservingAgent @@ -223,33 +223,6 @@ def test_ammo_agent(): ) -def test_ammo_observing_agent(): - class CustomAmmoObservingAgent(AmmoAgent, ObservingAgent): pass - - agent = AmmoObservingAgent( - id='agent', - encoding=1, - initial_ammo=4 - ) - assert isinstance(agent, AmmoAgent) - assert isinstance(agent, ObservingAgent) - assert isinstance(agent, GridWorldAgent) - - agent = CustomAmmoObservingAgent( - id='agent', - encoding=1, - initial_ammo=2 - ) - assert isinstance(agent, AmmoObservingAgent) - - with pytest.raises(AssertionError): - agent = AmmoObservingAgent( - id='agent', - encoding=1, - initial_ammo=2.4 - ) - - def test_orientation_agent(): agent = OrientationAgent( id='agent', diff --git a/tests/sim/gridworld/test_observer.py b/tests/sim/gridworld/test_observer.py index 10dcf2bd..1582ca0c 100644 --- a/tests/sim/gridworld/test_observer.py +++ b/tests/sim/gridworld/test_observer.py @@ -7,39 +7,11 @@ PositionCenteredEncodingObserver, StackedPositionCenteredEncodingObserver, \ AbsolutePositionObserver, AmmoObserver from abmarl.sim.gridworld.agent import GridObservingAgent, GridWorldAgent, MovingAgent, \ - AmmoAgent, AmmoObservingAgent + AmmoAgent from abmarl.sim.gridworld.state import PositionState, AmmoState from abmarl.sim.gridworld.grid import Grid -def test_ammo_observer(): - grid = Grid(3, 3) - agents = { - 'agent0': AmmoAgent(id='agent0', encoding=1, initial_ammo=10), - 'agent1': AmmoObservingAgent(id='agent1', encoding=1, initial_ammo=-3), - 'agent2': AmmoObservingAgent(id='agent2', encoding=1, initial_ammo=14), - 'agent3': AmmoObservingAgent(id='agent3', encoding=1, initial_ammo=12), - } - state = AmmoState(grid=grid, agents=agents) - observer = AmmoObserver(grid=grid, agents=agents) - assert isinstance(observer, ObserverBaseComponent) - assert observer._encodings_in_sim == {1} - state.reset() - - assert observer.get_obs(agents['agent1'])['ammo'] == agents['agent1'].ammo - assert observer.get_obs(agents['agent2'])['ammo'] == agents['agent2'].ammo - assert observer.get_obs(agents['agent3'])['ammo'] == agents['agent3'].ammo - - agents['agent0'].ammo -= 16 - agents['agent1'].ammo += 7 - agents['agent2'].ammo -= 15 - assert observer.get_obs(agents['agent1'])['ammo'] == agents['agent1'].ammo - assert observer.get_obs(agents['agent2'])['ammo'] == agents['agent2'].ammo - assert observer.get_obs(agents['agent3'])['ammo'] == agents['agent3'].ammo - - assert not observer.get_obs(agents['agent0']) - - def test_absolute_encoding_observer(): np.random.seed(24) grid = Grid(5, 5, overlapping={1: {6}, 6: {1}}) From fe2d6620cefb794c0c410f6560f23b7e2fde49e1 Mon Sep 17 00:00:00 2001 From: Ephraim Rusu Date: Wed, 3 Apr 2024 10:25:18 -0700 Subject: [PATCH 5/7] Hunting down any other uses of supported agent type --- abmarl/examples/sim/comms_blocking.py | 20 ++++++------- abmarl/sim/gridworld/actor.py | 8 +++++- abmarl/sim/gridworld/observer.py | 10 ++----- .../src/tutorials/gridworld/communication.rst | 27 +++++++++--------- tests/sim/gridworld/test_actor.py | 28 ++++++++++++------- tests/sim/gridworld/test_observer.py | 8 ++++-- tests/sim/gridworld/test_wrapper.py | 2 -- 7 files changed, 55 insertions(+), 48 deletions(-) diff --git a/abmarl/examples/sim/comms_blocking.py b/abmarl/examples/sim/comms_blocking.py index def2b38f..f01448db 100644 --- a/abmarl/examples/sim/comms_blocking.py +++ b/abmarl/examples/sim/comms_blocking.py @@ -63,16 +63,15 @@ def __init__(self, broadcast_mapping=None, **kwargs): super().__init__(**kwargs) self.broadcast_mapping = broadcast_mapping for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): agent.action_space[self.key] = Discrete(2) @property def key(self): return 'broadcast' - @property - def supported_agent_type(self): - return BroadcastingAgent + def _supported_agent(self, agent): + return isinstance(agent, BroadcastingAgent) @property def broadcast_mapping(self): @@ -138,7 +137,7 @@ def determine_broadcast(agent): receiving_agents.append(other) return receiving_agents - if isinstance(broadcasting_agent, self.supported_agent_type): + if self._supported_agent(broadcasting_agent): action = action_dict[self.key] if action: # Agent has chosen to attack return determine_broadcast(broadcasting_agent) @@ -182,23 +181,22 @@ def __init__(self, broadcasting_state=None, **kwargs): self._broadcasting_state = broadcasting_state for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): agent.observation_space[self.key] = Dict({ other.id: Box(-1, 1, (1,)) for other in self.agents.values() - if isinstance(other, self.supported_agent_type) + if self._supported_agent(other) }) @property def key(self): return 'message' - @property - def supported_agent_type(self): - return BroadcastingAgent + def _supported_agent(self, agent): + return isinstance(agent, BroadcastingAgent) def get_obs(self, agent, **kwargs): - if not isinstance(agent, self.supported_agent_type): + if not self._supported_agent(agent): return {} obs = {other: 0 for other in agent.observation_space[self.key]} diff --git a/abmarl/sim/gridworld/actor.py b/abmarl/sim/gridworld/actor.py index 90aadf2d..97c52eb2 100644 --- a/abmarl/sim/gridworld/actor.py +++ b/abmarl/sim/gridworld/actor.py @@ -211,6 +211,12 @@ class DriftMoveActor(CrossMoveActor): that change will fail and it will keep its current orientation, even though it is blocked that way too. """ + def _supported_agent(self, agent): + """ + This Actor works with MovingAgent and OrientationAgent. + """ + return isinstance(agent, MovingAgent) and isinstance(agent, OrientationAgent) + def process_action(self, agent, action_dict, **kwargs): """ The agent can move up, down, left, right, or stay in place. @@ -227,7 +233,7 @@ def process_action(self, agent, action_dict, **kwargs): Returns: True if the move is successful, False otherwise. """ - if isinstance(agent, OrientationAgent) and isinstance(agent, MovingAgent): + if self._supported_agent(agent): cross_action = action_dict[self.key] if cross_action != 0: # Agent has attempted to change directions, let the super process diff --git a/abmarl/sim/gridworld/observer.py b/abmarl/sim/gridworld/observer.py index 2bda61cb..52d1b710 100644 --- a/abmarl/sim/gridworld/observer.py +++ b/abmarl/sim/gridworld/observer.py @@ -6,7 +6,7 @@ from abmarl.tools import Box from abmarl.sim.agent_based_simulation import ObservingAgent from abmarl.sim.gridworld.base import GridWorldBaseComponent -from abmarl.sim.gridworld.agent import GridObservingAgent, AmmoAgent +from abmarl.sim.gridworld.agent import GridObservingAgent, AmmoAgent, GridWorldAgent import abmarl.sim.gridworld.utils as gu @@ -188,7 +188,6 @@ def key(self): """ return 'position_centered_encoding' - @property def _supported_agent(self, agent): """ This Observer works with GridObservingAgents. @@ -292,7 +291,6 @@ def key(self): """ return 'stacked_position_centered_encoding' - @property def _supported_agent(self, agent): """ This Observer works with GridObservingAgents. @@ -366,12 +364,11 @@ def key(self): """ return 'position' - @property def _supported_agent(self, agent): """ - This Observer works with ObservingAgents + This Observer works with agents that are both GridWorldAgents and ObservingAgents. """ - return isinstance(agent, ObservingAgent) + return isinstance(agent, ObservingAgent) and isinstance(agent, GridWorldAgent) def get_obs(self, agent, **kwargs): """ @@ -406,7 +403,6 @@ def key(self): """ return 'ammo' - @property def _supported_agent(self, agent): """ This Observer works with agents that are both AmmoAgents and ObservingAgents. diff --git a/docs/src/tutorials/gridworld/communication.rst b/docs/src/tutorials/gridworld/communication.rst index 59627eb2..c93da5e0 100644 --- a/docs/src/tutorials/gridworld/communication.rst +++ b/docs/src/tutorials/gridworld/communication.rst @@ -176,7 +176,7 @@ to each agent's message. # Tracks agents receiving messages from other agents self.receiving_state = { agent.id: [] for agent in self.agents.values() if isinstance(agent, BroadcastingAgent) - } + }) def update_receipients(self, from_agent, to_agents): """ @@ -225,17 +225,16 @@ a compatible encoding, and (3) is not blocked. super().__init__(**kwargs) self.broadcast_mapping = broadcast_mapping for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): agent.action_space[self.key] = Discrete(2) agent.null_action[self.key] = 0 @property def key(self): return 'broadcast' - - @property - def supported_agent_type(self): - return BroadcastingAgent + + def _supported_agent(self, agent): + return isinstance(agent, BroadcastingAgent) @property def broadcast_mapping(self): @@ -301,7 +300,7 @@ a compatible encoding, and (3) is not blocked. receiving_agents.append(other) return receiving_agents - if isinstance(broadcasting_agent, self.supported_agent_type): + if self._supported_agent(broadcasting_agent): action = action_dict[self.key] if action: # Agent has chosen to attack return determine_broadcast(broadcasting_agent) @@ -326,26 +325,26 @@ component, which will have a small impact in how we initialize the simulation. self._broadcasting_state = broadcasting_state for agent in self.agents.values(): - if isinstance(agent, self.supported_agent_type): + if self._supported_agent(agent): agent.observation_space[self.key] = Dict({ other.id: Box(-1, 1, (1,), float) - for other in self.agents.values() if isinstance(other, self.supported_agent_type) + for other in self.agents.values() if self._supported_agent(other) }) agent.null_observation[self.key] = { other.id: 0. for other in self.agents.values() - if isinstance(other, self.supported_agent_type) + if self._supported_agent(other) } @property def key(self): return 'message' - @property - def supported_agent_type(self): - return BroadcastingAgent + + def _supported_agent(self): + return isinstance(agent, BroadcastingAgent) and isinstance(agent, ObservingAgent) def get_obs(self, agent, **kwargs): - if not isinstance(agent, self.supported_agent_type): + if not self._supported_agent(agent): return {} obs = {other: 0 for other in agent.observation_space[self.key]} diff --git a/tests/sim/gridworld/test_actor.py b/tests/sim/gridworld/test_actor.py index 76987906..69911dfe 100644 --- a/tests/sim/gridworld/test_actor.py +++ b/tests/sim/gridworld/test_actor.py @@ -42,7 +42,7 @@ def test_move_actor(): move_actor = MoveActor(grid=grid, agents=agents) assert isinstance(move_actor, ActorBaseComponent) assert move_actor.key == 'move' - assert move_actor.supported_agent_type == MovingAgent + assert move_actor._supported_agent(agents['agent0']) assert move_actor._encodings_in_sim == {1, 2, 3} assert agents['agent0'].action_space['move'] == Box(-1, 1, (2,), int) assert agents['agent1'].action_space['move'] == Box(-2, 2, (2,), int) @@ -169,7 +169,7 @@ def test_cross_move_actor(): move_actor = CrossMoveActor(grid=grid, agents=agents) assert isinstance(move_actor, ActorBaseComponent) assert move_actor.key == 'move' - assert move_actor.supported_agent_type == MovingAgent + assert move_actor._supported_agent(agents['agent0']) assert move_actor._encodings_in_sim == {1, 2, 3} assert agents['agent0'].action_space['move'] == Discrete(5) assert agents['agent1'].action_space['move'] == Discrete(5) @@ -499,7 +499,8 @@ def test_binary_attack_actor(): attack_actor = BinaryAttackActor(attack_mapping={1: 1}, grid=grid, agents=agents) assert isinstance(attack_actor, ActorBaseComponent) assert attack_actor.key == 'attack' - assert attack_actor.supported_agent_type == AttackingAgent + assert not attack_actor._supported_agent(agents['agent0']) + assert attack_actor._supported_agent(agents['agent1']) assert agents['agent1'].action_space['attack'] == Discrete(2) assert attack_actor.attack_mapping == {1: {1}} assert attack_actor._encodings_in_sim == {1, 2} @@ -776,7 +777,8 @@ def test_selective_attack_actor(): assert isinstance(attack_actor, ActorBaseComponent) assert attack_actor._encodings_in_sim == {1, 2} assert attack_actor.key == 'attack' - assert attack_actor.supported_agent_type == AttackingAgent + assert attack_actor._supported_agent(agents['agent1']) + assert not attack_actor._supported_agent(agents['agent0']) assert agents['agent1'].action_space['attack'] == Box(0, 1, (5, 5), int) agents['agent1'].finalize() @@ -939,7 +941,8 @@ def test_selective_attack_actor_ammo(): attack_actor = SelectiveAttackActor(attack_mapping={1: {1}}, grid=grid, agents=agents) assert isinstance(attack_actor, ActorBaseComponent) assert attack_actor.key == 'attack' - assert attack_actor.supported_agent_type == AttackingAgent + assert attack_actor._supported_agent(agents['agent1']) + assert not attack_actor._supported_agent(agents['agent0']) assert agents['agent1'].action_space['attack'] == Box(0, 1, (5, 5), int) agents['agent1'].finalize() @@ -1051,7 +1054,8 @@ def test_selective_attack_actor_simultaneous_attacks(): attack_actor = SelectiveAttackActor(attack_mapping={3: {1, 2}}, grid=grid, agents=agents) assert isinstance(attack_actor, ActorBaseComponent) assert attack_actor.key == 'attack' - assert attack_actor.supported_agent_type == AttackingAgent + assert attack_actor._supported_agent(agents['agent0']) + assert not attack_actor._supported_agent(agents['agent1']) assert agents['agent0'].action_space['attack'] == Box(0, 3, (3, 3), int) agents['agent0'].finalize() np.testing.assert_array_equal( @@ -1230,7 +1234,8 @@ def test_encoding_based_attack_actor(): assert isinstance(attack_actor, ActorBaseComponent) assert attack_actor._encodings_in_sim == {1, 2, 3} assert attack_actor.key == 'attack' - assert attack_actor.supported_agent_type == AttackingAgent + assert attack_actor._supported_agent(agents['agent3']) + assert not attack_actor._supported_agent(agents['agent1']) assert agents['agent3'].action_space['attack'] == Dict({ 1: Discrete(2), 2: Discrete(2) @@ -1307,7 +1312,8 @@ def test_encoding_based_attack_actor_ammo(): attack_actor = EncodingBasedAttackActor(attack_mapping={3: {1, 2}}, grid=grid, agents=agents) assert isinstance(attack_actor, ActorBaseComponent) assert attack_actor.key == 'attack' - assert attack_actor.supported_agent_type == AttackingAgent + assert attack_actor._supported_agent(agents['agent3']) + assert not attack_actor._supported_agent(agents['agent1']) assert agents['agent3'].action_space['attack'] == Dict({ 1: Discrete(2), 2: Discrete(2) @@ -1497,7 +1503,8 @@ def test_encoding_based_attack_actor_stacked_attack(): ) assert isinstance(attack_actor, ActorBaseComponent) assert attack_actor.key == 'attack' - assert attack_actor.supported_agent_type == AttackingAgent + assert attack_actor._supported_agent(agents['agent3']) + assert not attack_actor._supported_agent(agents['agent1']) assert agents['agent3'].action_space['attack'] == Dict({ 1: Discrete(3), 2: Discrete(3) @@ -1564,7 +1571,8 @@ def test_restricted_selective_attack_actor(): assert isinstance(attack_actor, ActorBaseComponent) assert attack_actor._encodings_in_sim == {1, 2, 3} assert attack_actor.key == 'attack' - assert attack_actor.supported_agent_type == AttackingAgent + assert attack_actor._supported_agent(agents['agent3']) + assert not attack_actor._supported_agent(agents['agent1']) assert agents['agent3'].action_space['attack'] == MultiDiscrete([10, 10]) agents['agent3'].finalize() diff --git a/tests/sim/gridworld/test_observer.py b/tests/sim/gridworld/test_observer.py index 1582ca0c..667c298d 100644 --- a/tests/sim/gridworld/test_observer.py +++ b/tests/sim/gridworld/test_observer.py @@ -188,7 +188,8 @@ def test_single_grid_observer(): observer = PositionCenteredEncodingObserver(agents=agents, grid=grid) assert observer._encodings_in_sim == {1, 2, 3, 4, 5, 6} assert observer.key == 'position_centered_encoding' - assert observer.supported_agent_type == GridObservingAgent + assert observer._supported_agent(agents['agent0']) + assert not observer._supported_agent(agents['agent3']) assert isinstance(observer, ObserverBaseComponent) assert agents['agent0'].observation_space['position_centered_encoding'] == Box( -2, 6, (5, 5), int @@ -356,7 +357,8 @@ class HackAgent(GridObservingAgent, MovingAgent): pass observer = StackedPositionCenteredEncodingObserver(agents=agents, grid=grid) assert observer._encodings_in_sim == {1, 2, 3, 4, 5, 6} assert observer.key == 'stacked_position_centered_encoding' - assert observer.supported_agent_type == GridObservingAgent + assert observer._supported_agent(agents['agent0']) + assert not observer._supported_agent(agents['agent5']) assert isinstance(observer, ObserverBaseComponent) assert observer.number_of_encodings == 6 assert agents['agent0'].observation_space['stacked_position_centered_encoding'] == Box( @@ -927,7 +929,7 @@ class PositionObservingAgent(ObservingAgent, GridWorldAgent): pass observer = AbsolutePositionObserver(agents=agents, grid=grid) assert observer._encodings_in_sim == {1, 2, 3, 4, 5, 6} assert observer.key == 'position' - assert observer.supported_agent_type == ObservingAgent + assert observer._supported_agent(agents['agent0']) assert isinstance(observer, ObserverBaseComponent) for agent in agents.values(): agent.finalize() diff --git a/tests/sim/gridworld/test_wrapper.py b/tests/sim/gridworld/test_wrapper.py index 8d8a5be7..59cdea3e 100644 --- a/tests/sim/gridworld/test_wrapper.py +++ b/tests/sim/gridworld/test_wrapper.py @@ -60,7 +60,6 @@ def test_ravelled_move_wrapper_properties(): assert ravelled_move_actor.agents == move_actor.agents assert ravelled_move_actor.grid == move_actor.grid assert ravelled_move_actor.key == move_actor.key - assert ravelled_move_actor.supported_agent_type == move_actor.supported_agent_type def test_ravelled_move_wrapper_agent_spaces(): @@ -158,7 +157,6 @@ def test_exclusive_attack_wrapper_properties(): assert exclusive_attack_actor.agents == attack_actor.agents assert exclusive_attack_actor.grid == attack_actor.grid assert exclusive_attack_actor.key == attack_actor.key - assert exclusive_attack_actor.supported_agent_type == attack_actor.supported_agent_type def test_exclusive_attack_wrapper_agent_spaces(): From eb593a6d979c85d9e3a444f2d75847c0823b6457 Mon Sep 17 00:00:00 2001 From: Ephraim Rusu Date: Wed, 3 Apr 2024 10:41:09 -0700 Subject: [PATCH 6/7] Readded ammo observer test --- abmarl/sim/gridworld/observer.py | 2 +- tests/sim/gridworld/test_observer.py | 29 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/abmarl/sim/gridworld/observer.py b/abmarl/sim/gridworld/observer.py index 52d1b710..b3f66cb2 100644 --- a/abmarl/sim/gridworld/observer.py +++ b/abmarl/sim/gridworld/observer.py @@ -35,7 +35,7 @@ def _supported_agent(self, agent): For example, Grid Observers require the agent to be a Grid Observing Agent. Ammo Observer requires the agent to be an AmmoAgent and to be an Observing Agent. - + Args: agent: The agent to inspect. Returns: diff --git a/tests/sim/gridworld/test_observer.py b/tests/sim/gridworld/test_observer.py index 667c298d..309c0445 100644 --- a/tests/sim/gridworld/test_observer.py +++ b/tests/sim/gridworld/test_observer.py @@ -12,6 +12,35 @@ from abmarl.sim.gridworld.grid import Grid +def test_ammo_observer(): + class AmmoObservingAgent(AmmoAgent, ObservingAgent): pass + grid = Grid(3, 3) + agents = { + 'agent0': AmmoAgent(id='agent0', encoding=1, initial_ammo=10), + 'agent1': AmmoObservingAgent(id='agent1', encoding=1, initial_ammo=-3), + 'agent2': AmmoObservingAgent(id='agent2', encoding=1, initial_ammo=14), + 'agent3': AmmoObservingAgent(id='agent3', encoding=1, initial_ammo=12), + } + state = AmmoState(grid=grid, agents=agents) + observer = AmmoObserver(grid=grid, agents=agents) + assert isinstance(observer, ObserverBaseComponent) + assert observer._encodings_in_sim == {1} + state.reset() + + assert observer.get_obs(agents['agent1'])['ammo'] == agents['agent1'].ammo + assert observer.get_obs(agents['agent2'])['ammo'] == agents['agent2'].ammo + assert observer.get_obs(agents['agent3'])['ammo'] == agents['agent3'].ammo + + agents['agent0'].ammo -= 16 + agents['agent1'].ammo += 7 + agents['agent2'].ammo -= 15 + assert observer.get_obs(agents['agent1'])['ammo'] == agents['agent1'].ammo + assert observer.get_obs(agents['agent2'])['ammo'] == agents['agent2'].ammo + assert observer.get_obs(agents['agent3'])['ammo'] == agents['agent3'].ammo + + assert not observer.get_obs(agents['agent0']) + + def test_absolute_encoding_observer(): np.random.seed(24) grid = Grid(5, 5, overlapping={1: {6}, 6: {1}}) From f04bfb7303eb79754b52017aa24963713dc5ee18 Mon Sep 17 00:00:00 2001 From: Ephraim Rusu Date: Thu, 4 Apr 2024 10:39:45 -0700 Subject: [PATCH 7/7] syntax fix in rst file --- docs/src/tutorials/gridworld/communication.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/tutorials/gridworld/communication.rst b/docs/src/tutorials/gridworld/communication.rst index c93da5e0..e8cefb82 100644 --- a/docs/src/tutorials/gridworld/communication.rst +++ b/docs/src/tutorials/gridworld/communication.rst @@ -176,7 +176,7 @@ to each agent's message. # Tracks agents receiving messages from other agents self.receiving_state = { agent.id: [] for agent in self.agents.values() if isinstance(agent, BroadcastingAgent) - }) + } def update_receipients(self, from_agent, to_agents): """