CUDA Support for agents #388

Open: wants to merge 36 commits into master

Commits (36), all authored by hades-rp2010:
1d49049  Single actor critic shared params (Sep 1, 2020)
ef4a179  Shared layers for multi ACs (Sep 1, 2020)
2ecd086  Merge branch 'master' of https://github.com/SforAiDl/genrl (Sep 1, 2020)
53450a8  Fix lint errors (1) (Sep 1, 2020)
274aff9  Fixed tests (Sep 1, 2020)
38f95f0  Changes to dicstrings and classes (Sep 2, 2020)
835819e  Renaming Multi -> Two and comments (Sep 4, 2020)
43ed950  Remove compute_advantage from rollout buffer class (Sep 6, 2020)
edf1b07  Merge branch 'master' into new (Sep 6, 2020)
2f0a749  Remove duplication (Sep 6, 2020)
86f89d6  Merge (Sep 7, 2020)
aa305e7  Fixing tests (1) (Sep 7, 2020)
dc3dc18  unified gae and normal adv (Sep 8, 2020)
1a802fc  Shift fn to OnPolicyAgent (Sep 9, 2020)
6609952  Remove redundant line (Sep 9, 2020)
6453ee0  New file distributed.py (Sep 9, 2020)
3160e73  Fix LGTM (Sep 10, 2020)
4d7ec9c  Merge branch 'master' of https://github.com/SforAiDl/genrl into new (Sep 11, 2020)
6d8ed2d  Docstring (Sep 11, 2020)
c94a9a1  Merge branch 'master' of https://github.com/SforAiDl/genrl (Sep 12, 2020)
bf71710  Adding tutorial (Sep 12, 2020)
fc356b9  Small change (Sep 12, 2020)
844c53d  Index (Sep 13, 2020)
920d570  Merge branch 'master' of https://github.com/SforAiDl/genrl into new (Sep 13, 2020)
d15b9db  Adding return statement (Oct 10, 2020)
ca28fa4  Fix discount.py (Oct 12, 2020)
f1c5673  Merge branch 'master' of https://github.com/SforAiDl/genrl into new (Oct 12, 2020)
b9207af  Merge branch 'master' of https://github.com/SforAiDl/genrl into CUDA (Nov 23, 2020)
29b6134  CUDA support for onp agents (Nov 23, 2020)
25adb47  Merge branch 'new' of https://github.com/hades-rp2010/genrl into CUDA (Nov 24, 2020)
5b915dd  All agents except DQN (Nov 25, 2020)
eae42fd  Update to miniconda-v2 (Nov 27, 2020)
6ccd463  Update to miniconda-v2 (2) (Nov 27, 2020)
f79f41e  Update to miniconda-v2 (3) (Nov 27, 2020)
d62f44e  Update to miniconda-v2 (4) (Nov 27, 2020)
b2f71a6  Update to miniconda-v2 (4) (Nov 27, 2020)
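
The commit list covers both the CUDA plumbing and the earlier buffer refactors. As a rough sketch of the overall device pattern the later commits converge on (class and function names here are illustrative, not part of genrl's API):

    import torch
    import torch.nn as nn


    class DeviceAwareAgent:
        """Illustrative skeleton of how a device argument is threaded through an agent."""

        def __init__(self, network: nn.Module, device: str = "cpu"):
            # Fall back to CPU if CUDA is requested but not available
            if device == "cuda" and not torch.cuda.is_available():
                device = "cpu"
            self.device = torch.device(device)
            # Keep all learnable parameters on the chosen device
            self.network = network.to(self.device)

        def select_action(self, state: torch.Tensor) -> torch.Tensor:
            # Inputs are moved to the device before the forward pass ...
            state = state.to(self.device)
            with torch.no_grad():
                logits = self.network(state)
            action = torch.argmax(logits, dim=-1)
            # ... and only leave the device at the environment boundary
            return action.cpu()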
2 changes: 1 addition & 1 deletion .github/workflows/test_linux.yml
@@ -52,7 +52,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- uses: goanpeca/setup-miniconda@v1
- uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
python-version: ${{ matrix.python-version }}
2 changes: 1 addition & 1 deletion .github/workflows/test_macos.yml
@@ -52,7 +52,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- uses: goanpeca/setup-miniconda@v1
- uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
python-version: ${{ matrix.python-version }}
2 changes: 1 addition & 1 deletion .github/workflows/test_windows.yml
@@ -52,7 +52,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- uses: goanpeca/setup-miniconda@v1
- uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
python-version: ${{ matrix.python-version }}
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
rev: 20.8b1
hooks:
- id: black
language_version: python3.7
language_version: python3.8

- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.3
10 changes: 5 additions & 5 deletions genrl/agents/deep/a2c/a2c.py
@@ -118,7 +118,7 @@ def select_action(
action, dist = self.ac.get_action(state, deterministic=deterministic)
value = self.ac.get_value(state)

return action.detach(), value, dist.log_prob(action).cpu()
return action.detach(), value, dist.log_prob(action)

def get_traj_loss(self, values: torch.Tensor, dones: torch.Tensor) -> None:
"""Get loss from trajectory traversed by agent during rollouts
@@ -129,8 +129,8 @@ def get_traj_loss(self, values: torch.Tensor, dones: torch.Tensor) -> None:
values (:obj:`torch.Tensor`): Values of states encountered during the rollout
dones (:obj:`list` of bool): Game over statuses of each environment
"""
compute_returns_and_advantage(
self.rollout, values.detach().cpu().numpy(), dones.cpu().numpy()
self.rollout.returns, self.rollout.advantages = compute_returns_and_advantage(
self.rollout, values.detach(), dones.to(self.device)
)

def evaluate_actions(self, states: torch.Tensor, actions: torch.Tensor):
@@ -150,7 +150,7 @@ def evaluate_actions(self, states: torch.Tensor, actions: torch.Tensor):
states, actions = states.to(self.device), actions.to(self.device)
_, dist = self.ac.get_action(states, deterministic=False)
values = self.ac.get_value(states)
return values, dist.log_prob(actions).cpu(), dist.entropy().cpu()
return values, dist.log_prob(actions), dist.entropy()

def update_params(self) -> None:
"""Updates the the A2C network
@@ -171,7 +171,7 @@ def update_params(self) -> None:
policy_loss = -torch.mean(policy_loss)
self.logs["policy_loss"].append(policy_loss.item())

value_loss = self.value_coeff * F.mse_loss(rollout.returns, values.cpu())
value_loss = self.value_coeff * F.mse_loss(rollout.returns, values)
self.logs["value_loss"].append(torch.mean(value_loss).item())

entropy_loss = -torch.mean(entropy) # Change this to entropy
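
With log-probabilities and entropies no longer forced back to the CPU, the loss terms can be combined on whichever device the rollout tensors live on. A minimal sketch of that pattern (a generic helper, not genrl's update_params):

    import torch
    import torch.nn.functional as F

    def actor_critic_loss(log_probs, advantages, values, returns, entropy,
                          value_coeff=0.5, entropy_coeff=0.01):
        # All inputs are assumed to already be on the same device
        policy_loss = -torch.mean(log_probs * advantages)
        value_loss = value_coeff * F.mse_loss(values, returns)
        entropy_loss = -entropy_coeff * torch.mean(entropy)
        return policy_loss + value_loss + entropy_loss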
23 changes: 18 additions & 5 deletions genrl/agents/deep/base/offpolicy.py
@@ -47,6 +47,7 @@ def __init__(
self.replay_buffer = PrioritizedBuffer(self.replay_size)
else:
raise NotImplementedError
# self.replay_buffer = self.replay_buffer.to(self.device)

def update_params_before_select_action(self, timestep: int) -> None:
"""Update any parameters before selecting action like epsilon for decaying epsilon greedy
@@ -107,6 +108,7 @@ def sample_from_buffer(self, beta: float = None):
)
else:
raise NotImplementedError
# print(batch.device)
return batch

def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor:
@@ -118,9 +120,13 @@ def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor:
Returns:
loss (:obj:`torch.Tensor`): Calculated loss of the Q-function
"""
q_values = self.get_q_values(batch.states, batch.actions)
q_values = self.get_q_values(
batch.states.to(self.device), batch.actions.to(self.device)
)
target_q_values = self.get_target_q_values(
batch.next_states, batch.rewards, batch.dones
batch.next_states.to(self.device),
batch.rewards.to(self.device),
batch.dones.to(self.device),
)
loss = F.mse_loss(q_values, target_q_values)
return loss
@@ -167,15 +173,16 @@ def select_action(
Returns:
action (:obj:`torch.Tensor`): Action taken by the agent
"""
state = state.to(self.device)
action, _ = self.ac.get_action(state, deterministic)
action = action.detach()

# add noise to output from policy network
if self.noise is not None:
action += self.noise()
action += self.noise().to(self.device)

return torch.clamp(
action, self.env.action_space.low[0], self.env.action_space.high[0]
action.cpu(), self.env.action_space.low[0], self.env.action_space.high[0]
)

def update_target_model(self) -> None:
@@ -199,6 +206,7 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
Returns:
q_values (:obj:`torch.Tensor`): Q values for the given states and actions
"""
states, actions = states.to(self.device), actions.to(self.device)
if self.doublecritic:
q_values = self.ac.get_value(
torch.cat([states, actions], dim=-1), mode="both"
@@ -221,6 +229,7 @@ def get_target_q_values(
Returns:
target_q_values (:obj:`torch.Tensor`): Target Q values for the TD3
"""
next_states = next_states.to(self.device)
next_target_actions = self.ac_target.get_action(next_states, True)[0]

if self.doublecritic:
@@ -231,7 +240,10 @@
next_q_target_values = self.ac_target.get_value(
torch.cat([next_states, next_target_actions], dim=-1)
)
target_q_values = rewards + self.gamma * (1 - dones) * next_q_target_values
target_q_values = (
rewards.to(self.device)
+ self.gamma * (1 - dones.to(self.device)) * next_q_target_values
)

return target_q_values

@@ -265,6 +277,7 @@ def get_p_loss(self, states: torch.Tensor) -> torch.Tensor:
Returns:
loss (:obj:`torch.Tensor`): Calculated policy loss
"""
states = states.to(self.device)
next_best_actions = self.ac.get_action(states, True)[0]
q_values = self.ac.get_value(torch.cat([states, next_best_actions], dim=-1))
policy_loss = -torch.mean(q_values)
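
The off-policy changes follow one rule: sampled batches are moved onto the agent's device before any Q-value arithmetic. A hedged sketch of the TD target computation under that assumption (the helper below is illustrative, not genrl's exact signature):

    import torch

    @torch.no_grad()
    def td_target(rewards, dones, next_q_values, gamma, device):
        # Move the sampled batch tensors to the device of the critic outputs
        rewards = rewards.to(device)
        dones = dones.to(device)
        return rewards + gamma * (1.0 - dones) * next_q_values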
4 changes: 2 additions & 2 deletions genrl/agents/deep/base/onpolicy.py
@@ -36,7 +36,7 @@ def __init__(

if buffer_type == "rollout":
self.rollout = RolloutBuffer(
self.rollout_size, self.env, gae_lambda=gae_lambda
self.rollout_size, self.env, gae_lambda=gae_lambda, device=self.device
)
else:
raise NotImplementedError
@@ -73,7 +73,7 @@ def collect_rollouts(self, state: torch.Tensor):
dones (:obj:`torch.Tensor`): Game over statuses of each environment
"""
for i in range(self.rollout_size):
action, values, old_log_probs = self.select_action(state)
action, values, old_log_probs = self.select_action(state.to(self.device))

next_state, reward, dones, _ = self.env.step(action)

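
The on-policy base class now creates its RolloutBuffer with the agent's device and moves states there before action selection. A simplified sketch of that loop (the environment and buffer interfaces are assumptions, not genrl's exact API):

    import torch

    def collect_rollouts(agent, env, rollout_size, device):
        # Assumes a gym-style env and a buffer exposing .add(); genrl's interfaces differ slightly
        state = torch.as_tensor(env.reset(), dtype=torch.float32)
        for _ in range(rollout_size):
            # Tensors go to the agent's device before the policy forward pass
            action, value, log_prob = agent.select_action(state.to(device))
            next_state, reward, done, _ = env.step(action.cpu().numpy())
            agent.rollout.add(state, action, reward, done, value, log_prob)
            state = torch.as_tensor(next_state, dtype=torch.float32)
        # The final value estimate and done flags feed get_traj_loss afterwards
        return value, torch.as_tensor(done, dtype=torch.float32)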
12 changes: 6 additions & 6 deletions genrl/agents/deep/ppo1/ppo1.py
@@ -113,7 +113,7 @@ def select_action(
action, dist = self.ac.get_action(state, deterministic=deterministic)
value = self.ac.get_value(state)

return action.detach(), value, dist.log_prob(action).cpu()
return action.detach(), value, dist.log_prob(action)

def evaluate_actions(self, states: torch.Tensor, actions: torch.Tensor):
"""Evaluates actions taken by actor
@@ -132,7 +132,7 @@ def evaluate_actions(self, states: torch.Tensor, actions: torch.Tensor):
states, actions = states.to(self.device), actions.to(self.device)
_, dist = self.ac.get_action(states, deterministic=False)
values = self.ac.get_value(states)
return values, dist.log_prob(actions).cpu(), dist.entropy().cpu()
return values, dist.log_prob(actions), dist.entropy()

def get_traj_loss(self, values, dones):
"""Get loss from trajectory traversed by agent during rollouts
@@ -143,10 +143,10 @@ def get_traj_loss(self, values, dones):
values (:obj:`torch.Tensor`): Values of states encountered during the rollout
dones (:obj:`list` of bool): Game over statuses of each environment
"""
compute_returns_and_advantage(
self.rollout.returns, self.rollout.advantages = compute_returns_and_advantage(
self.rollout,
values.detach().cpu().numpy(),
dones.cpu().numpy(),
values.detach(),
dones.to(self.device),
use_gae=True,
)

@@ -180,7 +180,7 @@ def update_params(self):
values = values.flatten()

value_loss = self.value_coeff * nn.functional.mse_loss(
rollout.returns, values.cpu()
rollout.returns, values
)
self.logs["value_loss"].append(torch.mean(value_loss).item())

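
Instead of writing NumPy arrays into the buffer in place, compute_returns_and_advantage now hands back tensors that stay on the rollout's device. A self-contained GAE sketch under the same assumption (this is a generic implementation, not genrl's helper):

    import torch

    def gae_returns_and_advantages(rewards, values, dones, last_value,
                                   gamma=0.99, gae_lambda=0.95):
        # rewards, values, dones: (T, n_envs) tensors on one device;
        # dones[t] flags that the episode ended after step t
        advantages = torch.zeros_like(rewards)
        last_gae = torch.zeros_like(last_value)
        num_steps = rewards.shape[0]
        for t in reversed(range(num_steps)):
            next_value = last_value if t == num_steps - 1 else values[t + 1]
            next_non_terminal = 1.0 - dones[t]
            delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
            last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
            advantages[t] = last_gae
        returns = advantages + values
        return returns, advantages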
20 changes: 13 additions & 7 deletions genrl/agents/deep/sac/sac.py
@@ -67,10 +67,10 @@ def _create_model(self, **kwargs) -> None:
else:
self.action_scale = torch.FloatTensor(
(self.env.action_space.high - self.env.action_space.low) / 2.0
)
).to(self.device)
self.action_bias = torch.FloatTensor(
(self.env.action_space.high + self.env.action_space.low) / 2.0
)
).to(self.device)

if isinstance(self.network, str):
state_dim, action_dim, discrete, _ = get_env_properties(
@@ -89,7 +89,7 @@
sac=True,
action_scale=self.action_scale,
action_bias=self.action_bias,
)
).to(self.device)
else:
self.model = self.network

@@ -102,7 +102,7 @@ def _create_model(self, **kwargs) -> None:
self.target_entropy = -torch.prod(
torch.Tensor(self.env.action_space.shape)
).item()
self.log_alpha = torch.zeros(1, requires_grad=True)
self.log_alpha = torch.zeros(1, device=self.device, requires_grad=True)
self.optimizer_alpha = opt.Adam([self.log_alpha], lr=self.lr_policy)

def select_action(
@@ -119,8 +119,9 @@ def select_action(
Returns:
action (:obj:`np.ndarray`): Action taken by the agent
"""
action, _, _ = self.ac.get_action(state, deterministic)
return action.detach()
state = state.to(self.device)
action, _, _ = self.ac.get_action(state.to(self.device), deterministic)
return action.detach().cpu()

def update_target_model(self) -> None:
"""Function to update the target Q model
@@ -147,11 +148,15 @@ def get_target_q_values(
Returns:
target_q_values (:obj:`torch.Tensor`): Target Q values for the SAC
"""
next_states = next_states.to(self.device)
next_target_actions, next_log_probs, _ = self.ac.get_action(next_states)
next_q_target_values = self.ac_target.get_value(
torch.cat([next_states, next_target_actions], dim=-1), mode="min"
).squeeze() - self.alpha * next_log_probs.squeeze(1)
target_q_values = rewards + self.gamma * (1 - dones) * next_q_target_values
target_q_values = (
rewards.to(self.device)
+ self.gamma * (1 - dones.to(self.device)) * next_q_target_values
)
return target_q_values

def get_p_loss(self, states: torch.Tensor) -> torch.Tensor:
@@ -163,6 +168,7 @@ def get_p_loss(self, states: torch.Tensor) -> torch.Tensor:
Returns:
loss (:obj:`torch.Tensor`): Calculated policy loss
"""
states = states.to(self.device)
next_best_actions, log_probs, _ = self.ac.get_action(states)
q_values = self.ac.get_value(
torch.cat([states, next_best_actions], dim=-1), mode="min"
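
For SAC the entropy temperature has to be created directly on the target device: a leaf tensor built with device= keeps its gradient on the device, whereas calling .to(device) afterwards would hand the optimizer a non-leaf copy. A sketch of the automatic entropy tuning step under that assumption (hyperparameters are illustrative):

    import torch
    import torch.optim as opt

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Learnable log-temperature created on the device from the start
    log_alpha = torch.zeros(1, device=device, requires_grad=True)
    alpha_optimizer = opt.Adam([log_alpha], lr=3e-4)

    def update_alpha(log_probs: torch.Tensor, target_entropy: float) -> torch.Tensor:
        # Standard SAC temperature loss; log_probs come from the current policy
        alpha_loss = -(log_alpha * (log_probs + target_entropy).detach()).mean()
        alpha_optimizer.zero_grad()
        alpha_loss.backward()
        alpha_optimizer.step()
        return log_alpha.exp().detach()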
4 changes: 2 additions & 2 deletions genrl/agents/deep/td3/td3.py
@@ -79,9 +79,9 @@ def _create_model(self) -> None:
value_layers=self.value_layers,
val_type="Qsa",
discrete=False,
)
).to(self.device)
else:
self.ac = self.network
self.ac = self.network.to(self.device)

if self.noise is not None:
self.noise = self.noise(
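
TD3 now moves whichever network it ends up with, generated or user-supplied, onto the agent's device. A minimal sketch of that step together with the matching target copy (a generic helper, not genrl's _create_model):

    import copy
    import torch.nn as nn

    def build_actor_critic(network: nn.Module, device):
        # Online and target networks must live on the same device
        ac = network.to(device)
        ac_target = copy.deepcopy(ac)
        ac_target.load_state_dict(ac.state_dict())
        return ac, ac_target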
10 changes: 5 additions & 5 deletions genrl/agents/deep/vpg/vpg.py
@@ -86,8 +86,8 @@ def select_action(

return (
action.detach(),
torch.zeros((1, self.env.n_envs)),
dist.log_prob(action).cpu(),
torch.zeros((1, self.env.n_envs), device=self.device),
dist.log_prob(action),
)

def get_log_probs(self, states: torch.Tensor, actions: torch.Tensor):
@@ -105,7 +105,7 @@ def get_log_probs(self, states: torch.Tensor, actions: torch.Tensor):
"""
states, actions = states.to(self.device), actions.to(self.device)
_, dist = self.actor.get_action(states, deterministic=False)
return dist.log_prob(actions).cpu()
return dist.log_prob(actions)

def get_traj_loss(self, values, dones):
"""Get loss from trajectory traversed by agent during rollouts
@@ -116,8 +116,8 @@ def get_traj_loss(self, values, dones):
values (:obj:`torch.Tensor`): Values of states encountered during the rollout
dones (:obj:`list` of bool): Game over statuses of each environment
"""
compute_returns_and_advantage(
self.rollout, values.detach().cpu().numpy(), dones.cpu().numpy()
self.rollout.returns, self.rollout.advantages = compute_returns_and_advantage(
self.rollout, values.detach().to(self.device), dones.to(self.device)
)

def update_params(self) -> None:
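
VPG has no critic, so its placeholder value tensor must still be allocated on the agent's device to line up with the rest of the rollout. A sketch with simplified names (the actor here is assumed to output action logits):

    import torch

    def select_action_vpg(actor, state, n_envs, device):
        state = state.to(device)
        dist = torch.distributions.Categorical(logits=actor(state))
        action = dist.sample()
        # No value function: keep the zero placeholder on the same device
        values = torch.zeros((1, n_envs), device=device)
        return action.detach(), values, dist.log_prob(action)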
5 changes: 3 additions & 2 deletions genrl/core/buffers.py
@@ -32,9 +32,10 @@ class ReplayBuffer:
:type capacity: int
"""

def __init__(self, capacity: int):
def __init__(self, capacity: int, device="cpu"):
self.capacity = capacity
self.memory = deque([], maxlen=capacity)
self.device = device

def push(self, inp: Tuple) -> None:
"""
@@ -60,7 +61,7 @@ def sample(
batch = random.sample(self.memory, batch_size)
state, action, reward, next_state, done = map(np.stack, zip(*batch))
return [
torch.from_numpy(v).float()
torch.from_numpy(v).float().to(self.device)
for v in [state, action, reward, next_state, done]
]

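
A condensed, self-contained version of the device-aware ReplayBuffer sampling path shown above (trimmed to the essentials; error handling and the prioritized variant are omitted):

    import random
    from collections import deque

    import numpy as np
    import torch

    class ReplayBuffer:
        def __init__(self, capacity: int, device: str = "cpu"):
            self.capacity = capacity
            self.memory = deque([], maxlen=capacity)
            self.device = device

        def push(self, transition: tuple) -> None:
            self.memory.append(transition)

        def sample(self, batch_size: int):
            batch = random.sample(self.memory, batch_size)
            state, action, reward, next_state, done = map(np.stack, zip(*batch))
            # Convert once and place every tensor on the configured device
            return [
                torch.from_numpy(v).float().to(self.device)
                for v in (state, action, reward, next_state, done)
            ]

Usage would look like buffer = ReplayBuffer(100000, device="cuda"); batch = buffer.sample(64), with every returned tensor already on the GPU.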