Commit

add all to all function and debug function (#1)
* add all to all function

* add debug function

* modify debug function

* polish code
KKZ20 authored Feb 18, 2024
1 parent dc19bb0 commit 5f24031
Showing 2 changed files with 78 additions and 0 deletions.
7 changes: 7 additions & 0 deletions dit/debug_utils.py
@@ -0,0 +1,7 @@
import torch.distributed as dist


# Print debug information only on the selected rank
def print_rank(var_name, var_value, rank=0):
    if dist.get_rank() == rank:
        print(f"[Rank {rank}] {var_name}: {var_value}")
71 changes: 71 additions & 0 deletions dit/models/utils/operation.py
@@ -0,0 +1,71 @@
import torch
import torch.distributed as dist


# Use the all_to_all_single API to perform all-to-all communication
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
    inp_shape = list(input_.shape)
    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
    # Move the per-rank chunks of scatter_dim to the leading dimension so that
    # all_to_all_single exchanges contiguous slices.
    if scatter_dim < 2:
        input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous()
    else:
        input_t = (
            input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :])
            .transpose(0, 1)
            .contiguous()
        )

    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    # Move the rank dimension back so the reshape below concatenates the
    # received chunks along gather_dim.
    if scatter_dim < 2:
        output = output.transpose(0, 1).contiguous()

    return output.reshape(
        inp_shape[:gather_dim]
        + [
            inp_shape[gather_dim] * seq_world_size,
        ]
        + inp_shape[gather_dim + 1 :]
    ).contiguous()


# Use the all_to_all API to perform all-to-all communication
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
    # Split along scatter_dim, exchange one chunk with every rank,
    # then concatenate the received chunks along gather_dim.
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input tensor
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        world_size = dist.get_world_size(process_group)
        bsz, _, _ = input_.shape

        # TODO: make all_to_all_single compatible with batch sizes larger than 1
        if bsz == 1:
            return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
        else:
            return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, *grad_output):
        # The gradient of an all-to-all is another all-to-all with the
        # scatter and gather dimensions swapped.
        process_group = ctx.process_group
        scatter_dim = ctx.gather_dim
        gather_dim = ctx.scatter_dim
        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
        return (return_grad, None, None, None)
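
For context, a minimal sketch of how _AllToAll might be driven (not part of this commit); the import path, process-group setup, and tensor shapes are illustrative assumptions:

import torch
import torch.distributed as dist

from dit.models.utils.operation import _AllToAll  # assumed import path

# Illustrative launch: torchrun --nproc_per_node=2 a2a_example.py
dist.init_process_group(backend="gloo")
world_size = dist.get_world_size()

# (batch=1, seq, hidden): scatter the sequence dim (1), gather the hidden dim (2).
x = torch.randn(1, 8 * world_size, 16, requires_grad=True)
y = _AllToAll.apply(x, dist.group.WORLD, 1, 2)
# y has shape (1, 8, 16 * world_size) on every rank; batch size 1 takes the
# all_to_all_single path, larger batches fall back to _all_to_all.
y.sum().backward()  # the backward pass runs the inverse all-to-all

dist.destroy_process_group()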
