Skip to content

Commit

Permalink
Support MPS
Browse files Browse the repository at this point in the history
  • Loading branch information
levje committed May 28, 2024
1 parent e409c5c commit f3953c4
Show file tree
Hide file tree
Showing 17 changed files with 85 additions and 51 deletions.
4 changes: 2 additions & 2 deletions TrackToLearn/algorithms/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer
from TrackToLearn.algorithms.shared.utils import add_item_to_means
from TrackToLearn.environments.env import BaseEnv

from TrackToLearn.utils.torch_utils import get_device

class DDPG(RLAlgorithm):
"""
Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(
batch_size: int = 2**12,
replay_size: int = 1e6,
rng: np.random.RandomState = None,
device: torch.device = "cuda:0",
device: torch.device = get_device(),
):
"""
Parameters
Expand Down
4 changes: 2 additions & 2 deletions TrackToLearn/algorithms/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch

from TrackToLearn.environments.env import BaseEnv

from TrackToLearn.utils.torch_utils import get_device

class RLAlgorithm(object):
"""
Expand All @@ -18,7 +18,7 @@ def __init__(
gamma: float = 0.99,
batch_size: int = 10000,
rng: np.random.RandomState = None,
device: torch.device = "cuda:0",
device: torch.device = get_device(),
):
"""
Parameters
Expand Down
4 changes: 2 additions & 2 deletions TrackToLearn/algorithms/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from TrackToLearn.algorithms.ddpg import DDPG
from TrackToLearn.algorithms.shared.offpolicy import SACActorCritic
from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer

from TrackToLearn.utils.torch_utils import get_device

class SAC(DDPG):
"""
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
batch_size: int = 2**12,
replay_size: int = 1e6,
rng: np.random.RandomState = None,
device: torch.device = "cuda:0",
device: torch.device = get_device(),
):
""" Initialize the algorithm. This includes the replay buffer,
the policy and the target policy.
Expand Down
4 changes: 2 additions & 2 deletions TrackToLearn/algorithms/sac_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from TrackToLearn.algorithms.sac import SAC
from TrackToLearn.algorithms.shared.offpolicy import SACActorCritic
from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer

from TrackToLearn.utils.torch_utils import get_device

