From 3796f1e4b2ac8fd13620c8265e9f0cad70533f32 Mon Sep 17 00:00:00 2001
From: Maxim Kalashnikov <38400083+FuTSy13@users.noreply.github.com>
Date: Tue, 16 Apr 2024 12:59:02 +0300
Subject: [PATCH] Add multi-gpu (#139)

* Multi-GPU

* delete use_gpu_dependent_labels and add do_loss_mult parameters

* small refactoring

* fix typo in softmax loss

---------

Co-authored-by: Maxim Kalashnikov
---
 ptls/data_load/datasets/__init__.py          |   2 +-
 ptls/data_load/datasets/parquet_dataset.py   | 113 +++++++++++++++++++
 ptls/frames/coles/losses/contrastive_loss.py |  17 ++-
 ptls/frames/coles/losses/dist_utils.py       |  20 ++++
 ptls/frames/coles/losses/softmax_loss.py     |  10 +-
 ptls/pl_train_module.py                      |   3 +
 6 files changed, 161 insertions(+), 4 deletions(-)
 create mode 100644 ptls/frames/coles/losses/dist_utils.py

diff --git a/ptls/data_load/datasets/__init__.py b/ptls/data_load/datasets/__init__.py
index 9f0cac16..9f296ca6 100644
--- a/ptls/data_load/datasets/__init__.py
+++ b/ptls/data_load/datasets/__init__.py
@@ -2,6 +2,6 @@
 from .persist_dataset import PersistDataset
 from .duckdb_dataset import DuckDbDataset
 from .memory_dataset import MemoryMapDataset, MemoryIterableDataset
-from .parquet_dataset import ParquetFiles, ParquetDataset
+from .parquet_dataset import ParquetFiles, ParquetDataset, DistributedParquetDataset
 from .parquet_file_scan import parquet_file_scan
 from .dataloaders import inference_data_loader
diff --git a/ptls/data_load/datasets/parquet_dataset.py b/ptls/data_load/datasets/parquet_dataset.py
index 815c1307..6c090dd3 100644
--- a/ptls/data_load/datasets/parquet_dataset.py
+++ b/ptls/data_load/datasets/parquet_dataset.py
@@ -7,6 +7,8 @@
 import numpy as np
 import torch
+import torch.distributed as dist
+import pyarrow.parquet as pq
 from omegaconf import ListConfig
 
 from ptls.data_load import read_pyarrow_file
@@ -14,6 +16,16 @@
 logger = logging.getLogger(__name__)
 
 
+def iter_with_max_num(it, max_num=None):
+    num = 0
+    for i in it:
+        yield i
+        num += 1
+        if max_num is not None and num >= max_num:
+            if dist.is_initialized():
+                logger.info('Stopping worker on GPU %d', dist.get_rank())
+            break
+
 class ParquetFiles:
     """Helper file which search parquet files in specified path.
@@ -162,3 +174,104 @@ def to_torch(x):
     if type(x) is np.ndarray and x.dtype.kind in ('i', 'f'):
         return torch.from_numpy(x)
     return x
+
+
+class DistributedParquetDataset(ParquetDataset):
+    """Modification of ParquetDataset for working with DDP.
+
+    Each GPU processes its own subset of files, so make sure that the number of
+    parquet files is greater than the number of GPUs.
+
+    `max_items_per_file` is used to ensure that dataloaders on different GPUs produce
+    the same number of batches. It is better to calculate it manually (the minimal
+    number of rows over all parquet files) before training.
+
+    Parameters
+    ----------
+    data_files:
+        ParquetFiles object with a list of files, or just a list of files
+    post_processing:
+        - deprecated, use i_filters
+    i_filters:
+        - list of `ptls.data_load.iterable_processing` filters
+    shuffle_files:
+        - shuffle data_files before reading when True
+    cache_schema:
+        - dict schema (feature names) will be read only once
+    shuffle_seed:
+        - random seed for shuffle_files
+    max_items_per_file:
+        - if passed, it is used as the per-file row count when computing how many items
+          each worker may yield, which keeps the number of examples equal across GPUs;
+          otherwise the row counts are calculated before training, which takes a
+          significant amount of time. Calculate it yourself as the minimal number of
+          rows over all parquet files.
+    repeat_items:
+        - whether to start reading the same files again on a worker; this prevents
+          deadlocks caused by the inability to compute the exact number of yielded
+          items per worker and, as a result, to yield the same number of items on
+          different GPUs
+    """
+    def __init__(self, data_files: Union[ParquetFiles, List[str]],
+                 post_processing=None,
+                 i_filters: List = None,
+                 shuffle_files=False, cache_schema=True, shuffle_seed=42,
+                 max_items_per_file=None, repeat_items=True):
+        super().__init__(data_files=data_files,
+                         post_processing=post_processing,
+                         i_filters=i_filters,
+                         shuffle_files=shuffle_files, cache_schema=cache_schema,
+                         shuffle_seed=shuffle_seed)
+        self.max_items_per_file = max_items_per_file
+        self.items_per_worker = None
+        self.repeat_items = repeat_items
+
+    def _calc_min_items_per_worker(self):
+        # minimal number of rows available to a single GPU, divided by the number of
+        # dataloader workers per GPU; every worker yields at most this many items
+        nums = []
+        for rank in range(dist.get_world_size()):
+            per_gpu = 0
+            for fname in self.data_files[rank::dist.get_world_size()]:
+                if self.max_items_per_file is not None:
+                    per_gpu += self.max_items_per_file
+                else:
+                    per_gpu += pq.read_table(fname).shape[0]
+            nums.append(per_gpu)
+        return min(nums) // self._num_workers
+
+    def _get_my_files(self):
+        my_files = [name for i, name in enumerate(sorted(self.data_files))
+                    if i % self.real_num_workers == self.real_worker_id]
+        return my_files
+
+    def _init_worker(self):
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:  # single-process data loading, return the full iterator
+            self._worker_id = 0
+            self._num_workers = 1
+            self._shuffle_seed = self.shuffle_seed
+        else:  # in a worker process
+            self._worker_id = worker_info.id
+            self._num_workers = worker_info.num_workers
+            self._shuffle_seed = worker_info.seed
+
+        self.real_worker_id = self._worker_id
+        self.real_num_workers = self._num_workers
+        if dist.is_initialized():
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            self.real_num_workers = self._num_workers * world_size
+            self.real_worker_id = rank + self._worker_id * world_size
+        logger.debug(f'Started [{self.real_worker_id:02d}/{self.real_num_workers:02d}]')
+
+    def __iter__(self):
+        self._init_worker()
+        if dist.is_initialized() and self.items_per_worker is None:
+            self.items_per_worker = self._calc_min_items_per_worker()
+
+        my_files = self._get_my_files()
+        if self.shuffle_files:
+            rs = np.random.RandomState(self._shuffle_seed % 2**32)
+            rs.shuffle(my_files)
+
+        logger.debug(f'Iter [{self._worker_id:02d}/{self._num_workers:02d}]: {my_files}')
+        if self.repeat_items:
+            gen = chain(*[self.iter_file(name) for _ in range(2) for name in my_files])
+        else:
+            gen = chain(*[self.iter_file(name) for name in my_files])
+
+        if self.post_processing is not None:
+            gen = self.post_processing(gen)
+        return iter_with_max_num(gen, self.items_per_worker)
\ No newline at end of file
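A minimal usage sketch for the new DistributedParquetDataset (the file paths and the max_items_per_file value below are hypothetical, not part of this patch). Each DDP rank reads its own subset of files, and iteration is capped at items_per_worker once a process group is initialized:

    from ptls.data_load.datasets import DistributedParquetDataset

    dataset = DistributedParquetDataset(
        data_files=['data/part_0.parquet', 'data/part_1.parquet', 'data/part_2.parquet'],
        shuffle_files=True,
        max_items_per_file=10_000,  # assumed: the minimal row count over all files, computed beforehand
        repeat_items=True,
    )
    for rec in dataset:  # yields feature dicts; the cap applies only when torch.distributed is initialized
        pass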
diff --git a/ptls/frames/coles/losses/contrastive_loss.py b/ptls/frames/coles/losses/contrastive_loss.py
index e6e2ff34..56d2313e 100644
--- a/ptls/frames/coles/losses/contrastive_loss.py
+++ b/ptls/frames/coles/losses/contrastive_loss.py
@@ -1,7 +1,9 @@
 import torch
 from torch import nn as nn
 from torch.nn import functional as F
+import torch.distributed as dist
+from ptls.frames.coles.losses.dist_utils import all_gather_and_cat
 
 
 class ContrastiveLoss(nn.Module):
     """
@@ -11,12 +13,19 @@ class ContrastiveLoss(nn.Module):
     https://papers.nips.cc/paper/769-signature-verification-using-a-siamese-time-delay-neural-network.pdf
     """
-    def __init__(self, margin, sampling_strategy):
+    def __init__(self, margin, sampling_strategy, distributed_mode=False, do_loss_mult=False):
         super(ContrastiveLoss, self).__init__()
         self.margin = margin
         self.pair_selector = sampling_strategy
+        self.distributed_mode = distributed_mode
+        self.do_loss_mult = do_loss_mult
 
     def forward(self, embeddings, target):
+        if dist.is_initialized() and self.distributed_mode:
+            dist.barrier()
+            embeddings = all_gather_and_cat(embeddings)
+            target = target + (target.max() + 1) * dist.get_rank()
+            target = all_gather_and_cat(target)
         positive_pairs, negative_pairs = self.pair_selector.get_pairs(embeddings, target)
         positive_loss = F.pairwise_distance(embeddings[positive_pairs[:, 0]], embeddings[positive_pairs[:, 1]]).pow(2)
@@ -26,4 +35,8 @@ def forward(self, embeddings, target):
         ).pow(2)
 
         loss = torch.cat([positive_loss, negative_loss], dim=0)
-        return loss.sum()
\ No newline at end of file
+        if dist.is_initialized() and self.do_loss_mult:
+            loss_mult = dist.get_world_size()
+        else:
+            loss_mult = 1
+        return loss.sum() * loss_mult
\ No newline at end of file
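The per-rank label shift in ContrastiveLoss.forward assumes every rank sees the same local label range (same batch layout on each GPU). A toy single-process illustration of the idea, with world_size simulated rather than taken from a real process group:

    import torch

    target = torch.tensor([0, 1, 2, 3])  # local CoLES class ids, identical on every GPU
    world_size = 2                       # simulated
    shifted = [target + (target.max() + 1) * rank for rank in range(world_size)]
    print(torch.cat(shifted))            # tensor([0, 1, 2, 3, 4, 5, 6, 7]) - no label collisions after gathering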
diff --git a/ptls/frames/coles/losses/dist_utils.py b/ptls/frames/coles/losses/dist_utils.py
new file mode 100644
index 00000000..34ab9f25
--- /dev/null
+++ b/ptls/frames/coles/losses/dist_utils.py
@@ -0,0 +1,20 @@
+import torch
+import torch.distributed as dist
+
+
+class AllGather(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor):
+        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
+        dist.all_gather(gathered, tensor)
+        return tuple(gathered)
+
+    @staticmethod
+    def backward(ctx, *grad_outs):
+        # dist.all_gather does not propagate gradients by itself, so the backward
+        # pass keeps only the gradient slice that corresponds to this rank's input
+        return grad_outs[dist.get_rank()]
+
+
+def all_gather_and_cat(tensor):
+    return torch.cat(AllGather.apply(tensor))
diff --git a/ptls/frames/coles/losses/softmax_loss.py b/ptls/frames/coles/losses/softmax_loss.py
index 30cb5f58..e5419e27 100644
--- a/ptls/frames/coles/losses/softmax_loss.py
+++ b/ptls/frames/coles/losses/softmax_loss.py
@@ -1,5 +1,7 @@
 import torch
+import torch.distributed as dist
+from ptls.frames.coles.losses.dist_utils import all_gather_and_cat
 
 
 class SoftmaxLoss(torch.nn.Module):
     """Also known as NCE loss
@@ -13,12 +15,18 @@ class SoftmaxLoss(torch.nn.Module):
         `softmax(distances / temperature)` - scale a sub-exponent expression.
         default 0.05 value is for l2-normalized `embeddings` where dot product distance is in range [-1, 1]
     """
-    def __init__(self, temperature=0.05):
+    def __init__(self, temperature=0.05, distributed_mode=False):
         super().__init__()
         self.temperature = temperature
+        self.distributed_mode = distributed_mode
 
     def forward(self, embeddings, classes):
+        if dist.is_initialized() and self.distributed_mode:
+            dist.barrier()
+            embeddings = all_gather_and_cat(embeddings)
+            classes = classes + (classes.max() + 1) * dist.get_rank()
+            classes = all_gather_and_cat(classes)
         d = torch.einsum('bh,kh->bk', embeddings, embeddings) / self.temperature
 
         ix_pos = classes.unsqueeze(1) == classes.unsqueeze(0)
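With distributed_mode left at its default (or with no process group initialized), both losses behave exactly as before this patch. A quick single-process sketch with SoftmaxLoss on random inputs, assuming the unchanged tail of forward still returns a scalar loss:

    import torch
    from ptls.frames.coles.losses.softmax_loss import SoftmaxLoss

    loss_fn = SoftmaxLoss(temperature=0.05, distributed_mode=False)
    embeddings = torch.nn.functional.normalize(torch.randn(8, 16), dim=1)  # l2-normalized, as the docstring suggests
    classes = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])                       # two views per class, CoLES-style
    print(loss_fn(embeddings, classes))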
diff --git a/ptls/pl_train_module.py b/ptls/pl_train_module.py
index fd21c26a..fc652a23 100644
--- a/ptls/pl_train_module.py
+++ b/ptls/pl_train_module.py
@@ -49,6 +49,9 @@ def main(conf: DictConfig):
             save_dir='lightning_logs',
             name=conf.get('logger_name'),
         )
+    if not isinstance(_trainer_params.get('strategy', ''), str):  # do nothing when `strategy` is absent or a plain string
+        _trainer_params_additional['strategy'] = hydra.utils.instantiate(_trainer_params.strategy)
+        del _trainer_params.strategy
 
     lr_monitor = LearningRateMonitor(logging_interval='step')
     _trainer_params_callbacks.append(lr_monitor)
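A sketch of what the new strategy handling in pl_train_module.py enables (the config values below are assumed, not part of this patch): a structured trainer strategy node is instantiated through Hydra instead of being passed to the Trainer as a plain string:

    import hydra
    from omegaconf import OmegaConf

    trainer_params = OmegaConf.create({
        'strategy': {
            '_target_': 'pytorch_lightning.strategies.DDPStrategy',
            'find_unused_parameters': False,  # forwarded to DistributedDataParallel
        },
    })
    # mirrors the added branch: a non-string `strategy` node is built via hydra.utils.instantiate
    strategy = hydra.utils.instantiate(trainer_params.strategy)
    print(type(strategy).__name__)  # DDPStrategy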