[doc] ddp multigpu tutorial - small updates (#3141)
* [doc] ddp multigpu tutorial - small updates
The main fix is to change `diff` blocks into `python` blocks so that the
user can easily copy and paste the code to run parts of the tutorial.
Fix author's link.

---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
c-p-i-o and svekars authored Nov 4, 2024
1 parent f08670d commit 32d2b29
114 changes: 66 additions & 48 deletions beginner_source/ddp_series_multigpu.rst
@@ -9,7 +9,7 @@
Multi GPU training with DDP
===========================

Authors: `Suraj Subramanian <https://github.com/subramen>`__

.. grid:: 2

@@ -19,13 +19,13 @@ Authors: `Suraj Subramanian <https://github.com/suraj813>`__
- How to migrate a single-GPU training script to multi-GPU via DDP
- Setting up the distributed process group
- Saving and loading models in a distributed setup

.. grid:: 1

.. grid-item::

:octicon:`code-square;1.0em;` View the code used in this tutorial on `GitHub <https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu.py>`__

.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
:class-card: card-prerequisites

@@ -45,11 +45,11 @@ In the `previous tutorial <ddp_series_theory.html>`__, we got a high-level overv
In this tutorial, we start with a single-GPU training script and migrate that to running it on 4 GPUs on a single node.
Along the way, we will talk through important concepts in distributed training while implementing them in our code.

.. note::
If your model contains any ``BatchNorm`` layers, it needs to be converted to ``SyncBatchNorm`` to sync the running stats of ``BatchNorm``
layers across replicas.

Use the helper function
`torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) <https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm.convert_sync_batchnorm>`__ to convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``.

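For reference, a minimal sketch of that conversion, applied before wrapping the model in DDP (the ``torchvision`` model here is just an illustrative stand-in for any model that contains ``BatchNorm`` layers):

.. code-block:: python

    import torch
    import torchvision

    model = torchvision.models.resnet50()  # any model containing BatchNorm layers
    # Swap every BatchNorm layer for SyncBatchNorm so that running statistics
    # are synchronized across all DDP replicas.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
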

@@ -58,27 +58,27 @@ Diff for `single_gpu.py <https://github.com/pytorch/examples/blob/main/distribut
These are the changes you typically make to a single-GPU training script to enable DDP.

Imports
-------
- ``torch.multiprocessing`` is a PyTorch wrapper around Python's native
multiprocessing
- The distributed process group contains all the processes that can
communicate and synchronize with each other.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from utils import MyTrainDataset

    import torch.multiprocessing as mp
    from torch.utils.data.distributed import DistributedSampler
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.distributed import init_process_group, destroy_process_group
    import os

Constructing the process group
------------------------------

- First, before initializing the group process, call `set_device <https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html?highlight=set_device#torch.cuda.set_device>`__,
which sets the default GPU for each process. This is important to prevent hangs or excessive memory utilization on `GPU:0`
@@ -90,66 +90,66 @@ Constructing the process group
- Read more about `choosing a DDP
backend <https://pytorch.org/docs/stable/distributed.html#which-backend-to-use>`__

.. code-block:: python

    def ddp_setup(rank: int, world_size: int):
        """
        Args:
            rank: Unique identifier of each process
            world_size: Total number of processes
        """
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        torch.cuda.set_device(rank)
        init_process_group(backend="nccl", rank=rank, world_size=world_size)

Constructing the DDP model
--------------------------

.. code-block:: python

    self.model = DDP(model, device_ids=[gpu_id])

Distributing input data
-----------------------

- `DistributedSampler <https://pytorch.org/docs/stable/data.html?highlight=distributedsampler#torch.utils.data.distributed.DistributedSampler>`__
chunks the input data across all distributed processes.
- The `DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`__ combines a dataset and a
sampler, and provides an iterable over the given dataset.
- Each process will receive an input batch of 32 samples; the effective
batch size is ``32 * nprocs``, or 128 when using 4 GPUs.

.. code-block:: python

    train_data = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=32,
        shuffle=False,  # We don't shuffle
        sampler=DistributedSampler(train_dataset),  # Use the Distributed Sampler here.
    )

- Calling the ``set_epoch()`` method on the ``DistributedSampler`` at the beginning of each epoch is necessary to make shuffling work
properly across multiple epochs. Otherwise, the same ordering will be used in each epoch.

.. code-block:: python

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        self.train_data.sampler.set_epoch(epoch)  # call this additional line at every epoch
        for source, targets in self.train_data:
            ...
            self._run_batch(source, targets)

Saving model checkpoints
------------------------

- We only need to save model checkpoints from one process. Without this
  condition, each process would save its copy of the identical model. Read
  more on saving and loading models with
  DDP `here <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints>`__

.. code-block:: diff
@@ -164,18 +164,18 @@ Saving model checkpoints
.. warning::
`Collective calls <https://pytorch.org/docs/stable/distributed.html#collective-functions>`__ are functions that run on all the distributed processes,
and they are used to gather certain states or values to a specific process. Collective calls require all ranks to run the collective code.
In this example, ``_save_checkpoint`` should not have any collective calls because it is only run on the ``rank:0`` process.
If you need to make any collective calls, they should be placed before the ``if self.gpu_id == 0`` check.
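
For reference, here is a minimal sketch of the rank-0-only checkpoint pattern described above. The method names follow the tutorial's ``Trainer`` class (see the full script linked at the top), but the bodies are illustrative rather than a verbatim copy:

.. code-block:: python

    def _save_checkpoint(self, epoch):
        # Runs only on rank 0, so no collective calls belong in here.
        ckp = self.model.module.state_dict()  # unwrap the DDP-wrapped model
        torch.save(ckp, "checkpoint.pt")
        print(f"Epoch {epoch} | Training checkpoint saved at checkpoint.pt")

    def train(self, max_epochs: int):
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            # Any collective call (e.g. torch.distributed.barrier()) would go
            # here, before the rank check, so that every rank executes it.
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)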


Running the distributed training job
------------------------------------

- Include new arguments ``rank`` (replacing ``device``) and
``world_size``.
- ``rank`` is auto-allocated by DDP when calling
`mp.spawn <https://pytorch.org/docs/stable/multiprocessing.html#spawning-subprocesses>`__.
- ``world_size`` is the number of processes across the training job. For GPU training,
this corresponds to the number of GPUs in use, and each process works on a dedicated GPU.

.. code-block:: diff
@@ -189,7 +189,7 @@ Running the distributed training job
+ trainer = Trainer(model, train_data, optimizer, rank, save_every)
trainer.train(total_epochs)
+ destroy_process_group()
if __name__ == "__main__":
import sys
total_epochs = int(sys.argv[1])
@@ -199,13 +199,31 @@ Running the distributed training job
+ world_size = torch.cuda.device_count()
+ mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
Here's what the code looks like:

.. code-block:: python

    def main(rank, world_size, total_epochs, save_every):
        ddp_setup(rank, world_size)
        dataset, model, optimizer = load_train_objs()
        train_data = prepare_dataloader(dataset, batch_size=32)
        trainer = Trainer(model, train_data, optimizer, rank, save_every)
        trainer.train(total_epochs)
        destroy_process_group()


    if __name__ == "__main__":
        import sys
        total_epochs = int(sys.argv[1])
        save_every = int(sys.argv[2])
        world_size = torch.cuda.device_count()
        mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
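
``load_train_objs`` and ``prepare_dataloader`` are defined in the full ``multigpu.py`` script linked above; a minimal sketch of what they might look like is below (the dataset size, model, and learning rate are illustrative placeholders, not the tutorial's exact values):

.. code-block:: python

    def load_train_objs():
        train_set = MyTrainDataset(2048)  # load your dataset
        model = torch.nn.Linear(20, 1)  # load your model
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        return train_set, model, optimizer


    def prepare_dataloader(dataset, batch_size: int):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            pin_memory=True,
            shuffle=False,  # shuffling is handled by the DistributedSampler
            sampler=DistributedSampler(dataset),
        )
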
Further Reading
---------------

- `Fault Tolerant distributed training <ddp_series_fault_tolerance.html>`__ (next tutorial in this series)
- `Intro to DDP <ddp_series_theory.html>`__ (previous tutorial in this series)
- `Getting Started with DDP <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
- `Process Group Initialization <https://pytorch.org/docs/stable/distributed.html#tcp-initialization>`__
