Multi-GPU support for unsupervised learning (#207)
* Clean up device handling throughout the codebase

1. PCALoss default device changed to cuda:LOCAL_RANK for multi-GPU
2. Temporal loss device changed to use the device of the input arguments for
   multi-GPU
3. Like the above, functions should generally prefer the device of their
   input arguments.
4. Removed _TORCH_DEVICE for the most part.
5. Removed LightningModule parent class of Loss
6. Fixed test_train chdir causing subsequent tests to fail with
   file-not-found errors.

* Multi-GPU support for unsupervised learning

* Update docs for unsupervised multi-GPU

* Cleanup omegaconf.create

* sort imports

* PR comments

* batch size division ceiling

* context batch size adjustment

* update docs

* add doc file

* fold pytest.ini into setup.cfg and update docs
ksikka authored Oct 24, 2024
1 parent 84538bf commit ba95702
Showing 28 changed files with 496 additions and 306 deletions.
2 changes: 1 addition & 1 deletion docs/roadmap.md
@@ -2,7 +2,7 @@

## General enhancements
- [x] multi-GPU training for supervised models ([#206](https://github.com/paninski-lab/lightning-pose/pull/206))
- [ ] multi-GPU training for unsupervised models
- [x] multi-GPU training for unsupervised models ([#207](https://github.com/paninski-lab/lightning-pose/pull/207))
- [ ] introduce jaxtyping (see [here](https://github.com/google/jaxtyping/issues/70))

## Losses and backbones
2 changes: 1 addition & 1 deletion docs/source/faqs.rst
@@ -73,7 +73,7 @@ FAQs
There are a few techniques available to reduce the memory consumption:

* Reduce ``train_batch_size``. Memory usage is directly proportional to batch size.
* Enable multi-GPU training (only supported for supervised training) using ``num_gpus``.
* Enable :ref:`multi-GPU training <multi_gpu_training>` using ``num_gpus``.
* Reduce image resolution using ``image_resize_dims``.
* Enable gradient accumulation using ``accumulate_grad_batches``. This parameter is not included
in the config by default and should be added manually to the ``training`` section.
5 changes: 2 additions & 3 deletions docs/source/user_guide/config_file.rst
@@ -70,9 +70,8 @@ Below is a list of some commonly modified arguments related to model architectur
Therefore, 300 epochs, at 40 batches per epoch, is equal to 300*40=12k total batches
(or iterations).

* ``training.num_gpus``: enables multi-GPU training, only supported for supervised losses
(``losses_to_use: []``) . ``train_batch_size`` and ``val_batch_size`` must be divisible by
``num_gpus`` as each batch will be divided amongst the GPUs.
.. _config_num_gpus:
* ``training.num_gpus``: the number of GPUs for :ref:`multi-GPU training <multi_gpu_training>`.

* ``training.accumulate_grad_batches``: (experimental) number of batches to accumulate gradients
for before updating weights. Simulates larger batch sizes with memory-constrained GPUs. This
1 change: 1 addition & 0 deletions docs/source/user_guide_advanced/index.rst
@@ -14,3 +14,4 @@ For each feature, we point out necessary modifications and useful information re
context_frames
multiview_fused
multiview_separate
multi_gpu_training
78 changes: 78 additions & 0 deletions docs/source/user_guide_advanced/multi_gpu_training.rst
@@ -0,0 +1,78 @@
.. _multi_gpu_training:

###################
Multi-GPU Training
###################

Multi-GPU training allows you to distribute the load of model training across GPUs.
In addition to accelerating training, this helps overcome out-of-memory errors.

To use this feature, set :ref:`num_gpus <config_num_gpus>` in your config file.

How to choose batch_size
========================

Multi-GPU training distributes batches across multiple GPUs in a way that maintains the same
effective batch size as if you ran on 1 GPU. **Thus, if you reduced the batch size in order to
make your model fit on one GPU, you should increase it back to your desired effective batch
size.**

The batch size configuration parameters that this applies to are ``training.train_batch_size`` and
``training.val_batch_size`` for the labeled frames, and ``dali.base.train.sequence_length`` and
``dali.context.train.batch_size`` for unlabeled video frames. Test batch sizes are not relevant
here, as testing only occurs on one GPU.

Calculation of per-GPU batch size
---------------------------------

Given the above, you need not worry about how lightning-pose calculates per-GPU batch size,
but it is documented here for transparency.

In general, the per-GPU batch size is:

.. code-block:: python

    ceil(batch_size / num_gpus)

The exception to this is the unlabeled per-GPU batch size for context models (``heatmap_mhcrnn``):

.. code-block:: python

    ceil((batch_size - 4) / num_gpus) + 4

The adjusted calculation for the unlabeled batch size of context models maintains the same
single-GPU effective batch size by accounting for the 4 context frames that are loaded with each
training frame.
For example, if you specified ``dali.context.train.batch_size=16``, then your effective batch size
was 16 - 4 = 12.
To maintain 12 with 2 GPUs, each GPU will load 6 frames + 4 context frames, for a per-GPU batch
size of 10.
This is larger than simply dividing the original batch size of 16 across 2 GPUs.
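
Below is a minimal sketch of these formulas (illustrative only; the helper name
``per_gpu_batch_size`` is hypothetical, not part of the lightning-pose API):

.. code-block:: python

    import math

    def per_gpu_batch_size(batch_size: int, num_gpus: int, is_context: bool = False) -> int:
        """Per-GPU batch size following the formulas documented above (sketch)."""
        if is_context:
            # Context models: keep the 4 context frames on every GPU, split the rest.
            return math.ceil((batch_size - 4) / num_gpus) + 4
        return math.ceil(batch_size / num_gpus)

    # Worked example from the text: a 16-frame context batch on 2 GPUs -> 10 per GPU.
    assert per_gpu_batch_size(16, 2, is_context=True) == 10
    assert per_gpu_batch_size(16, 2) == 8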

.. _execution_model:

Execution model
===============

.. warning::

    The implementation spawns ``num_gpus - 1`` processes of the same command originally executed,
    repeating all of the command's execution per process.
    Thus it is advised to run multi-GPU training only in a dedicated training script
    (``scripts/train_hydra.py``). If you use lightning-pose as part of a custom script and don't
    want your entire script to run once per GPU, your script should run ``scripts/train_hydra.py``
    rather than calling the ``train`` method directly.
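
For example, a custom pipeline can delegate to the training script via a subprocess.
The sketch below assumes the usual Hydra ``key=value`` override form; ``prepare_dataset``
is a hypothetical placeholder for your own preprocessing:

.. code-block:: python

    import subprocess

    # Heavy setup that should NOT run once per GPU belongs in the parent
    # script, outside the spawned training processes.
    prepare_dataset()  # hypothetical placeholder

    # Delegate training to the dedicated script; only it is re-executed per GPU.
    subprocess.run(
        ["python", "scripts/train_hydra.py", "training.num_gpus=2"],
        check=True,
    )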

Tensorboard metric calculation
==============================

All metrics can be interpreted the same way as with a single GPU.
Each logged metric value is the average of that metric across the GPUs.
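
Lightning's standard mechanism for this kind of cross-GPU averaging is ``sync_dist``;
the sketch below shows generic Lightning usage, and whether lightning-pose passes
exactly this flag is an assumption:

.. code-block:: python

    # Inside a LightningModule: a value logged with sync_dist=True is averaged
    # across GPU processes before it is written to TensorBoard.
    self.log("total_loss", loss, sync_dist=True)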

Specifying the GPUs to run on
=============================

Use the environment variable ``CUDA_VISIBLE_DEVICES`` if you want lightning-pose to run on
certain GPUs. For example, to train on only the first two GPUs on your machine:

.. code-block:: bash

    CUDA_VISIBLE_DEVICES=0,1 python scripts/train_hydra.py
35 changes: 16 additions & 19 deletions lightning_pose/data/dali.py
@@ -21,22 +21,18 @@
"PrepareDALI",
]

_DALI_DEVICE = "gpu" if torch.cuda.is_available() else "cpu"


# cannot typecheck due to the way the pipeline_def decorator consumes additional args
@pipeline_def
def video_pipe(
filenames: Union[List[str], str],
resize_dims: Optional[List[int]] = None,
random_shuffle: bool = False,
seed: int = 123456,
sequence_length: int = 16,
pad_sequences: bool = True,
initial_fill: int = 16,
normalization_mean: List[float] = _IMAGENET_MEAN,
normalization_std: List[float] = _IMAGENET_STD,
device: str = _DALI_DEVICE,
name: str = "reader",
step: int = 1,
pad_last_batch: bool = False,
@@ -60,7 +56,6 @@ def video_pipe(
initial_fill: size of the buffer that is used for random shuffling
normalization_mean: mean values in (0, 1) to subtract from each channel
normalization_std: standard deviation values to subtract from each channel
device: "cpu" | "gpu"
name: pipeline name, used to string together DataNode elements
step: number of frames to advance on each read
pad_last_batch:
@@ -89,10 +84,9 @@
orig_size_list = []
for f, filename_list in enumerate(filenames):
video = fn.readers.video(
device=device,
device="gpu",
filenames=filename_list,
random_shuffle=random_shuffle,
seed=seed,
sequence_length=sequence_length,
step=step,
pad_sequences=pad_sequences,
@@ -325,6 +319,9 @@ def _setup_pipe_dict(
imgaug: str,
) -> Dict[str, dict]:
"""All of the pipeline args in one place."""
# When running with multiple GPUs, the LOCAL_RANK variable correctly
# contains the DDP Local Rank, which is also the cuda device index.
device_id = int(os.environ.get("LOCAL_RANK", "0"))

dict_args = {
"predict": {"context": {}, "base": {}},
@@ -340,11 +337,11 @@
"sequence_length": base_train_cfg["sequence_length"],
"step": base_train_cfg["sequence_length"],
"batch_size": 1,
"seed": gen_cfg["seed"],
# Multi-GPU strategy is to have each GPU randomize differently.
"seed": gen_cfg["seed"] + device_id,
"num_threads": self.num_threads,
"device_id": 0,
"device_id": device_id,
"random_shuffle": True,
"device": "gpu",
"imgaug": imgaug,
}

@@ -356,11 +353,11 @@
"sequence_length": base_pred_cfg["sequence_length"],
"step": base_pred_cfg["sequence_length"],
"batch_size": 1,
"seed": gen_cfg["seed"],
# Multi-GPU strategy is to have each GPU randomize differently.
"seed": gen_cfg["seed"] + device_id,
"num_threads": self.num_threads,
"device_id": 0,
"device_id": device_id,
"random_shuffle": False,
"device": "gpu",
"name": "reader",
"pad_sequences": True,
"imgaug": "default", # no imgaug when predicting
@@ -375,11 +372,11 @@
"step": context_pred_cfg["sequence_length"] - 4,
"batch_size": 1,
"num_threads": self.num_threads,
"device_id": 0,
"device_id": device_id,
"random_shuffle": False,
"device": "gpu",
"name": "reader",
"seed": gen_cfg["seed"],
# Multi-GPU strategy is to have each GPU randomize differently.
"seed": gen_cfg["seed"] + device_id,
"pad_sequences": True,
# "pad_last_batch": True,
"imgaug": "default", # no imgaug when predicting
@@ -396,11 +393,11 @@
"sequence_length": context_train_cfg["batch_size"],
"step": context_train_cfg["batch_size"],
"batch_size": 1,
"seed": gen_cfg["seed"],
# Multi-GPU strategy is to have each GPU randomize differently.
"seed": gen_cfg["seed"] + device_id,
"num_threads": self.num_threads,
"device_id": 0,
"device_id": device_id,
"random_shuffle": True,
"device": "gpu",
"imgaug": imgaug,
}
# our floor above should prevent us from getting to the very final batch.
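
The per-GPU seeding pattern above generalizes as follows (a minimal sketch; base_seed
stands in for gen_cfg["seed"], and it assumes Lightning's DDP launcher sets LOCAL_RANK):

    import os

    # Each DDP process reads its local rank; single-GPU and CPU runs fall back to 0.
    device_id = int(os.environ.get("LOCAL_RANK", "0"))

    # Offsetting the seed by the rank gives every GPU a different shuffle order,
    # while a fixed base seed keeps the run reproducible.
    base_seed = 123456  # stand-in for gen_cfg["seed"]
    seed = base_seed + device_id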
2 changes: 0 additions & 2 deletions lightning_pose/data/datamodules.py
@@ -24,8 +24,6 @@
"UnlabeledDataModule",
]

_TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class BaseDataModule(pl.LightningDataModule):
"""Splits a labeled dataset into train, val, and test data loaders."""
2 changes: 0 additions & 2 deletions lightning_pose/data/datasets.py
@@ -26,8 +26,6 @@
"MultiviewHeatmapDataset",
]

_TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class BaseTrackingDataset(torch.utils.data.Dataset):
"""Base dataset that contains images and keypoints as (x, y) pairs."""
48 changes: 29 additions & 19 deletions lightning_pose/losses/losses.py
@@ -18,6 +18,7 @@
"""

import os
import warnings
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

@@ -49,11 +50,15 @@
"get_loss_classes",
]

_TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_DEFAULT_TORCH_DEVICE = "cpu"
if torch.cuda.is_available():
# When running with multiple GPUs, the LOCAL_RANK variable correctly
# contains the DDP Local Rank, which is also the cuda device index.
_DEFAULT_TORCH_DEVICE = f"cuda:{int(os.environ.get('LOCAL_RANK', '0'))}"


# @typechecked
class Loss(pl.LightningModule):
class Loss:
"""Parent class for all losses."""

def __init__(
@@ -75,8 +80,8 @@ def __init__(
super().__init__()
self.data_module = data_module
# epsilon can either be a float or a list of floats
self.epsilon = torch.tensor(epsilon, dtype=torch.float, device=self.device)
self.log_weight = torch.tensor(log_weight, dtype=torch.float, device=self.device)
self.epsilon = torch.tensor(epsilon, dtype=torch.float)
self.log_weight = torch.tensor(log_weight, dtype=torch.float)
self.loss_name = "base"

self.reduce_methods_dict = {"mean": torch.mean, "sum": torch.sum}
@@ -178,7 +183,6 @@ def __call__(
elementwise_loss = self.compute_loss(
targets=clean_targets, predictions=clean_predictions
)
# epsilon_insensitive_loss = self.rectify_epsilon(loss=elementwise_loss)
scalar_loss = self.reduce_loss(elementwise_loss, method="mean")
logs = self.log_loss(loss=scalar_loss, stage=stage)

@@ -279,10 +283,11 @@ def __init__(
columns_for_singleview_pca: Optional[Union[ListConfig, List]] = None,
data_module: Optional[Union[BaseDataModule, UnlabeledDataModule]] = None,
log_weight: float = 0.0,
device: str = _TORCH_DEVICE,
device: Union[Literal["cuda", "cpu"], torch.device] = _DEFAULT_TORCH_DEVICE,
**kwargs,
) -> None:
super().__init__(data_module=data_module, log_weight=log_weight)
self.device = device
self.loss_name = loss_name

if loss_name == "pca_multiview":
@@ -337,6 +342,10 @@ def compute_loss(
self,
predictions: TensorType["num_samples", "sample_dim"],
) -> TensorType["num_samples", -1]:
assert predictions.device == torch.device(self.device), (
predictions.device,
torch.device(self.device),
)
# compute either reprojection error or projection onto discarded evecs.
# they will vary in the last dim, hence -1.
return self.pca.compute_reprojection_error(data_arr=predictions)
@@ -347,7 +356,10 @@ def __call__(
stage: Optional[Literal["train", "val", "test"]] = None,
**kwargs,
) -> Tuple[TensorType[()], List[dict]]:

assert keypoints_pred.device == torch.device(self.device), (
keypoints_pred.device,
torch.device(self.device),
)
keypoints_pred = self.pca._format_data(data_arr=keypoints_pred)
elementwise_loss = self.compute_loss(predictions=keypoints_pred)
epsilon_insensitive_loss = self.rectify_epsilon(loss=elementwise_loss)
@@ -374,7 +386,7 @@ def __init__(
) -> None:
super().__init__(data_module=data_module, epsilon=epsilon, log_weight=log_weight)
self.loss_name = "temporal"
self.prob_threshold = torch.tensor(prob_threshold, dtype=torch.float, device=self.device)
self.prob_threshold = torch.tensor(prob_threshold, dtype=torch.float)

def rectify_epsilon(
self, loss: TensorType["batch_minus_one", "num_keypoints"]
@@ -398,8 +410,10 @@ def remove_nans(
idxs_ignore = confidences < self.prob_threshold
# ignore the loss values in the diff where one of the heatmaps is 'nan'
union_idxs_ignore = torch.zeros(
(confidences.shape[0] - 1, confidences.shape[1]), dtype=torch.bool
).to(_TORCH_DEVICE)
(confidences.shape[0] - 1, confidences.shape[1]),
dtype=torch.bool,
device=loss.device,
)
for i in range(confidences.shape[0] - 1):
union_idxs_ignore[i] = torch.logical_or(idxs_ignore[i], idxs_ignore[i + 1])

@@ -472,9 +486,7 @@ def __init__(
else:
raise NotImplementedError

self.prob_threshold = torch.tensor(
prob_threshold, dtype=torch.float, device=self.device
)
self.prob_threshold = torch.tensor(prob_threshold, dtype=torch.float)

def rectify_epsilon(
self, loss: TensorType["batch_minus_one", "num_valid_keypoints"]
@@ -499,7 +511,7 @@ def remove_nans(
# ignore the loss values in the diff where one of the heatmaps is 'nan'
union_idxs_ignore = torch.zeros(
(confidences.shape[0] - 1, confidences.shape[1]), dtype=torch.bool
).to(_TORCH_DEVICE)
).to(loss.device)
for i in range(confidences.shape[0] - 1):
union_idxs_ignore[i] = torch.logical_or(idxs_ignore[i], idxs_ignore[i + 1])

@@ -512,8 +524,8 @@ def compute_loss(
) -> TensorType["batch_minus_one", "num_valid_keypoints"]:
# compute the differences between matching heatmaps for each keypoint

diffs = torch.zeros((predictions.shape[0] - 1, predictions.shape[1])).to(
_TORCH_DEVICE
diffs = torch.zeros(
(predictions.shape[0] - 1, predictions.shape[1]), device=predictions.device
)

for i in range(diffs.shape[0]):
@@ -576,9 +588,7 @@ def __init__(
self.downsampled_image_width = downsampled_image_width
self.uniform_heatmaps = uniform_heatmaps

self.prob_threshold = torch.tensor(
prob_threshold, dtype=torch.float, device=self.device
)
self.prob_threshold = torch.tensor(prob_threshold, dtype=torch.float)

if self.loss_name == "unimodal_mse":
self.loss = None
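
The loss changes above follow the commit's broader rule that functions should prefer the
device of their input arguments over a module-level device constant. A minimal
illustrative sketch of that pattern (not the lightning-pose API):

    import torch

    def temporal_diffs(predictions: torch.Tensor) -> torch.Tensor:
        # Allocate the output on the device of the input argument, so the same
        # code runs on any GPU rank (or on CPU) without a global device constant.
        diffs = torch.zeros(
            (predictions.shape[0] - 1, predictions.shape[1]),
            device=predictions.device,
        )
        for i in range(diffs.shape[0]):
            diffs[i] = predictions[i + 1] - predictions[i]
        return diffs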
2 changes: 1 addition & 1 deletion lightning_pose/metrics.py
@@ -48,7 +48,7 @@ def temporal_norm(keypoints_pred: Union[np.ndarray, torch.Tensor]) -> np.ndarray
t_loss = TemporalLoss()

if not isinstance(keypoints_pred, torch.Tensor):
keypoints_pred = torch.tensor(keypoints_pred, device=t_loss.device, dtype=torch.float32)
keypoints_pred = torch.tensor(keypoints_pred, dtype=torch.float32)

# (samples, n_keypoints, 2) -> (samples, n_keypoints * 2)
if len(keypoints_pred.shape) != 2: