Merge pull request #15 from huggingface/brrr-nanotron-sync
Helping make brrr depend on nanotron
thomwolf authored Jan 17, 2024
Parents: 7f2fd88 + 6298c5e · Commit: 2be81dc
Showing 88 changed files with 755 additions and 1,254 deletions.
README.md (5 changes: 3 additions & 2 deletions)
@@ -33,6 +33,7 @@ In the `/examples` directory, we provide a set of **self-sufficient** examples f
Requirements:
- Python >= 3.10
- PyTorch >= 2.0.0
- Flash-Attention >= 2.4.2

To install:
```bash
@@ -78,7 +79,7 @@ Let's go through some key concepts.

`ParallelContext` is the base class referencing all the process groups you might need when running parallel workloads. You can initialize it using the following:
```python
from nanotron.distributed import ParallelContext
from nanotron.parallel import ParallelContext

# define your topology
parallel_context = ParallelContext(
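    # The hunk ends here; a hedged completion of the call, assuming one size per
    # parallelism axis (these keyword names are an assumption, not part of this diff):
    data_parallel_size=2,
    pipeline_parallel_size=2,
    tensor_parallel_size=2,
)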
@@ -190,7 +191,7 @@ Usually the go-to solution when models can't fit within a device. The basic idea
- Distributed samplers for generation

[Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) introduces that notion upon implementing one of the first large scale transformers:
![Tensor parallelism in transformer model](assets/tensor_parallelism_in_transformer.png)
![Tensor parallelism in transformer model](assets/tensor_parallel_in_transformer.png)
(Source: [link](https://arxiv.org/abs/1909.08053))
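
To make the weight-sharding idea concrete, here is a minimal, self-contained sketch of the column-parallel split in plain PyTorch. It runs on a single device and only illustrates the math; it is not nanotron's tensor-parallel implementation.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)    # activations: (batch, hidden)
w = torch.randn(8, 16)   # full weight of a linear layer

# Pretend there are 2 tensor-parallel ranks: each one holds a column slice of
# the weight and computes only its share of the output features.
shards = w.chunk(2, dim=1)
partial_outputs = [x @ shard for shard in shards]

# Gathering the partial outputs along the feature dimension recovers the
# result of the unsharded linear layer.
y = torch.cat(partial_outputs, dim=1)
assert torch.allclose(y, x @ w)
```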

## Pipeline parallelism
run_generate.py (25 changes: 12 additions & 13 deletions)
@@ -20,28 +20,27 @@

import torch
from nanotron import logging
from nanotron import distributed as dist
from nanotron.config import GenerationArgs, LoggingArgs, ParallelismArgs, get_config_from_file
from nanotron.core import distributed as dist
from nanotron.core.parallel.parameters import sanity_check
from nanotron.core.parallel.pipeline_parallelism.engine import (
from nanotron.parallel.parameters import sanity_check
from nanotron.parallel.pipeline_parallel.engine import (
OneForwardOneBackwardPipelineEngine,
)
from nanotron.core.parallel.pipeline_parallelism.tensor_pointer import TensorPointer
from nanotron.core.parallel.tensor_parallelism.enum import TensorParallelLinearMode
from nanotron.core.random import (
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.random import (
RandomStates,
get_current_random_state,
get_synced_random_state,
set_random_seed,
)
from nanotron.distributed import ParallelContext
from nanotron.generate.generation import (
from nanotron.parallel import ParallelContext
from nanotron.generation.decode import (
GenerationInput,
TokenizerConfig,
greedy_search_text,
decode_text,
)
from nanotron.helpers import set_logger_verbosity_format
from nanotron.logging import log_rank
from nanotron.logging import log_rank, set_logger_verbosity_format
from nanotron.serialize import (
load_weights,
)
@@ -184,7 +183,7 @@ def main():
# "This film was probably inspired by Godzilla",
]

outputs = greedy_search_text(
outputs = decode_text(
input_iter=(GenerationInput(text=text) for text in dummy_inputs),
tokenizer=tokenizer,
# TODO @thomasw21: From ModelWithLoss extract the model.
@@ -235,7 +234,7 @@ def main():

if args.compare_with_no_cache:

outputs = greedy_search_text(
outputs = decode_text(
input_iter=(GenerationInput(text=text) for text in dummy_inputs),
tokenizer=tokenizer,
# TODO @thomasw21: From ModelWithLoss extract the model.
run_train.py (4 changes: 2 additions & 2 deletions)
@@ -13,8 +13,8 @@
from nanotron.config import (
PretrainDatasetsArgs,
)
from nanotron.core import distributed as dist
from nanotron.core.utils import (
from nanotron import distributed as dist
from nanotron.utils import (
main_rank_first,
)
from nanotron.dataloader import (
src/nanotron/config/config.py (51 changes: 10 additions & 41 deletions)
@@ -2,27 +2,27 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Union, Type

import dacite
import torch
import yaml
from dacite import from_dict
from yaml.loader import SafeLoader

from nanotron.config.models_config import NanotronConfigs
from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit
from nanotron.config.utils_config import (
RecomputeGranularity,
cast_str_to_pipeline_engine,
cast_str_to_torch_dtype,
serialize,
)
from nanotron.core.parallel.pipeline_parallelism.engine import (
from nanotron.parallel.pipeline_parallel.engine import (
AllForwardAllBackwardPipelineEngine,
PipelineEngine,
)
from nanotron.core.parallel.tensor_parallelism.nn import TensorParallelLinearMode
from nanotron.generate.sampler import SamplerType
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
from nanotron.generation.sampler import SamplerType
from nanotron.logging import get_logger

logger = get_logger(__name__)
@@ -42,7 +42,6 @@ class LoggingArgs:
log_level: Optional[str] = None
log_level_replica: Optional[str] = None
iteration_step_info_interval: Optional[int] = 1
extensions = None

def __post_init__(self):
if self.log_level is None:
@@ -104,16 +103,14 @@ class CheckpointsArgs:
checkpoints_path: where to save the checkpoints
checkpoint_interval: how often to save the checkpoints
resume_checkpoint_path: if you want to load from a specific checkpoint path
s3: if you want to upload the checkpoints on s3
"""

checkpoints_path: Path
checkpoint_interval: int
save_initial_state: Optional[bool] = False
resume_checkpoint_path: Optional[Path] = None
checkpoints_path_is_shared_file_system: Optional[bool] = True
extensions = None
checkpoints_path_is_shared_file_system: Optional[bool] = False

def __post_init__(self):
if isinstance(self.checkpoints_path, str):
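
As a rough illustration of the fields documented in this hunk, here is a hedged sketch of constructing the dataclass directly; the import path and the example values are assumptions, not taken from the diff.

```python
from pathlib import Path

from nanotron.config.config import CheckpointsArgs  # import path assumed

ckpt = CheckpointsArgs(
    checkpoints_path=Path("checkpoints"),          # where to save the checkpoints
    checkpoint_interval=500,                       # how often to save them
    save_initial_state=True,                       # optional, defaults to False
    resume_checkpoint_path=None,                   # or a Path to resume from
    checkpoints_path_is_shared_file_system=False,  # new default after this change
)
```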
@@ -140,31 +137,19 @@ class GeneralArgs:
seed: Optional[int] = None
step: Optional[int] = None
consumed_train_samples: Optional[int] = None
# If you want to signal the training script to stop, you just need to touch the following file
# We force users to set one in order to programmatically be able to remove it.
kill_switch_path: Optional[Path] = None
# If you want to signal the training script to pause, you just need to add the following file
benchmark_csv_path: Optional[Path] = None
ignore_sanity_checks: bool = False

def __post_init__(self):
if self.seed is None:
self.seed = 42
if isinstance(self.kill_switch_path, str):
self.kill_switch_path = Path(self.kill_switch_path)
if self.benchmark_csv_path is not None:
assert (
os.environ.get("NANOTRON_BENCHMARK", None) is not None
), f"Please set NANOTRON_BENCHMARK to 1 when using benchmark_csv_path. Got {os.environ.get('NANOTRON_BENCHMARK', None)}"

if self.run is None:
self.run = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if os.environ.get("SLURM_JOB_ID", None) is not None:
self.run += f"_{os.environ['SLURM_JOB_ID']}"
else:
self.run = self.run.replace("%d", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
if os.environ.get("SLURM_JOB_ID", None) is not None:
self.run = self.run.replace("%j", os.environ["SLURM_JOB_ID"])


@dataclass
@@ -213,22 +198,6 @@ def __post_init__(self):
self.recompute_granularity = RecomputeGranularity[self.recompute_granularity.upper()]


@dataclass
class RandomInit:
std: float


@dataclass
class ExistingCheckpointInit:
"""This is used to initialize from an already existing model (without optimizer, lr_scheduler...)"""

path: Path

def __post_init__(self):
if isinstance(self.path, str):
self.path = Path(self.path)


@dataclass
class ModelArgs:
"""Arguments related to model architecture"""
@@ -353,7 +322,7 @@ class Config:
tokens: TokensArgs
optimizer: OptimizerArgs
data: DataArgs
profiler: Optional[ProfilerArgs] = None
profiler: Optional[ProfilerArgs]

def __post_init__(self):
# Some final sanity checks across separate arguments sections:
@@ -380,13 +349,13 @@ def save_as_yaml(self, file_path: str):
yaml.dump(config_dict, f)

# Sanity test config can be reloaded
_ = get_config_from_file(file_path)
_ = get_config_from_file(file_path, config_class=self.__class__)

def as_dict(self) -> dict:
return serialize(self)


def get_config_from_file(config_path: str) -> Config:
def get_config_from_file(config_path: str, config_class: Type[Config] = Config) -> Config:
"""Get a config objet from a file (python or YAML)
Args:
@@ -402,7 +371,7 @@ def get_config_from_file(config_path: str) -> Config:
# Make a nice dataclass from our yaml
try:
config = from_dict(
data_class=Config,
data_class=config_class,
data=args,
config=dacite.Config(
cast=[Path],
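
A hedged usage sketch of the extended loader; the YAML path and the subclass name are placeholders, not taken from this diff.

```python
from nanotron.config import get_config_from_file

# Default behaviour: parse the YAML into nanotron's own Config dataclass.
config = get_config_from_file("config.yaml")

# The new parameter lets a downstream project (e.g. brrr) reuse the same loader
# for its own Config subclass; `MyProjectConfig` is hypothetical.
# config = get_config_from_file("config.yaml", config_class=MyProjectConfig)
```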
src/nanotron/config/models_config.py (13 changes: 13 additions & 0 deletions)
@@ -1,7 +1,20 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Union


@dataclass
class RandomInit:
std: float


@dataclass
class ExistingCheckpointInit:
"""This is used to initialize from an already existing model (without optimizer, lr_scheduler...)"""

path: Path
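
With both initialization dataclasses now living in models_config.py, here is a hedged sketch of how they might be used; the checkpoint path is a placeholder.

```python
from pathlib import Path

from nanotron.config.models_config import ExistingCheckpointInit, RandomInit

# Either initialize weights randomly with a given standard deviation...
init_method = RandomInit(std=0.02)

# ...or start from the weights of an existing checkpoint (no optimizer or
# lr_scheduler state). The path below is a placeholder.
init_method = ExistingCheckpointInit(path=Path("checkpoints/10000"))
```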


@dataclass
class LlamaConfig:
"""Configuration for a LLAMA model
src/nanotron/config/utils_config.py (6 changes: 3 additions & 3 deletions)
@@ -4,13 +4,13 @@

import torch

from nanotron.core.parallel.pipeline_parallelism.engine import (
from nanotron.parallel.pipeline_parallel.engine import (
AllForwardAllBackwardPipelineEngine,
OneForwardOneBackwardPipelineEngine,
PipelineEngine,
)
from nanotron.core.parallel.tensor_parallelism.nn import TensorParallelLinearMode
from nanotron.generate.sampler import SamplerType
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
from nanotron.generation.sampler import SamplerType


class RecomputeGranularity(Enum):
src/nanotron/constants.py (43 changes: 1 addition & 42 deletions)
@@ -1,48 +1,7 @@
import importlib
import importlib.metadata as importlib_metadata
import platform
import warnings
from typing import Tuple, Union

from packaging.version import Version, parse

CHECKPOINT_VERSION = Version("1.2")
CHECKPOINT_VERSION = Version("0.1")

PY_VERSION = parse(platform.python_version())

# https://github.com/huggingface/transformers/blob/f67dac97bdc63874f2288546b3fa87e69d2ea1c8/src/transformers/utils/import_utils.py#L41
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
package_version = importlib_metadata.version(pkg_name)
package_exists = True
except importlib_metadata.PackageNotFoundError:
package_exists = False
if return_version:
return package_exists, package_version
else:
return package_exists


def _can_import_from_module(module: str, name: str) -> bool:
"""
Check if a specific module can be imported from a package.
"""
if not _is_package_available(module):
return False
try:
spec = importlib.util.find_spec(module)
module_obj = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module_obj)
return hasattr(module_obj, name)
except Exception as e:
warnings.warn(f"Unable to import {name} from {module}: {e}")
return False


TENSORBOARDX_AVAILABLE = _is_package_available("tensorboardX")
HUGGINGFACE_HUB_AVAILABLE = _is_package_available("huggingface_hub")
HF_TENSORBOARD_LOGGER_AVAILABLE = _can_import_from_module("huggingface_hub", "HFSummaryWriter")