Refactoring
Maxim Kalashnikov committed Apr 12, 2024
1 parent 1080072 commit 018c80e
Showing 3 changed files with 24 additions and 39 deletions.
24 changes: 3 additions & 21 deletions ptls/frames/coles/losses/contrastive_loss.py
@@ -3,23 +3,7 @@
 from torch.nn import functional as F
 import torch.distributed as dist
 
-class AllGather(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, tensor):
-        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
-        dist.all_gather(gathered, tensor)
-        return tuple(gathered)
-
-    @staticmethod
-    def backward(ctx, *grad_outs):
-        # if os.environ.get('REDUCE_GRADS'):
-        #     grad_outs = torch.stack(grad_outs)
-        #     dist.all_reduce(grad_outs)
-        return grad_outs[dist.get_rank()]
-
-def all_gather_and_cat(tensor):
-    return torch.cat(AllGather.apply(tensor))
-
+from dist_utils import all_gather_and_cat
 
 class ContrastiveLoss(nn.Module):
     """
@@ -29,20 +13,18 @@ class ContrastiveLoss(nn.Module):
     https://papers.nips.cc/paper/769-signature-verification-using-a-siamese-time-delay-neural-network.pdf
     """
 
-    def __init__(self, margin, sampling_strategy, distributed_mode = False, do_loss_mult = False, use_gpu_dependent_labels = False):
+    def __init__(self, margin, sampling_strategy, distributed_mode = False, do_loss_mult = False):
         super(ContrastiveLoss, self).__init__()
         self.margin = margin
         self.pair_selector = sampling_strategy
         self.distributed_mode = distributed_mode
         self.do_loss_mult = do_loss_mult
-        self.use_gpu_dependent_labels = use_gpu_dependent_labels
 
     def forward(self, embeddings, target):
         if dist.is_initialized() and self.distributed_mode:
             dist.barrier()
             embeddings = all_gather_and_cat(embeddings)
-            if self.use_gpu_dependent_labels:
-                target = target + (target.max()+1) * dist.get_rank()
+            target = target + (target.max()+1) * dist.get_rank()
             target = all_gather_and_cat(target)
 
         positive_pairs, negative_pairs = self.pair_selector.get_pairs(embeddings, target)
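
Note on the label offset above: a short, hedged sketch (not part of the commit) of why target is shifted by (target.max() + 1) * dist.get_rank() before the all-gather. Without the offset, ranks that both label their local classes 0..N-1 would collide after gathering, and samples from different GPUs could be wrongly paired as positives. The sketch assumes every rank sees the same number of distinct labels per batch, and simulates two ranks locally instead of calling dist.get_rank().

# Hypothetical illustration, not from this commit: rank-based label offsetting.
import torch

# Pretend two ranks each computed targets for their own mini-batch.
local_targets = [torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 1, 2])]

shifted = []
for rank, target in enumerate(local_targets):
    # Same offset the refactored forward() applies on each rank.
    shifted.append(target + (target.max() + 1) * rank)

gathered = torch.cat(shifted)
print(gathered)  # tensor([0, 0, 1, 2, 3, 4, 4, 5]): rank 1's labels no longer collide with rank 0's
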
20 changes: 20 additions & 0 deletions ptls/frames/coles/losses/dist_utils.py
@@ -0,0 +1,20 @@
+import torch
+import torch.distributed as dist
+
+class AllGather(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor):
+        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
+        dist.all_gather(gathered, tensor)
+        return tuple(gathered)
+
+    @staticmethod
+    def backward(ctx, *grad_outs):
+        # if os.environ.get('REDUCE_GRADS'):
+        #     grad_outs = torch.stack(grad_outs)
+        #     dist.all_reduce(grad_outs)
+        return grad_outs[dist.get_rank()]
+
+def all_gather_and_cat(tensor):
+    return torch.cat(AllGather.apply(tensor))
+
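
A minimal usage sketch for the new helper (hypothetical call site, import path assumed, requiring an initialised process group): all_gather_and_cat concatenates every rank's tensor along dim 0, and the custom backward returns only the local rank's gradient slice, with no cross-rank reduction (the commented-out REDUCE_GRADS branch would add one).

# Hypothetical usage, not part of the commit; the import path is an assumption.
import torch
import torch.distributed as dist

from ptls.frames.coles.losses.dist_utils import all_gather_and_cat

def loss_over_full_batch(local_embeddings: torch.Tensor) -> torch.Tensor:
    if dist.is_initialized():
        # Every rank receives the concatenation of all ranks' embeddings;
        # backward routes gradients only for the slice this rank produced.
        embeddings = all_gather_and_cat(local_embeddings)
    else:
        embeddings = local_embeddings
    return embeddings.pow(2).sum()  # placeholder loss for illustration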

19 changes: 1 addition & 18 deletions ptls/frames/coles/losses/softmax_loss.py
@@ -1,24 +1,7 @@
 import torch
 import torch.distributed as dist
 
-class AllGather(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, tensor):
-        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
-        dist.all_gather(gathered, tensor)
-        return tuple(gathered)
-
-    @staticmethod
-    def backward(ctx, *grad_outs):
-        # if os.environ.get('REDUCE_GRADS'):
-        #     grad_outs = torch.stack(grad_outs)
-        #     dist.all_reduce(grad_outs)
-        return grad_outs[dist.get_rank()]
-
-def all_gather_and_cat(tensor):
-    return torch.cat(AllGather.apply(tensor))
-
-
+from dist_utils import all_gather_and_cat
 
 class SoftmaxLoss(torch.nn.Module):
     """Also known as NCE loss
