Support multi-GPU training for supervised training (#206)
* Add multi-GPU training support for supervised models

* Shuffle labeled data every epoch

* Add docs on multi-gpu and CUDA OOMs

* update roadmap for multi-gpu
ksikka authored Oct 17, 2024
1 parent b9386a5 commit 84538bf
Showing 17 changed files with 202 additions and 89 deletions.
3 changes: 2 additions & 1 deletion docs/roadmap.md
@@ -1,7 +1,8 @@
# Lightning Pose development roadmap

## General enhancements
- [ ] multi-GPU training for supervised models
- [x] multi-GPU training for supervised models ([#206](https://github.com/paninski-lab/lightning-pose/pull/206))
- [ ] multi-GPU training for unsupervised models
- [ ] introduce jaxtyping (see [here](https://github.com/google/jaxtyping/issues/70))

## Losses and backbones
22 changes: 16 additions & 6 deletions docs/source/faqs.rst
@@ -65,12 +65,22 @@ FAQs

.. dropdown:: What if I encounter a CUDA out of memory error?

We recommend a GPU with at least 8GB of memory.
Note that both semi-supervised and context models will increase memory usage
(with semi-supervised context models needing the most memory).
If you encounter this error, reduce batch sizes during training or inference.
You can find the relevant parameters to adjust in :ref:`The configuration file <config_file>`
section.
Model training can be GPU-memory-intensive, particularly when using unsupervised losses, the
Temporal Context Network model, multi-view datasets, or high-resolution images. For this reason
we recommend a GPU with a minimum of 8GB of memory, and preferably 16GB.

Users who combine several of the memory-intensive features above may still run into issues.
There are a few techniques available to reduce memory consumption:

* Reduce ``train_batch_size``. Memory usage is directly proportional to batch size.
* Enable multi-GPU training (only supported for supervised training) using ``num_gpus``.
* Reduce image resolution using ``image_resize_dims``.
* Enable gradient accumulation using ``accumulate_grad_batches``. This parameter is not included
in the config by default and should be added manually to the ``training`` section.

Each technique above has trade-offs; the right choice depends on your individual situation.

See :ref:`The configuration file <config_file>` section for more information about the above parameters.
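
To make the trade-offs concrete, here is a minimal sketch (not part of this PR; the helper name and the numbers are made up) of how the batch-size-related knobs combine under the conventions introduced here: the configured ``train_batch_size`` is split across GPUs, while gradient accumulation multiplies the batch size the optimizer effectively sees.

```python
def memory_plan(train_batch_size: int, num_gpus: int = 1, accumulate_grad_batches: int = 1):
    """Hypothetical helper: per-GPU mini-batch vs. effective optimizer batch."""
    assert train_batch_size % num_gpus == 0, "train_batch_size must be divisible by num_gpus"
    per_gpu = train_batch_size // num_gpus                  # drives GPU memory usage
    effective = train_batch_size * accumulate_grad_batches  # drives optimizer behavior
    return per_gpu, effective

# e.g. train_batch_size: 16, num_gpus: 2, accumulate_grad_batches: 2
# -> each GPU holds 8 images at a time, while gradients are accumulated to an
#    effective batch size of 32 before each weight update.
print(memory_plan(16, num_gpus=2, accumulate_grad_batches=2))  # (8, 32)
```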

.. dropdown:: Why does the network produce high confidence values for keypoints even when they are occluded?

9 changes: 9 additions & 0 deletions docs/source/user_guide/config_file.rst
@@ -70,6 +70,15 @@ Below is a list of some commonly modified arguments related to model architectur
Therefore, 300 epochs, at 40 batches per epoch, is equal to 300*40=12k total batches
(or iterations).

* ``training.num_gpus``: enables multi-GPU training, which is only supported for supervised losses
  (``losses_to_use: []``). ``train_batch_size`` and ``val_batch_size`` must be divisible by
  ``num_gpus``, as each batch will be divided among the GPUs.

* ``training.accumulate_grad_batches``: (experimental) number of batches over which to accumulate
  gradients before updating weights; this simulates larger batch sizes on memory-constrained GPUs.
  This parameter is not included in the config by default and should be added manually to the
  ``training`` section (see the sketch below for one way to add it).

* ``model.model_type``:

* regression: model directly outputs an (x, y) prediction for each keypoint; not recommended
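As a concrete illustration of adding ``accumulate_grad_batches`` (and ``num_gpus``) to an existing config, here is a sketch using OmegaConf, which the training code already relies on; the values chosen are arbitrary examples.

```python
from omegaconf import OmegaConf, open_dict

# Load one of the example configs (any of the YAML files in scripts/configs works).
cfg = OmegaConf.load("scripts/configs/config_default.yaml")

# open_dict allows adding keys even if the config is in struct mode.
with open_dict(cfg):
    cfg.training.accumulate_grad_batches = 4  # accumulate 4 mini-batches per optimizer step
    cfg.training.num_gpus = 2                 # split each batch across 2 GPUs

print(OmegaConf.to_yaml(cfg.training))
```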
13 changes: 5 additions & 8 deletions lightning_pose/data/datamodules.py
@@ -139,6 +139,8 @@ def train_dataloader(self) -> torch.utils.data.DataLoader:
batch_size=self.train_batch_size,
num_workers=self.num_workers,
persistent_workers=True if self.num_workers > 0 else False,
shuffle=True,
generator=torch.Generator().manual_seed(self.torch_seed),
)

def val_dataloader(self) -> torch.utils.data.DataLoader:
@@ -161,7 +163,6 @@ def full_labeled_dataloader(self) -> torch.utils.data.DataLoader:
self.dataset,
batch_size=self.val_batch_size,
num_workers=self.num_workers,
shuffle=False,
)


@@ -225,10 +226,11 @@ def __init__(
self.video_paths_list = video_paths_list
self.filenames = check_video_paths(self.video_paths_list, view_names=view_names)
self.num_workers_for_unlabeled = 1 # WARNING!! do not increase above 1, weird behavior
self.num_workers_for_labeled = num_workers
self.dali_config = dali_config
self.unlabeled_dataloader = None # initialized in setup_unlabeled
self.imgaug = imgaug
# TODO: Should these belong in a setup method that is called by lightning,
# rather than in __init__? BaseDataModule already follows that pattern.
super().setup()
self.setup_unlabeled()

@@ -248,12 +250,7 @@ def setup_unlabeled(self) -> None:

def train_dataloader(self) -> CombinedLoader:
loader = SemiSupervisedDataLoaderDict(
labeled=DataLoader(
self.train_dataset,
batch_size=self.train_batch_size,
num_workers=self.num_workers_for_labeled,
persistent_workers=True,
),
labeled=super().train_dataloader(),
unlabeled=self.unlabeled_dataloader,
)
return CombinedLoader(loader, mode="max_size_cycle")
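
The labeled loader now shuffles every epoch with a generator seeded from ``torch_seed``, so the ordering differs between epochs but is reproducible across runs. A standalone sketch of that behavior (the toy dataset and seed are illustrative only):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))

def epoch_orders(seed: int, num_epochs: int = 2):
    # Same pattern as the updated BaseDataModule.train_dataloader:
    # shuffle=True with a generator seeded once at construction time.
    loader = DataLoader(
        dataset, batch_size=4, shuffle=True,
        generator=torch.Generator().manual_seed(seed),
    )
    return [[b[0].tolist() for b in loader] for _ in range(num_epochs)]

# Two "runs" with the same seed visit the data in exactly the same orders...
assert epoch_orders(seed=0) == epoch_orders(seed=0)
# ...while within a run the order is re-drawn each epoch (the generator advances).
print(epoch_orders(seed=0))
```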
13 changes: 10 additions & 3 deletions lightning_pose/models/base.py
@@ -456,13 +456,20 @@ def evaluate_labeled(
loss_rmse, _ = self.rmse_loss(stage=stage, **data_dict)

if stage:
# logging with sync_dist=True will average the metric across GPUs in
# multi-GPU training. Performance overhead was found negligible.

# log overall supervised loss
self.log(f"{stage}_supervised_loss", loss, prog_bar=True)
self.log(f"{stage}_supervised_loss", loss, prog_bar=True, sync_dist=True)
# log supervised pixel_error
self.log(f"{stage}_supervised_rmse", loss_rmse)
self.log(f"{stage}_supervised_rmse", loss_rmse, sync_dist=True)
# log individual supervised losses
for log_dict in log_list:
self.log(**log_dict)
self.log(
log_dict['name'],
log_dict['value'].to(self.device),
prog_bar=log_dict.get('prog_bar', False),
sync_dist=True)

return loss

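For readers unfamiliar with ``sync_dist``, a minimal sketch (not from the repo) of the logging pattern: with ``sync_dist=True``, Lightning reduces the logged value across all GPU processes, so validation metrics are averaged rather than reported per rank.

```python
import lightning.pytorch as pl
import torch
import torch.nn.functional as F

class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)
        # In multi-GPU training this value is averaged across ranks before logging;
        # on a single GPU, sync_dist=True has negligible overhead.
        self.log("val_supervised_loss", loss, prog_bar=True, sync_dist=True)
        return loss
```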
25 changes: 22 additions & 3 deletions lightning_pose/train.py
@@ -8,6 +8,8 @@
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from typeguard import typechecked
import warnings
import sys

from lightning_pose.utils import pretty_print_cfg, pretty_print_str
from lightning_pose.utils.io import (
@@ -98,9 +100,18 @@ def train(cfg: DictConfig) -> None:
limit_train_batches = calculate_train_batches(cfg, dataset)

# set up trainer
trainer = pl.Trainer( # TODO: be careful with devices when scaling to multiple gpus
accelerator="gpu", # TODO: control from outside
devices=1, # TODO: control from outside

# Old configs may have num_gpus: 0. We will remove support in a future release.
if cfg.training.num_gpus == 0:
warnings.warn(
"Config contains unsupported value num_gpus: 0. "
"Update num_gpus to 1 in your config."
)
cfg.training.num_gpus = max(cfg.training.num_gpus, 1)

trainer = pl.Trainer(
accelerator="gpu",
devices=cfg.training.num_gpus,
max_epochs=cfg.training.max_epochs,
min_epochs=cfg.training.min_epochs,
check_val_every_n_epoch=min(
@@ -113,11 +124,17 @@
limit_train_batches=limit_train_batches,
accumulate_grad_batches=cfg.training.get("accumulate_grad_batches", 1),
profiler=cfg.training.get("profiler", None),
sync_batchnorm=True,
)

# train model!
trainer.fit(model=model, datamodule=data_module)

# When devices > 1, lightning creates one process per device.
# Exit processes other than the main one, otherwise they would all run the post-training analysis.
if not trainer.is_global_zero:
sys.exit(0)

# ----------------------------------------------------------------------------------
# Post-training analysis
# ----------------------------------------------------------------------------------
@@ -150,6 +167,8 @@ def train(cfg: DictConfig) -> None:
# ----------------------------------------------------------------------------------
# predict on all labeled frames (train/val/test)
# ----------------------------------------------------------------------------------
# Rebuild trainer with devices=1 for prediction. Training flags not needed.
trainer = pl.Trainer(accelerator="gpu", devices=1)
pretty_print_str("Predicting train/val/test images...")
# compute and save frame-wise predictions
preds_file = os.path.join(hydra_output_directory, "predictions.csv")
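Condensed, the multi-GPU flow added to train.py looks roughly like this (a sketch; the model and data module are assumed to be built elsewhere, and keyword arguments not relevant to multi-GPU are omitted):

```python
import sys
import lightning.pytorch as pl

def fit_and_prepare_prediction(model, data_module, num_gpus: int, max_epochs: int):
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=num_gpus,      # DDP spawns one process per device when > 1
        max_epochs=max_epochs,
        sync_batchnorm=True,   # keep BatchNorm statistics consistent across GPUs
    )
    trainer.fit(model=model, datamodule=data_module)

    # Post-training analysis should run once, on the main process only.
    if not trainer.is_global_zero:
        sys.exit(0)

    # Predictions are made with a fresh single-device trainer.
    return pl.Trainer(accelerator="gpu", devices=1)
```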
48 changes: 40 additions & 8 deletions lightning_pose/utils/scripts.py
@@ -2,7 +2,7 @@

import os
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union

import imgaug.augmenters as iaa
import lightning.pytorch as pl
@@ -11,7 +11,9 @@
import torch
from moviepy.editor import VideoFileClip
from omegaconf import DictConfig, OmegaConf
from omegaconf.errors import ValidationError
from typeguard import typechecked
import warnings

from lightning_pose.callbacks import AnnealWeight, UnfreezeBackbone
from lightning_pose.data.augmentations import imgaug_transform
@@ -135,14 +137,37 @@ def get_data_module(
) -> Union[BaseDataModule, UnlabeledDataModule]:
"""Create a data module that splits a dataset into train/val/test iterators."""

# Old configs may have num_gpus: 0. We will remove support in a future release.
if cfg.training.num_gpus == 0:
warnings.warn(
"Config contains unsupported value num_gpus: 0. "
"Update num_gpus to 1 in your config."
)
cfg.training.num_gpus = max(cfg.training.num_gpus, 1)

semi_supervised = check_if_semi_supervised(cfg.model.losses_to_use)
if not semi_supervised:
if not (cfg.training.gpu_id, int):
raise NotImplementedError("Cannot fit fully supervised model on multiple gpus")
# Divide config batch_size by num_gpus to maintain the same effective batch
# size in a multi-gpu setting.
if cfg.training.train_batch_size % cfg.training.num_gpus != 0:
raise ValidationError(
f"train_batch_size should be a multiple of num_gpus. "
"train_batch_size={cfg.training.train_batch_size}, "
"num_gpus={cfg.training.num_gpus}"
)
if cfg.training.val_batch_size % cfg.training.num_gpus != 0:
raise ValidationError(
f"val_batch_size should be a multiple of num_gpus. "
"val_batch_size={cfg.training.val_batch_size}, "
"num_gpus={cfg.training.num_gpus}"
)
train_batch_size = int(cfg.training.train_batch_size / cfg.training.num_gpus)
val_batch_size = int(cfg.training.val_batch_size / cfg.training.num_gpus)

data_module = BaseDataModule(
dataset=dataset,
train_batch_size=cfg.training.train_batch_size,
val_batch_size=cfg.training.val_batch_size,
train_batch_size=train_batch_size,
val_batch_size=val_batch_size,
test_batch_size=cfg.training.test_batch_size,
num_workers=cfg.training.num_workers,
train_probability=cfg.training.train_prob,
@@ -152,11 +177,14 @@
)
else:
if cfg.model.model_type == "heatmap_mhcrnn" and cfg.dali.context.train.batch_size < 5:
raise ValueError(
raise ValidationError(
"cfg.dali.context.train.batch_size must be >=5 for semi-supervised context models"
)
if not (cfg.training.gpu_id, int):
raise NotImplementedError("Cannot fit semi-supervised model on multiple gpus")
if cfg.training.num_gpus > 1:
raise ValidationError(
"Detected num_gpus > 1 and losses != null. "
"Multi-gpu not yet supported for unsupervised losses."
)
view_names = cfg.data.get("view_names", None)
view_names = list(view_names) if view_names is not None else None
data_module = UnlabeledDataModule(
@@ -500,6 +528,10 @@ def calculate_train_batches(
num_train_frames = compute_num_train_frames(
data_splits_list[0], cfg.training.get("train_frames", None)
)
# For multi-GPU, the computation is unchanged.
# num_train_frames is divided by num_gpus to get num_train_frames per gpu
# train_batch_size is also divided by num_gpus to get the mini-batch size
# so num_gpus cancels out of the numerator and denominator.
num_labeled_batches = int(np.ceil(num_train_frames / cfg.training.train_batch_size))
limit_train_batches = np.max([num_labeled_batches, 10]) # 10 is minimum
else:
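The comment in calculate_train_batches about the computation being unchanged under multi-GPU can be seen with a quick sketch (the frame counts are made up): each rank sees 1/num_gpus of the frames in mini-batches that are 1/num_gpus of the configured size, so the per-epoch batch count matches the single-GPU case.

```python
import math

def limit_train_batches(num_train_frames: int, train_batch_size: int, num_gpus: int = 1) -> int:
    frames_per_gpu = num_train_frames / num_gpus       # DDP shards the dataset across ranks
    batch_size_per_gpu = train_batch_size / num_gpus   # per this PR, batches are split too
    num_labeled_batches = math.ceil(frames_per_gpu / batch_size_per_gpu)
    return max(num_labeled_batches, 10)                # 10 is the minimum, as in scripts.py

# num_gpus cancels out, so the limit matches the single-GPU value.
assert limit_train_batches(500, 16, num_gpus=1) == limit_train_batches(500, 16, num_gpus=4) == 32
```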
2 changes: 0 additions & 2 deletions scripts/configs/config_crim13.yaml
@@ -67,8 +67,6 @@ training:
log_every_n_steps: 10
# frequency to log validation metrics
check_val_every_n_epoch: 5
# select gpu for training
gpu_id: 0
# rng seed for labeled batches
rng_seed_data_pt: 0
# rng seed for weight initialization
2 changes: 0 additions & 2 deletions scripts/configs/config_default.yaml
@@ -65,8 +65,6 @@ training:
early_stopping: false
# epochs over which to assess validation metrics for early stopping
early_stop_patience: 3
# select gpu for training
gpu_id: 0
# rng seed for labeled batches
rng_seed_data_pt: 0
# rng seed for weight initialization
2 changes: 0 additions & 2 deletions scripts/configs/config_ibl-paw.yaml
@@ -53,8 +53,6 @@ training:
log_every_n_steps: 10
# frequency to log validation metrics
check_val_every_n_epoch: 5
# select gpu for training
gpu_id: 0
# rng seed for labeled batches
rng_seed_data_pt: 0
# rng seed for weight initialization
2 changes: 0 additions & 2 deletions scripts/configs/config_ibl-pupil.yaml
@@ -55,8 +55,6 @@ training:
log_every_n_steps: 10
# frequency to log validation metrics
check_val_every_n_epoch: 5
# select gpu for training
gpu_id: 0
# rng seed for labeled batches
rng_seed_data_pt: 0
# rng seed for weight initialization
2 changes: 0 additions & 2 deletions scripts/configs/config_mirror-fish.yaml
@@ -114,8 +114,6 @@ training:
log_every_n_steps: 10
# frequency to log validation metrics
check_val_every_n_epoch: 5
# select gpu for training
gpu_id: 0
# rng seed for labeled batches
rng_seed_data_pt: 0
# rng seed for weight initialization
4 changes: 1 addition & 3 deletions scripts/configs/config_mirror-mouse-example.yaml
@@ -64,7 +64,7 @@ training:
# >1: number of total train frames used for training
train_frames: 1
# number of gpus to train a single model
num_gpus: 0
num_gpus: 1
# number of cpu workers for data loaders
num_workers: 4
# epochs over which to assess validation metrics for early stopping
@@ -79,8 +79,6 @@
log_every_n_steps: 1
# frequency to log validation metrics
check_val_every_n_epoch: 10
# select gpu for training
gpu_id: 0
# rng seed for labeled batches
rng_seed_data_pt: 42
# rng seed for weight initialization
2 changes: 0 additions & 2 deletions scripts/configs/config_mirror-mouse.yaml
@@ -78,8 +78,6 @@ training:
log_every_n_steps: 10
# frequency to log validation metrics
check_val_every_n_epoch: 5
# select gpu for training
gpu_id: 0
# rng seed for labeled batches
rng_seed_data_pt: 0
# rng seed for weight initialization
17 changes: 13 additions & 4 deletions tests/data/test_datamodules.py
@@ -1,8 +1,9 @@
"""Test datamodule functionality."""

import pytest
import torch
from lightning.pytorch.utilities import CombinedLoader
import torch
from torch.utils.data import RandomSampler


def test_base_datamodule(cfg, base_data_module):
@@ -15,17 +16,23 @@ def test_base_datamodule(cfg, base_data_module):
num_targets = base_data_module.dataset.num_targets

# check batch properties
batch = next(iter(base_data_module.train_dataloader()))
train_dataloader = base_data_module.train_dataloader()
assert isinstance(train_dataloader.sampler, RandomSampler) # shuffle=True
batch = next(iter(train_dataloader))
assert batch["images"].shape == (train_size, 3, im_height, im_width)
assert batch["keypoints"].shape == (train_size, num_targets)

batch = next(iter(base_data_module.val_dataloader()))
val_dataloader = base_data_module.val_dataloader()
batch = next(iter(val_dataloader))
assert not isinstance(val_dataloader.sampler, RandomSampler) # shuffle=False
assert batch["images"].shape[1:] == (3, im_height, im_width)
assert batch["keypoints"].shape[1:] == (num_targets,)
assert batch["images"].shape[0] == batch["keypoints"].shape[0]
assert batch["images"].shape[0] <= val_size

batch = next(iter(base_data_module.test_dataloader()))
test_dataloader = base_data_module.test_dataloader()
batch = next(iter(test_dataloader))
assert not isinstance(test_dataloader.sampler, RandomSampler) # shuffle=False
assert batch["images"].shape[1:] == (3, im_height, im_width)
assert batch["keypoints"].shape[1:] == (num_targets,)
assert batch["images"].shape[0] == batch["keypoints"].shape[0]
@@ -193,6 +200,8 @@ def test_base_data_module_combined(cfg, base_data_module_combined):

# test outputs for single batch
loader = base_data_module_combined.train_dataloader()
assert isinstance(loader.sampler["labeled"], RandomSampler) # shuffle=True

batch = next(iter(loader))
# batch tuple in lightning >=2.0.9
batch = batch[0] if isinstance(batch, tuple) else batch
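
Finally, the combined-loader behavior that the last test exercises can be reproduced outside the repo with a small sketch (toy datasets; in lightning >= 2.0.9, iterating the loader yields a (batch, batch_idx, dataloader_idx) tuple, which is why the test unpacks batch[0]):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch.utilities import CombinedLoader

labeled = DataLoader(TensorDataset(torch.arange(6)), batch_size=2, shuffle=True)
unlabeled = DataLoader(TensorDataset(torch.arange(100)), batch_size=10)

# "max_size_cycle" iterates until the longest loader (unlabeled) is exhausted,
# restarting the shorter labeled loader as needed -- so the labeled frames are
# revisited, and with shuffle=True they arrive in a new order on each cycle.
loader = CombinedLoader({"labeled": labeled, "unlabeled": unlabeled}, mode="max_size_cycle")

for item in loader:
    batch = item[0] if isinstance(item, tuple) else item
    print(batch["labeled"][0].shape, batch["unlabeled"][0].shape)
```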