Skip to content

Commit

Permalink
Merge branch 'fix_tp_mem_cache' into mem_fix_async
Browse files Browse the repository at this point in the history
  • Loading branch information
AleHD committed Jul 30, 2024
2 parents 81e7a54 + d3db06a commit 6f82050
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/nanotron/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch LLaMa model."""

from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, List

import torch
from torch import nn
Expand Down
33 changes: 20 additions & 13 deletions src/nanotron/parallel/tensor_parallel/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,14 @@ def forward(
# Do allgather.
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
if tp_recompute_allgather:
if group.size() == 1:
total_input = input.contiguous()
elif tp_recompute_allgather:
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
else:
total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)

# Prepare context.
ctx.group = group
Expand All @@ -390,29 +393,33 @@ def backward(ctx, grad_output: torch.Tensor):
group = ctx.group
tp_recompute_allgather = ctx.tp_recompute_allgather
input_size = ctx.input_size
if tp_recompute_allgather:
if group.size() == 1 or not tp_recompute_allgather:
total_input, weight, bias = ctx.saved_tensors
else:
input, weight, bias = ctx.saved_tensors
sharded_batch_size, *rest_size = input.shape
total_input = sharded_batch_size * group.size()
unsharded_batch_size = sharded_batch_size * group.size()
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
else:
total_input, weight, bias = ctx.saved_tensors

# Get the grad_output and total_input on the correct views to be able to transpose them below.
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.contiguous()
assert grad_output.dim() == 3
grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2))
total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2))
grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1]
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim)

# Compute gradients.
grad_weight = grad_output.T @ total_input
grad_input = grad_output @ weight
sub_grad_input = torch.empty(
input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
if group.size() == 1:
sub_grad_input = grad_input
else:
sub_grad_input = torch.empty(
input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None

return sub_grad_input, grad_weight, grad_bias, None, None
Expand Down
22 changes: 18 additions & 4 deletions tests/test_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,30 @@
@pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@pytest.mark.parametrize("async_communication", [False, True])
@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
@rerun_if_address_is_in_use()
def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool):
def test_column_linear(
tp: int,
dp: int,
pp: int,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool,
):
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
pytest.skip("ALL_REDUCE mode does not support async communication")
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather:
pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather")
init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)(
tp_mode=tp_mode, async_communication=async_communication
tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather
)


def _test_column_linear(
parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool
parallel_context: ParallelContext,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool,
):
if async_communication:
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
Expand All @@ -44,6 +57,7 @@ def _test_column_linear(
mode=tp_mode,
device="cuda",
async_communication=async_communication,
tp_recompute_allgather=tp_recompute_allgather,
)

# Un-sharded
Expand Down Expand Up @@ -86,7 +100,7 @@ def _test_column_linear(
random_input = sharded_random_input
else:
ValueError(f"Unsupported mode: {tp_mode}")
# It's important that `random_input` and `sharded_random_input` are two seperate tensors with seperate storage
# It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage
sharded_random_input = sharded_random_input.clone()
random_input.requires_grad = True
sharded_random_input.requires_grad = True
Expand Down

0 comments on commit 6f82050

Please sign in to comment.