Feat: Fix distributed sampler for multi GPU #3053

Open: wants to merge 17 commits into base: main
Changes from all commits (17 commits)
18 changes: 13 additions & 5 deletions src/scvi/dataloaders/_ann_dataloader.py
@@ -117,22 +117,30 @@ def __init__(
batch_size=batch_size,
drop_last=drop_last,
)
# do not touch batch size here, sampler gives batched indices
# This disables PyTorch automatic batching, which is necessary
# for fast access to sparse matrices
self.kwargs.update({"batch_size": None, "shuffle": False})
else:
if "save_path" not in kwargs:
kwargs["save_path"] = "/."
if "num_processes" not in kwargs:
kwargs["num_processes"] = 1
sampler = BatchDistributedSampler(
self.dataset,
batch_size=batch_size,
drop_last=drop_last,
drop_dataset_tail=drop_dataset_tail,
shuffle=shuffle,
**kwargs,
)
# do not touch batch size here, sampler gives batched indices
# This disables PyTorch automatic batching, which is necessary
# for fast access to sparse matrices
self.kwargs.update({"batch_size": None, "shuffle": False})

self.kwargs.update({"sampler": sampler})

if iter_ndarray:
self.kwargs.update({"collate_fn": lambda x: x})

super().__init__(self.dataset, **self.kwargs)
for redundant_key in ["save_path", "num_processes"]:
if redundant_key in self.kwargs:
self.kwargs.pop(redundant_key)
super().__init__(self.dataset, drop_last=drop_dataset_tail, **self.kwargs)
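For context, a minimal single-process sketch of how the reworked loader could be exercised (not part of the diff; the gloo backend, the /tmp scratch directory, and the test-utility import are illustrative assumptions, and the directory must already exist):

    import torch
    import scvi
    from tests.data.utils import generic_setup_adata_manager

    save_path = "/tmp/scvi_dist"  # hypothetical writable scratch directory

    adata = scvi.data.synthetic_iid()
    manager = generic_setup_adata_manager(adata)

    # The sampler can bootstrap the process group itself on CUDA machines; on CPU we
    # initialize it explicitly with the gloo backend so the sketch still runs.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            "gloo",
            init_method=f"file://{save_path}/dist_file",
            rank=0,
            world_size=1,
        )

    dl = scvi.dataloaders.AnnDataLoader(manager, distributed_sampler=True)
    for batch in dl:  # each element is a dict of tensors for the registered fields
        break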
19 changes: 18 additions & 1 deletion src/scvi/dataloaders/_concat_dataloader.py
@@ -25,6 +25,13 @@ class ConcatDataLoader(DataLoader):
Dictionary with keys representing keys in data registry (``adata_manager.data_registry``)
and value equal to desired numpy loading type (later made into torch tensor).
If ``None``, defaults to all registered data.
drop_last
If `True` and the dataset is not evenly divisible by `batch_size`, the last
incomplete batch is dropped. If `False` and the dataset is not evenly divisible
by `batch_size`, then the last batch will be smaller than `batch_size`.
distributed_sampler
``EXPERIMENTAL`` Whether to use :class:`~scvi.dataloaders.BatchDistributedSampler` as the
sampler. If `True`, `sampler` must be `None`.
data_loader_kwargs
Keyword arguments for :class:`~torch.utils.data.DataLoader`
"""
@@ -37,6 +44,7 @@ def __init__(
batch_size: int = 128,
data_and_attributes: dict | None = None,
drop_last: bool | int = False,
distributed_sampler: bool = False,
**data_loader_kwargs,
):
self.adata_manager = adata_manager
@@ -45,6 +53,11 @@ def __init__(
self._shuffle = shuffle
self._batch_size = batch_size
self._drop_last = drop_last
self._drop_dataset_tail = self.dataloader_kwargs.get("drop_dataset_tail", False)

self.dataloaders = []
for indices in indices_list:
@@ -56,12 +69,16 @@ def __init__(
batch_size=batch_size,
data_and_attributes=data_and_attributes,
drop_last=drop_last,
distributed_sampler=distributed_sampler,
**self.dataloader_kwargs,
)
)
lens = [len(dl) for dl in self.dataloaders]
self.largest_dl = self.dataloaders[np.argmax(lens)]
super().__init__(self.largest_dl, **data_loader_kwargs)
for redundant_key in ["save_path", "num_processes", "drop_dataset_tail"]:
if redundant_key in data_loader_kwargs:
data_loader_kwargs.pop(redundant_key)
super().__init__(self.largest_dl, drop_last=self._drop_dataset_tail, **data_loader_kwargs)

def __len__(self):
return len(self.largest_dl)
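A hedged usage sketch of the updated signature (the index split and batch size are arbitrary, and distributed_sampler=True would additionally require an initialized process group as in the tests further down):

    import numpy as np
    import scvi
    from scvi.dataloaders import ConcatDataLoader
    from tests.data.utils import generic_setup_adata_manager

    adata = scvi.data.synthetic_iid()
    manager = generic_setup_adata_manager(adata)

    # Two index lists -> two internal AnnDataLoaders; the new distributed_sampler flag
    # (plus any save_path / num_processes / drop_dataset_tail kwargs) is forwarded to each.
    half = adata.n_obs // 2
    dl = ConcatDataLoader(
        manager,
        indices_list=[np.arange(half), np.arange(half, adata.n_obs)],
        shuffle=True,
        batch_size=64,
        distributed_sampler=False,
    )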
21 changes: 21 additions & 0 deletions src/scvi/dataloaders/_samplers.py
@@ -1,3 +1,4 @@
import torch
from torch.utils.data import Dataset, DistributedSampler


@@ -36,6 +37,26 @@ def __init__(
drop_dataset_tail: bool = False,
**kwargs,
):
if not torch.distributed.is_initialized() and torch.cuda.is_available():
# initializes the distributed backend that takes care of synchronizing processes
torch.distributed.init_process_group(
"nccl", # backend that works on all systems
init_method="file://" + kwargs["save_path"] + "/dist_file",
rank=0,
world_size=kwargs["num_processes"],
store=None,
)

for redundant_key in [
"save_path",
"pin_memory",
"num_processes",
"num_workers",
"persistent_workers",
]:
if redundant_key in kwargs:
kwargs.pop(redundant_key)

super().__init__(dataset, drop_last=drop_dataset_tail, **kwargs)
self.batch_size = batch_size
self.drop_last_batch = drop_last # drop_last already defined in parent
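A hedged sketch of constructing one rank's sampler directly, mirroring the tests below; the scratch path is a placeholder, and on machines without CUDA the process-group bootstrap above is skipped so the explicit num_replicas/rank are used as-is:

    import scvi
    from scvi.dataloaders import BatchDistributedSampler
    from tests.data.utils import generic_setup_adata_manager

    adata = scvi.data.synthetic_iid()
    dataset = generic_setup_adata_manager(adata).create_torch_dataset()

    # Rank 0 of a 2-replica setup; save_path/num_processes are consumed (then stripped)
    # only when the sampler has to initialize the process group itself.
    sampler = BatchDistributedSampler(
        dataset,
        num_replicas=2,
        rank=0,
        batch_size=128,
        drop_last=False,
        drop_dataset_tail=False,
        save_path="/tmp/scvi_dist",  # hypothetical scratch directory
        num_processes=2,
    )
    for batch_indices in sampler:  # each element is one minibatch of dataset indices
        pass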
75 changes: 62 additions & 13 deletions tests/dataloaders/test_dataloaders.py
@@ -1,10 +1,13 @@
import os

import numpy as np
import pytest
import torch
from tests.data.utils import generic_setup_adata_manager

import scvi
from scvi import REGISTRY_KEYS
from scvi.model import SCANVI


class TestSemiSupervisedTrainingPlan(scvi.train.SemiSupervisedTrainingPlan):
@@ -92,29 +95,75 @@ def test_anndataloader_distributed_sampler_init():


def multiprocessing_worker(
rank: int, world_size: int, manager: scvi.data.AnnDataManager, save_path: str
rank: int,
world_size: int,
manager: scvi.data.AnnDataManager,
save_path: str,
datasplitter_kwargs,
):
# initializes the distributed backend that takes care of synchronizing processes
torch.distributed.init_process_group(
"gloo", # backend that works on all systems
"nccl", # backend that works on all systems
init_method=f"file://{save_path}/dist_file",
rank=rank,
world_size=world_size,
store=None,
)

_ = scvi.dataloaders.AnnDataLoader(manager, distributed_sampler=True)
_ = scvi.dataloaders.AnnDataLoader(manager, **datasplitter_kwargs)

return


@pytest.mark.optional
def test_anndataloader_distributed_sampler(save_path: str, num_processes: int = 2):
adata = scvi.data.synthetic_iid()
manager = generic_setup_adata_manager(adata)
@pytest.mark.parametrize("num_processes", [1, 2])
def test_anndataloader_distributed_sampler(num_processes: int, save_path: str):
if torch.cuda.is_available():
adata = scvi.data.synthetic_iid()
manager = generic_setup_adata_manager(adata)

torch.multiprocessing.spawn(
multiprocessing_worker,
args=(num_processes, manager, save_path),
nprocs=num_processes,
join=True,
)
file_path = save_path + "/dist_file"
if os.path.exists(file_path):  # remove a stale dist_file from a previous run
os.remove(file_path)

torch.multiprocessing.spawn(
multiprocessing_worker,
args=(num_processes, manager, save_path, {}),
nprocs=num_processes,
join=True,
)


@pytest.mark.parametrize("num_processes", [1, 2])
def test_scanvi_with_distributed_sampler(num_processes: int, save_path: str):
if torch.cuda.is_available():
adata = scvi.data.synthetic_iid()
manager = generic_setup_adata_manager(adata)
SCANVI.setup_anndata(
adata,
"labels",
"label_0",
batch_key="batch",
)
file_path = save_path + "/dist_file"
if os.path.exists(file_path):  # remove a stale dist_file from a previous run
os.remove(file_path)
datasplitter_kwargs = {}
# Multi-GPU settings
datasplitter_kwargs["distributed_sampler"] = True
datasplitter_kwargs["save_path"] = save_path
datasplitter_kwargs["num_processes"] = num_processes
datasplitter_kwargs["drop_dataset_tail"] = True
datasplitter_kwargs["drop_last"] = False
if num_processes == 1:
datasplitter_kwargs["distributed_sampler"] = False
datasplitter_kwargs["drop_dataset_tail"] = False
model = SCANVI(adata, n_latent=10)

torch.multiprocessing.spawn(
multiprocessing_worker,
args=(num_processes, manager, save_path, {}),
nprocs=num_processes,
join=True,
)

model.train(1, datasplitter_kwargs=datasplitter_kwargs)
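Outside the test harness, the end-to-end flow exercised above might look as follows; a sketch assuming a writable scratch directory, with values mirroring the test rather than recommended settings:

    import scvi
    from scvi.model import SCANVI

    adata = scvi.data.synthetic_iid()
    SCANVI.setup_anndata(adata, "labels", "label_0", batch_key="batch")

    # Multi-GPU data-splitting settings mirrored from the test above; "dist_file" is
    # created under save_path by the file-based process-group store.
    datasplitter_kwargs = {
        "distributed_sampler": True,
        "save_path": "/tmp/scvi_dist",  # hypothetical scratch directory
        "num_processes": 2,
        "drop_dataset_tail": True,
        "drop_last": False,
    }

    model = SCANVI(adata, n_latent=10)
    model.train(max_epochs=1, datasplitter_kwargs=datasplitter_kwargs)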
28 changes: 28 additions & 0 deletions tests/dataloaders/test_samplers.py
@@ -1,3 +1,4 @@
import os
from math import ceil, floor

import numpy as np
@@ -8,14 +9,21 @@
from scvi.dataloaders import BatchDistributedSampler


@pytest.mark.parametrize("num_processes", [1, 2])
def test_batchdistributedsampler_init(
num_processes: int,
save_path: str,
batch_size: int = 128,
n_batches: int = 2,
):
adata = scvi.data.synthetic_iid(batch_size=batch_size, n_batches=n_batches)
manager = generic_setup_adata_manager(adata)
dataset = manager.create_torch_dataset()

file_path = save_path + "/dist_file"
if os.path.exists(file_path):  # remove a stale dist_file from a previous run
os.remove(file_path)

sampler = BatchDistributedSampler(
dataset,
num_replicas=1,
@@ -24,6 +32,8 @@ def test_batchdistributedsampler_init(
shuffle=True,
drop_last=True,
drop_dataset_tail=True,
num_processes=num_processes,
save_path=save_path,
)
assert sampler.batch_size == batch_size
assert sampler.rank == 0
@@ -35,9 +45,12 @@

@pytest.mark.parametrize("drop_last", [True, False])
@pytest.mark.parametrize("drop_dataset_tail", [True, False])
@pytest.mark.parametrize("num_processes", [1, 2])
def test_batchdistributedsampler_drop_last(
num_processes: int,
drop_last: bool,
drop_dataset_tail: bool,
save_path: str,
batch_size: int = 128,
n_batches: int = 3,
num_replicas: int = 2,
@@ -101,6 +114,10 @@ def check_samplers(samplers: list, sampler_batch_size: int):
assert len(all_indices) == effective_n_obs_per_sampler
assert [len(indices) for indices in batch_indices] == batch_sizes

file_path = save_path + "/dist_file"
if os.path.exists(file_path):  # remove a stale dist_file from a previous run
os.remove(file_path)

for sampler_batch_size in [batch_size, batch_size - 1, batch_size + 1]:
samplers = [
BatchDistributedSampler(
@@ -110,13 +127,18 @@ def check_samplers(samplers: list, sampler_batch_size: int):
batch_size=sampler_batch_size,
drop_last=drop_last,
drop_dataset_tail=drop_dataset_tail,
num_processes=num_processes,
save_path=save_path,
)
for i in range(num_replicas)
]
check_samplers(samplers, sampler_batch_size)


@pytest.mark.parametrize("num_processes", [1, 2])
def test_batchdistributedsampler_indices(
num_processes: int,
save_path: str,
batch_size: int = 128,
n_batches: int = 3,
num_replicas: int = 2,
@@ -125,12 +147,18 @@
manager = generic_setup_adata_manager(adata)
dataset = manager.create_torch_dataset()

file_path = save_path + "/dist_file"
if os.path.exists(file_path):  # remove a stale dist_file from a previous run
os.remove(file_path)

samplers = [
BatchDistributedSampler(
dataset,
num_replicas=num_replicas,
rank=i,
batch_size=batch_size,
num_processes=num_processes,
save_path=save_path,
)
for i in range(num_replicas)
]