From 12ec0866286ea9b258d531917c8a0797d9fe92d2 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Sat, 22 Jul 2023 11:14:09 +0000 Subject: [PATCH 1/4] gpt2 fix --- src/tensor_parallel/autoconfig.py | 10 ++++++---- src/tensor_parallel/slicing_configs.py | 6 +++--- src/tensor_parallel/state_actions.py | 23 +++++++++++++---------- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/tensor_parallel/autoconfig.py b/src/tensor_parallel/autoconfig.py index 0b71aee..1e31952 100644 --- a/src/tensor_parallel/autoconfig.py +++ b/src/tensor_parallel/autoconfig.py @@ -6,7 +6,7 @@ from torch.nn.modules import conv from tensor_parallel.config import Config -from tensor_parallel.state_actions import Scale, Split, SplitInsideChunks +from tensor_parallel.state_actions import Scale, Split, SplitInGroupedChunks def get_default_config(module: nn.Module, device_ids: Sequence[torch.device]) -> Config: @@ -63,11 +63,13 @@ def get_default_config(module: nn.Module, device_ids: Sequence[torch.device]) -> if not module.transposed: state_rules[f"^{name}.weight$"] = Split(world_size=len(device_ids), dim=1) else: - state_rules[f"^{name}.weight$"] = SplitInsideChunks( - world_size=len(device_ids), dim=0, num_chunks=groups + state_rules[f"^{name}.weight$"] = SplitInGroupedChunks( + world_size=len(device_ids), dim=0, num_groups=groups, chunk_size=1 ) if module.bias is not None: state_rules[f"^{name}.bias$"] = Scale(world_size=len(device_ids)) - input_rules[f"^{name}$"] = {0: SplitInsideChunks(world_size=len(device_ids), dim=1, num_chunks=groups)} + input_rules[f"^{name}$"] = { + 0: SplitInGroupedChunks(world_size=len(device_ids), dim=1, num_groups=groups, chunk_size=1) + } output_rules[f"^{name}$"] = {0: "sum"} return Config(state_rules, input_rules, output_rules, {}) diff --git a/src/tensor_parallel/slicing_configs.py b/src/tensor_parallel/slicing_configs.py index f8504fa..5412fe2 100644 --- a/src/tensor_parallel/slicing_configs.py +++ b/src/tensor_parallel/slicing_configs.py @@ -23,7 +23,7 @@ from tensor_parallel.communications import CollectiveOperation from tensor_parallel.config import Config from tensor_parallel.per_device_tensors import PerDeviceTensors -from tensor_parallel.state_actions import Scale, Split, SplitInChunks, SplitInsideChunks +from tensor_parallel.state_actions import Scale, Split, SplitInChunks, SplitInGroupedChunks ConfigGetter = Callable[[PretrainedConfig, Sequence[torch.device]], Config] @@ -188,8 +188,8 @@ def get_gpt2_config(model_config: GPT2Config, devices: Sequence[torch.device]) - return Config( state_rules={ # GPT2Attention - r".*c_attn\.weight$": SplitInsideChunks(world_size=world_size, dim=1, num_chunks=3), - r".*c_attn\.bias$": SplitInsideChunks(world_size=world_size, dim=0, num_chunks=3), + r".*c_attn\.weight$": SplitInGroupedChunks(world_size=world_size, dim=1, num_groups=3, chunk_size=head_dim), + r".*c_attn\.bias$": SplitInGroupedChunks(world_size=world_size, dim=0, num_groups=3, chunk_size=head_dim), r".*q_attn\.weight$": SplitInChunks(world_size=world_size, dim=1, chunk_size=head_dim), r".*q_attn\.bias$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim), r".*attn\.c_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim), diff --git a/src/tensor_parallel/state_actions.py b/src/tensor_parallel/state_actions.py index 042750f..2908f9f 100644 --- a/src/tensor_parallel/state_actions.py +++ b/src/tensor_parallel/state_actions.py @@ -90,20 +90,23 @@ def undo(self, tensors: Sequence[Tensor]) -> Tensor: return torch.cat([tensor.cpu() for tensor in tensors], dim=self.dim) -class SplitInsideChunks(Split): - """AAABBBCCCDDDEEE -> (world_size = 3, num_chunks = 5) -> [ABCDE, ABCDE, ABCDE]""" +class SplitInGroupedChunks(Split): + """AABBCCDDEE AABBCCDDEE AABBCCDDEE -> (world_size = 3, num_groups = 3, chunk_size = 2) -> [AABB AABB AABB, CCDD CCDD CCDD, EE EE EE]""" - def __init__(self, world_size: int, dim: int, num_chunks: int) -> None: + def __init__(self, world_size: int, dim: int, num_groups: int, chunk_size: int) -> None: super().__init__(world_size, dim) - self.num_chunks = num_chunks + self.num_groups = num_groups + self.chunk_size = chunk_size def __call__(self, tensor: Tensor, rank: int) -> Tensor: - shape = list(tensor.shape) - shape[self.dim] = shape[self.dim] // self.num_chunks - shape.insert(self.dim, self.num_chunks) - grouped_tensor = tensor.reshape(*shape) - grouped_shard = torch.tensor_split(grouped_tensor, self.world_size, dim=self.dim + 1)[rank] - return torch.flatten(grouped_shard, start_dim=self.dim, end_dim=self.dim + 1) + shape = list(tensor.shape) # ... x hidden_size x ... + shape[self.dim] //= self.num_groups + shape.insert(self.dim, self.num_groups) # ... group x group_size x ... + shape[self.dim + 1] //= self.chunk_size + shape.insert(self.dim + 2, self.chunk_size) # ... group x chunk x chunk_size ... + return ( + tensor.reshape(shape).tensor_split(self.world_size, dim=self.dim + 1)[rank].flatten(self.dim, self.dim + 2) + ) def undo(self, tensors: Sequence[Tensor]) -> Tensor: grouped_tensor = [] From dd6250b6ed7feb9a0079221639c5df6fcdaa1650 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Sat, 22 Jul 2023 11:14:51 +0000 Subject: [PATCH 2/4] version update --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 030677c..2ceb9ea 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = tensor_parallel -version = 1.2.9 +version = 1.3.0 author = Andrei Panferov and Yaroslav Lisnyak author_email = yalisnyak@nes.com description = Automatically shard your large model between multiple GPUs, works without torch.distributed From 8f1b5839a185d1ab410a0e746e8126986ede1c09 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Sat, 22 Jul 2023 11:34:24 +0000 Subject: [PATCH 3/4] undo fix --- src/tensor_parallel/state_actions.py | 12 ++++-------- tests/test_saving.py | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/tensor_parallel/state_actions.py b/src/tensor_parallel/state_actions.py index 2908f9f..81d0731 100644 --- a/src/tensor_parallel/state_actions.py +++ b/src/tensor_parallel/state_actions.py @@ -111,13 +111,9 @@ def __call__(self, tensor: Tensor, rank: int) -> Tensor: def undo(self, tensors: Sequence[Tensor]) -> Tensor: grouped_tensor = [] for tensor in tensors: - shape = list(input.shape) - shape[self.dim] = shape[self.dim] // self.num_chunks - shape.insert(self.dim, self.num_chunks) + shape = list(tensor.shape) # ... x hidden_size x ... + shape[self.dim] = shape[self.dim] // self.num_groups + shape.insert(self.dim, self.num_groups) # ... group x group_size x ... grouped_tensor.append(tensor.reshape(*shape).cpu()) - output_shape = tensors[0].shape - del output_shape[self.dim] - output_shape[self.dim] = -1 - - return torch.cat(grouped_tensor, dim=self.dim).reshape(*output_shape) + return torch.cat(grouped_tensor, dim=self.dim + 1).flatten(self.dim, self.dim + 1) diff --git a/tests/test_saving.py b/tests/test_saving.py index f0afb71..89f5567 100644 --- a/tests/test_saving.py +++ b/tests/test_saving.py @@ -44,7 +44,7 @@ def test_no_parallelism_zero_3(devices, model_name): @pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3]) @pytest.mark.parametrize( - "model_name", ["bert-base-uncased", "hf-internal-testing/tiny-random-t5", "hf-internal-testing/tiny-random-t5"] + "model_name", ["bert-base-uncased", "gpt2", "hf-internal-testing/tiny-random-t5"] ) def test_parallelism_no_zero_3(devices, model_name): model = AutoModel.from_pretrained(model_name).to(devices[0]).half() From 9b83437fb30a82c82ba86d061c0c40a30b7d30ce Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Sat, 22 Jul 2023 11:37:04 +0000 Subject: [PATCH 4/4] black tests --- tests/test_saving.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_saving.py b/tests/test_saving.py index 89f5567..85b4204 100644 --- a/tests/test_saving.py +++ b/tests/test_saving.py @@ -43,9 +43,7 @@ def test_no_parallelism_zero_3(devices, model_name): @pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3]) -@pytest.mark.parametrize( - "model_name", ["bert-base-uncased", "gpt2", "hf-internal-testing/tiny-random-t5"] -) +@pytest.mark.parametrize("model_name", ["bert-base-uncased", "gpt2", "hf-internal-testing/tiny-random-t5"]) def test_parallelism_no_zero_3(devices, model_name): model = AutoModel.from_pretrained(model_name).to(devices[0]).half() model_state_dict = model.state_dict()