Fix TP initialization (#35860)
* fix tp

* Update modeling_utils.py

* style

* style

* Update test_tp.py

* Update test_tp.py

* style

* Update test_tp.py

* Update test_tp.py

* Update test_tp.py

* Update test_tp.py
Cyrilvallez authored Jan 28, 2025
1 parent f85ba20 commit f48ecd7
Showing 2 changed files with 104 additions and 43 deletions.
92 changes: 49 additions & 43 deletions src/transformers/modeling_utils.py
@@ -3443,6 +3443,29 @@ def from_pretrained(
             # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
             raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
 
+        if tp_plan is not None and device_map is not None:
+            raise ValueError(
+                "`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
+            )
+
+        # We need to correctly dispatch the model on the current process device. The easiest way to do this is to use a
+        # simple `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for
+        # all child processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs.
+        # Temporarily setting the default device to the current process rank instead results in the following error:
+        # `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group`
+        tp_device = None
+        if tp_plan is not None:
+            if not torch.distributed.is_initialized():
+                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
+
+            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
+            device_type = torch._C._get_accelerator().type
+            device_module = torch.get_device_module(device_type)
+            # Get device with index assuming equal number of devices per host
+            tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
+            # This is the easiest way to dispatch to the current process device
+            device_map = tp_device
+
         if is_fsdp_enabled():
             low_cpu_mem_usage = True
 
@@ -4090,7 +4113,6 @@ def from_pretrained(
 
         # Instantiate model.
         init_contexts = [no_init_weights(_enable=_fast_init)]
-        tp_device = None
 
         if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
             import deepspeed
@@ -4106,16 +4128,6 @@ def from_pretrained(
                     f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
                 )
             init_contexts.append(init_empty_weights())
-        elif tp_plan is not None:
-            if not torch.distributed.is_initialized():
-                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
-
-            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
-            device_type = torch._C._get_accelerator().type
-            device_module = torch.get_device_module(device_type)
-            # Get device with index assuming equal number of devices per host
-            tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
-            init_contexts.append(tp_device)
 
         if is_deepspeed_zero3_enabled() and is_quantized:
             init_contexts.append(set_quantized_state())
@@ -4249,38 +4261,32 @@ def from_pretrained(
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
 
-        load_contexts = []
-        # Make sure we load onto targeted device
-        if tp_device is not None:
-            load_contexts.append(tp_device)
-
-        with ContextManagers(load_contexts):
-            (
-                model,
-                missing_keys,
-                unexpected_keys,
-                mismatched_keys,
-                offload_index,
-                error_msgs,
-            ) = cls._load_pretrained_model(
-                model,
-                state_dict,
-                loaded_state_dict_keys,  # XXX: rename?
-                resolved_archive_file,
-                pretrained_model_name_or_path,
-                ignore_mismatched_sizes=ignore_mismatched_sizes,
-                sharded_metadata=sharded_metadata,
-                _fast_init=_fast_init,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-                device_map=device_map,
-                offload_folder=offload_folder,
-                offload_state_dict=offload_state_dict,
-                dtype=torch_dtype,
-                hf_quantizer=hf_quantizer,
-                keep_in_fp32_modules=keep_in_fp32_modules,
-                gguf_path=gguf_path,
-                weights_only=weights_only,
-            )
+        (
+            model,
+            missing_keys,
+            unexpected_keys,
+            mismatched_keys,
+            offload_index,
+            error_msgs,
+        ) = cls._load_pretrained_model(
+            model,
+            state_dict,
+            loaded_state_dict_keys,  # XXX: rename?
+            resolved_archive_file,
+            pretrained_model_name_or_path,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            sharded_metadata=sharded_metadata,
+            _fast_init=_fast_init,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            device_map=device_map,
+            offload_folder=offload_folder,
+            offload_state_dict=offload_state_dict,
+            dtype=torch_dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            gguf_path=gguf_path,
+            weights_only=weights_only,
+        )
 
         # make sure token embedding weights are still tied if needed
         model.tie_weights()
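For readers following the change: the block added in the first hunk above is what decides which device each rank loads onto. A minimal standalone sketch of that rank-to-device mapping is shown below; it mirrors the diff rather than adding any new API, the script name and the "gloo" backend are assumptions made so the sketch also runs on CPU-only hosts, and real tensor-parallel loading would typically use "nccl" on GPUs.

```python
# Minimal sketch of the per-rank device selection introduced above (illustrative only).
# Run with e.g. `torchrun --nproc_per_node=2 tp_device_sketch.py` (hypothetical file name).
import torch
import torch.distributed as dist

# The "gloo" backend is an assumption for the sketch so it works without GPUs.
if not dist.is_initialized():
    dist.init_process_group(backend="gloo")

# Detect the accelerator type on this machine ("cuda", "xpu", ...); "cpu" if none is available.
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)

# Map each rank to a local device, assuming an equal number of devices per host.
tp_device = torch.device(device_type, dist.get_rank() % device_module.device_count())

# Passing this single device as `device_map` to `from_pretrained` is what makes each
# rank materialize its weights on its own device instead of every rank using device 0.
device_map = tp_device
print(f"rank {dist.get_rank()} -> {tp_device}")

dist.destroy_process_group()
```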
55 changes: 55 additions & 0 deletions tests/tp/test_tp.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 
 import os
+import subprocess
+import tempfile
+import textwrap
 
 from transformers import is_torch_available
 from transformers.models.llama.configuration_llama import LlamaConfig
@@ -30,6 +33,22 @@
 
 
 class TestTensorParallel(TestCasePlus):
+    def torchrun(self, script: str):
+        """Run `script` using the `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
+        with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
+            tmp.write(script)
+            tmp.flush()
+            tmp.seek(0)
+            cmd = (
+                f"torchrun --nproc_per_node {torch.cuda.device_count()} --master_port {get_torch_dist_unique_port()} {tmp.name}"
+            ).split()
+
+            # Note that the subprocess will be waited for here, and raise an error if not successful
+            try:
+                _ = subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True)
+            except subprocess.CalledProcessError as e:
+                raise Exception(f"The following error was captured: {e.stderr}")
+
     @require_torch_multi_gpu
     def test_tp(self):
         distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
@@ -43,6 +62,42 @@ def test_tp(self):
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call
 
+    @require_torch_multi_gpu
+    def test_loading_memory_consumption(self):
+        script_to_run = textwrap.dedent(
+            """
+            import torch
+            import os
+            from transformers import AutoModelForCausalLM
+
+            model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+            rank = int(os.environ["RANK"])
+            world_size = int(os.environ["WORLD_SIZE"])
+
+            device = torch.device(f"cuda:{rank}")
+            torch.distributed.init_process_group("nccl", device_id=device)
+
+            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
+            torch.distributed.barrier()
+
+            # The expected full model memory footprint
+            expected_model_memory = 16
+            overhead_factor = 1.2
+
+            # Assert we did not use more than the full model expected memory (with some overhead)
+            if not torch.cuda.max_memory_allocated(device) / 1024**3 < expected_model_memory * overhead_factor:
+                raise ValueError("Loading the model used more than the full model size")
+
+            # Assert we correctly handled the sharding between devices
+            if not torch.cuda.memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor:
+                raise ValueError("Each model shard is larger than what is expected.")
+
+            torch.distributed.barrier()
+            torch.distributed.destroy_process_group()
+            """
+        )
+        self.torchrun(script_to_run)
+
 
 if __name__ == "__main__":
     # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
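As a side note on the constants in `test_loading_memory_consumption` above: the 16 GiB bound corresponds to roughly 8 billion parameters stored in float16 (2 bytes each), and the per-rank bound simply divides that by the world size before applying the 1.2 overhead factor. A minimal sketch of that arithmetic follows, assuming the approximate parameter count of Meta-Llama-3-8B; it is illustrative only and not part of the test.

```python
# Rough memory budget behind the test's assertions (illustrative only).
num_params = 8e9         # approximate parameter count of Meta-Llama-3-8B (assumption)
bytes_per_param = 2      # float16
full_model_gib = num_params * bytes_per_param / 1024**3
print(f"full model ~= {full_model_gib:.1f} GiB")  # ~14.9 GiB; the test rounds up to 16

expected_model_memory = 16  # GiB, as in the test
overhead_factor = 1.2
for world_size in (2, 4, 8):
    # Each rank should hold about 1/world_size of the weights, plus some overhead.
    per_rank_bound = expected_model_memory / world_size * overhead_factor
    print(f"world_size={world_size}: each shard must stay under ~{per_rank_bound:.1f} GiB")
```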