LOG_STD_MAX = 2
LOG_STD_MIN = -20
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
batch_size: int = 2**12,
replay_size: int = 1e6,
rng: np.random.RandomState = None,
device: torch.device = "cuda:0",
device: torch.device = get_device(),
):
"""
Parameters
Expand Down
40 changes: 28 additions & 12 deletions TrackToLearn/algorithms/shared/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import torch

from typing import Tuple
from TrackToLearn.utils.torch_utils import get_device, get_device_str


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = get_device()


class OffPolicyReplayBuffer(object):
Expand Down Expand Up @@ -33,16 +33,25 @@ def __init__(
self.size = 0

# Buffers "filled with zeros"


self.state = torch.zeros(
(self.max_size, state_dim), dtype=torch.float32).pin_memory()
(self.max_size, state_dim), dtype=torch.float32)
self.action = torch.zeros(
(self.max_size, action_dim), dtype=torch.float32).pin_memory()
(self.max_size, action_dim), dtype=torch.float32)
self.next_state = torch.zeros(
(self.max_size, state_dim), dtype=torch.float32).pin_memory()
(self.max_size, state_dim), dtype=torch.float32)
self.reward = torch.zeros(
(self.max_size, 1), dtype=torch.float32).pin_memory()
(self.max_size, 1), dtype=torch.float32)
self.not_done = torch.zeros(
(self.max_size, 1), dtype=torch.float32).pin_memory()
(self.max_size, 1), dtype=torch.float32)

if get_device_str() == "cuda":
self.state = self.state.pin_memory()
self.action = self.action.pin_memory()
self.next_state = self.next_state.pin_memory()
self.reward = self.reward.pin_memory()
self.not_done = self.not_done.pin_memory()

def add(
self,
Expand Down Expand Up @@ -112,12 +121,19 @@ def sample(
ind = torch.randperm(self.size, dtype=torch.long)[
:min(self.size, batch_size)]

s = self.state.index_select(0, ind).pin_memory()
a = self.action.index_select(0, ind).pin_memory()
ns = self.next_state.index_select(0, ind).pin_memory()
r = self.reward.index_select(0, ind).squeeze(-1).pin_memory()
s = self.state.index_select(0, ind)
a = self.action.index_select(0, ind)
ns = self.next_state.index_select(0, ind)
r = self.reward.index_select(0, ind).squeeze(-1)
d = self.not_done.index_select(0, ind).to(
dtype=torch.float32).squeeze(-1).pin_memory()
dtype=torch.float32).squeeze(-1)

if get_device_str() == "cuda":
s = s.pin_memory()
a = a.pin_memory()
ns = ns.pin_memory()
r = r.pin_memory()
d = d.pin_memory()

# Return tensors on the same device as the buffer in pinned memory
return (s.to(device=self.device, non_blocking=True),
Expand Down
5 changes: 2 additions & 3 deletions TrackToLearn/environments/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

# from dipy.io.utils import get_reference_info

def collate_fn(data):
    """Identity collate function: return the batch exactly as received.

    NOTE(review): moved to module level (it previously lived inside
    BaseEnv.__init__) — presumably so DataLoader worker processes can
    pickle it; confirm against the DataLoader usage below.
    """
    return data

class BaseEnv(object):
"""
Expand Down Expand Up @@ -84,9 +86,6 @@ def __init__(
self.dataset_file = subject_data
self.split = split_id

def collate_fn(data):
return data

self.dataset = SubjectDataset(
self.dataset_file, self.split)
self.loader = DataLoader(self.dataset, 1, shuffle=True,
Expand Down
9 changes: 6 additions & 3 deletions TrackToLearn/oracles/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from dipy.tracking.streamline import set_number_of_points

from TrackToLearn.oracles.transformer_oracle import TransformerOracle
from TrackToLearn.utils.torch_utils import get_device_str, get_device
import contextlib

autocast_context = torch.cuda.amp.autocast if torch.cuda.is_available() else contextlib.nullcontext

class OracleSingleton:
_self = None
Expand All @@ -15,7 +18,7 @@ def __new__(cls, *args, **kwargs):
return cls._self

def __init__(self, checkpoint: str, device: str, batch_size=4096):
self.checkpoint = torch.load(checkpoint)
self.checkpoint = torch.load(checkpoint, map_location=get_device())

hyper_parameters = self.checkpoint["hyper_parameters"]
# The model's class is saved in hparams
Expand All @@ -38,7 +41,7 @@ def predict(self, streamlines):
N = len(streamlines)
# Placeholders for input and output data
placeholder = torch.zeros(
(self.batch_size, 127, 3), pin_memory=True)
(self.batch_size, 127, 3), pin_memory=get_device_str() == "cuda")
result = torch.zeros((N), dtype=torch.float, device=self.device)

# Get the first batch
Expand Down Expand Up @@ -70,7 +73,7 @@ def predict(self, streamlines):
# Put the directions in pinned memory
placeholder[:end-start] = torch.from_numpy(dirs)

with torch.cuda.amp.autocast():
with autocast_context():
with torch.no_grad():
predictions = self.model(input_data)
result[
Expand Down
5 changes: 2 additions & 3 deletions TrackToLearn/runners/ttl_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from TrackToLearn.experiment.experiment import Experiment
from TrackToLearn.tracking.tracker import Tracker
from TrackToLearn.utils.torch_utils import get_device

# Define the example model paths from the install folder.
# Hackish ? I'm not aware of a better solution but I'm
Expand Down Expand Up @@ -79,9 +80,7 @@ def __init__(
self.compute_reward = False
self.render = False

self.device = torch.device(
"cuda" if torch.cuda.is_available()
else "cpu")
self.device = get_device()

self.fa_map = None
if 'fa_map_file' in track_dto:
Expand Down
5 changes: 2 additions & 3 deletions TrackToLearn/runners/ttl_track_from_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
add_tractometer_args)
from TrackToLearn.tracking.tracker import Tracker
from TrackToLearn.experiment.experiment import Experiment

from TrackToLearn.utils.torch_utils import get_device

class TrackToLearnValidation(Experiment):
""" TrackToLearn validing script. Should work on any model trained with a
Expand Down Expand Up @@ -98,8 +98,7 @@ def __init__(

self.comet_experiment = None

self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu")
self.device = get_device()

self.random_seed = valid_dto['rng_seed']
torch.manual_seed(self.random_seed)
Expand Down
6 changes: 3 additions & 3 deletions TrackToLearn/searchers/sac_auto_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from TrackToLearn.trainers.sac_auto_train import (
parse_args,
SACAutoTrackToLearnTraining)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert torch.cuda.is_available()
from TrackToLearn.utils.torch_utils import get_device, assert_accelerator
device = get_device()
assert_accelerator()


def main():
Expand Down
6 changes: 3 additions & 3 deletions TrackToLearn/searchers/sac_auto_searcher_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from TrackToLearn.trainers.sac_auto_train import (
parse_args,
SACAutoTrackToLearnTraining)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert torch.cuda.is_available()
from TrackToLearn.utils.torch_utils import get_device, assert_accelerator
device = get_device()
assert_accelerator()


def main():
Expand Down
5 changes: 3 additions & 2 deletions TrackToLearn/trainers/ddpg_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from TrackToLearn.algorithms.ddpg import DDPG
from TrackToLearn.experiment.train import (
add_training_args, TrackToLearnTraining)
from TrackToLearn.utils.torch_utils import get_device, assert_accelerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert torch.cuda.is_available()
device = get_device()
assert_accelerator()


class DDPGTrackToLearnTraining(TrackToLearnTraining):
Expand Down
4 changes: 2 additions & 2 deletions TrackToLearn/trainers/sac_auto_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from TrackToLearn.algorithms.sac_auto import SACAuto
from TrackToLearn.trainers.train import (TrackToLearnTraining,
add_training_args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from TrackToLearn.utils.torch_utils import get_device
device = get_device()


class SACAutoTrackToLearnTraining(TrackToLearnTraining):
Expand Down
6 changes: 3 additions & 3 deletions TrackToLearn/trainers/sac_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from TrackToLearn.experiment.train import (
add_rl_args,
TrackToLearnTraining)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert torch.cuda.is_available()
from TrackToLearn.utils.torch_utils import get_device, assert_accelerator
device = get_device()
assert_accelerator()


class SACTrackToLearnTraining(TrackToLearnTraining):
Expand Down
5 changes: 3 additions & 2 deletions TrackToLearn/trainers/td3_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from TrackToLearn.experiment.train import (
add_rl_args,
TrackToLearnTraining)
from TrackToLearn.utils.torch_utils import get_device, assert_accelerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert torch.cuda.is_available()
device = get_device()
assert_accelerator()


class TD3TrackToLearnTraining(TrackToLearnTraining):
Expand Down
9 changes: 5 additions & 4 deletions TrackToLearn/trainers/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from TrackToLearn.experiment.tractometer_validator import TractometerValidator
from TrackToLearn.experiment.experiment import Experiment
from TrackToLearn.tracking.tracker import Tracker
from TrackToLearn.utils.torch_utils import get_device, assert_accelerator


class TrackToLearnTraining(Experiment):
Expand Down Expand Up @@ -99,8 +100,8 @@ def __init__(
self.comet_experiment = comet_experiment
self.last_episode = 0

self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu")
self.device = get_device()


self.use_comet = train_dto['use_comet']

Expand Down Expand Up @@ -353,8 +354,8 @@ def run(self):
training loop
"""

assert torch.cuda.is_available(), \
"Training is only supported on CUDA devices."
# Training is only supported with hardware-accelerated devices.
assert_accelerator()

# Instantiate environment. Actions will be fed to it and new
# states will be returned. The environment updates the streamline
Expand Down
15 changes: 15 additions & 0 deletions TrackToLearn/utils/torch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

def get_device():
    """Return the best available torch device.

    Preference order: CUDA, then Apple Metal (MPS), then CPU.

    Returns
    -------
    torch.device
        The preferred device for tensor allocation.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Guard the attribute access: torch < 1.12 has no torch.backends.mps,
    # so touching it directly would raise AttributeError on older builds.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def assert_accelerator():
    """Assert that a hardware accelerator (CUDA or MPS) is available.

    Raises
    ------
    AssertionError
        If neither CUDA nor MPS is available. (Unchanged exception type:
        existing callers rely on AssertionError.)
    """
    # hasattr guard: torch < 1.12 has no torch.backends.mps attribute.
    assert torch.cuda.is_available() or (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    ), "No hardware accelerator (CUDA or MPS) is available."

def get_device_str():
    """Return the preferred device's name as a plain string.

    Yields "cuda", "mps" or "cpu", mirroring get_device().
    """
    device = get_device()
    return "{}".format(device)

0 comments on commit f3953c4

Please sign in to comment.