Merge pull request #106 from BlackSamorez/sharding_refactor

[WIP] ZeRO-3 refactoring (sharding)

Andrei Panferov authored Aug 6, 2023
2 parents 807a2ff + 7b6f692 commit a7d1939
Showing 16 changed files with 316 additions and 368 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = tensor_parallel
version = 1.3.2
version = 2.0.0
author = Andrei Panferov and Yaroslav Lisnyak
author_email = [email protected]
description = Automatically shard your large model between multiple GPUs, works without torch.distributed
3 changes: 2 additions & 1 deletion src/tensor_parallel/__init__.py
@@ -6,7 +6,8 @@
from tensor_parallel.config import Config
from tensor_parallel.dispatch import convert_state_dict, infer_sharded_device_map, save_tensor_parallel
from tensor_parallel.factory import tensor_parallel
from tensor_parallel.legacy import Sharded
from tensor_parallel.pretrained_model import TensorParallelPreTrainedModel
from tensor_parallel.sharding import Sharded
from tensor_parallel.shard import make_distributed_shard
from tensor_parallel.state_actions import StateAction
from tensor_parallel.tensor_parallel import TensorParallel
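For orientation, a minimal sketch of how the public API re-exported here is typically used after this refactor; the checkpoint name and device list are placeholders, not part of this diff:

```python
import tensor_parallel as tp
from transformers import AutoModelForSeq2SeqLM

# Placeholder checkpoint; any PreTrainedModel or plain nn.Module works.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Split the model across the listed devices. With this PR, ZeRO-3 sharding is a
# capability of the returned wrapper itself rather than a separate `Sharded` object.
model_tp = tp.tensor_parallel(model, device_ids=["cuda:0", "cuda:1"], sharded=True)
```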
40 changes: 39 additions & 1 deletion src/tensor_parallel/config.py
@@ -3,7 +3,8 @@
import os
import re
from functools import partial
from typing import Any, Callable, Dict, Sequence, Union
from itertools import chain
from typing import Any, Callable, Dict, Iterable, Mapping, Sequence, Union

import torch
from torch import nn
@@ -151,3 +152,40 @@ def add_lora_rules(model: nn.Module, config: Config) -> Config:
config.input_rules.update(lora_input_rules)
config.output_rules.update(lora_output_rules)
return config


def get_parameter_name_mapping(names: Iterable[str], tensor_parallel_config: Config) -> Mapping[str, str]:
    """Maps original model's parameter names to tensor_parallel parameter names.
    Args:
        names (Iterable[str]): Parameter names
        tensor_parallel_config (Config): Config
    Returns:
        Mapping[str, str]: mapping from original parameter names to tensor_parallel parameter names
    """
    patterns = tuple(
        regex.pattern
        for regex in chain(tensor_parallel_config.input_rules.keys(), tensor_parallel_config.output_rules.keys())
    )
    patterns = [pattern[:-1] if pattern.endswith("$") else pattern for pattern in patterns]
    patterns = set(pattern if pattern.endswith(".") else pattern + r"\." for pattern in patterns)
    patterns = [re.compile(pattern) for pattern in patterns]

    insertions = {name: [] for name in names}
    for pattern in patterns:
        for name in names:
            match = pattern.search(name)
            if match is not None:
                end_pos = match.span()[1]
                insertions[name].append(end_pos)
    insertions = {name: sorted(pos) for name, pos in insertions.items()}

    name_replacements = {}
    for name in names:
        new_name = name
        for pos in insertions[name][::-1]:
            new_name = new_name[:pos] + r"tp_wrapped_module." + new_name[pos:]
        name_replacements[name] = new_name

    return name_replacements
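A quick illustration of the new helper. Since it only reads the regex keys of `input_rules` / `output_rules`, a duck-typed stand-in config with hypothetical rules is enough here:

```python
import re
from types import SimpleNamespace

from tensor_parallel.config import get_parameter_name_mapping

# Stand-in for a Config: hypothetical rules keyed by compiled regexes,
# mimicking Config.input_rules / Config.output_rules.
stub_config = SimpleNamespace(
    input_rules={re.compile(r"self_attn\.q_proj$"): {}},
    output_rules={re.compile(r"self_attn\.o_proj$"): {}},
)

names = ["layers.0.self_attn.q_proj.weight", "embed_tokens.weight"]
print(get_parameter_name_mapping(names, stub_config))
# Matched parameters gain a "tp_wrapped_module." segment right after the matched module path:
# {'layers.0.self_attn.q_proj.weight': 'layers.0.self_attn.q_proj.tp_wrapped_module.weight',
#  'embed_tokens.weight': 'embed_tokens.weight'}
```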
27 changes: 6 additions & 21 deletions src/tensor_parallel/dispatch.py
@@ -1,18 +1,17 @@
import re
from contextlib import contextmanager
from itertools import chain
from typing import Union

import torch

from tensor_parallel.config import get_parameter_name_mapping
from tensor_parallel.pretrained_model import TensorParallelPreTrainedModel
from tensor_parallel.sharding import Sharded
from tensor_parallel.tensor_parallel import Config, TensorParallel
from tensor_parallel.utils import find_tied_weight_aliases


@contextmanager
def save_tensor_parallel(model: Union[TensorParallel, TensorParallelPreTrainedModel, Sharded]):
def save_tensor_parallel(model: Union[TensorParallel, TensorParallelPreTrainedModel]):
"""Enables state_dict reconstruction for tensor_parallel models.
With it '.state_dict()' produces a state dict that can be loaded into an underlying model.
Example:
@@ -24,13 +23,13 @@ def save_tensor_parallel(model: Union[TensorParallel, TensorParallelPreTrainedMo
```
Args:
model (Union[TensorParallel, TensorParallelPreTrainedModel, Sharded]): tensor_parallel model
model (Union[TensorParallel, TensorParallelPreTrainedModel]): tensor_parallel model
"""
model.preserve_shards_when_saving = False
model.set_preserve_shards_when_saving(False)
try:
yield
finally:
model.preserve_shards_when_saving = True
model.set_preserve_shards_when_saving(True)


def infer_sharded_data_device_id(name: str):
@@ -112,21 +111,7 @@ def convert_data(input_state_dict, output_state_dict, tensor_parallel_config: Co


def convert_names(state_dict, tensor_parallel_config: Config):
patterns = tuple(
regex.pattern
for regex in chain(tensor_parallel_config.input_rules.keys(), tensor_parallel_config.output_rules.keys())
)
patterns = set(pattern[:-1] + "\." if pattern.endswith("$") else pattern for pattern in patterns)
patterns = [re.compile(pattern) for pattern in patterns]

name_replacements = {name: name for name in state_dict.keys()}
for pattern in patterns:
for initial_name, old_name in name_replacements.items():
match = pattern.search(old_name)
if match is not None:
end_pos = match.span()[1]
new_name = old_name[:end_pos] + "tp_wrapped_module." + old_name[end_pos:]
name_replacements[initial_name] = new_name
name_replacements = get_parameter_name_mapping(state_dict.keys(), tensor_parallel_config)

for initial_name, final_name in name_replacements.items():
state_dict[final_name] = state_dict.pop(initial_name)
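A sketch of the context manager in use, assuming `model_tp` is a tensor_parallel wrapper as in the earlier example and the output path is arbitrary:

```python
import torch

from tensor_parallel import save_tensor_parallel

with save_tensor_parallel(model_tp):
    # Inside the context, shard preservation is switched off via
    # set_preserve_shards_when_saving(False), so state_dict() yields original
    # (unwrapped, unsharded) names and tensors loadable into the base model.
    torch.save(model_tp.state_dict(), "model.pt")
# Outside the context, state_dict() again preserves the per-device shards.
```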
49 changes: 8 additions & 41 deletions src/tensor_parallel/factory.py
@@ -1,5 +1,5 @@
import logging
from typing import Collection, Optional, Sequence, Union
from typing import Collection, Optional, Sequence, Tuple, Union

import torch
import torch.distributed
@@ -9,7 +9,6 @@
from tensor_parallel.config import Config
from tensor_parallel.pretrained_model import TensorParallelPreTrainedModel
from tensor_parallel.shard import make_distributed_shard
from tensor_parallel.sharding import Sharded
from tensor_parallel.tensor_parallel import TensorParallel

logger = logging.getLogger(__file__)
@@ -44,11 +43,10 @@ def tensor_parallel(
and manually re-assembled for each forward. This is equivalent to pytorch FullyShardedDataParallel
:param sharded_param_names: if sharded=True, this is a list of all parameter names (strings) that ZeRO-3 applies to;
by default, ZeRO-3 applies to all parameters that are not split with tensor parallelism.
:note: the default sharded_param_names are formed of parameters that are equal between shards after TP is applied
:note: the default sharded_param_names are formed of all parameters that were not processed with tensor parallelism
:param kwargs: additional keyword arguments passed to TensorParallel init
"""
num_trainable_parameters = sum(p.numel() for p in module.parameters() if p.requires_grad)
distributed = distributed if distributed is not None else torch.distributed.is_initialized()

if distributed:
@@ -60,53 +58,22 @@
return make_distributed_shard(module, device=torch.device(device_ids[0]), **kwargs)
else:
if isinstance(module, PreTrainedModel):
module = TensorParallelPreTrainedModel(
return TensorParallelPreTrainedModel(
module,
device_ids=device_ids,
tensor_parallel_config=tensor_parallel_config,
distributed=distributed,
sharded=sharded,
sharded_param_names=sharded_param_names,
**kwargs,
)
module.wrapped_model = _maybe_sharded(
module.wrapped_model, sharded, num_trainable_parameters, sharded_param_names=sharded_param_names
)
else:
module = TensorParallel(
return TensorParallel(
module,
device_ids=device_ids,
tensor_parallel_config=tensor_parallel_config,
distributed=distributed,
sharded=sharded,
sharded_param_names=sharded_param_names,
**kwargs,
)
module = _maybe_sharded(module, sharded, num_trainable_parameters, sharded_param_names=sharded_param_names)

return module


def _maybe_sharded(
module: TensorParallel,
sharded: Optional[bool],
num_trainable_parameters: int,
sharded_param_names: Optional[Collection[str]],
**kwargs,
) -> Union[Sharded, TensorParallel]:
"""Determines if sharding is necessary, returns either Sharded(module) or module itself, if unchanged"""
determined_automatically = sharded is None
if sharded is None:
num_trainable_parameters_after_tp = sum(p.numel() for p in module.parameters() if p.requires_grad)
assert num_trainable_parameters_after_tp >= num_trainable_parameters
sharded = num_trainable_parameters_after_tp > num_trainable_parameters
# use sharding if there are some *trainable* parameter that are replicated on more than one device

model_is_meta = any([p.device.type == "meta" for p in module.parameters()])
if sharded and model_is_meta and sharded_param_names is None:
logger.warning(
f"Not sharding the model that should be sharded because it has meta tensors which prevent sharding without 'sharded_param_names'. It's recomended to shard a model after loading it's weights."
)
sharded = False
elif sharded and determined_automatically:
num_extra_parameters = num_trainable_parameters_after_tp - num_trainable_parameters
replicated_parameters = num_extra_parameters // max(1, len(module.devices) - 1)
logger.warning(f"Using ZeRO-3 sharding for {replicated_parameters} non tensor-parallel parameters")

return Sharded(module, sharded_param_names=sharded_param_names, **kwargs) if sharded else module
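After this change the factory no longer wraps the result in `Sharded`; it simply forwards `sharded` and `sharded_param_names` into the wrapper constructors. A rough sketch of the explicit form (the listed parameter name is purely illustrative):

```python
from tensor_parallel import tensor_parallel

model_tp = tensor_parallel(
    model,                                    # any nn.Module or PreTrainedModel
    device_ids=["cuda:0", "cuda:1"],
    sharded=True,                             # enable ZeRO-3 for replicated parameters
    sharded_param_names=["lm_head.weight"],   # illustrative subset; default is all non-TP params
)
```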
20 changes: 20 additions & 0 deletions src/tensor_parallel/legacy.py
@@ -0,0 +1,20 @@
import logging
from typing import Collection, Optional, Tuple

from torch import nn

from tensor_parallel.pretrained_model import TensorParallelPreTrainedModel
from tensor_parallel.tensor_parallel import TensorParallel

logger = logging.getLogger(__file__)


class Sharded(nn.Module):
    def __new__(
        cls,
        module: Tuple[TensorParallel, TensorParallelPreTrainedModel],
        sharded_param_names: Optional[Collection[str]] = None,
    ):
        logger.warning(f"`Sharded` is deprecated. Please use `.apply_sharding()` method")
        module.apply_sharding(sharded_param_names)
        return module
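The shim keeps old call sites working while pointing users at the new method. Roughly, with `model_tp` a tensor_parallel wrapper:

```python
# Old call site (still works; now just logs a deprecation warning and delegates in place):
#     model_tp = Sharded(model_tp)
# New, preferred spelling of the same operation:
model_tp.apply_sharding()   # ZeRO-3 shard the trainable parameters not split by tensor parallelism
```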
110 changes: 40 additions & 70 deletions src/tensor_parallel/pretrained_model.py
@@ -11,7 +11,6 @@

from tensor_parallel.config import TENSOR_PARALLEL_USE_NATIVE, Config
from tensor_parallel.per_device_tensors import PerDeviceTensors
from tensor_parallel.sharding import Sharded
from tensor_parallel.slicing_configs import PREDEFINED_CONFIGS
from tensor_parallel.tensor_parallel import TensorParallel, check_device_ids, parallel_apply, parallel_apply_simple
from tensor_parallel.utils import nested_map
@@ -43,7 +42,7 @@ def __init__(
output_device: Optional[torch.device] = None,
output_device_index: Optional[int] = None,
tensor_parallel_config: Optional[Config] = None,
distributed: bool = True,
**kwargs,
):
super().__init__(module.config) # Temporary empty config. Gets replaced in from_pretrained

@@ -56,7 +55,7 @@
tensor_parallel_config = find_predefined_tensor_parallel_config(module.config, device_ids)

self.wrapped_model = TensorParallel(
module, device_ids, output_device, output_device_index, tensor_parallel_config, distributed=distributed
module, device_ids, output_device, output_device_index, tensor_parallel_config, **kwargs
)

@property
@@ -67,13 +66,11 @@ def devices(self):
def tensor_parallel_config(self):
return self.wrapped_model.tensor_parallel_config

@property
def preserve_shards_when_saving(self):
return self.wrapped_model.preserve_shards_when_saving
def set_preserve_shards_when_saving(self, value: bool):
self.wrapped_model.set_preserve_shards_when_saving(value)

@preserve_shards_when_saving.setter
def preserve_shards_when_saving(self, value):
self.wrapped_model.preserve_shards_when_saving = value
def apply_sharding(self, *args, **kwargs):
self.wrapped_model.apply_sharding(*args, **kwargs)

def forward(self, *args, **kwargs):
return self.wrapped_model(*args, **kwargs)
@@ -123,70 +120,43 @@ def get_encoder(self):
encoder_decoder_shard.get_encoder() for encoder_decoder_shard in self.wrapped_model.module_shards
]

encoder_wrapper_class = None
if isinstance(self.wrapped_model, TensorParallel):
class _EncoderWrapper(torch.nn.Module):
def __init__(self, wrapped_pretrained_model: TensorParallelPreTrainedModel) -> None:
super().__init__()
self.wrapped_pretrained_model = wrapped_pretrained_model

def forward(self, *args, **kwargs):
if self.wrapped_pretrained_model.wrapped_model.need_delayed_init:
for shard, device in zip(
self.wrapped_pretrained_model.wrapped_model.module_shards,
self.wrapped_pretrained_model.wrapped_model.devices,
):
shard.to(device)
self.wrapped_pretrained_model.wrapped_model.need_delayed_init = False

class _EncoderWrapper(torch.nn.Module):
def __init__(self, wrapped_pretrained_model: TensorParallelPreTrainedModel) -> None:
super().__init__()
self.wrapped_pretrained_model = wrapped_pretrained_model
# Synchronize sharded parameters
if self.wrapped_pretrained_model.wrapped_model.sharding_manager is not None:
self.wrapped_pretrained_model.wrapped_model.sharding_manager.synchronize_weights(
self.wrapped_pretrained_model.wrapped_model.all_cuda
)

def forward(self, *args, **kwargs):
(
(
inputs,
kwargs_tup,
) = self.wrapped_pretrained_model.wrapped_model.prepare_args_kwargs_for_forward(*args, **kwargs)
if self.wrapped_pretrained_model.wrapped_model.all_cuda and not TENSOR_PARALLEL_USE_NATIVE:
return parallel_apply(
encoder_shards,
inputs,
kwargs_tup,
) = self.wrapped_pretrained_model.wrapped_model.prepare_args_kwargs_for_forward(*args, **kwargs)
if self.wrapped_pretrained_model.wrapped_model.all_cuda and not TENSOR_PARALLEL_USE_NATIVE:
return parallel_apply(
encoder_shards,
inputs,
kwargs_tup,
self.wrapped_pretrained_model.wrapped_model.devices,
)[self.wrapped_pretrained_model.wrapped_model.output_device_index]
else:
return parallel_apply_simple(
encoder_shards,
inputs,
kwargs_tup,
self.wrapped_pretrained_model.wrapped_model.devices,
)[self.wrapped_pretrained_model.wrapped_model.output_device_index]

encoder_wrapper_class = _EncoderWrapper

elif isinstance(self.wrapped_model, Sharded):

class _EncoderWrapper(torch.nn.Module):
def __init__(self, wrapped_pretrained_model: TensorParallelPreTrainedModel) -> None:
super().__init__()
self.wrapped_pretrained_model = wrapped_pretrained_model

def forward(self, *args, **kwargs):
if (
len(self.wrapped_pretrained_model.wrapped_model.module.module_shards) > 1
and len(self.wrapped_pretrained_model.wrapped_model.sharded_param_names) > 0
):
self.wrapped_pretrained_model.wrapped_model._maybe_fill_sharded_params()
(
self.wrapped_pretrained_model.wrapped_model.devices,
)[self.wrapped_pretrained_model.wrapped_model.output_device_index]
else:
return parallel_apply_simple(
encoder_shards,
inputs,
kwargs_tup,
) = self.wrapped_pretrained_model.wrapped_model.module.prepare_args_kwargs_for_forward(
*args, **kwargs
)
if self.wrapped_pretrained_model.wrapped_model.module.all_cuda and not TENSOR_PARALLEL_USE_NATIVE:
return parallel_apply(
encoder_shards,
inputs,
kwargs_tup,
self.wrapped_pretrained_model.wrapped_model.module.devices,
)[self.wrapped_pretrained_model.wrapped_model.module.output_device_index]
else:
return parallel_apply_simple(
encoder_shards,
inputs,
kwargs_tup,
self.wrapped_pretrained_model.wrapped_model.module.devices,
)[self.wrapped_pretrained_model.wrapped_model.module.output_device_index]

encoder_wrapper_class = _EncoderWrapper

return encoder_wrapper_class(self)
self.wrapped_pretrained_model.wrapped_model.devices,
)[self.wrapped_pretrained_model.wrapped_model.output_device_index]

return _EncoderWrapper(self)
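For context on why the encoder wrapper exists: Hugging Face seq2seq generation fetches the encoder once via `get_encoder()` and reuses its outputs, so the wrapper has to dispatch that single call across every shard itself. A rough usage sketch with a placeholder checkpoint:

```python
import tensor_parallel as tp
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint
model_tp = tp.tensor_parallel(
    T5ForConditionalGeneration.from_pretrained("t5-small"), device_ids=["cuda:0", "cuda:1"]
)

batch = tokenizer("translate English to German: Hello", return_tensors="pt").to("cuda:0")
# .generate() calls get_encoder() internally; the single unified _EncoderWrapper above
# now covers both the plain tensor-parallel and the ZeRO-3 sharded case.
output = model_tp.generate(**batch, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```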