Merge pull request #16 from huggingface/phuc/feature_add_parallel_context

[Refactor] Add minimal ParallelContext
xrsrke authored Jan 12, 2024
2 parents cbe0d30 + 490bc83 commit fd99571
Showing 41 changed files with 1,198 additions and 947 deletions.
14 changes: 9 additions & 5 deletions README.md
@@ -74,14 +74,18 @@ We showcase usage in the `examples` directory.

Let's go through some key concepts.

## DistributedProcessGroups
## ParallelContext

`DistributedProcessGroups` is the base class referencing all the process groups you might need when running parallel workloads. You can initialize it using the following:
`ParallelContext` is the base class referencing all the process groups you might need when running parallel workloads. You can initialize it using the following:
```python
from nanotron.core.process_groups import get_process_groups
from nanotron.distributed import ParallelContext

dp, tp, pp = ... # Predefine your topology
dpg: DistributedProcessGroups = get_process_groups(data_parallel_size=dp, tensor_parallel_size=tp, pipeline_parallel_size=pp)
# define your topology
parallel_context = ParallelContext(
    tensor_parallel_size=2,
    data_parallel_size=2,
    pipeline_parallel_size=2
)
```
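
The example above describes a 2 × 2 × 2 topology, so it assumes the job is launched with 8 processes; the world size of the launched job is expected to match the product of the three parallel sizes.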

`ProcessGroups` is a mechanism for running distributed collectives (`all-reduce`, `all-gather`, ...) on a subgroup of all the ranks. It provides the granularity needed for 3D parallelism.
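
As a rough illustration (not part of this commit's diff), the sketch below shows how a collective can be scoped to one of these subgroups. It assumes the `ParallelContext` attributes already used elsewhere in this PR (`dp_pg`, `tp_pg`, `pp_pg`, `world_pg`), the `nanotron.core.distributed` wrapper that `run_generate.py` imports as `dist`, and a launch with 8 processes:

```python
import torch

from nanotron.core import distributed as dist
from nanotron.distributed import ParallelContext

# 2 x 2 x 2 topology, so this sketch expects a world size of 8.
parallel_context = ParallelContext(
    tensor_parallel_size=2,
    data_parallel_size=2,
    pipeline_parallel_size=2,
)

# Every rank contributes its global rank; the average is taken only over
# the ranks in this rank's data-parallel group, not over the whole world.
value = torch.tensor([float(dist.get_rank(parallel_context.world_pg))], device="cuda")
dist.all_reduce(value, op=dist.ReduceOp.AVG, group=parallel_context.dp_pg)
print(f"rank {dist.get_rank(parallel_context.world_pg)} -> dp-averaged value {value.item()}")
```

The same pattern with `tp_pg` or `pp_pg` scopes the collective to a tensor- or pipeline-parallel group instead.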
58 changes: 29 additions & 29 deletions run_generate.py
@@ -14,33 +14,32 @@
- Benchmark:
USE_BENCH=1 USE_FAST=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=2 run_benchmark2.py --pp 2 --tp 1 --dp 1 --model_name huggyllama/llama-7b --ckpt-path /admin/home/ferdinand_mom/.cache/huggingface/hub/models--HuggingFaceBR4--llama-7b-orig/snapshots/2160b3d0134a99d365851a7e95864b21e873e1c3
"""
import os
import argparse
import os
from pathlib import Path

import torch
from nanotron import logging
from nanotron.config import GenerationArgs, ParallelismArgs, LoggingArgs, get_config_from_file
from nanotron.config import GenerationArgs, LoggingArgs, ParallelismArgs, get_config_from_file
from nanotron.core import distributed as dist
from nanotron.core.parallel.parameters import sanity_check
from nanotron.core.parallel.pipeline_parallelism.engine import (
OneForwardOneBackwardPipelineEngine,
)
from nanotron.core.parallel.pipeline_parallelism.tensor_pointer import TensorPointer
from nanotron.core.parallel.tensor_parallelism.enum import TensorParallelLinearMode
from nanotron.core.process_groups import get_process_groups
from nanotron.core.random import (
RandomStates,
get_current_random_state,
get_synced_random_state,
set_random_seed,
)
from nanotron.distributed import ParallelContext
from nanotron.generate.generation import (
GenerationInput,
TokenizerConfig,
greedy_search_text,
)

from nanotron.helpers import set_logger_verbosity_format
from nanotron.logging import log_rank
from nanotron.serialize import (
@@ -51,6 +50,7 @@

logger = logging.get_logger(__name__)


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default=None, help="Model name")
@@ -74,32 +74,31 @@ def main():
recompute_granularity=None,
tp_linear_async_communication=True,
)

logging_config = LoggingArgs(
log_level="info",
log_level_replica="info",
)

dtype = torch.bfloat16

# Set random states
set_random_seed(42)

# Initialise all process groups
dpg = get_process_groups(
parallel_context = ParallelContext(
data_parallel_size=parallel_config.dp,
pipeline_parallel_size=parallel_config.pp,
tensor_parallel_size=parallel_config.tp,
)

# Set log levels
if dist.get_rank(dpg.world_pg) == 0:
if dist.get_rank(parallel_context.world_pg) == 0:
if logging_config.log_level is not None:
set_logger_verbosity_format(logging_config.log_level, dpg=dpg)
set_logger_verbosity_format(logging_config.log_level, parallel_context=parallel_context)
else:
if logging_config.log_level_replica is not None:
set_logger_verbosity_format(logging_config.log_level_replica, dpg=dpg)

set_logger_verbosity_format(logging_config.log_level_replica, parallel_context=parallel_context)

tokenizer_path = args.model_name
# if config.yaml in checkpoint path we use it
@@ -118,7 +117,7 @@ def main():
assert args.model_name is not None, "model_name must be provided or config.yaml must be in checkpoint path"
model_name = args.model_name
model_config: AutoConfig = AutoConfig.from_pretrained(model_name)

# model_config.num_hidden_layers = 1
log_rank(f"model_config: {model_config}", logger=logger, level=logging.INFO, rank=0)

@@ -131,7 +130,7 @@ def main():
# Get synchronized random states
if parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE:
random_states = RandomStates(
{"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=dpg.tp_pg)}
{"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=parallel_context.tp_pg)}
)
else:
# We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER)
@@ -140,17 +139,17 @@ def main():
model = DistributedTrainer.build_model(
model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls](
config=model_config,
dpg=dpg,
parallel_context=parallel_context,
parallel_config=parallel_config,
random_states=random_states,
),
dtype=dtype,
dpg=dpg,
parallel_context=parallel_context,
)

# Mark some parameters as tied
# TODO @nouamane: this is only needed for training, can we just mark params as NanotronParameter instead?
mark_tied_parameters(model=model, dpg=dpg, parallel_config=parallel_config)
mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)

# Sanity check model
sanity_check(root_module=model)
@@ -163,7 +162,7 @@ def main():
level=logging.INFO,
rank=0,
)
load_weights(model=model, dpg=dpg, root_folder=checkpoint_path)
load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path)

model.eval()
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
@@ -184,15 +183,15 @@ def main():
"def fib(n)",
# "This film was probably inspired by Godzilla",
]

outputs = greedy_search_text(
input_iter=(GenerationInput(text=text) for text in dummy_inputs),
tokenizer=tokenizer,
# TODO @thomasw21: From ModelWithLoss extract the model.
model=model.model,
# TODO @thomasw21: Figure out how to pass p2p.
p2p=model.model.p2p,
dpg=dpg,
parallel_context=parallel_context,
max_new_tokens=args.max_new_tokens,
max_micro_batch_size=2,
generation_config=GenerationArgs(sampler="greedy", use_cache=False),
@@ -204,36 +203,36 @@ def main():
)

dist.barrier()

for output in outputs:
input_ids = output.input_ids
generated_ids = output.generation_ids
if isinstance(input_ids, TensorPointer):
assert isinstance(generated_ids, TensorPointer)
continue
assert isinstance(generated_ids, torch.Tensor)

log_rank(
f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}",
logger=logger,
level=logging.INFO,
rank=0,
)

log_rank(
f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}",
logger=logger,
level=logging.INFO,
rank=0,
)

log_rank(
"--------------------------------------------------",
logger=logger,
level=logging.INFO,
rank=0,
)

if args.compare_with_no_cache:

outputs = greedy_search_text(
@@ -243,7 +242,7 @@ def main():
model=model.model,
# TODO @thomasw21: Figure out how to pass p2p.
p2p=model.model.p2p,
dpg=dpg,
parallel_context=parallel_context,
max_new_tokens=args.max_new_tokens,
max_micro_batch_size=2,
generation_config=GenerationArgs(sampler="greedy", use_cache=True),
@@ -255,35 +254,36 @@ def main():
)

dist.barrier()

for output in outputs:
input_ids = output.input_ids
generated_ids = output.generation_ids
if isinstance(input_ids, TensorPointer):
assert isinstance(generated_ids, TensorPointer)
continue
assert isinstance(generated_ids, torch.Tensor)

log_rank(
f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}",
logger=logger,
level=logging.INFO,
rank=0,
)

log_rank(
f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}",
logger=logger,
level=logging.INFO,
rank=0,
)

log_rank(
"--------------------------------------------------",
logger=logger,
level=logging.INFO,
rank=0,
)


if __name__ == "__main__":
main()
12 changes: 6 additions & 6 deletions run_train.py
@@ -56,15 +56,15 @@ def get_dataloader(trainer: DistributedTrainer, sanity_check_dataloader_interval
output_pp_rank=output_pp_rank,
vocab_size=trainer.model_config.vocab_size,
seed=trainer.config.data.seed,
dpg=trainer.dpg,
parallel_context=trainer.parallel_context,
)()
elif isinstance(trainer.config.data.dataset, PretrainDatasetsArgs):
log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

with main_rank_first(trainer.dpg.world_pg):
with main_rank_first(trainer.parallel_context.world_pg):
# 1st device processes dataset and cache it, then other devices load from cache
# TODO @nouamanetazi: this may timeout before 1st device finishes processing dataset. Can we have a ctxmanager to modify timeout?
# TODO: generalise to include for validation/test splits
@@ -85,7 +85,7 @@ def get_dataloader(trainer: DistributedTrainer, sanity_check_dataloader_interval
dataloader = get_train_dataloader(
train_dataset=train_dataset,
sequence_length=trainer.sequence_length,
dpg=trainer.dpg,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
@@ -97,9 +97,9 @@ def get_dataloader(trainer: DistributedTrainer, sanity_check_dataloader_interval
# Check if we have enough samples for train_steps
assert (
trainer.config.tokens.train_steps - trainer.start_iteration_step
) * trainer.global_batch_size // trainer.dpg.dp_pg.size() < len(dataloader), (
f"Dataset is too small for steps ({len(dataloader)} < {(trainer.config.tokens.train_steps - trainer.start_iteration_step) * trainer.global_batch_size // trainer.dpg.dp_pg.size()}), "
f"Try train_steps<={len(dataloader) * trainer.dpg.dp_pg.size() // trainer.global_batch_size + trainer.start_iteration_step}"
) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size() < len(dataloader), (
f"Dataset is too small for steps ({len(dataloader)} < {(trainer.config.tokens.train_steps - trainer.start_iteration_step) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size()}), "
f"Try train_steps<={len(dataloader) * trainer.parallel_context.dp_pg.size() // trainer.global_batch_size + trainer.start_iteration_step}"
)
else:
raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {trainer.config.data.dataset}")
11 changes: 5 additions & 6 deletions src/nanotron/core/parallel/model.py
@@ -1,19 +1,18 @@
from torch import nn

from nanotron.core import distributed as dist
from nanotron.core.parallel.tied_parameters import get_tied_id_to_param
from nanotron.core.process_groups import DistributedProcessGroups
from nanotron.distributed import ParallelContext
from torch import nn


def initial_sync(model: nn.Module, dpg: DistributedProcessGroups):
def initial_sync(model: nn.Module, parallel_context: ParallelContext):
# Synchronize across dp: basic assumption
sorted_name_params = sorted(model.named_parameters(), key=lambda x: x[0])
for name, param in sorted_name_params:
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=dpg.dp_pg)
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=parallel_context.dp_pg)

# Synchronize across tied weights: basic assumption
for (_, group_ranks), param in sorted(
get_tied_id_to_param(parameters=model.parameters(), root_module=model).items(), key=lambda x: x[0]
):
group = dpg.world_ranks_to_pg[group_ranks]
group = parallel_context.world_ranks_to_pg[group_ranks]
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)