Merge pull request #500 from LLNL/abmarl-473-predator-prey-resources
Abmarl 473 predator prey resources
rusu24edward authored Mar 19, 2024
2 parents a60a87b + b32eab8 commit 3b9e96b
Showing 4 changed files with 211 additions and 0 deletions.
1 change: 1 addition & 0 deletions abmarl/examples/__init__.py
@@ -5,3 +5,4 @@
from .sim import EmptyABS, MultiAgentSim, MultiAgentGymSpacesSim, \
    MultiAgentContinuousGymSpaceSim, MultiAgentSameSpacesSim
from .sim import ReachTheTargetSim, RunningAgent, TargetAgent, BarrierAgent
from .sim import PredatorPreyResourcesSim, ResourceAgent, PreyAgent, PredatorAgent
3 changes: 3 additions & 0 deletions abmarl/examples/sim/__init__.py
@@ -6,3 +6,6 @@
    MultiAgentContinuousGymSpaceSim, MultiAgentSameSpacesSim
from .multi_agent_grid_sim import MultiAgentGridSim
from .reach_the_target import ReachTheTargetSim, RunningAgent, TargetAgent, BarrierAgent
from .predator_prey_resources import (
    PredatorPreyResourcesSim, ResourceAgent, PreyAgent, PredatorAgent
)
111 changes: 111 additions & 0 deletions abmarl/examples/sim/predator_prey_resources.py
@@ -0,0 +1,111 @@

from abmarl.sim import Agent
from abmarl.sim.gridworld.agent import MovingAgent, AttackingAgent, \
    GridObservingAgent, HealthAgent
from abmarl.sim.gridworld.smart import SmartGridWorldSimulation
from abmarl.sim.gridworld.actor import MoveActor, BinaryAttackActor


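# A stationary, health-only entity. Resources have no observation or action
# space; they exist to be consumed by prey.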
class ResourceAgent(HealthAgent):
    def __init__(
        self,
        encoding=1,
        render_shape='s',
        render_color='g',
        **kwargs
    ):
        super().__init__(
            encoding=encoding,
            render_shape=render_shape,
            render_color=render_color,
            **kwargs
        )


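# Prey move around the grid, observe the cells near them, and attack (eat)
# adjacent resources.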
class PreyAgent(HealthAgent, MovingAgent, AttackingAgent, GridObservingAgent):
    def __init__(
        self,
        encoding=2,
        render_color='b',
        move_range=1,
        attack_range=1,
        attack_strength=1,
        attack_accuracy=1,
        view_range=3,
        **kwargs
    ):
        super().__init__(
            encoding=encoding,
            render_color=render_color,
            move_range=move_range,
            attack_range=attack_range,
            attack_strength=attack_strength,
            attack_accuracy=attack_accuracy,
            view_range=view_range,
            **kwargs
        )


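# Predators move, observe, and hunt prey; compared to prey they attack from
# farther away (attack_range=2) and render with a different shape.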
class PredatorAgent(HealthAgent, MovingAgent, AttackingAgent, GridObservingAgent):
    def __init__(
        self,
        encoding=3,
        render_color='r',
        render_shape='d',
        move_range=1,
        attack_range=2,
        attack_strength=1,
        attack_accuracy=1,
        view_range=3,
        **kwargs
    ):
        super().__init__(
            encoding=encoding,
            render_color=render_color,
            render_shape=render_shape,
            move_range=move_range,
            attack_range=attack_range,
            attack_strength=attack_strength,
            attack_accuracy=attack_accuracy,
            view_range=view_range,
            **kwargs
        )


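# The full simulation. SmartGridWorldSimulation builds the state, observer,
# and done components from the arguments passed to build_sim; only the move
# and attack actors are wired up here.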
class PredatorPreyResourcesSim(SmartGridWorldSimulation):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.move_actor = MoveActor(**kwargs)
        self.attack_actor = BinaryAttackActor(**kwargs)

        self.finalize()

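    # Rewards: +1 for destroying a target, -1 to a learning agent that dies,
    # -0.1 for a failed attack or blocked move, and -0.01 per action taken.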
    def step(self, action_dict, **kwargs):
        # Process the attacks
        for agent_id, action in action_dict.items():
            agent = self.agents[agent_id]
            if agent.active:
                attack_status, attacked_agents = \
                    self.attack_actor.process_action(agent, action, **kwargs)
                if attack_status: # Attack was attempted
                    if not attacked_agents: # Attack failed
                        self.rewards[agent_id] -= 0.1
                    else:
                        for attacked_agent in attacked_agents:
                            if not attacked_agent.active: # Agent has died
                                self.rewards[agent_id] += 1
                                if isinstance(attacked_agent, Agent):
                                    self.rewards[attacked_agent.id] -= 1

        # Process the moves
        for agent_id, action in action_dict.items():
            agent = self.agents[agent_id]
            if agent.active:
                move_result = self.move_actor.process_action(agent, action, **kwargs)
                if not move_result:
                    self.rewards[agent_id] -= 0.1

        # Entropy penalty
        for agent_id in action_dict:
            self.rewards[agent_id] -= 0.01
96 changes: 96 additions & 0 deletions examples/rllib_predator_prey_resources.py
@@ -0,0 +1,96 @@

from abmarl.examples import ResourceAgent, PreyAgent, PredatorAgent, PredatorPreyResourcesSim
from abmarl.managers import AllStepManager
from abmarl.external import MultiAgentWrapper

resources = {
    f'resource_{i}': ResourceAgent(id=f'resource_{i}') for i in range(11)
}
prey = {
    f'prey_{i}': PreyAgent(id=f'prey_{i}') for i in range(5)
}
predators = {
    f'predator_{i}': PredatorAgent(id=f'predator_{i}') for i in range(2)
}
agents = {**resources, **prey, **predators}

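# Overlapping is keyed by encoding (1=resource, 2=prey, 3=predator): each
# entry maps an encoding to the set of encodings it may share a cell with.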
overlap_map = {
    1: {2, 3},
    2: {1, 2, 3},
    3: {1, 2},
}
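# Prey (2) attack resources (1); predators (3) attack prey (2).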
attack_map = {
    2: {1},
    3: {2}
}
sim = MultiAgentWrapper(
    AllStepManager(
        PredatorPreyResourcesSim.build_sim(
            20, 20,
            agents=agents,
            overlapping=overlap_map,
            attack_mapping=attack_map,
            target_mapping=attack_map,
            states={'PositionState', 'HealthState'},
            observers={'PositionCenteredEncodingObserver'},
            dones={'ActiveDone', 'TargetEncodingInactiveDone'}
        ),
        randomize_action_input=True
    )
)


sim_name = "PredatorPreyResources"
from ray.tune.registry import register_env
register_env(sim_name, lambda sim_config: sim)


policies = {
    'prey': (None, prey['prey_0'].observation_space, prey['prey_0'].action_space, {}),
    'predator': (
        None, predators['predator_0'].observation_space, predators['predator_0'].action_space, {}
    ),
}


def policy_mapping_fn(agent_id):
    if agents[agent_id].encoding == 2:
        return 'prey'
    if agents[agent_id].encoding == 3:
        return 'predator'


# Experiment parameters
params = {
    'experiment': {
        'title': f'{sim_name}',
        'sim_creator': lambda config=None: sim,
    },
    'ray_tune': {
        'run_or_experiment': 'PPO',
        'checkpoint_freq': 50,
        'checkpoint_at_end': True,
        'stop': {
            'episodes_total': 20_000,
        },
        'verbose': 2,
        'local_dir': 'output_dir',
        'config': {
            # --- Simulation ---
            'disable_env_checking': False,
            'env': sim_name,
            'horizon': 200,
            'env_config': {},
            # --- Multiagent ---
            'multiagent': {
                'policies': policies,
                'policy_mapping_fn': policy_mapping_fn,
            },
            # --- Parallelism ---
            # Number of workers per experiment: int
            "num_workers": 7,
            # Number of simulations that each worker starts: int
            "num_envs_per_worker": 1, # This must be 1 because we are not "threadsafe"
        },
    }
}

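As a usage sketch (assuming abmarl is installed with its RLlib extras), the experiment file above is meant to be handed to Abmarl's command-line interface; the results path passed to visualize is illustrative and depends on the 'local_dir' and 'title' set above:

abmarl train examples/rllib_predator_prey_resources.py
abmarl visualize output_dir/PredatorPreyResources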