
Added tp_recompute_allgather test
AleHD committed Jul 30, 2024
1 parent cd84d4f commit d3db06a
Showing 1 changed file with 18 additions and 4 deletions.
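For context: `tp_recompute_allgather` trades memory for compute in the column-parallel linear. When the input is sharded across tensor-parallel ranks, the forward pass all-gathers the shards before the matmul; with the flag enabled, the gathered activation is dropped after the forward and re-created by a second all-gather during the backward, so only the small shard stays alive in between. The sketch below is a minimal, hypothetical illustration of that trade-off, not nanotron's implementation; the 2D shapes, the dim-0 sharding, and the autograd-function structure are all assumptions.

import torch
import torch.distributed as dist


class _AllGatherLinear(torch.autograd.Function):
    """Hypothetical illustration of the tp_recompute_allgather trade-off.

    The input is assumed 2D and sharded along dim 0 across the TP group.
    Forward all-gathers the shards before the matmul. With recompute on,
    only the shard is saved and the all-gather is redone in backward.
    """

    @staticmethod
    def forward(ctx, input_shard, weight, group, tp_recompute_allgather):
        world = dist.get_world_size(group)
        gathered = input_shard.new_empty(world * input_shard.shape[0], input_shard.shape[1])
        dist.all_gather_into_tensor(gathered, input_shard, group=group)
        ctx.group = group
        ctx.tp_recompute_allgather = tp_recompute_allgather
        if tp_recompute_allgather:
            ctx.save_for_backward(input_shard, weight)  # keep only the small shard
        else:
            ctx.save_for_backward(gathered, weight)  # keep the full gathered activation
        return gathered @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        saved, weight = ctx.saved_tensors
        world = dist.get_world_size(ctx.group)
        if ctx.tp_recompute_allgather:
            gathered = saved.new_empty(world * saved.shape[0], saved.shape[1])
            dist.all_gather_into_tensor(gathered, saved, group=ctx.group)  # recompute
        else:
            gathered = saved
        grad_weight = grad_output.t() @ gathered
        # Each rank only sees its slice of the output features, so the full
        # input gradient is a sum over ranks: a reduce-scatter yields this
        # rank's shard of it.
        partial_grad_input = grad_output @ weight
        grad_shard = partial_grad_input.new_empty(
            partial_grad_input.shape[0] // world, partial_grad_input.shape[1]
        )
        dist.reduce_scatter_tensor(grad_shard, partial_grad_input, group=ctx.group)
        return grad_shard, grad_weight, None, None

With recompute enabled, the layer keeps only the shard between forward and backward at the cost of one extra all-gather, which is exactly the behaviour the new parametrization exercises in both settings.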
22 changes: 18 additions & 4 deletions tests/test_tensor_parallel.py
@@ -18,17 +18,30 @@
 @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
 @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
 @pytest.mark.parametrize("async_communication", [False, True])
+@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
 @rerun_if_address_is_in_use()
-def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool):
+def test_column_linear(
+    tp: int,
+    dp: int,
+    pp: int,
+    tp_mode: TensorParallelLinearMode,
+    async_communication: bool,
+    tp_recompute_allgather: bool,
+):
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
         pytest.skip("ALL_REDUCE mode does not support async communication")
+    if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather:
+        pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather")
     init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)(
-        tp_mode=tp_mode, async_communication=async_communication
+        tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather
     )


 def _test_column_linear(
-    parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool
+    parallel_context: ParallelContext,
+    tp_mode: TensorParallelLinearMode,
+    async_communication: bool,
+    tp_recompute_allgather: bool,
 ):
     if async_communication:
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
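The `init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)(...)` call in the hunk above is a launcher pattern: a factory that wraps the test body and runs it once per rank. A hypothetical re-implementation of the pattern might look like the following; the worker function, the `gloo` backend, and passing the world group as a stand-in for `ParallelContext` are all assumptions, not nanotron's actual helper.

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int, fn, kwargs):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        # The real helper would build a ParallelContext with tp/dp/pp
        # subgroups; the world group is a crude stand-in here.
        fn(parallel_context=dist.group.WORLD, **kwargs)
    finally:
        dist.destroy_process_group()


def init_distributed(tp: int, dp: int, pp: int):
    def decorator(fn):
        def launcher(**kwargs):
            mp.spawn(_worker, args=(tp * dp * pp, fn, kwargs), nprocs=tp * dp * pp)
        return launcher
    return decorator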
@@ -44,6 +57,7 @@ def _test_column_linear(
         mode=tp_mode,
         device="cuda",
         async_communication=async_communication,
+        tp_recompute_allgather=tp_recompute_allgather,
     )

     # Un-sharded
@@ -86,7 +100,7 @@ def _test_column_linear(
         random_input = sharded_random_input
     else:
         raise ValueError(f"Unsupported mode: {tp_mode}")
-    # It's important that `random_input` and `sharded_random_input` are two seperate tensors with seperate storage
+    # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage
     sharded_random_input = sharded_random_input.clone()
     random_input.requires_grad = True
     sharded_random_input.requires_grad = True
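The clone matters because the rest of `_test_column_linear` back-propagates through both the sharded layer and an un-sharded reference and compares gradients; aliased storage would let the two backward passes interfere. A hedged sketch of that style of check, assuming ALL_REDUCE mode (where both layers consume the same full input) and a layer whose output is already gathered; the real test slices or gathers per TP rank:

import torch


def _assert_matches_reference(column_linear, reference_linear, random_input, sharded_random_input):
    # Both inputs have independent storage thanks to the clone() above, so
    # the two backward passes accumulate into independent .grad tensors.
    out = column_linear(sharded_random_input)
    ref = reference_linear(random_input)
    torch.testing.assert_close(out, ref)
    out.sum().backward()
    ref.sum().backward()
    torch.testing.assert_close(sharded_random_input.grad, random_input.grad)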
