From d3db06acb2e3fe235ae512861ff64ee9fbc9ac11 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Tue, 30 Jul 2024 19:16:28 +0200
Subject: [PATCH] Added tp_recompute_allgather test

---
 tests/test_tensor_parallel.py | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
index f5dcaeb0..8e73973b 100644
--- a/tests/test_tensor_parallel.py
+++ b/tests/test_tensor_parallel.py
@@ -18,17 +18,30 @@
 @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
 @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
 @pytest.mark.parametrize("async_communication", [False, True])
+@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
 @rerun_if_address_is_in_use()
-def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool):
+def test_column_linear(
+    tp: int,
+    dp: int,
+    pp: int,
+    tp_mode: TensorParallelLinearMode,
+    async_communication: bool,
+    tp_recompute_allgather: bool,
+):
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
         pytest.skip("ALL_REDUCE mode does not support async communication")
+    if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather:
+        pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather")
     init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)(
-        tp_mode=tp_mode, async_communication=async_communication
+        tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather
     )
 
 
 def _test_column_linear(
-    parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool
+    parallel_context: ParallelContext,
+    tp_mode: TensorParallelLinearMode,
+    async_communication: bool,
+    tp_recompute_allgather: bool,
 ):
     if async_communication:
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
@@ -44,6 +57,7 @@ def _test_column_linear(
         mode=tp_mode,
         device="cuda",
         async_communication=async_communication,
+        tp_recompute_allgather=tp_recompute_allgather,
     )
 
     # Un-sharded
@@ -86,7 +100,7 @@ def _test_column_linear(
         random_input = sharded_random_input
     else:
         ValueError(f"Unsupported mode: {tp_mode}")
-    # It's important that `random_input` and `sharded_random_input` are two seperate tensors with seperate storage
+    # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage
     sharded_random_input = sharded_random_input.clone()
     random_input.requires_grad = True
     sharded_random_input.requires_grad = True
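
For context, a minimal single-file sketch of the idea the new `tp_recompute_allgather` flag exercises: a column-parallel linear whose backward either reuses the cached all-gathered input or re-gathers it, trading a second all-gather for lower activation memory. This is NOT the nanotron implementation; the class name, file name, shapes, and the use of plain `all_gather`/`all_reduce` (instead of fused collectives on the NCCL backend) are illustrative assumptions. It runs on CPU with the gloo backend, e.g. `torchrun --nproc_per_node=2 tp_recompute_sketch.py`.

# tp_recompute_sketch.py -- conceptual sketch only, not nanotron's ColumnLinear.
import torch
import torch.distributed as dist


class _ColumnLinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_shard, weight, recompute_allgather):
        world = dist.get_world_size()
        # Gather the (sequence-)sharded input so every rank sees the full input.
        shards = [torch.empty_like(x_shard) for _ in range(world)]
        dist.all_gather(shards, x_shard)
        full = torch.cat(shards, dim=0)
        ctx.recompute_allgather = recompute_allgather
        # With recomputation, only the small local shard is kept alive for backward.
        ctx.save_for_backward(x_shard if recompute_allgather else full, weight)
        return full @ weight.t()  # weight is sharded along the output dimension

    @staticmethod
    def backward(ctx, grad_out):
        saved, weight = ctx.saved_tensors
        world = dist.get_world_size()
        rank = dist.get_rank()
        if ctx.recompute_allgather:
            # Re-gather the full input instead of having cached it in forward.
            shards = [torch.empty_like(saved) for _ in range(world)]
            dist.all_gather(shards, saved.detach())
            full = torch.cat(shards, dim=0)
        else:
            full = saved
        grad_weight = grad_out.t() @ full
        grad_full = grad_out @ weight
        # Sum the full-input gradient across ranks, then keep the local slice
        # (equivalent to a reduce-scatter, spelled simply for the sketch).
        dist.all_reduce(grad_full)
        shard_len = grad_full.shape[0] // world
        grad_shard = grad_full[rank * shard_len : (rank + 1) * shard_len]
        return grad_shard, grad_weight, None


if __name__ == "__main__":
    dist.init_process_group("gloo")
    torch.manual_seed(0)
    x = torch.randn(4, 8, requires_grad=True)   # local input shard
    w = torch.randn(16, 8, requires_grad=True)  # local output-dim weight shard
    for flag in (False, True):
        y = _ColumnLinearSketch.apply(x, w, flag)
        y.sum().backward()
        print(f"recompute_allgather={flag}: out={tuple(y.shape)}, dx={tuple(x.grad.shape)}")
        x.grad = None
        w.grad = None
    dist.destroy_process_group()

Both settings must produce identical outputs and gradients, which is exactly what parametrizing the existing `test_column_linear` over `tp_recompute_allgather` checks; the flag only matters for the gather-based modes, hence the skip for ALL_REDUCE.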