Merge pull request #73 from BlackSamorez/state_dict_fixes
State dict fixes for tied weights
IaroslavLisniak authored May 14, 2023
2 parents 8457dc3 + 4314ad3 commit f99e200
Showing 8 changed files with 99 additions and 47 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = tensor_parallel
version = 1.2.3
version = 1.2.4
author = Andrei Panferov and Yaroslav Lisnyak
author_email = [email protected]
description = Automatically shard your large model between multiple GPUs, works without torch.distributed
17 changes: 13 additions & 4 deletions src/tensor_parallel/dispatch.py
@@ -8,6 +8,7 @@
from tensor_parallel.pretrained_model import TensorParallelPreTrainedModel
from tensor_parallel.sharding import Sharded
from tensor_parallel.tensor_parallel import Config, TensorParallel
from tensor_parallel.utils import find_tied_weight_aliases


@contextmanager
@@ -55,18 +56,26 @@ def infer_sharded_device_map(tp_model: Union[TensorParallel, TensorParallelPreTr
Returns:
dict: parameter device mapping
"""
id_to_device = {}
for name, param in tp_model.named_parameters():
id_to_device[id(param)] = tp_model.devices[infer_sharded_data_device_id(name)]
for name, buffer in tp_model.named_buffers():
id_to_device[id(buffer)] = tp_model.devices[infer_sharded_data_device_id(name)]
id_to_aliases = find_tied_weight_aliases(tp_model)

device_map = {}
for name, _ in tp_model.named_parameters():
device_map[name] = tp_model.devices[infer_sharded_data_device_id(name)]
for name, _ in tp_model.named_buffers():
device_map[name] = tp_model.devices[infer_sharded_data_device_id(name)]
for idx, aliases in id_to_aliases.items():
for alias in aliases:
device_map[alias] = id_to_device[idx]

return device_map


def convert_state_dict(
input_state_dict, tensor_parallel_config: Config, world_size: int, for_pretrained: bool = False
) -> dict:
"""Creates a state_dict to be loaded into a tensor parallel model from a state_dict of a base model.
WARNING: this function doesn't properly work with tied weights. You'll probably need to fix the resulting state_dict by hand.
Args:
input_state_dict (_type_): state_dict to be converted
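
For illustration (not from the commit): the device-map change above exists because `named_parameters()` deduplicates tied tensors, so a map keyed only by the names it yields misses aliases such as a tied `lm_head.weight`, even though checkpoints carry every alias. A minimal sketch with a hypothetical tied module:

```python
from torch import nn

class TinyTiedLM(nn.Module):
    """Hypothetical toy model with tied input/output embeddings."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(8, 4)
        self.lm_head = nn.Linear(4, 8, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the two weights

model = TinyTiedLM()
print([name for name, _ in model.named_parameters()])
# ['embed.weight']  -- the tied alias 'lm_head.weight' is silently skipped
print(list(model.state_dict().keys()))
# ['embed.weight', 'lm_head.weight']  -- but checkpoints carry both names,
# so a device map needs an entry for each alias
```

Grouping parameters by `id(param)` and copying the device entry to every alias, as `infer_sharded_device_map` now does, covers both names.
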
19 changes: 10 additions & 9 deletions src/tensor_parallel/pretrained_model.py
@@ -2,6 +2,7 @@
The TensorParallel module wrapper for Hugging Face PreTrainedModel
"""
import logging
from functools import lru_cache
from typing import Any, Dict, Optional, Sequence

import torch
@@ -57,11 +58,6 @@ def __init__(
module, device_ids, output_device, output_device_index, tensor_parallel_config
)

self.encoder_shards = nn.ModuleList()
if module.config.is_encoder_decoder:
for encoder_decoder_shard in self.wrapped_model.module_shards:
self.encoder_shards.append(encoder_decoder_shard.get_encoder())

@property
def devices(self):
return self.wrapped_model.devices
@@ -113,11 +109,16 @@ def _reorder_cache(self, past, beam_idx):
beam_idx.to(self.wrapped_model.devices[i]),
)

@lru_cache(maxsize=None)
def get_encoder(self):
assert len(self.wrapped_model.module_shards), "Can't get encoder since no module shards present"
if len(self.wrapped_model.module_shards) == 1:
return self.wrapped_model.module_shards[0].get_encoder()

encoder_shards = [
encoder_decoder_shard.get_encoder() for encoder_decoder_shard in self.wrapped_model.module_shards
]

encoder_wrapper_class = None
if isinstance(self.wrapped_model, TensorParallel):

@@ -133,14 +134,14 @@ def forward(self, *args, **kwargs):
) = self.wrapped_pretrained_model.wrapped_model.prepare_args_kwargs_for_forward(*args, **kwargs)
if self.wrapped_pretrained_model.wrapped_model.all_cuda and not TENSOR_PARALLEL_USE_NATIVE:
return parallel_apply(
self.wrapped_pretrained_model.encoder_shards,
encoder_shards,
inputs,
kwargs_tup,
self.wrapped_pretrained_model.wrapped_model.devices,
)[self.wrapped_pretrained_model.wrapped_model.output_device_index]
else:
return parallel_apply_simple(
self.wrapped_pretrained_model.encoder_shards,
encoder_shards,
inputs,
kwargs_tup,
self.wrapped_pretrained_model.wrapped_model.devices,
@@ -169,14 +170,14 @@ def forward(self, *args, **kwargs):
)
if self.wrapped_pretrained_model.wrapped_model.module.all_cuda and not TENSOR_PARALLEL_USE_NATIVE:
return parallel_apply(
self.wrapped_pretrained_model.encoder_shards,
encoder_shards,
inputs,
kwargs_tup,
self.wrapped_pretrained_model.wrapped_model.module.devices,
)[self.wrapped_pretrained_model.wrapped_model.module.output_device_index]
else:
return parallel_apply_simple(
self.wrapped_pretrained_model.encoder_shards,
encoder_shards,
inputs,
kwargs_tup,
self.wrapped_pretrained_model.wrapped_model.module.devices,
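
For illustration (not from the diff): a stripped-down sketch of the caching pattern introduced above, where per-shard encoders are gathered lazily on the first `get_encoder()` call instead of being pre-built in `__init__`; class and attribute names are illustrative only.

```python
from functools import lru_cache

class ShardedSeq2SeqSketch:
    """Illustrative stand-in for the tensor-parallel wrapper; not the library class."""

    def __init__(self, module_shards):
        self.module_shards = module_shards  # one model shard per device (hypothetical)

    @lru_cache(maxsize=None)  # built once on first use, then reused
    def get_encoder(self):
        assert len(self.module_shards), "Can't get encoder since no module shards present"
        if len(self.module_shards) == 1:
            return self.module_shards[0].get_encoder()
        # Gather one encoder per shard only when an encoder is actually requested;
        # the real implementation wraps these for parallel_apply across devices.
        return [shard.get_encoder() for shard in self.module_shards]
```

One caveat of `@lru_cache` on a method is that the cache holds a strong reference to `self`; for a long-lived model wrapper this is usually acceptable.
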
8 changes: 6 additions & 2 deletions src/tensor_parallel/slicing_configs.py
@@ -51,7 +51,8 @@ def get_bloom_config(model_config: BloomConfig, devices: Sequence[torch.device])
r".*mlp\.dense_4h_to_h\.weight$": Split(world_size=world_size, dim=1),
r".*mlp\.dense_4h_to_h\.bias$": Scale(world_size=world_size),
# BloomModel
r".*word_embeddings.weight$": Split(world_size=world_size, dim=1),
r".*word_embeddings\.weight$": Split(world_size=world_size, dim=1),
r".*lm_head\.weight$": Split(world_size=world_size, dim=1),
# note: ^-- lm_head.weight is tied with word_embeddings
},
input_rules={
@@ -108,7 +109,8 @@ def select_kv_for_rank(*kvs, rank):
# T5DenseActDense
r".*DenseReluDense\.wo\.weight$": Split(world_size=world_size, dim=1),
# T5Model
r".*shared.weight$": Split(world_size=world_size, dim=1),
r".*embed_tokens\.weight$": Split(world_size=world_size, dim=1),
r".*shared\.weight$": Split(world_size=world_size, dim=1),
r".*lm_head\.weight$": Split(world_size=world_size, dim=1),
# note: ^-- lm_head.weight tied with word embeddings
},
@@ -121,6 +123,7 @@ def select_kv_for_rank(*kvs, rank):
r".*SelfAttention$": {0: "sum", 1: gather_kv_across_ranks},
r".*DenseReluDense$": {0: "sum"},
r".*shared$": {0: "gather -1"},
r".*embed_tokens$": {0: "gather -1"},
r".*lm_head$": {0: "sum"},
},
attr_rules={
@@ -216,6 +219,7 @@ def __call__(self, tensor: torch.Tensor, rank: int) -> torch.Tensor:
# GPT2Model
r".*wte\.weight$": Split(world_size=world_size, dim=1),
r".*wpe\.weight$": Split(world_size=world_size, dim=1),
r".*lm_head\.weight$": Split(world_size=world_size, dim=1),
# GPT2LMHeadModel
# note: ^-- lm_head.weight is tied with word_embeddings
},
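
A side note on the escaped dots above (not from the commit): the rule keys are regular expressions, so an unescaped `.` also matches arbitrary characters. The quick check below uses a contrived parameter name:

```python
import re

loose  = r".*word_embeddings.weight$"   # old pattern: "." matches any character
strict = r".*word_embeddings\.weight$"  # fixed pattern: a literal dot only

print(bool(re.match(loose,  "transformer.word_embeddings_weight")))   # True  (unintended match)
print(bool(re.match(strict, "transformer.word_embeddings_weight")))   # False
print(bool(re.match(strict, "transformer.word_embeddings.weight")))   # True  (intended match)
```
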
17 changes: 10 additions & 7 deletions src/tensor_parallel/tensor_parallel.py
@@ -136,12 +136,21 @@ def state_dict(self, *args, **kwargs):
if self.preserve_shards_when_saving:
return state_dict

for i in range(len(self.module_shards)):
sanity_check_param_name = next(
name for name, _ in state_dict.items() if name.endswith(f"_sanity_check_params.{i}")
)
del state_dict[sanity_check_param_name]

# fix names for zero-3'ed params that were inside _TensorParallelWrapper
names_inside_tp_wrapper = [name for name in state_dict.keys() if "tp_wrapped_module." in name]
for name in names_inside_tp_wrapper:
state_dict[name.replace("tp_wrapped_module.", "")] = state_dict.pop(name)

shards_prefix = next(name for name, _ in state_dict.items() if "module_shards." in name)
try:
shards_prefix = next(name for name, _ in state_dict.items() if "module_shards." in name)
except StopIteration:
return state_dict # no parameters are actually tensor parallel
shards_prefix = shards_prefix[: shards_prefix.find("module_shards.") + len("module_shards.")]
module_prefix = shards_prefix[: -len("module_shards.")]

@@ -171,12 +180,6 @@ def state_dict(self, *args, **kwargs):
for i in range(len(self.module_shards)):
del state_dict[f"{shards_prefix}{i}.{unsharded_name}"]

for i in range(len(self.module_shards)):
sanity_check_param_name = next(
name for name, _ in state_dict.items() if name.endswith(f"_sanity_check_params.{i}")
)
del state_dict[sanity_check_param_name]

return state_dict


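
For illustration (not from the commit): a sketch of the reordered `state_dict` post-processing on a plain dict of keys. Deleting the sanity-check entries before probing for `module_shards.`, and returning early when no such prefix exists, avoids the `StopIteration` the previous code could hit on models with nothing tensor parallel. Key names below are illustrative.

```python
def clean_tp_state_dict(state_dict: dict, num_shards: int) -> dict:
    """Sketch of the key cleanup above; not the library method."""
    # 1) Drop the per-shard sanity-check entries first.
    for i in range(num_shards):
        victim = next((k for k in state_dict if k.endswith(f"_sanity_check_params.{i}")), None)
        if victim is not None:
            del state_dict[victim]

    # 2) Strip the zero-3 wrapper infix from parameter names.
    for k in [k for k in state_dict if "tp_wrapped_module." in k]:
        state_dict[k.replace("tp_wrapped_module.", "")] = state_dict.pop(k)

    # 3) Nothing tensor parallel left? Return early instead of raising StopIteration.
    if not any("module_shards." in k for k in state_dict):
        return state_dict

    # ... the real method goes on to merge per-shard tensors here ...
    return state_dict

print(clean_tp_state_dict({"encoder.tp_wrapped_module.weight": 1, "_sanity_check_params.0": 2}, 1))
# {'encoder.weight': 1}
```
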
26 changes: 26 additions & 0 deletions src/tensor_parallel/utils.py
@@ -3,6 +3,10 @@
Based on: https://stackoverflow.com/questions/49739102/python-nested-dictionary-comparison
"""

from itertools import chain
from typing import Mapping, Optional, Sequence

from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions


@@ -97,3 +101,25 @@ def nested_map(fn, *t):
# Map.
flat = map(nested_flatten, t)
return nested_pack(map(fn, *flat), t[0])


def find_tied_weight_aliases(
module: nn.Module, destination: Optional[Mapping[int, Sequence[str]]] = None, prefix: Optional[str] = None
) -> Mapping[int, Sequence[str]]:
if prefix is None:
prefix = ""
if destination is None:
destination = {}

for name, param in chain(module._parameters.items(), module._buffers.items()):
if param is not None:
if id(param) in destination:
destination[id(param)].append(prefix + name)
else:
destination[id(param)] = [prefix + name]

for name, submodule in module._modules.items():
if submodule is not None:
find_tied_weight_aliases(module=submodule, destination=destination, prefix=prefix + name + ".")

return destination
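
For orientation (not from the diff): `find_tied_weight_aliases` groups names by the identity of the underlying tensor, so every alias of a tied weight lands in the same bucket. A toy usage with a hypothetical tied module:

```python
from torch import nn
from tensor_parallel.utils import find_tied_weight_aliases

class TiedToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(8, 4)
        self.lm_head = nn.Linear(4, 8, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the two weights

aliases = find_tied_weight_aliases(TiedToy())
print(list(aliases.values()))
# [['embed.weight', 'lm_head.weight']]  -- one bucket per underlying tensor, listing all of its names
```
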
43 changes: 23 additions & 20 deletions tests/test_saving.py
@@ -21,13 +21,15 @@
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
def test_no_parallelism_zero_3(devices, model_name):
model = AutoModel.from_pretrained(model_name).to(devices[0])
model = AutoModel.from_pretrained(model_name).to(devices[0]).half()
model_state_dict = model.state_dict()
model_tp = Sharded(
TensorParallel(model, devices, tensor_parallel_config=Config({}, {}, {}, {}))
) # zero-3 sharding only
del model
with save_tensor_parallel(model_tp):
model_tp_state_dict = model_tp.state_dict()
del model_tp

assert sorted(list(model_state_dict.keys())) == sorted(list(model_tp_state_dict.keys()))

@@ -41,13 +43,17 @@ def test_no_parallelism_zero_3(devices, model_name):


@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
@pytest.mark.parametrize(
"model_name", ["bert-base-uncased", "hf-internal-testing/tiny-random-t5", "hf-internal-testing/tiny-random-t5"]
)
def test_parallelism_no_zero_3(devices, model_name):
model = AutoModel.from_pretrained(model_name).to(devices[0])
model = AutoModel.from_pretrained(model_name).to(devices[0]).half()
model_state_dict = model.state_dict()
model_tp = TensorParallelPreTrainedModel(model, devices)
del model
with save_tensor_parallel(model_tp):
model_tp_state_dict = model_tp.state_dict()
del model_tp

assert sorted(list(model_state_dict.keys())) == sorted(list(model_tp_state_dict.keys()))

@@ -63,28 +69,32 @@ def test_parallelism_no_zero_3(devices, model_name):
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
def test_parallelism_zero_3(devices, model_name):
model = AutoModel.from_pretrained(model_name).to(devices[0])
model = AutoModel.from_pretrained(model_name).to(devices[0]).half()
model_state_dict = model.state_dict()
model_tp = tensor_parallel(model, devices, sharded=True)
del model
with save_tensor_parallel(model_tp):
model_tp_state_dict = model_tp.state_dict()
del model_tp

assert sorted(list(model_state_dict.keys())) == sorted(list(model_tp_state_dict.keys()))

for name in model_state_dict.keys():
data = model_state_dict[name]
data_tp = model_tp_state_dict[name]

assert data.shape == data_tp.shape
assert data.shape == data_tp.shape, name

torch.testing.assert_close(data, data_tp, msg=lambda msg: f"{name}:\n{msg}")


@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
@pytest.mark.parametrize(
"model_name", ["bert-base-uncased", "hf-internal-testing/tiny-random-t5", "hf-internal-testing/tiny-random-bloom"]
)
@pytest.mark.parametrize("shard_as_pretrained", [True, False])
def test_save_keep_shards(devices, model_name, shard_as_pretrained):
model = AutoModel.from_pretrained(model_name).to(devices[0])
model = AutoModel.from_pretrained(model_name).to(devices[0]).half()
if shard_as_pretrained:
model_tp = TensorParallelPreTrainedModel(model, devices)
else:
@@ -93,22 +103,15 @@ def test_save_keep_shards(devices, model_name, shard_as_pretrained):
model_tp.load_state_dict(model_tp.state_dict())


def test_sharding_meta():
model_name = "bert-base-uncased"
with init_empty_weights():
model_tp = TensorParallel(AutoModel.from_pretrained(model_name), ["meta", "meta"])

with pytest.raises(RuntimeError):
Sharded(model_tp)


@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
@pytest.mark.parametrize(
"model_name", ["bert-base-uncased", "hf-internal-testing/tiny-random-t5", "hf-internal-testing/tiny-random-bloom"]
)
@pytest.mark.parametrize("pretrained", [True])
def test_save_shards_load_shards(devices, model_name, pretrained):
devices = [torch.device(device) for device in devices]

model = AutoModel.from_pretrained(model_name).to(devices[0])
model = AutoModel.from_pretrained(model_name).to(devices[0]).half()
shraded_class = TensorParallelPreTrainedModel if pretrained else TensorParallel
model_tp = shraded_class(model, devices)

@@ -120,7 +123,7 @@ def test_save_shards_load_shards(devices, model_name, pretrained):
del model_tp

with init_empty_weights():
model_tp = shraded_class(AutoModel.from_config(AutoConfig.from_pretrained(model_name)), devices)
model_tp = shraded_class(AutoModel.from_config(AutoConfig.from_pretrained(model_name)).half(), devices)

checkpoint = PATH_TO_SAVE + ("pytorch_model.bin.index.json" if pretrained else "test_save_shards_load_shards.bin")
load_checkpoint_in_model(
@@ -135,7 +138,7 @@ def test_save_shards_load_shards(devices, model_name, pretrained):
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
def test_convert_state_dict(use_pretrained, devices, model_name):
model = AutoModel.from_pretrained(model_name).to(devices[0])
model = AutoModel.from_pretrained(model_name).to(devices[0]).half()
torch.save(model.state_dict(), PATH_TO_SAVE + "test_convert_state_dict.bin")

if use_pretrained:
14 changes: 10 additions & 4 deletions tests/test_transformers.py
@@ -265,14 +265,20 @@ def test_encoder(use_predefined_config, model_name, sharded):
devices = ["cpu"] * 2
model = T5ForConditionalGeneration.from_pretrained(model_name, low_cpu_mem_usage=True).float().to(devices[0])

input = torch.randint(1, 1000, size=(2, 3), device=devices[0])
out_ref = model.get_encoder()(input)
inp1 = torch.randint(1, 1000, size=(2, 3), device=devices[0])
inp2 = torch.randint(1, 1000, size=(2, 3), device=devices[0])

out1_ref = model.get_encoder()(inp1)
out2_ref = model.get_encoder()(inp2)

if not use_predefined_config:
model.config.architectures = ["Pretend we don't know this architecture"]
model_tp = tensor_parallel(model, devices, sharded=sharded)
assert isinstance(model_tp, TensorParallelPreTrainedModel)
del model

out = model_tp.get_encoder()(input)
torch.testing.assert_close(out_ref, out, atol=3e-3, rtol=1e-05)
out1 = model_tp.get_encoder()(inp1)
out2 = model_tp.get_encoder()(inp2)

torch.testing.assert_close(out1_ref, out1, atol=3e-3, rtol=1e-05)
torch.testing.assert_close(out2_ref, out2, atol=3e-3, rtol=1e-05)
