Add multi-gpu (#139)
* Multi-GPU
* delete use_gpu_dependent_labels and add do_loss_mult parameters
* small refactoring
* fix typo in softmax loss
---------
Co-authored-by: Maxim Kalashnikov <[email protected]>
FuTSy13 authored Apr 16, 2024
1 parent c5c26ef commit 3796f1e
Showing 6 changed files with 161 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ptls/data_load/datasets/__init__.py
@@ -2,6 +2,6 @@
from .persist_dataset import PersistDataset
from .duckdb_dataset import DuckDbDataset
from .memory_dataset import MemoryMapDataset, MemoryIterableDataset
-from .parquet_dataset import ParquetFiles, ParquetDataset
+from .parquet_dataset import ParquetFiles, ParquetDataset, DistributedParquetDataset
from .parquet_file_scan import parquet_file_scan
from .dataloaders import inference_data_loader
113 changes: 113 additions & 0 deletions ptls/data_load/datasets/parquet_dataset.py
@@ -7,13 +7,25 @@

import numpy as np
import torch
import torch.distributed as dist
import pyarrow.parquet as pq
from omegaconf import ListConfig

from ptls.data_load import read_pyarrow_file
from ptls.data_load import IterableChain

logger = logging.getLogger(__name__)

def iter_with_max_num(iter, max_num=None):
    # Yield at most `max_num` items from `iter`; used to make every DDP worker
    # produce the same number of items.
    num = 0
    for i in iter:
        yield i
        num += 1
        if max_num is not None and num >= max_num:
            if dist.is_initialized():
                print("STOPPING WORKER on GPU", dist.get_rank())
            break


class ParquetFiles:
"""Helper file which search parquet files in specified path.
@@ -162,3 +174,104 @@ def to_torch(x):
    if type(x) is np.ndarray and x.dtype.kind in ('i', 'f'):
        return torch.from_numpy(x)
    return x

class DistributedParquetDataset(ParquetDataset):
    """Modification of ParquetDataset for working with DDP.

    Each GPU processes its own set of files.
    Make sure that the number of parquet files is greater than the number of GPUs.
    `max_items_per_file` is used to ensure that dataloaders on different GPUs produce the same number of batches;
    it is better to calculate it manually (the minimal number of rows over the parquet files) before training.

    Parameters
    ----------
    data_files:
        ParquetFiles object with a list of files, or just a list of files
    post_processing:
        - deprecated, use i_filters
    i_filters:
        - list of `ptls.data_load.iterable_processing` filters
    shuffle_files:
        - shuffle data_files before reading when True
    cache_schema:
        - dict schema (feature names) will be read once
    shuffle_seed:
        - random seed for shuffle_files
    max_items_per_file:
        - if passed, each worker reads max_items_per_file rows per parquet file to ensure an equal number
          of examples on different GPUs; otherwise this quantity is calculated before training,
          which takes a significant amount of time.
          You should calculate it yourself as the minimal number of rows over all parquet files.
    repeat_items:
        - whether the worker starts reading the same files again, to prevent deadlocks
          (caused by the inability to calculate the exact number of items yielded per worker
          and, as a result, the inability to yield the same number of items on different GPUs)
    """
    def __init__(self, data_files: Union[ParquetFiles, List[str]],
                 post_processing=None,
                 i_filters: List = None,
                 shuffle_files=False, cache_schema=True, shuffle_seed=42, max_items_per_file=None, repeat_items=True):
        super().__init__(data_files=data_files,
                         post_processing=post_processing,
                         i_filters=i_filters,
                         shuffle_files=shuffle_files, cache_schema=cache_schema, shuffle_seed=shuffle_seed)
        self.max_items_per_file = max_items_per_file
        self.items_per_worker = None
        self.repeat_items = repeat_items

    def _calc_min_items_per_worker(self):
        # Count (or estimate via max_items_per_file) the rows assigned to each GPU,
        # take the minimum over GPUs and divide by the number of dataloader workers.
        nums = []
        for rank in range(dist.get_world_size()):
            per_gpu = 0
            for fname in self.data_files[rank::dist.get_world_size()]:
                if self.max_items_per_file is not None:
                    per_gpu += self.max_items_per_file
                else:
                    per_gpu += pq.read_table(fname).shape[0]
            nums.append(per_gpu)
        return min(nums) // self._num_workers

    def _get_my_files(self):
        # Each global worker (GPU rank x dataloader worker) reads its own disjoint subset of files.
        my_files = [name for i, name in enumerate(sorted(self.data_files)) if i % self.real_num_workers == self.real_worker_id]
        return my_files

    def _init_worker(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            self._worker_id = 0
            self._num_workers = 1
            self._shuffle_seed = self.shuffle_seed
        else:  # in a worker process
            self._worker_id = worker_info.id
            self._num_workers = worker_info.num_workers
            self._shuffle_seed = worker_info.seed

        # Fold the DDP rank into a global worker id so that file sharding is unique
        # across all GPUs and all dataloader workers.
        self.real_worker_id = self._worker_id
        self.real_num_workers = self._num_workers
        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            self.real_num_workers = self._num_workers * world_size
            self.real_worker_id = rank + self._worker_id * world_size
        logger.debug(f'Started [{self.real_worker_id:02d}/{self.real_num_workers:02d}]')

    def __iter__(self):
        self._init_worker()
        if dist.is_initialized() and self.items_per_worker is None:
            self.items_per_worker = self._calc_min_items_per_worker()

        my_files = self._get_my_files()
        if self.shuffle_files:
            rs = np.random.RandomState(self._shuffle_seed % 2**32)
            rs.shuffle(my_files)

        logger.debug(f'Iter [{self._worker_id:02d}/{self._num_workers:02d}]: {my_files}')
        if self.repeat_items:
            # Chain the file list twice so a worker never runs out of items before yielding
            # items_per_worker of them; iter_with_max_num truncates the stream.
            gen = chain(*[self.iter_file(name) for _ in range(2) for name in my_files])
        else:
            gen = chain(*[self.iter_file(name) for name in my_files])

        if self.post_processing is not None:
            gen = self.post_processing(gen)
        return iter_with_max_num(gen, self.items_per_worker)
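
For context, a minimal usage sketch of the new dataset (hypothetical file paths and parameter values; it assumes the process group is already initialized by the DDP launcher and that batching/collation is handled elsewhere, as in the ptls data modules):

from torch.utils.data import DataLoader

from ptls.data_load.datasets import DistributedParquetDataset

# Hypothetical parquet shards; keep the number of files larger than the number of GPUs.
files = [f'data/train/part-{i:04d}.parquet' for i in range(8)]

dataset = DistributedParquetDataset(
    data_files=files,
    shuffle_files=True,
    max_items_per_file=50_000,  # precomputed: minimal row count over all shards
    repeat_items=True,
)
loader = DataLoader(dataset, batch_size=None, num_workers=4)  # records are yielded one by one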
17 changes: 15 additions & 2 deletions ptls/frames/coles/losses/contrastive_loss.py
@@ -1,7 +1,9 @@
import torch
from torch import nn as nn
from torch.nn import functional as F
import torch.distributed as dist

from ptls.frames.coles.losses.dist_utils import all_gather_and_cat

class ContrastiveLoss(nn.Module):
"""
@@ -11,12 +13,19 @@ class ContrastiveLoss(nn.Module):
https://papers.nips.cc/paper/769-signature-verification-using-a-siamese-time-delay-neural-network.pdf
"""

-    def __init__(self, margin, sampling_strategy):
+    def __init__(self, margin, sampling_strategy, distributed_mode=False, do_loss_mult=False):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.pair_selector = sampling_strategy
        self.distributed_mode = distributed_mode
        self.do_loss_mult = do_loss_mult

    def forward(self, embeddings, target):
        if dist.is_initialized() and self.distributed_mode:
            dist.barrier()
            embeddings = all_gather_and_cat(embeddings)
            # Shift local class ids by rank so labels from different GPUs do not collide after gathering.
            target = target + (target.max() + 1) * dist.get_rank()
            target = all_gather_and_cat(target)

        positive_pairs, negative_pairs = self.pair_selector.get_pairs(embeddings, target)
        positive_loss = F.pairwise_distance(embeddings[positive_pairs[:, 0]], embeddings[positive_pairs[:, 1]]).pow(2)
@@ -26,4 +35,8 @@ def forward(self, embeddings, target):
        ).pow(2)
        loss = torch.cat([positive_loss, negative_loss], dim=0)

-        return loss.sum()
+        if dist.is_initialized() and self.do_loss_mult:
+            loss_mult = dist.get_world_size()
+        else:
+            loss_mult = 1
+        return loss.sum() * loss_mult
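
The per-rank label shift above is what keeps positive pairs from different GPUs separate after gathering. A small standalone illustration (not library code; a world size of 2 is simulated in one process):

import torch

def offset_labels(target, rank):
    # Same arithmetic as in ContrastiveLoss.forward: shift local class ids by rank.
    return target + (target.max() + 1) * rank

local = torch.tensor([0, 0, 1, 1])                  # per-GPU class ids
gathered = torch.cat([offset_labels(local, r) for r in range(2)])
print(gathered)                                     # tensor([0, 0, 1, 1, 2, 2, 3, 3])
# Pairs sampled from the gathered batch never mix classes coming from different GPUs.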
20 changes: 20 additions & 0 deletions ptls/frames/coles/losses/dist_utils.py
@@ -0,0 +1,20 @@
import torch
import torch.distributed as dist

class AllGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        # Gather the tensor from every rank; returns one tensor per rank.
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grad_outs):
        # if os.environ.get('REDUCE_GRADS'):
        #     grad_outs = torch.stack(grad_outs)
        #     dist.all_reduce(grad_outs)
        # Only the gradient w.r.t. the tensor produced on this rank flows back.
        return grad_outs[dist.get_rank()]

def all_gather_and_cat(tensor):
    return torch.cat(AllGather.apply(tensor))
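
A single-process sanity check for the gather helper is sketched below (an assumption-laden example, not part of the commit; it needs the gloo backend, and with world_size=1 the gather is an identity operation, but it exercises the same code path used under DDP):

import os

import torch
import torch.distributed as dist

from ptls.frames.coles.losses.dist_utils import all_gather_and_cat

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

x = torch.randn(4, 8, requires_grad=True)
y = all_gather_and_cat(x)   # shape [world_size * 4, 8] == [4, 8] here
y.sum().backward()          # the gradient flows back to the local tensor only
print(x.grad.shape)         # torch.Size([4, 8])

dist.destroy_process_group()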

10 changes: 9 additions & 1 deletion ptls/frames/coles/losses/softmax_loss.py
@@ -1,5 +1,7 @@
import torch
import torch.distributed as dist

from ptls.frames.coles.losses.dist_utils import all_gather_and_cat

class SoftmaxLoss(torch.nn.Module):
"""Also known as NCE loss
@@ -13,12 +15,18 @@ class SoftmaxLoss(torch.nn.Module):
    `softmax(distances / temperature)` - the temperature scales the sub-exponent expression;
    the default value 0.05 is for l2-normalized `embeddings`, where the dot-product distance is in the range [-1, 1].
    """
-    def __init__(self, temperature=0.05):
+    def __init__(self, temperature=0.05, distributed_mode=False):
        super().__init__()

        self.temperature = temperature
        self.distributed_mode = distributed_mode

    def forward(self, embeddings, classes):
        if dist.is_initialized() and self.distributed_mode:
            dist.barrier()
            embeddings = all_gather_and_cat(embeddings)
            # Shift local class ids by rank so labels from different GPUs do not collide after gathering.
            classes = classes + (classes.max() + 1) * dist.get_rank()
            classes = all_gather_and_cat(classes)
        d = torch.einsum('bh,kh->bk', embeddings, embeddings) / self.temperature

        ix_pos = classes.unsqueeze(1) == classes.unsqueeze(0)
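
As a small numeric sketch of what the (truncated) forward above computes, not library code: pairwise dot products scaled by the temperature, with positive pairs identified by a class mask.

import torch
import torch.nn.functional as F

emb = F.normalize(torch.randn(4, 8), dim=1)               # l2-normalized embeddings
classes = torch.tensor([0, 0, 1, 1])

d = torch.einsum('bh,kh->bk', emb, emb) / 0.05            # similarities roughly in [-20, 20]
ix_pos = classes.unsqueeze(1) == classes.unsqueeze(0)     # mask of positive pairs
print(d.shape, ix_pos.int())                              # torch.Size([4, 4]) and a 4x4 block mask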
3 changes: 3 additions & 0 deletions ptls/pl_train_module.py
@@ -49,6 +49,9 @@ def main(conf: DictConfig):
save_dir='lightning_logs',
name=conf.get('logger_name'),
)
    if not isinstance(_trainer_params.get('strategy', ''), str):  # if `strategy` is missing or is a plain string, do nothing
        _trainer_params_additional['strategy'] = hydra.utils.instantiate(_trainer_params.strategy)
        del _trainer_params.strategy

lr_monitor = LearningRateMonitor(logging_interval='step')
_trainer_params_callbacks.append(lr_monitor)
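
For context, a hedged sketch of the configuration case this branch handles (the DDPStrategy target is an illustrative assumption, not taken from the repo's configs): when `strategy` is a structured node rather than a plain string, it is instantiated via Hydra and passed to the trainer.

import hydra
from omegaconf import OmegaConf

_trainer_params = OmegaConf.create({
    'strategy': {
        '_target_': 'pytorch_lightning.strategies.DDPStrategy',
        'find_unused_parameters': False,
    },
})

if not isinstance(_trainer_params.get('strategy', ''), str):
    strategy = hydra.utils.instantiate(_trainer_params.strategy)  # -> DDPStrategy instance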
