
Merge pull request #103 from BlackSamorez/gpt2-fix
Gpt2 fix
Andrei Panferov authored Jul 22, 2023
2 parents 20b1cfb + 9b83437 commit 31329ad
Showing 5 changed files with 28 additions and 29 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = tensor_parallel
-version = 1.2.9
+version = 1.3.0
 author = Andrei Panferov and Yaroslav Lisnyak
 author_email = [email protected]
 description = Automatically shard your large model between multiple GPUs, works without torch.distributed
10 changes: 6 additions & 4 deletions src/tensor_parallel/autoconfig.py
@@ -6,7 +6,7 @@
 from torch.nn.modules import conv

 from tensor_parallel.config import Config
-from tensor_parallel.state_actions import Scale, Split, SplitInsideChunks
+from tensor_parallel.state_actions import Scale, Split, SplitInGroupedChunks


 def get_default_config(module: nn.Module, device_ids: Sequence[torch.device]) -> Config:
@@ -63,11 +63,13 @@ def get_default_config(module: nn.Module, device_ids: Sequence[torch.device]) -> Config:
         if not module.transposed:
             state_rules[f"^{name}.weight$"] = Split(world_size=len(device_ids), dim=1)
         else:
-            state_rules[f"^{name}.weight$"] = SplitInsideChunks(
-                world_size=len(device_ids), dim=0, num_chunks=groups
+            state_rules[f"^{name}.weight$"] = SplitInGroupedChunks(
+                world_size=len(device_ids), dim=0, num_groups=groups, chunk_size=1
             )
         if module.bias is not None:
             state_rules[f"^{name}.bias$"] = Scale(world_size=len(device_ids))
-        input_rules[f"^{name}$"] = {0: SplitInsideChunks(world_size=len(device_ids), dim=1, num_chunks=groups)}
+        input_rules[f"^{name}$"] = {
+            0: SplitInGroupedChunks(world_size=len(device_ids), dim=1, num_groups=groups, chunk_size=1)
+        }
         output_rules[f"^{name}$"] = {0: "sum"}
     return Config(state_rules, input_rules, output_rules, {})
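
For intuition, a minimal sketch (toy shapes of our own choosing, with SplitInGroupedChunks imported as in the diff above) of why a grouped convolution needs a chunk_size=1 split that draws channels from every group rather than one contiguous block:

import torch

from tensor_parallel.state_actions import SplitInGroupedChunks

# 4 output channels with groups=2: rows 0-1 belong to group 0, rows 2-3 to group 1
weight = torch.arange(8).reshape(4, 2)
action = SplitInGroupedChunks(world_size=2, dim=0, num_groups=2, chunk_size=1)

shards = [action(weight, rank) for rank in range(2)]
# rank 0 holds channel 0 of each group, rank 1 holds channel 1 of each group
assert shards[0].tolist() == [[0, 1], [4, 5]]
assert torch.equal(action.undo(shards), weight)  # undo restores the grouped layout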
6 changes: 3 additions & 3 deletions src/tensor_parallel/slicing_configs.py
@@ -23,7 +23,7 @@
 from tensor_parallel.communications import CollectiveOperation
 from tensor_parallel.config import Config
 from tensor_parallel.per_device_tensors import PerDeviceTensors
-from tensor_parallel.state_actions import Scale, Split, SplitInChunks, SplitInsideChunks
+from tensor_parallel.state_actions import Scale, Split, SplitInChunks, SplitInGroupedChunks

 ConfigGetter = Callable[[PretrainedConfig, Sequence[torch.device]], Config]

@@ -188,8 +188,8 @@ def get_gpt2_config(model_config: GPT2Config, devices: Sequence[torch.device]) -> Config:
     return Config(
         state_rules={
             # GPT2Attention
-            r".*c_attn\.weight$": SplitInsideChunks(world_size=world_size, dim=1, num_chunks=3),
-            r".*c_attn\.bias$": SplitInsideChunks(world_size=world_size, dim=0, num_chunks=3),
+            r".*c_attn\.weight$": SplitInGroupedChunks(world_size=world_size, dim=1, num_groups=3, chunk_size=head_dim),
+            r".*c_attn\.bias$": SplitInGroupedChunks(world_size=world_size, dim=0, num_groups=3, chunk_size=head_dim),
             r".*q_attn\.weight$": SplitInChunks(world_size=world_size, dim=1, chunk_size=head_dim),
             r".*q_attn\.bias$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
             r".*attn\.c_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
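
GPT-2's fused c_attn is a Conv1D that packs [Q | K | V] along one axis, and each block is itself laid out head by head, so shards must cut on head boundaries inside all three blocks, which tensor_split alone would not guarantee when world_size does not divide the heads evenly. A small sketch (hypothetical sizes of our own) of the resulting shard layout:

import torch

from tensor_parallel.state_actions import SplitInGroupedChunks

num_heads, head_dim, world_size = 4, 2, 2
hidden_size = num_heads * head_dim
bias = torch.arange(3 * hidden_size)  # toy c_attn bias: Q heads, then K heads, then V heads

action = SplitInGroupedChunks(world_size=world_size, dim=0, num_groups=3, chunk_size=head_dim)
shard = action(bias, rank=0)
# rank 0 keeps the first two heads of Q, of K, and of V, so every device
# still sees complete (q, k, v) triples for the heads it owns
assert shard.tolist() == [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]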
35 changes: 17 additions & 18 deletions src/tensor_parallel/state_actions.py
@@ -90,31 +90,30 @@ def undo(self, tensors: Sequence[Tensor]) -> Tensor:
         return torch.cat([tensor.cpu() for tensor in tensors], dim=self.dim)


-class SplitInsideChunks(Split):
-    """AAABBBCCCDDDEEE -> (world_size = 3, num_chunks = 5) -> [ABCDE, ABCDE, ABCDE]"""
+class SplitInGroupedChunks(Split):
+    """AABBCCDDEE AABBCCDDEE AABBCCDDEE -> (world_size = 3, num_groups = 3, chunk_size = 2) -> [AABB AABB AABB, CCDD CCDD CCDD, EE EE EE]"""

-    def __init__(self, world_size: int, dim: int, num_chunks: int) -> None:
+    def __init__(self, world_size: int, dim: int, num_groups: int, chunk_size: int) -> None:
         super().__init__(world_size, dim)
-        self.num_chunks = num_chunks
+        self.num_groups = num_groups
+        self.chunk_size = chunk_size

     def __call__(self, tensor: Tensor, rank: int) -> Tensor:
-        shape = list(tensor.shape)
-        shape[self.dim] = shape[self.dim] // self.num_chunks
-        shape.insert(self.dim, self.num_chunks)
-        grouped_tensor = tensor.reshape(*shape)
-        grouped_shard = torch.tensor_split(grouped_tensor, self.world_size, dim=self.dim + 1)[rank]
-        return torch.flatten(grouped_shard, start_dim=self.dim, end_dim=self.dim + 1)
+        shape = list(tensor.shape)  # ... x hidden_size x ...
+        shape[self.dim] //= self.num_groups
+        shape.insert(self.dim, self.num_groups)  # ... x group x group_size x ...
+        shape[self.dim + 1] //= self.chunk_size
+        shape.insert(self.dim + 2, self.chunk_size)  # ... x group x chunk x chunk_size x ...
+        return (
+            tensor.reshape(shape).tensor_split(self.world_size, dim=self.dim + 1)[rank].flatten(self.dim, self.dim + 2)
+        )

     def undo(self, tensors: Sequence[Tensor]) -> Tensor:
         grouped_tensor = []
         for tensor in tensors:
-            shape = list(input.shape)
-            shape[self.dim] = shape[self.dim] // self.num_chunks
-            shape.insert(self.dim, self.num_chunks)
+            shape = list(tensor.shape)  # ... x hidden_size x ...
+            shape[self.dim] = shape[self.dim] // self.num_groups
+            shape.insert(self.dim, self.num_groups)  # ... x group x group_size x ...
             grouped_tensor.append(tensor.reshape(*shape).cpu())

-        output_shape = tensors[0].shape
-        del output_shape[self.dim]
-        output_shape[self.dim] = -1
-
-        return torch.cat(grouped_tensor, dim=self.dim).reshape(*output_shape)
+        return torch.cat(grouped_tensor, dim=self.dim + 1).flatten(self.dim, self.dim + 1)
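
As a quick self-check of the new docstring, a sketch (the A=0 ... E=4 encoding is ours) confirming that the split reproduces the documented pattern and that the fixed undo round-trips it:

import torch

from tensor_parallel.state_actions import SplitInGroupedChunks

action = SplitInGroupedChunks(world_size=3, dim=0, num_groups=3, chunk_size=2)
group = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])  # "AABBCCDDEE" with A=0 ... E=4
tensor = torch.cat([group, group, group])  # three groups along dim 0

shards = [action(tensor, rank) for rank in range(3)]
assert shards[0].tolist() == [0, 0, 1, 1] * 3  # AABB AABB AABB
assert shards[1].tolist() == [2, 2, 3, 3] * 3  # CCDD CCDD CCDD
assert shards[2].tolist() == [4, 4] * 3        # EE EE EE
assert torch.equal(action.undo(shards), tensor)  # undo restores the input layout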
4 changes: 1 addition & 3 deletions tests/test_saving.py
@@ -43,9 +43,7 @@ def test_no_parallelism_zero_3(devices, model_name):


 @pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
-@pytest.mark.parametrize(
-    "model_name", ["bert-base-uncased", "hf-internal-testing/tiny-random-t5", "hf-internal-testing/tiny-random-t5"]
-)
+@pytest.mark.parametrize("model_name", ["bert-base-uncased", "gpt2", "hf-internal-testing/tiny-random-t5"])
 def test_parallelism_no_zero_3(devices, model_name):
     model = AutoModel.from_pretrained(model_name).to(devices[0]).half()
     model_state_dict = model.state_dict()
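
The old parametrization listed hf-internal-testing/tiny-random-t5 twice, so GPT-2 saving was never actually exercised. A rough sketch (API as in the repository README; two CPU "devices" mirroring the test parametrization) of the setup the fixed test now covers:

import tensor_parallel as tp
from transformers import AutoModel

model = AutoModel.from_pretrained("gpt2").half()
model_tp = tp.tensor_parallel(model, ["cpu", "cpu"])  # c_attn is sharded via SplitInGroupedChunks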
