Commit

cleaned up multiagent
veds12 committed Oct 22, 2020
1 parent 25eb018 commit 256f705
Showing 9 changed files with 564 additions and 16 deletions.
34 changes: 34 additions & 0 deletions genrl/agents/multiagent/base/offpolicy.py
@@ -0,0 +1,34 @@
import collections
from abc import ABC

import torch
import torch.nn as nn
import torch.optim as opt

from genrl.core import MultiAgentReplayBuffer
from genrl.utils import MultiAgentEnvInterface


class MultiAgentOffPolicy(ABC):
"""Base class for multiagent algorithms with OffPolicy agents
Attributes:
network (str): The network type of the Q-value function.
Supported types: ["cnn", "mlp"]
env (Environment): The environment that the agent is supposed to act on
agents (list) : A list of all the agents to be used
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the Q-value function
replay_size (int): Capacity of the Replay Buffer
seed (int): Seed for randomness
render (bool): Should the env be rendered during training?
device (str): Hardware being used for training. Options:
["cuda" -> GPU, "cpu" -> CPU]
"""

    def __init__(self, *args, **kwargs):
        raise NotImplementedError
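For orientation, a concrete subclass would be expected to accept and store the attributes documented above rather than call the base initialiser. The sketch below is purely illustrative; the class name and default values are assumptions, not part of this commit:

from genrl.agents.multiagent.base.offpolicy import MultiAgentOffPolicy


class DummyMultiAgentAlgo(MultiAgentOffPolicy):
    """Hypothetical subclass showing how the documented attributes map to a constructor."""

    def __init__(
        self,
        network,
        env,
        agents,
        batch_size=64,
        gamma=0.99,
        layers=(64, 64),
        lr_policy=1e-3,
        lr_value=1e-3,
        replay_size=int(1e6),
        seed=None,
        render=False,
        device="cpu",
    ):
        # The abstract base leaves __init__ unimplemented, so a subclass
        # stores its own hyperparameters.
        self.network = network
        self.env = env
        self.agents = agents
        self.batch_size = batch_size
        self.gamma = gamma
        self.layers = layers
        self.lr_policy = lr_policy
        self.lr_value = lr_value
        self.replay_size = replay_size
        self.seed = seed
        self.render = render
        self.device = device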
Empty file.
133 changes: 133 additions & 0 deletions genrl/agents/multiagent/maddpg/maddpg.py
@@ -0,0 +1,133 @@
from abc import ABC

import numpy as np
import torch

from genrl.agents import DDPG
from genrl.core import ActionNoise
from genrl.utils import MultiAgentReplayBuffer, PettingZooInterface, get_model


class MADDPG(ABC):
"""MultiAgent Controller using the MADDPG algorithm
Attributes:
network (str): The network type of the Q-value function of the agents.
Supported types: ["mlp"]
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
shared_layers (:obj:`tuple` of :obj:`int`): Sizes of the shared layers in the Actor Critic, if used
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the critic
replay_size (int): Capacity of the Replay Buffer
polyak (float): Target model update parameter (1 for hard update)
env (Environment): The environment that the agent is supposed to act on
render (bool): Should the env be rendered during training?
noise (:obj:`ActionNoise`): Action Noise function added to aid in exploration
noise_std (float): Standard deviation of the action noise distribution
seed (int): Seed for randomness
device (str): Hardware being used for training. Options:
["cuda" -> GPU, "cpu" -> CPU]
"""

    def __init__(
        self,
        *args,
        env,
        network: str = "mlp",
        shared_layers=None,
        replay_size: int = int(1e6),
        render: bool = False,
        noise: ActionNoise = None,
        noise_std: float = 0.2,
        warmup_steps: int = 1000,
        **kwargs,
    ):
        self.env = env
        self.network = network
        self.num_agents = self.env.num_agents
        self.replay_buffer = MultiAgentReplayBuffer(self.num_agents, replay_size)
        self.render = render
        self.warmup_steps = warmup_steps
        self.shared_layers = shared_layers
        # The agent list is created empty first so that the environment
        # interface (which _create_model queries for state/action spaces)
        # can hold a reference to it before the DDPG agents are built.
        self.agents = []
        self.EnvInterface = PettingZooInterface(self.env, self.agents)
        ac = self._create_model()
        self.agents.extend(
            DDPG(ac, noise, noise_std, **kwargs) for _ in self.env.agents
        )

    def _create_model(self):
        state_dim, action_dim, discrete, _ = self.EnvInterface.get_env_properties()
        if discrete:
            raise Exception(
                "Discrete Environments not supported for {}.".format(__class__.__name__)
            )
        model = get_model("ac", self.network)(
            state_dim, action_dim, self.shared_layers,
        )
        return model

    def update(self, batch_size):
        (
            obs_batch,
            indiv_action_batch,
            indiv_reward_batch,
            next_obs_batch,
            global_state_batch,
            global_actions_batch,
            global_next_state_batch,
            done_batch,
        ) = self.replay_buffer.sample(batch_size)
        for i in range(self.num_agents):
            obs_batch_i = obs_batch[i]
            indiv_action_batch_i = indiv_action_batch[i]
            indiv_reward_batch_i = indiv_reward_batch[i]
            next_obs_batch_i = next_obs_batch[i]
            # The environment interface is assumed to compute agent i's next
            # action and the joint next actions of all agents from the next
            # observations of agent i.
            (
                next_obs_batch_i,
                indiv_next_action,
                next_global_actions,
            ) = self.EnvInterface.trainer(next_obs_batch_i)
            next_global_actions = torch.cat(
                [next_actions_i for next_actions_i in next_global_actions], 1
            )
            self.EnvInterface.update_agents(
                indiv_reward_batch_i,
                obs_batch_i,
                global_state_batch,
                global_actions_batch,
                global_next_state_batch,
                next_global_actions,
            )

    def train(self, max_episode, max_steps, batch_size):
        episode_rewards = []
        for episode in range(max_episode):
            states = self.env.reset()
            episode_reward = 0
            for step in range(max_steps):
                if self.render:
                    self.env.render(mode="human")

                actions = self.EnvInterface.get_actions(
                    states, step, self.warmup_steps
                )
                next_states, rewards, dones, _ = self.env.step(actions)
                rewards = self.EnvInterface.flatten(rewards)
                episode_reward += np.mean(rewards)
                dones = self.EnvInterface.flatten(dones)
                if all(dones) or step == max_steps - 1:
                    dones = [1 for _ in range(self.num_agents)]
                    self.replay_buffer.push(
                        (states, actions, rewards, next_states, dones)
                    )
                    episode_rewards.append(episode_reward)
                    print(
                        f"Episode: {episode + 1} | Steps Taken: {step + 1} | Reward: {episode_reward}"
                    )
                    break
                else:
                    dones = [0 for _ in range(self.num_agents)]
                    self.replay_buffer.push(
                        (states, actions, rewards, next_states, dones)
                    )
                    states = next_states
                    if len(self.replay_buffer) > batch_size:
                        self.update(batch_size)
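A rough usage sketch for the controller above; the PettingZoo environment (`simple_spread_v2`) and the hyperparameter values are illustrative assumptions, since this commit does not ship an example script:

from pettingzoo.mpe import simple_spread_v2

from genrl.agents.multiagent.maddpg.maddpg import MADDPG

# A PettingZoo parallel environment exposes num_agents, reset() and step(),
# which MADDPG consumes through the PettingZooInterface.
env = simple_spread_v2.parallel_env()

trainer = MADDPG(
    env=env,
    network="mlp",  # only MLP actor-critics are supported
    replay_size=int(1e6),
    noise_std=0.2,
    warmup_steps=1000,
)
trainer.train(max_episode=100, max_steps=25, batch_size=32)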
2 changes: 2 additions & 0 deletions genrl/core/__init__.py
@@ -5,6 +5,7 @@
from genrl.core.buffers import PrioritizedReplayBufferSamples # noqa
from genrl.core.buffers import ReplayBuffer # noqa
from genrl.core.buffers import ReplayBufferSamples # noqa
from genrl.core.buffers import MultiAgentReplayBuffer  # noqa
from genrl.core.noise import ActionNoise # noqa
from genrl.core.noise import NoisyLinear # noqa
from genrl.core.noise import NormalActionNoise # noqa
@@ -16,6 +17,7 @@
get_policy_from_name,
)
from genrl.core.rollout_storage import RolloutBuffer # noqa
from genrl.core.rollout_storage import MultiAgentRolloutBuffer # noqa
from genrl.core.values import ( # noqa
BaseValue,
CnnCategoricalValue,
120 changes: 111 additions & 9 deletions genrl/core/buffers.py
@@ -146,15 +146,7 @@ def sample(

        return [
            torch.as_tensor(v, dtype=torch.float32)
-           for v in [
-               states,
-               actions,
-               rewards,
-               next_states,
-               dones,
-               indices,
-               weights,
-           ]
+           for v in [states, actions, rewards, next_states, dones, indices, weights,]
        ]

def update_priorities(self, batch_indices: Tuple, batch_priorities: Tuple) -> None:
@@ -181,3 +173,113 @@ def __len__(self) -> int:
@property
def pos(self):
return len(self.buffer)


class MultiAgentReplayBuffer:
"""
Implements the basic Experience Replay Mechanism for MultiAgents
by feeding in global states, global actions, global rewards,
global next_states, global dones
:param capacity: Size of the replay buffer
:type capacity: int
:param num_agents: Number of agents in the environment
:type num_agents: int
"""

def __init__(self, num_agents: int, capacity: int):
"""
Initialising the buffer
:param num_agents: number of agents in the environment
:type num_agents: int
:param capacity: Max buffer size
:type capacity: int
"""
self.capacity = capacity
self.num_agents = num_agents
self.buffer = deque(maxlen=self.capacity)

def push(self, inp: Tuple) -> None:
"""
Adds new experience to buffer
:param inp: (Tuple containing `state`, `action`, `reward`,
`next_state` and `done`)
:type inp: tuple
:returns: None
"""
self.buffer.append(inp)

    def sample(self, batch_size):
        """
Returns randomly sampled experiences from replay memory
:param batch_size: Number of samples per batch
:type batch_size: int
:returns: (Tuple composing of `indiv_obs_batch`,
`indiv_action_batch`, `indiv_reward_batch`, `indiv_next_obs_batch`,
`global_state_batch`, `global_actions_batch`, `global_next_state_batch`,
`done_batch`)
"""
indiv_obs_batch = [
[] for _ in range(self.num_agents)
] # [ [states of agent 1], ... ,[states of agent n] ] ]
indiv_action_batch = [
[] for _ in range(self.num_agents)
] # [ [actions of agent 1], ... , [actions of agent n]]
indiv_reward_batch = [[] for _ in range(self.num_agents)]
indiv_next_obs_batch = [[] for _ in range(self.num_agents)]

global_state_batch = []
global_next_state_batch = []
global_actions_batch = []
done_batch = []

batch = random.sample(self.buffer, batch_size)

for experience in batch:
state, action, reward, next_state, done = experience

for i in range(self.num_agents):
indiv_obs_batch[i].append(state[i])
indiv_action_batch[i].append(action[i])
indiv_reward_batch[i].append(reward[i])
indiv_next_obs_batch[i].append(next_state[i])

global_state_batch.append(torch.cat(state))
global_actions_batch.append(torch.cat(action))
global_next_state_batch.append(torch.cat(next_state))
done_batch.append(done)

        # Per-agent observations and actions are assumed to be torch tensors
        # (they are concatenated into the global batches above), so they are
        # stacked directly; rewards are per-agent scalars and dones are
        # per-agent 0/1 flags.
        global_state_batch = torch.stack(global_state_batch)
        global_actions_batch = torch.stack(global_actions_batch)
        global_next_state_batch = torch.stack(global_next_state_batch)
        done_batch = torch.as_tensor(done_batch, dtype=torch.float32)
        indiv_obs_batch = torch.stack(
            [torch.stack(obs) for obs in indiv_obs_batch]
        )
        indiv_action_batch = torch.stack(
            [torch.stack(act) for act in indiv_action_batch]
        )
        indiv_reward_batch = torch.stack(
            [torch.FloatTensor(rew) for rew in indiv_reward_batch]
        )
        indiv_next_obs_batch = torch.stack(
            [torch.stack(next_obs) for next_obs in indiv_next_obs_batch]
        )

return (
indiv_obs_batch,
indiv_action_batch,
indiv_reward_batch,
indiv_next_obs_batch,
global_state_batch,
global_actions_batch,
global_next_state_batch,
done_batch,
)

def __len__(self):
"""
Returns the number of experiences currently stored in the buffer
:returns: Length of replay memory
"""
return len(self.buffer)
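A minimal sketch of how the buffer is exercised, assuming per-agent observations and actions are stored as torch tensors while rewards and dones are plain per-agent scalars (the agent count and dimensions are arbitrary):

import torch

from genrl.core import MultiAgentReplayBuffer

num_agents, obs_dim, action_dim = 2, 4, 2
buffer = MultiAgentReplayBuffer(num_agents, 100)

# Push a handful of joint transitions; every element is indexed per agent.
for _ in range(10):
    state = [torch.randn(obs_dim) for _ in range(num_agents)]
    action = [torch.randn(action_dim) for _ in range(num_agents)]
    reward = [0.0 for _ in range(num_agents)]
    next_state = [torch.randn(obs_dim) for _ in range(num_agents)]
    done = [0 for _ in range(num_agents)]
    buffer.push((state, action, reward, next_state, done))

(
    obs_b,
    act_b,
    rew_b,
    next_obs_b,
    global_state_b,
    global_act_b,
    global_next_state_b,
    done_b,
) = buffer.sample(batch_size=4)
print(global_state_b.shape)  # torch.Size([4, 8]): (batch, num_agents * obs_dim)
print(obs_b.shape)  # torch.Size([2, 4, 4]): (num_agents, batch, obs_dim)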
