From d877757cc828fa5ed6e97e7925e8f0de44f93dce Mon Sep 17 00:00:00 2001 From: Jai Dhyani Date: Tue, 26 Mar 2024 19:31:08 -0700 Subject: [PATCH 1/4] Make training more legible (#74) * more verbose logging during setup * fixed/slightly simplified model/optimizer init/loading logic * epoch logging + option to silence wandb noise * add metadata help to ModelTrainingState fields * docstring for ModelTrainingState * explicitly log all non-default config values * simplify skip logic; rename t0 to last_training_step_time * "User-set" is more accurate than "non-default" * delete mfu stuff * move optimizer.zero_grad into if-actually-training conditional * remove unused functions and add some docstrings * Simplify model configuration to only use transformers library directly (#75) * Drop support for typedmodels, do everything through transformers directly * minor cleanup * remove null and obsolete mamba line from sample config * update help strings on ModelConfig --- .../sample_config.json | 5 +- .../sample_mamba.json | 4 +- .../sample_transformers_bloom.json | 2 +- src/delphi/eval/utils.py | 5 +- src/delphi/static/configs/debug.json | 4 +- src/delphi/static/configs/debug_mamba.json | 4 +- .../configs/debug_transformers_bloom.json | 2 +- src/delphi/static/configs/v0-llama2-1.6m.json | 4 +- .../static/configs/v0-llama2-100k-quick.json | 4 +- src/delphi/static/configs/v0-llama2-100k.json | 4 +- .../static/configs/v0-llama2-12.8m.json | 4 +- src/delphi/static/configs/v0-llama2-200k.json | 4 +- .../static/configs/v0-llama2-25.6m.json | 4 +- src/delphi/static/configs/v0-llama2-3.2m.json | 4 +- src/delphi/static/configs/v0-llama2-400k.json | 4 +- src/delphi/static/configs/v0-llama2-6.4m.json | 4 +- src/delphi/static/configs/v0-llama2-800k.json | 4 +- src/delphi/train/config/gigaconfig.py | 2 +- src/delphi/train/config/model_config.py | 36 +++++ src/delphi/train/config/models/__init__.py | 4 - .../train/config/models/model_config.py | 64 --------- src/delphi/train/config/models/model_types.py | 70 ---------- .../train/config/models/typed_llama_config.py | 30 ---- .../train/config/models/typed_mamba_config.py | 39 ------ .../train/config/models/typed_model_config.py | 15 -- src/delphi/train/config/utils.py | 38 +++++ src/delphi/train/config/wandb_config.py | 1 + src/delphi/train/iteration_params.py | 20 ++- src/delphi/train/train_step.py | 71 ++-------- src/delphi/train/training.py | 8 +- src/delphi/train/utils.py | 132 ++++++++---------- src/delphi/train/wandb_utils.py | 2 +- .../train/config/models/test_model_config.py | 39 ++---- tests/train/test_wandb_utils.py | 16 ++- 34 files changed, 224 insertions(+), 429 deletions(-) create mode 100644 src/delphi/train/config/model_config.py delete mode 100644 src/delphi/train/config/models/__init__.py delete mode 100644 src/delphi/train/config/models/model_config.py delete mode 100644 src/delphi/train/config/models/model_types.py delete mode 100644 src/delphi/train/config/models/typed_llama_config.py delete mode 100644 src/delphi/train/config/models/typed_mamba_config.py delete mode 100644 src/delphi/train/config/models/typed_model_config.py diff --git a/scripts/training_config_examples/sample_config.json b/scripts/training_config_examples/sample_config.json index b2de5093..c6fc86b5 100644 --- a/scripts/training_config_examples/sample_config.json +++ b/scripts/training_config_examples/sample_config.json @@ -16,9 +16,8 @@ "batch_size": 64, "max_seq_len": 512, "model_config": { - "model_type": "llama2", - "mamba": null, - "llama2": { + "model_type": 
"LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": -1, diff --git a/scripts/training_config_examples/sample_mamba.json b/scripts/training_config_examples/sample_mamba.json index 94a2dd37..81bdf313 100644 --- a/scripts/training_config_examples/sample_mamba.json +++ b/scripts/training_config_examples/sample_mamba.json @@ -16,8 +16,8 @@ "batch_size": 64, "max_seq_len": 512, "model_config": { - "model_type": "mamba", - "mamba": { + "model_type": "MambaForCausalLM", + "model_params": { "vocab_size": 4096, "hidden_size": 768, "state_size": 16, diff --git a/scripts/training_config_examples/sample_transformers_bloom.json b/scripts/training_config_examples/sample_transformers_bloom.json index 7d812ea4..4030dcb9 100644 --- a/scripts/training_config_examples/sample_transformers_bloom.json +++ b/scripts/training_config_examples/sample_transformers_bloom.json @@ -8,7 +8,7 @@ "batch_size": 64, "model_config": { "model_type": "BloomForCausalLM", - "transformers_config": { + "model_params": { "apply_residual_connection_post_layernorm": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py index 2d052974..b45a49d1 100644 --- a/src/delphi/eval/utils.py +++ b/src/delphi/eval/utils.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from typing import cast @@ -85,7 +86,9 @@ def load_delphi_dataset(dataset_name: str, split: str, slice: str = "") -> Datas # or we'd get a Dataset dict. See https://github.com/huggingface/datasets/issues/5189 split=f"train{slice}", ) - return cast(Dataset, dataset) + dataset = cast(Dataset, dataset) + logging.info(f" Loaded {data_files_str} ({len(dataset)} entries)") + return dataset def load_validation_dataset(dataset_name: str, slice: str = "") -> Dataset: diff --git a/src/delphi/static/configs/debug.json b/src/delphi/static/configs/debug.json index 5d40717a..1694c587 100644 --- a/src/delphi/static/configs/debug.json +++ b/src/delphi/static/configs/debug.json @@ -8,8 +8,8 @@ "train_sample_limit": 256, "batch_size": 64, "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "hidden_size": 48, "intermediate_size": 48, "num_attention_heads": 2, diff --git a/src/delphi/static/configs/debug_mamba.json b/src/delphi/static/configs/debug_mamba.json index 2a619e59..16d0391a 100644 --- a/src/delphi/static/configs/debug_mamba.json +++ b/src/delphi/static/configs/debug_mamba.json @@ -9,8 +9,8 @@ "train_sample_limit": 64, "batch_size": 8, "model_config": { - "model_type": "mamba", - "mamba": { + "model_type": "MambaForCausalLM", + "model_params": { "vocab_size": 4096, "hidden_size": 48, "state_size": 16, diff --git a/src/delphi/static/configs/debug_transformers_bloom.json b/src/delphi/static/configs/debug_transformers_bloom.json index 6c89e9e2..18720113 100644 --- a/src/delphi/static/configs/debug_transformers_bloom.json +++ b/src/delphi/static/configs/debug_transformers_bloom.json @@ -9,7 +9,7 @@ "batch_size": 64, "model_config": { "model_type": "BloomForCausalLM", - "transformers_config": { + "model_params": { "apply_residual_connection_post_layernorm": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-1.6m.json b/src/delphi/static/configs/v0-llama2-1.6m.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-1.6m.json +++ b/src/delphi/static/configs/v0-llama2-1.6m.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": 
"llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-100k-quick.json b/src/delphi/static/configs/v0-llama2-100k-quick.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-100k-quick.json +++ b/src/delphi/static/configs/v0-llama2-100k-quick.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-100k.json b/src/delphi/static/configs/v0-llama2-100k.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-100k.json +++ b/src/delphi/static/configs/v0-llama2-100k.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-12.8m.json b/src/delphi/static/configs/v0-llama2-12.8m.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-12.8m.json +++ b/src/delphi/static/configs/v0-llama2-12.8m.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-200k.json b/src/delphi/static/configs/v0-llama2-200k.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-200k.json +++ b/src/delphi/static/configs/v0-llama2-200k.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-25.6m.json b/src/delphi/static/configs/v0-llama2-25.6m.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-25.6m.json +++ b/src/delphi/static/configs/v0-llama2-25.6m.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-3.2m.json b/src/delphi/static/configs/v0-llama2-3.2m.json index 909a898c..978b7007 100644 --- a/src/delphi/static/configs/v0-llama2-3.2m.json +++ b/src/delphi/static/configs/v0-llama2-3.2m.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-400k.json b/src/delphi/static/configs/v0-llama2-400k.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-400k.json +++ b/src/delphi/static/configs/v0-llama2-400k.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-6.4m.json b/src/delphi/static/configs/v0-llama2-6.4m.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-6.4m.json +++ b/src/delphi/static/configs/v0-llama2-6.4m.json @@ -1,7 +1,7 
@@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-800k.json b/src/delphi/static/configs/v0-llama2-800k.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-800k.json +++ b/src/delphi/static/configs/v0-llama2-800k.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/train/config/gigaconfig.py b/src/delphi/train/config/gigaconfig.py index 4da07941..9d6236d2 100644 --- a/src/delphi/train/config/gigaconfig.py +++ b/src/delphi/train/config/gigaconfig.py @@ -7,7 +7,7 @@ from .debug_config import DebugConfig from .huggingface_config import HuggingfaceConfig -from .models import ModelConfig +from .model_config import ModelConfig from .optimizer_config import OptimizerConfig from .wandb_config import WandbConfig diff --git a/src/delphi/train/config/model_config.py b/src/delphi/train/config/model_config.py new file mode 100644 index 00000000..df6b66d7 --- /dev/null +++ b/src/delphi/train/config/model_config.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass, field +from typing import Any, cast + +import transformers +from beartype import beartype +from beartype.typing import Type + + +@beartype +@dataclass(frozen=True) +class ModelConfig: + model_type: str = field( + metadata={ + "help": ( + "Name of any CausalLM Model from the transformers " + "library (e.g. 'BartForCausalLM'). Model configuration arguments, " + "e.g. hidden size, should be specified in model_params" + ) + } + ) + model_params: dict[str, Any] = field( + default_factory=dict, + metadata={ + "help": ( + "config for the transformers model specified by model_type. " + "e.g. {'hidden_size': 128, ...}" + ) + }, + ) + + def get_model(self): + model_class = getattr(transformers, self.model_type) + config_class = cast( + Type[transformers.PretrainedConfig], model_class.config_class + ) + return model_class(config_class(**(self.model_params))) diff --git a/src/delphi/train/config/models/__init__.py b/src/delphi/train/config/models/__init__.py deleted file mode 100644 index cdf71c6a..00000000 --- a/src/delphi/train/config/models/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .model_config import ModelConfig -from .model_types import ModelType, ModelTypes -from .typed_llama_config import TypedLlamaConfig -from .typed_mamba_config import TypedMambaConfig diff --git a/src/delphi/train/config/models/model_config.py b/src/delphi/train/config/models/model_config.py deleted file mode 100644 index 55e3bf36..00000000 --- a/src/delphi/train/config/models/model_config.py +++ /dev/null @@ -1,64 +0,0 @@ -from dataclasses import asdict, dataclass, field -from typing import Any, Optional, Type, cast - -import transformers -from beartype import beartype -from beartype.typing import Type -from transformers import PreTrainedModel - -from .model_types import ModelType, ModelTypes -from .typed_llama_config import TypedLlamaConfig -from .typed_mamba_config import TypedMambaConfig - - -@beartype -@dataclass(frozen=True) -class ModelConfig: - model_type: str = field( - metadata={ - "help": ( - "The model type to train. May be either a predefined " - "type (delphi, mamba) or any CausalLM Model from the transformers " - "library (e.g. BartForCausalLM). 
Predefined types should " - "specify their respective configs in this model config; " - "transformer library models should specify their model " - "config arguments in transformers_config." - ) - } - ) - transformers_config: dict[str, Any] = field( - default_factory=dict, - metadata={"help": "config for the transformers model specified by model_type"}, - ) - mamba: Optional[TypedMambaConfig] = field( - default=None, - metadata={"help": "config for Delphi mamba model. See TypedMambaConfig"}, - ) - llama2: Optional[TypedLlamaConfig] = field( - default=None, - metadata={"help": "config for Delphi llama2 model. See TypedLlamaConfig"}, - ) - - def is_predefined_type(self): - return hasattr(self, self.model_type) - - def get_config_args(self) -> dict[str, Any]: - if self.is_predefined_type(): - return asdict(getattr(self, self.model_type)) - else: - return self.transformers_config - - def get_model_class(self) -> type[PreTrainedModel]: - if self.is_predefined_type(): - model_type = cast(ModelType, ModelTypes.get(self.model_type)) - return model_type.model - else: - model_class = getattr(transformers, self.model_type) - return model_class - - def get_model(self): - model_class = self.get_model_class() - config_class = cast( - Type[transformers.PretrainedConfig], model_class.config_class - ) - return model_class(config_class(**(self.get_config_args()))) diff --git a/src/delphi/train/config/models/model_types.py b/src/delphi/train/config/models/model_types.py deleted file mode 100644 index 5ccd0990..00000000 --- a/src/delphi/train/config/models/model_types.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -For any given model we use, there are three associated types: -- TypedModelConfig: a typed dataclass that defines the arguments to the model. - We use this to enforce some semblance of type safety in configs and code in general. -- PretrainedConfig: a transformers config that defines the model architecture. - The arguments for this are defined in TypedModelConfig. -- PreTrainedModel: a transformers model that implements the model architecture. - Configured by PretrainedConfig. - - This file defines a ModelType dataclass that associated these three types for a given model, - and a ModelTypes container class that defines all the models we use in Delphi along with a - helpful ModelTypes.get() method for getting ModelType from a string. 
-""" -from dataclasses import dataclass -from typing import Optional - -from beartype import beartype -from beartype.typing import Type -from transformers import LlamaForCausalLM, MambaForCausalLM, PreTrainedModel - -from .typed_llama_config import TypedLlamaConfig -from .typed_mamba_config import TypedMambaConfig -from .typed_model_config import TypedModelConfig - - -@beartype -@dataclass(frozen=True) -class ModelType: - name: str - delphi_config: type[TypedModelConfig] - model: type[PreTrainedModel] - - # Allow for ModelType == 'llama2' - def __eq__(self, other): - if isinstance(other, str): - return self.name == other - else: - return super().__eq__(other) - - def __post_init__(self): - # register the ModelType so ModelTypes.get(model_type_name) works - _model_name_to_model_type[self.name.lower()] = self - - -_model_name_to_model_type: dict[str, ModelType] = {} - - -# define new model types here -class ModelTypes: - MAMBA = ModelType( - name="mamba", - delphi_config=TypedMambaConfig, - model=MambaForCausalLM, - ) - LLAMA2 = ModelType( - name="llama2", - delphi_config=TypedLlamaConfig, - model=LlamaForCausalLM, - ) - - # NEWMODEL = ModelType( # var name should match name - # name="newmodel", # string that will be associated with model in configs, etc - # typed_config=TypedNewModelConfig, # typed dataclass for args to config - # config=NewModelConfig, # transformers config - # model=NewModelForCausalLM, # transformers model - # ) - - @classmethod - def get(cls: Type["ModelTypes"], name: str) -> Optional[ModelType]: - return _model_name_to_model_type.get(name.lower()) diff --git a/src/delphi/train/config/models/typed_llama_config.py b/src/delphi/train/config/models/typed_llama_config.py deleted file mode 100644 index 7e1a2f83..00000000 --- a/src/delphi/train/config/models/typed_llama_config.py +++ /dev/null @@ -1,30 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Optional - -from beartype import beartype - -from .typed_model_config import TypedModelConfig - - -@beartype -@dataclass(frozen=True) -class TypedLlamaConfig(TypedModelConfig): - attention_bias: bool = False - attention_dropout: float = 0.0 - bos_token_id: int = -1 - eos_token_id: int = -2 - hidden_act: str = "silu" - hidden_size: int = 288 - initializer_range: float = 0.02 - intermediate_size: int = 288 - max_position_embeddings: int = 512 - num_attention_heads: int = 6 - num_hidden_layers: int = 6 - num_key_value_heads: int = 6 - pretraining_tp: int = 1 - rms_norm_eps: float = 1e-06 - rope_scaling: Optional[dict[str, Any]] = None - rope_theta: float = 10000.0 - tie_word_embeddings: bool = False - use_cache: bool = True - vocab_size: int = 4096 diff --git a/src/delphi/train/config/models/typed_mamba_config.py b/src/delphi/train/config/models/typed_mamba_config.py deleted file mode 100644 index 8b01a932..00000000 --- a/src/delphi/train/config/models/typed_mamba_config.py +++ /dev/null @@ -1,39 +0,0 @@ -from dataclasses import dataclass -from typing import Union - -from beartype import beartype - -from .typed_model_config import TypedModelConfig - - -@beartype -@dataclass(frozen=True) -class TypedMambaConfig(TypedModelConfig): - # model shape - vocab_size: int = 4096 - hidden_size: int = 768 - state_size: int = 16 - num_hidden_layers: int = 32 - conv_kernel: int = 4 - expand: int = 2 - use_bias: bool = False - use_conv_bias: bool = True - # tokens - bos_token_id: int = 0 - eos_token_id: int = 0 - pad_token_id: int = 0 - # time step - time_step_rank: Union[int, str] = "auto" - time_step_scale: float = 1.0 - 
time_step_min: float = 0.001 - time_step_max: float = 0.1 - time_step_init_scheme: str = "random" # "random" or "uniform" - time_step_floor: float = 0.0001 - # misc - layer_norm_epsilon: float = 1e-05 - hidden_act: str = "silu" - initializer_range: float = 0.1 - residual_in_fp32: bool = True - rescale_prenorm_residual: bool = False - use_cache: bool = True - tie_word_embeddings: bool = True diff --git a/src/delphi/train/config/models/typed_model_config.py b/src/delphi/train/config/models/typed_model_config.py deleted file mode 100644 index 8eae1de8..00000000 --- a/src/delphi/train/config/models/typed_model_config.py +++ /dev/null @@ -1,15 +0,0 @@ -from dataclasses import dataclass - - -@dataclass(frozen=True) -class TypedModelConfig: - """ - This is a dummy class for typing purposes. We could make a Union class that we update - every time we add a TypedModelConfig class, but that would mean remembering to go update - another thing when adding a new TypedModelConfig. - """ - - def __init__(self): - raise NotImplementedError( - "TypedModelConfig is a dummy class to provide typing for actual ModelConfig classes. It shouldn't ever be instantiated." - ) diff --git a/src/delphi/train/config/utils.py b/src/delphi/train/config/utils.py index 0f582104..719b522d 100644 --- a/src/delphi/train/config/utils.py +++ b/src/delphi/train/config/utils.py @@ -1,8 +1,10 @@ import json import logging import os +from dataclasses import fields, is_dataclass from datetime import datetime from pathlib import Path +from typing import Type import platformdirs from beartype.typing import Any, Iterable @@ -68,6 +70,24 @@ def build_config_dict_from_files(config_files: list[Path]) -> dict[str, Any]: return combined_config +def filter_config_to_actual_config_values(target_dataclass: Type, config: dict): + """Remove non-config values from config dict. + + This can happen if e.g. 
being lazy and passing in all args from a script + """ + datafields = fields(target_dataclass) + name_to_field = {f.name: f for f in datafields} + to_remove = [] + for k, v in config.items(): + if k not in name_to_field.keys(): + logging.debug(f"removing non-config-value {k}={v} from config dict") + to_remove.append(k) + elif isinstance(v, dict) and is_dataclass(name_to_field.get(k)): + filter_config_to_actual_config_values(name_to_field[k].type, v) + for k in to_remove: + config.pop(k) + + def set_backup_vals(config: dict[str, Any], config_files: list[Path]): if len(config_files) == 1: prefix = f"{config_files[0].stem}__" @@ -76,10 +96,23 @@ def set_backup_vals(config: dict[str, Any], config_files: list[Path]): if "run_name" not in config: run_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") config["run_name"] = f"{prefix}{run_time}" + logging.info(f"Setting run_name to {config['run_name']}") if "output_dir" not in config: config["output_dir"] = os.path.join( platformdirs.user_data_dir(appname="delphi"), config["run_name"] ) + logging.info(f"Setting output_dir to {config['output_dir']}") + + +def log_config_recursively( + config: dict, logging_fn=logging.info, indent=" ", prefix="" +): + for k, v in config.items(): + if isinstance(v, dict): + logging_fn(f"{prefix}{k}") + log_config_recursively(v, logging_fn, indent, prefix=indent + prefix) + else: + logging_fn(f"{prefix}{k}: {v}") def build_config_from_files_and_overrides( @@ -89,6 +122,11 @@ def build_config_from_files_and_overrides( combined_config = build_config_dict_from_files(config_files) _merge_dicts(merge_into=combined_config, merge_from=overrides) set_backup_vals(combined_config, config_files) + filter_config_to_actual_config_values(GigaConfig, combined_config) + logging.info("User-set config values:") + log_config_recursively( + combined_config, logging_fn=logging.info, prefix=" ", indent=" " + ) return from_dict(GigaConfig, combined_config) diff --git a/src/delphi/train/config/wandb_config.py b/src/delphi/train/config/wandb_config.py index 716dec0a..23fbc3a0 100644 --- a/src/delphi/train/config/wandb_config.py +++ b/src/delphi/train/config/wandb_config.py @@ -7,3 +7,4 @@ class WandbConfig: log: bool = False project: str = "delphi" entity: str = "set_wandb.entity_to_your_wandb_username_to_make_wandb_logging_work" + silence: bool = False diff --git a/src/delphi/train/iteration_params.py b/src/delphi/train/iteration_params.py index 8a7e8d88..b6c624db 100644 --- a/src/delphi/train/iteration_params.py +++ b/src/delphi/train/iteration_params.py @@ -30,9 +30,23 @@ def set_iteration_params( * config.batch_size * config.max_seq_len ) - logging.debug(f"tokens per iteration will be: {tokens_per_iter:,}") - logging.debug( - f"breaks down as: {config.optimizer.gradient_accumulation_steps} grad accum steps * {config.batch_size} batch size * {config.max_seq_len} max seq len" + logging.info("Iteration setup:") + logging.info(f" batch size: {config.batch_size}") + logging.info(f" training set size: {len(train_ds)}") + logging.info(f" training batches: {num_batches}") + logging.info( + f" gradient accumulations per step (=batches per step): {config.optimizer.gradient_accumulation_steps}" + ) + logging.info(f" steps per batch: {num_steps}") + logging.info(f" tokens per sequence: {config.max_seq_len}") + logging.info(f" tokens per training step will be: {tokens_per_iter:,}") + logging.info( + f" breaks down as: {config.optimizer.gradient_accumulation_steps} grad accum steps * {config.batch_size} batch size * {config.max_seq_len} tokens per 
sequence" + ) + logging.info(f" validation set size: {len(validation_ds)}") + logging.info(f" batches per validation step: {eval_iters}") + logging.info( + f" tokens per validation step: {eval_iters * config.batch_size * config.max_seq_len:,}" ) return IterationParams( num_batches, num_steps, eval_iters, lr_decay_iters, tokens_per_iter diff --git a/src/delphi/train/train_step.py b/src/delphi/train/train_step.py index baee2147..9852d633 100644 --- a/src/delphi/train/train_step.py +++ b/src/delphi/train/train_step.py @@ -6,7 +6,6 @@ from datasets import Dataset from .config import GigaConfig -from .config.models import ModelTypes from .iteration_params import IterationParams from .run_context import RunContext from .utils import EvalData, ModelTrainingState, estimate_loss, get_next_xy, set_lr @@ -76,84 +75,36 @@ def train_step( callback(eval_data) # 3. forward backward update, with optional gradient accumulation to simulate larger batch size - logging.debug( - f"gradient accumulation steps: {config.optimizer.gradient_accumulation_steps}, " - f"num_steps: {iteration_params.num_steps}, iter_num: {model_training_state.iter_num}" - ) - for micro_step in range(config.optimizer.gradient_accumulation_steps): - X, Y = get_next_xy(train_batch_iter, run_context.device) - if config.debug_config.no_training: - logging.debug("no_training set, skipping forward backward update") - loss = torch.Tensor([42.1]).to(run_context.device) - else: + if config.debug_config.no_training: + logging.debug("no_training set, skipping forward backward pass") + loss = torch.Tensor([42.1]).to(run_context.device) + else: + for micro_step in range(config.optimizer.gradient_accumulation_steps): + X, Y = get_next_xy(train_batch_iter, run_context.device) loss = ( model(X, labels=Y, return_dict=True).loss / config.optimizer.gradient_accumulation_steps ) loss.backward() - if config.debug_config.no_training: - logging.debug("debug no_training is set, skipping optimizer step") - else: # clip the gradient if config.grad_clip != 0.0: torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) # type: ignore optimizer.step() - - # flush the gradients as soon as we can, no need for this memory anymore - optimizer.zero_grad(set_to_none=True) + # flush the gradients as soon as we can, no need for this memory anymore + optimizer.zero_grad(set_to_none=True) # 4. log timing t1 = time.time() - dt = t1 - model_training_state.t0 - model_training_state.t0 = t1 + dt = t1 - model_training_state.last_training_step_time + model_training_state.last_training_step_time = t1 if model_training_state.iter_num % config.log_interval == 0: # get loss as float, scale up due to the divide above. 
note: this is a CPU-GPU sync point lossf = loss.item() * config.optimizer.gradient_accumulation_steps - if ( - model_training_state.local_iter_num >= 5 - ): # let the training loop settle a bit - mfu = estimate_mfu( - config=config, model=model_training_state.model, timedelta=dt - ) - model_training_state.running_mfu = ( - mfu - if model_training_state.running_mfu == -1.0 - else 0.9 * model_training_state.running_mfu + 0.1 * mfu - ) logging.debug( ( f"{model_training_state.iter_num} | loss {lossf:.4f} | lr {model_training_state.lr:e} | " - f"{dt*1000:.2f}ms | mfu {model_training_state.running_mfu*100:.2f}%" + f"{dt*1000:.2f}ms" ) ) model_training_state.iter_num += 1 model_training_state.local_iter_num += 1 - - -def estimate_mfu(config: GigaConfig, model: torch.nn.Module, timedelta: float) -> float: - """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS""" - # first estimate the number of flops we do per iteration. - # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 - N = sum(p.numel() for p in model.parameters()) - if config.model_config.model_type == ModelTypes.LLAMA2: - cfg = model.config - L, H, Q, T = ( - cfg.num_hidden_layers, - cfg.num_attention_heads, - cfg.hidden_size // cfg.num_attention_heads, - cfg.max_position_embeddings, - ) - else: - logging.debug( - f"estimate_mfu not implemented for {config.model_config.model_type}, setting MFU to -1" - ) - return -1.0 - flops_per_token = 6 * N + 12 * L * H * Q * T - flops_per_fwdbwd = flops_per_token * T - fwdbwd_per_iter = config.batch_size * config.optimizer.gradient_accumulation_steps - flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter - # express our flops throughput as ratio of A100 bfloat16 peak flops - flops_achieved = flops_per_iter * (1.0 / timedelta) # per second - flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS - mfu = flops_achieved / flops_promised - return mfu diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py index ab189f14..ba0603c7 100644 --- a/src/delphi/train/training.py +++ b/src/delphi/train/training.py @@ -45,7 +45,7 @@ def run_training(config: GigaConfig) -> tuple[ModelTrainingState, RunContext]: logging.debug(f"Run context: {run_context}") # load data - logging.debug("Loading data...") + logging.info("Loading data...") train_ds = cast( Dataset, load_delphi_training_dataset("train", limit=config.train_sample_limit) ) @@ -67,14 +67,20 @@ def run_training(config: GigaConfig) -> tuple[ModelTrainingState, RunContext]: model_training_state = initialize_model_training_state(config, run_context.device) # setup eval callbacks + logging.info("Setting eval step callbacks...") eval_callbacks = [save_checkpoint_if_needed] + logging.info(f" added save_checkpoint_if_needed eval callback") if config.wandb_config.log: + if config.wandb_config.silence: + wandb_utils.silence_wandb() wandb_utils.init_wandb(config) eval_callbacks.append(wandb_utils.log_to_wandb) + logging.info(f" added log_to_wandb callback") # training loop logging.info("Starting training...") for epoch in range(config.max_epochs): + logging.info(f"Epoch: {epoch} / {config.max_epochs - 1}") train_batch_iter = iter( batch_generator( train_ds, config.batch_size, epoch, config.batch_ordering_seed diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py index 3d7b2a5c..c974d573 100644 --- a/src/delphi/train/utils.py +++ b/src/delphi/train/utils.py @@ -4,7 +4,7 @@ import os import time from collections.abc import Generator -from dataclasses import asdict, dataclass +from 
dataclasses import asdict, dataclass, field from pathlib import Path from typing import cast @@ -25,16 +25,23 @@ @dataclass class ModelTrainingState: + """mutable training state - stuff that changes over the course of training""" + model: torch.nn.Module optimizer: torch.optim.Optimizer - iter_num: int - local_iter_num: int - best_val_loss: float - running_mfu: float - t0: float - epoch: int - step: int - lr: float = 1.0e-5 + iter_num: int = field( + metadata={"help": "total iterations so far across all epochs"} + ) + local_iter_num: int = field( + metadata={"help": "total iterations on this instance so far"} + ) + best_val_loss: float = field(metadata={"help": "best validation loss so far"}) + last_training_step_time: float = field( + metadata={"help": "time last iteration ended"} + ) + epoch: int = field(metadata={"help": "current epoch"}) + step: int = field(metadata={"help": "step within current epoch"}) + lr: float = field(default=1.0e-5, metadata={"help": "learning rate"}) @dataclass @@ -63,25 +70,6 @@ def get_device(device_str: str = "auto") -> torch.device: return torch.device(device_str) -def get_optimizer( - model: torch.nn.Module, - config: GigaConfig, - output_dir=None, - device: torch.device = torch.device("cpu"), -) -> AdamW: - optimizer = AdamW( - lr=config.optimizer.learning_rate, - params=model.parameters(), - weight_decay=config.optimizer.weight_decay, - betas=(config.optimizer.beta1, config.optimizer.beta2), - ) - if output_dir is not None: - opt_path = os.path.join(output_dir, "opt.pt") - with open(opt_path, "rb") as f: - optimizer.load_state_dict(torch.load(f)) - return optimizer - - def get_lr( iter_num: int, warmup_iters: int, @@ -145,56 +133,48 @@ def save_checkpoint_if_needed(eval_data: EvalData): ) -def load_model_from_checkpoint(config: GigaConfig, output_dir: str) -> torch.nn.Module: - model = config.model_config.get_model() - st.load_model(model, os.path.join(output_dir, "model", "model.safetensors")) - return model - - def initialize_model_training_state( config: GigaConfig, device: torch.device ) -> ModelTrainingState: t0 = time.time() - training_state = None + model = config.model_config.get_model() + model.to(device) # type: ignore + optimizer = AdamW( + lr=config.optimizer.learning_rate, + params=model.parameters(), + weight_decay=config.optimizer.weight_decay, + betas=(config.optimizer.beta1, config.optimizer.beta2), + ) + training_state_vals = dict() if config.init_from == "scratch": - # init a new model from scratch - logging.debug("Initializing a new model from scratch") - model = config.model_config.get_model() - checkpoint = None + logging.info(f" initialized model and optimizer from scratch") # TODO: resume from huggingface model elif config.init_from == "resume": logging.info(f"Resuming training from {config.output_dir}") checkpoint = config.output_dir - model = load_model_from_checkpoint(config, checkpoint) + st.load_model( + model, os.path.join(config.output_dir, "model", "model.safetensors") + ) with open(os.path.join(checkpoint, "training_state.json"), "r") as f: - training_state = json.load(f) - model.to(device) # type: ignore - # optimizer - optimizer = get_optimizer( - model=model, - config=config, - output_dir=config.output_dir - if (Path(config.output_dir) / "opt.safetensors").exists() - else None, - device=device, - ) - epoch = training_state.get("epoch", 0) if training_state is not None else 0 - step = training_state.get("step", 0) if training_state is not None else 0 - best_val_loss = training_state.get("best_val_loss", 1e9) if 
training_state else 1e9 - iter_num = training_state.get("iter_num", 0) if training_state else 0 - local_iter_num = training_state.get("local_iter_num", 0) if training_state else 0 - running_mfu = training_state.get("running_mfu", 0.0) if training_state else -1.0 - checkpoint = None # free up memory + training_state_vals = json.load(f) + opt_state_dict_path = Path(os.path.join(config.output_dir, "opt.pt")) + if opt_state_dict_path.exists(): + with open(opt_state_dict_path, "rb") as f: + logging.info(" Loading optimizer state from {state_dict_path}") + optimizer.load_state_dict(torch.load(f)) + else: + raise ValueError( + f"{config.init_from} is not one of (scratch, resume), which are the two valid initialization methods. Unable to initialize model." + ) return ModelTrainingState( model=model, optimizer=optimizer, - iter_num=iter_num, - local_iter_num=local_iter_num, - best_val_loss=best_val_loss, - running_mfu=running_mfu, - t0=t0, - epoch=epoch, - step=step, + last_training_step_time=t0, + iter_num=training_state_vals.get("iter_num", 0), + local_iter_num=training_state_vals.get("local_iter_num", 0), + best_val_loss=training_state_vals.get("best_val_loss", 1e9), + epoch=training_state_vals.get("epoch", 0), + step=training_state_vals.get("step", 0), ) @@ -214,10 +194,9 @@ def load_delphi_training_dataset(split: str, limit: int = -1): def get_next_xy( - train_batch_iter: Generator, - device: torch.device - # train_batch_iter: Generator[dict[str, list[int]], None, None], device: torch.device + train_batch_iter: Generator, device: torch.device ) -> tuple[torch.Tensor, torch.Tensor]: + """break a (max_seq_len +1) sequence of tokens into sample [:-1] and label [1:] pairs""" data = next(train_batch_iter).to(device) X, Y = data[:, :-1], data[:, 1:] return X, Y @@ -226,6 +205,9 @@ def get_next_xy( def batch_generator( dataset: Dataset, batch_size: int, epoch: int, ordering_seed: int ) -> Generator[torch.Tensor, None, None]: + """ + Generate batches of training data for a given epoch with pseudorandom determinism + """ sampler = list(range(len(dataset))) # type: ignore shuffle_list(sampler, seed=ordering_seed + epoch) sampler = torch.Tensor(sampler) @@ -257,19 +239,18 @@ def estimate_loss( return out -def upload_to_huggingface(eval_data: EvalData): - model = eval_data.model_training_state.model - if isinstance(model, PreTrainedModel): - model = cast(PreTrainedModel, model) - model.save_pretrained(eval_data.config.output_dir) - - def save_results( config: GigaConfig, train_results: ModelTrainingState, run_context: RunContext, results_path: str, ): + """ + save results to disk, and to huggingface if configured to do so. + + Saves everything required to replicate the current state of training, including optimizer state, + config, context (e.g. 
hardware), training step, etc + """ os.makedirs(results_path, exist_ok=True) with open(os.path.join(results_path, "config.json"), "w") as file: json.dump(asdict(config), file, indent=2) @@ -291,7 +272,6 @@ def save_results( "iter_num": train_results.iter_num, "local_iter_num": train_results.local_iter_num, "best_val_loss": train_results.best_val_loss, - "running_mfu": train_results.running_mfu, "lr": train_results.lr, "epoch": train_results.epoch, "step": train_results.step, diff --git a/src/delphi/train/wandb_utils.py b/src/delphi/train/wandb_utils.py index 89ed5518..b24926df 100644 --- a/src/delphi/train/wandb_utils.py +++ b/src/delphi/train/wandb_utils.py @@ -10,6 +10,7 @@ def silence_wandb(): # set env var WANDB_SILENT=true + logging.info("silencing wandb output") os.environ["WANDB_SILENT"] = "true" @@ -35,7 +36,6 @@ def log_to_wandb(eval_data: EvalData): "loss/train": eval_data.losses["train"], "loss/val": eval_data.losses["val"], "lr": mts.lr, - "mfu": mts.running_mfu * 100, # convert to percentage }, step=mts.iter_num, ) diff --git a/tests/train/config/models/test_model_config.py b/tests/train/config/models/test_model_config.py index aad4a13a..eca5fe69 100644 --- a/tests/train/config/models/test_model_config.py +++ b/tests/train/config/models/test_model_config.py @@ -1,9 +1,8 @@ import pytest from dacite import from_dict -from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM +from transformers import BloomConfig, BloomForCausalLM -from delphi.train.config.models import ModelConfig, TypedLlamaConfig -from delphi.train.config.models.model_config import ModelConfig +from delphi.train.config.model_config import ModelConfig @pytest.fixture @@ -11,8 +10,8 @@ def llama_config(): return from_dict( ModelConfig, { - "model_type": "llama2", - "llama2": {"hidden_size": 49, "num_attention_heads": 7}, + "model_type": "LlamaForCausalLM", + "model_params": {"hidden_size": 49, "num_attention_heads": 7}, }, ) @@ -23,36 +22,20 @@ def bloom_config(): ModelConfig, { "model_type": "BloomForCausalLM", - "transformers_config": {"layer_norm_epsilon": 0.0042}, + "model_params": {"layer_norm_epsilon": 0.0042}, }, ) -def test_deserialziation(llama_config): - direct_llama_config = ModelConfig( - model_type="llama2", - llama2=TypedLlamaConfig(hidden_size=49, num_attention_heads=7), +def test_deserialziation(bloom_config): + direct_bloom_config = ModelConfig( + model_type="BloomForCausalLM", + model_params=dict(layer_norm_epsilon=0.0042), ) - assert llama_config == direct_llama_config + assert bloom_config == direct_bloom_config -def test_model_config_is_predefined_type(llama_config): - assert llama_config.is_predefined_type() - - -def test_model_config_is_not_predefined_type(bloom_config): - assert not bloom_config.is_predefined_type() - - -def test_config_to_model_predefined(llama_config): - model = llama_config.get_model() - - assert isinstance(model, LlamaForCausalLM) - assert isinstance(model.config, LlamaConfig) - assert model.config.hidden_size == 49 - - -def test_config_to_model_generic_type(bloom_config): +def test_config_to_model(bloom_config): model = bloom_config.get_model() assert isinstance(model, BloomForCausalLM) diff --git a/tests/train/test_wandb_utils.py b/tests/train/test_wandb_utils.py index 093fd0c5..a7cbb80e 100644 --- a/tests/train/test_wandb_utils.py +++ b/tests/train/test_wandb_utils.py @@ -4,10 +4,11 @@ import pytest import torch +import transformers from dacite import from_dict from delphi.train.config import GigaConfig -from delphi.train.config.models 
import TypedLlamaConfig +from delphi.train.config.utils import load_preset from delphi.train.run_context import RunContext from delphi.train.utils import EvalData, initialize_model_training_state from delphi.train.wandb_utils import init_wandb, log_to_wandb, silence_wandb @@ -21,8 +22,15 @@ def mock_giga_config(): "run_name": "test_run", "device": "cpu", "model_config": { - "model_type": "llama2", - "llama2": asdict(TypedLlamaConfig()), + "model_type": "LlamaForCausalLM", + "model_params": { + "hidden_size": 48, + "intermediate_size": 48, + "num_attention_heads": 2, + "num_hidden_layers": 2, + "num_key_value_heads": 2, + "vocab_size": 4096, + }, }, "wandb_config": { "log": True, @@ -43,7 +51,6 @@ def mock_model_training_state(mock_giga_config): mts.epoch = 1 mts.iter_num = 1 mts.lr = 0.001 - mts.running_mfu = 3.0 return mts @@ -93,7 +100,6 @@ def test_log_to_wandb(mock_wandb_log, mock_eval_data): "loss/train": 0.5, "loss/val": 0.4, "lr": 0.001, - "mfu": 300.0, }, step=1, ) From 01eb277cceeb65fbde2159330187aeb275bb4b80 Mon Sep 17 00:00:00 2001 From: Rai <62800649+menamerai@users.noreply.github.com> Date: Sat, 30 Mar 2024 14:06:20 -0400 Subject: [PATCH 2/4] added token selector (#92) --- requirements-nocuda.txt | 2 ++ src/delphi/eval/vis.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/requirements-nocuda.txt b/requirements-nocuda.txt index 95816526..d4dc94eb 100644 --- a/requirements-nocuda.txt +++ b/requirements-nocuda.txt @@ -21,6 +21,8 @@ wandb==0.16.3 spacy==3.7.2 pandas==1.3.4 dacite==1.8.1 +panel==1.4.0 +jupyter_bokeh==4.0.1 # temporarily installing transformers from main until 4.39.0 comes out (for mamba support) transformers @ git+https://github.com/huggingface/transformers@main diff --git a/src/delphi/eval/vis.py b/src/delphi/eval/vis.py index 5dd4fdb2..2ae9273b 100644 --- a/src/delphi/eval/vis.py +++ b/src/delphi/eval/vis.py @@ -1,6 +1,7 @@ import uuid from typing import cast +import panel as pn import torch from IPython.core.display import HTML from IPython.core.display_functions import display @@ -138,3 +139,18 @@ def vis_sample_prediction_probs( """ display(HTML(html_str)) return html_str + + +def token_selector( + vocab_map: dict[str, int] +) -> tuple[pn.widgets.MultiChoice, list[int]]: + tokens = list(vocab_map.keys()) + token_selector = pn.widgets.MultiChoice(name="Tokens", options=tokens) + token_ids = [vocab_map[token] for token in cast(list[str], token_selector.value)] + + def update_tokens(event): + token_ids.clear() + token_ids.extend([vocab_map[token] for token in event.new]) + + token_selector.param.watch(update_tokens, "value") + return token_selector, token_ids From 5b7ec89061c6d111c665678a97d895efe5414a53 Mon Sep 17 00:00:00 2001 From: Siwei Li <46750682+siwei-li@users.noreply.github.com> Date: Sat, 30 Mar 2024 12:47:42 -0700 Subject: [PATCH 3/4] tokenize text stories and split into batches (#55) * Add function to tokenize text stories and split into batches * Split the tokenization function into two parts, fixing the while-loop issues * Add docstrings to the functions * Minor edits in the code, fix the test * Uses batch_encode() method to save time * Add script to upload to delphi-suite/batched-tokenized-stories * Remove the test file in tests/train to pass pytest * Update function name --------- authored-by: Siwei Li --- scripts/tokenize_dataset.py | 78 +++++++++++++++++++++ src/delphi/dataset/tokenization.py | 107 +++++++++++++++++++++++++++++ tests/dataset/test_tokenizer.py | 88 ++++++++++++++++++++++++ 3 files changed, 273 insertions(+) 
create mode 100755 scripts/tokenize_dataset.py create mode 100644 src/delphi/dataset/tokenization.py create mode 100644 tests/dataset/test_tokenizer.py diff --git a/scripts/tokenize_dataset.py b/scripts/tokenize_dataset.py new file mode 100755 index 00000000..5a3d01d5 --- /dev/null +++ b/scripts/tokenize_dataset.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +import argparse + +from datasets import Dataset +from transformers import AutoTokenizer + +from delphi.dataset.tokenization import tokenize_dataset +from delphi.eval.utils import load_validation_dataset + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + + parser.add_argument( + "--input-dataset-name", + type=str, + help="Text dataset from huggingface to tokenize", + ) + parser.add_argument( + "--output-dataset-name", + type=str, + help="Name of the tokenized dataset to upload to huggingface", + ) + parser.add_argument( + "--tokenizer-name", + type=str, + help="Name of the tokenizer from huggingface", + ) + parser.add_argument( + "--token", + type=str, + help="Hugging Face API token", + ) + parser.add_argument( + "--context-size", + type=int, + default=512, + help="Context size of the tokenized dataset as input of the model", + ) + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Batch size of text inputs into the tokenizer", + ) + parser.add_argument( + "--column-name", + type=str, + help="Name of the column containing text documents in the input dataset", + ) + args = parser.parse_args() + + input_dataset = load_validation_dataset(f"delphi-suite/{args.input_dataset_name}") + tokenizer = AutoTokenizer.from_pretrained(f"delphi-suite/{args.tokenizer_name}") + + if args.column_name: + text_docs = input_dataset[args.column_name] + else: + if len(input_dataset.column_names) > 1: + raise ValueError("There are more than one column in the specified dataset") + text_docs = input_dataset[input_dataset.column_names[0]] + + output_dataset = Dataset.from_dict( + { + "tokens": tokenize_dataset( + text_docs, + tokenizer, + context_size=args.context_size, + batch_size=args.batch_size, + ) + } + ) + + output_dataset.push_to_hub( + repo_id=f"delphi-suite/{args.output_dataset_name}", + private=False, + token=args.token, + ) diff --git a/src/delphi/dataset/tokenization.py b/src/delphi/dataset/tokenization.py new file mode 100644 index 00000000..b800b64b --- /dev/null +++ b/src/delphi/dataset/tokenization.py @@ -0,0 +1,107 @@ +from collections import deque +from typing import Optional + +from transformers import PreTrainedTokenizerBase + + +def extend_deque( + dq: deque[int], + context_size: int, + text_documents: list[str], + doc_idx: int, + tokenizer: PreTrainedTokenizerBase, + batch_size: int, +) -> int: + """ + Extends the deque with tokenized text documents until the deque grows large + enough to reach the context size, or until all text documents are processed. + + The usage of a deque here aims to save the memory as opposed to + load all the documents and tokenize them at once. + + Args: + dq: Deque to extend with tokenized tokens. + context_size: Size of the context(input sequences). + text_documents: List of (untokenized) text documents to be tokenized. + doc_idx: Index of the current text story. + tokenizer: Tokenizer to encode the text strings. + Returns: + int: Updated index in the text documents dataset. 
+    """
+    while len(dq) < context_size and doc_idx < len(text_documents):
+        text_doc = text_documents[doc_idx : doc_idx + batch_size]
+        batch_input_ids = tokenizer(
+            text_doc, return_attention_mask=False, add_special_tokens=False
+        )["input_ids"]
+        for input_ids in batch_input_ids:
+            dq.extend(input_ids + [tokenizer.eos_token_id])
+        doc_idx += batch_size
+    return doc_idx
+
+
+def make_new_samples(
+    dq: deque[int], context_size: int, bos_token_id: int
+) -> list[list[int]]:
+    """
+    Generates new samples for training by creating sequences of tokens
+    from the deque until the deque does not hold enough tokens to generate
+    another sample.
+
+    Note: the model is unable to use the last token in an input sequence,
+    so we repeat this token in the next input sequence.
+
+    Args:
+        dq: Deque containing tokenized tokens.
+        context_size: Size of the context (input sequences).
+        bos_token_id: bos_token_id of the tokenizer used.
+
+    Returns:
+        list[list[int]]: List of token sequences of the same length (context_size).
+    """
+
+    samples = []
+    while len(dq) >= context_size:
+        sample = [bos_token_id]
+
+        # For the first (n-1) elements, pop from the left of the deque
+        # and add them to the new sample; the n-th element is retained
+        # in the deque for making the next sample.
+        for _ in range(context_size - 1):
+            sample.append(dq.popleft())
+        sample.append(dq[0])
+
+        samples.append(sample)
+    return samples
+
+
+def tokenize_dataset(
+    text_documents: list[str],
+    tokenizer: PreTrainedTokenizerBase,
+    context_size: int,
+    batch_size: int,
+) -> list[list[int]]:
+    """
+    Tokenizes the input text documents using the provided tokenizer and
+    generates token sequences of the specified length.
+
+    Args:
+        text_documents: List of (untokenized) text documents to be tokenized.
+        tokenizer: Tokenizer to encode the text strings.
+        context_size: Size of the context (input sequences).
+        batch_size: Number of text documents passed to the tokenizer at once.
+
+    Returns:
+        list[list[int]]: List of token sequences of length equal to context_size.
+ """ + + dq = deque() + doc_idx = 0 + samples = [] + + while doc_idx < len(text_documents): + doc_idx = extend_deque( + dq, context_size, text_documents, doc_idx, tokenizer, batch_size + ) + samples.extend(make_new_samples(dq, context_size, tokenizer.bos_token_id)) + + # We discard the last chunk, so no processing on the remainder of the deque here + return samples diff --git a/tests/dataset/test_tokenizer.py b/tests/dataset/test_tokenizer.py new file mode 100644 index 00000000..99b2dcb3 --- /dev/null +++ b/tests/dataset/test_tokenizer.py @@ -0,0 +1,88 @@ +import collections +import random + +import pytest +from transformers import AutoTokenizer + +from delphi.dataset.tokenization import extend_deque, make_new_samples, tokenize_dataset + + +@pytest.fixture +def tokenizer(): + return AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer") + + +def test_extend_deque(tokenizer): + CTX_SIZE = 10 + BATCH_SIZE = 2 + # generate 100 random stories + text_stories = [ + " ".join( + [ + tokenizer.decode(random.randint(3, tokenizer.vocab_size)) + for _ in range(random.randint(100, 800)) + ] + ) + for _ in range(100) + ] + prompt_idx = 0 + dq = collections.deque() + + while prompt_idx < len(text_stories): + prompt_idx = extend_deque( + dq, CTX_SIZE, text_stories, prompt_idx, tokenizer, BATCH_SIZE + ) + if prompt_idx < len(text_stories) - 1: + # assert that the deque has grown large enough in each round + assert len(dq) >= CTX_SIZE + while len(dq) >= CTX_SIZE: + for _ in range(CTX_SIZE - 1): + dq.popleft() + + +def test_make_new_sample(tokenizer): + for _ in range(100): + total_tokens = random.randint(100, 1000) + context_size = random.randint(5, total_tokens // 2) + dq = collections.deque(random.choices(range(3, 1000), k=total_tokens)) + samples = make_new_samples(dq, context_size, tokenizer.bos_token_id) + tokens_cnt = 0 + for i, sample in enumerate(samples): + assert sample[0] == tokenizer.bos_token_id + if i > 0: + # assert that there is an overlap of the last token in the previous sample + # and the first token in its following sample + assert sample[1] == samples[i - 1][-1] + tokens_cnt += len(sample) + + # We discard the last chunk so the following lines are only for testing + tokens_cnt += 1 + len(dq) # the last batch with BOS in the beginning + assert tokens_cnt == total_tokens + ( + 2 * len(samples) + 1 + ) # BOS for each batch + overlapping of the last tokens in the batches + assert len(dq) > 0 # always leaving at least one element in the deque + + +def test_tokenize_dataset(tokenizer): + CTX_SIZE = 10 + BATCH_SIZE = 2 + + text_stories = [ + "Once upon a", + "Mother woke up alert. She put on her coat", + "Once upon a time, in a small town, there was a weird", + "Once upon a time, there was a", + "Sara and Tom are friends. 
They like to play in the park.", + ] + correct_batches = [ + [1, 432, 440, 261, 2, 367, 501, 1917, 372, 3398, 4037], + [1, 4037, 341, 577, 359, 342, 1854, 2, 432, 440, 261], + [1, 261, 403, 4045, 317, 261, 560, 1000, 4045, 406, 286], + [1, 286, 261, 2567, 2, 432, 440, 261, 403, 4045, 406], + [1, 406, 286, 261, 2, 787, 269, 396, 484, 415, 4037], + [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2], + ] + assert ( + tokenize_dataset(text_stories, tokenizer, CTX_SIZE, BATCH_SIZE) + == correct_batches + ) From bb5797f5b02598c645231083defa33c68ee3e9d2 Mon Sep 17 00:00:00 2001 From: Jai Dhyani Date: Sat, 30 Mar 2024 14:09:42 -0700 Subject: [PATCH 4/4] Train rework (#89) * rename "train_step" to "iteration_step" * isolate train_step * vscode debugging configs * replace generator with direct indexing for data * unused import * get rid of eval_callbacks * remove unused return value * only pass device to train_step * remember to init wandb when using it! * factor accumulate_gradients out + testing for train_step * fixed incorrect num_batches in test that could give cause negative * conditional wandb silencing in init_wandb * configurable train/validation datasets * fix docstring on accumulate_gradients * factor out some setup boilerplate from run_training * fix end2end demo * wip * CR: test deterministic train step --- .vscode/launch.json | 14 +- .vscode/settings.json | 5 + notebooks/end2end_demo.ipynb | 17 +- scripts/run_training.py | 8 +- .../sample_config.json | 4 +- .../sample_mamba.json | 4 +- src/delphi/constants.py | 2 +- src/delphi/eval/utils.py | 6 +- src/delphi/static/configs/debug.json | 4 +- src/delphi/static/configs/debug_mamba.json | 4 +- .../configs/debug_transformers_bloom.json | 4 +- src/delphi/train/checkpoint_step.py | 70 +++++ src/delphi/train/config/__init__.py | 3 +- src/delphi/train/config/data_config.py | 58 ++++ .../{gigaconfig.py => training_config.py} | 19 +- src/delphi/train/config/utils.py | 14 +- src/delphi/train/iteration_params.py | 4 +- src/delphi/train/train_step.py | 128 +++------ src/delphi/train/training.py | 122 +++++---- src/delphi/train/utils.py | 172 ++++++++---- src/delphi/train/wandb_utils.py | 11 +- tests/train/test_train_step.py | 253 ++++++++++++++++++ tests/train/test_wandb_utils.py | 16 +- 23 files changed, 706 insertions(+), 236 deletions(-) create mode 100644 src/delphi/train/checkpoint_step.py create mode 100644 src/delphi/train/config/data_config.py rename src/delphi/train/config/{gigaconfig.py => training_config.py} (88%) create mode 100644 tests/train/test_train_step.py diff --git a/.vscode/launch.json b/.vscode/launch.json index cc1d9f99..69d07ea3 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,14 +4,22 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { - "name": "run_training 256", + "name": "run_training debug", "type": "debugpy", "request": "launch", "program": "scripts/run_training.py", "console": "integratedTerminal", - "args": "--debug --train_sample_limit=256" - //"args": "${command:pickArgs}" + "args": "--debug --loglevel 20" + }, + { + "name": "run_training custom", + "type": "debugpy", + "request": "launch", + "program": "scripts/run_training.py", + "console": "integratedTerminal", + "args": "${command:pickArgs}" }, { "name": "run_training --help", diff --git a/.vscode/settings.json b/.vscode/settings.json index 5a69a6b6..bbfd0f7e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,4 +8,9 @@ }, 
"python.analysis.typeCheckingMode": "basic", "black-formatter.importStrategy": "fromEnvironment", + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, } \ No newline at end of file diff --git a/notebooks/end2end_demo.ipynb b/notebooks/end2end_demo.ipynb index f08aba38..3f8e938d 100644 --- a/notebooks/end2end_demo.ipynb +++ b/notebooks/end2end_demo.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -26,7 +26,7 @@ "from delphi.eval.vis_per_token_model import visualize_per_token_category\n", "\n", "# from delphi.eval.calc_model_group_stats import calc_model_group_stats\n", - "from delphi.eval.token_labelling import TOKEN_LABELS" + "from delphi.eval.spacy_token_labelling import TOKEN_LABELS" ] }, { @@ -38,12 +38,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# load data\n", - "tokenized_corpus_dataset = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))[\"validation\"]\n", + "tokenized_corpus_dataset = cast(Dataset, load_dataset(\n", + " constants.tokenized_corpus_dataset,\n", + " split=\"validation\"\n", + "))\n", "\n", "# TODO: convert to use static paths\n", "# with open(\"../src/delphi/eval/labelled_token_ids_dict.pkl\", \"rb\") as f:\n", @@ -66,13 +69,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0f8846898fbb4a1b9e872ff6511acd3d", + "model_id": "d6c18c9588f3499b94e89ccea5954780", "version_major": 2, "version_minor": 0 }, @@ -80,7 +83,7 @@ "VBox(children=(Dropdown(description='Token Category:', options=('Capitalized', 'Is Determiner', 'Is Interjunct…" ] }, - "execution_count": 7, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } diff --git a/scripts/run_training.py b/scripts/run_training.py index 0b1960db..46b510b9 100755 --- a/scripts/run_training.py +++ b/scripts/run_training.py @@ -13,7 +13,7 @@ from delphi.constants import CONFIG_PRESETS_DIR from delphi.train.config import ( - GigaConfig, + TrainingConfig, build_config_from_files_and_overrides, get_preset_paths, get_user_config_path, @@ -208,7 +208,9 @@ def setup_parser() -> ( ) config_arg_group = parser.add_argument_group("Config arguments") help_parsers = dict() - add_dataclass_args_recursively(parser, GigaConfig, config_arg_group, help_parsers) + add_dataclass_args_recursively( + parser, TrainingConfig, config_arg_group, help_parsers + ) add_preset_args(parser) add_logging_args(parser) return parser, help_parsers @@ -232,7 +234,7 @@ def var_args_to_dict(config_vars: dict[str, Any]) -> dict[str, Any]: def args_to_dict(args: argparse.Namespace) -> dict[str, Any]: # at the toplevel, filter for args corresponding to field names in GigaConfig - field_names = set(field.name for field in fields(GigaConfig)) + field_names = set(field.name for field in fields(TrainingConfig)) config_vars = { k: v for k, v in vars(args).items() if k.split(".")[0] in field_names } diff --git a/scripts/training_config_examples/sample_config.json b/scripts/training_config_examples/sample_config.json index c6fc86b5..2efaef8f 100644 --- a/scripts/training_config_examples/sample_config.json +++ b/scripts/training_config_examples/sample_config.json @@ -51,7 +51,5 @@ "decay_lr": true, "warmup_iters": 1000, "min_lr": 0.0 - }, - "train_sample_limit": -1, - 
"val_sample_limit": -1 + } } \ No newline at end of file diff --git a/scripts/training_config_examples/sample_mamba.json b/scripts/training_config_examples/sample_mamba.json index 81bdf313..5b9a7be3 100644 --- a/scripts/training_config_examples/sample_mamba.json +++ b/scripts/training_config_examples/sample_mamba.json @@ -56,7 +56,5 @@ "decay_lr": true, "warmup_iters": 1000, "min_lr": 0.0 - }, - "train_sample_limit": -1, - "val_sample_limit": -1 + } } \ No newline at end of file diff --git a/src/delphi/constants.py b/src/delphi/constants.py index 4ede491e..8743087a 100644 --- a/src/delphi/constants.py +++ b/src/delphi/constants.py @@ -4,4 +4,4 @@ CONFIG_PRESETS_DIR = STATIC_ASSETS_DIR / "configs" CORPUS_DATASET = "delphi-suite/stories" -TOKENIZED_CORPUS_DATASET = "delphi-suite/v0-tinystories-v2-clean-tokenized" +TINYSTORIES_TOKENIZED_HF_DATASET = "delphi-suite/v0-tinystories-v2-clean-tokenized" diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py index b45a49d1..6dd310d9 100644 --- a/src/delphi/eval/utils.py +++ b/src/delphi/eval/utils.py @@ -109,12 +109,12 @@ def tokenize( ) -def load_logprob_dataset(model: str) -> Dataset: - return load_dataset(f"transcendingvictor/{model}-validation-logprobs") # type: ignore +def load_logprob_dataset(model: str): + return load_dataset(f"transcendingvictor/{model}-validation-logprobs") def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]: return { - model: cast(dict, load_logprob_dataset(model)[split])["logprobs"] + model: cast(dict, load_logprob_dataset(model)[split])["logprobs"] # type: ignore for model in constants.LLAMA2_MODELS } diff --git a/src/delphi/static/configs/debug.json b/src/delphi/static/configs/debug.json index 1694c587..3fbc6de7 100644 --- a/src/delphi/static/configs/debug.json +++ b/src/delphi/static/configs/debug.json @@ -5,7 +5,9 @@ "max_epochs": 2, "eval_interval": 1, "eval_iters": 1, - "train_sample_limit": 256, + "data_config": { + "train_sample_limit": 256 + }, "batch_size": 64, "model_config": { "model_type": "LlamaForCausalLM", diff --git a/src/delphi/static/configs/debug_mamba.json b/src/delphi/static/configs/debug_mamba.json index 16d0391a..813cf1bd 100644 --- a/src/delphi/static/configs/debug_mamba.json +++ b/src/delphi/static/configs/debug_mamba.json @@ -6,7 +6,9 @@ "eval_interval": 1, "log_interval": 1, "eval_iters": 10, - "train_sample_limit": 64, + "data_config": { + "train_sample_limit": 64 + }, "batch_size": 8, "model_config": { "model_type": "MambaForCausalLM", diff --git a/src/delphi/static/configs/debug_transformers_bloom.json b/src/delphi/static/configs/debug_transformers_bloom.json index 18720113..2532e99d 100644 --- a/src/delphi/static/configs/debug_transformers_bloom.json +++ b/src/delphi/static/configs/debug_transformers_bloom.json @@ -5,7 +5,9 @@ "max_epochs": 2, "eval_interval": 1, "eval_iters": 1, - "train_sample_limit": 256, + "data_config": { + "train_sample_limit": 256 + }, "batch_size": 64, "model_config": { "model_type": "BloomForCausalLM", diff --git a/src/delphi/train/checkpoint_step.py b/src/delphi/train/checkpoint_step.py new file mode 100644 index 00000000..ac581296 --- /dev/null +++ b/src/delphi/train/checkpoint_step.py @@ -0,0 +1,70 @@ +import logging +from collections.abc import Callable + +from datasets import Dataset + +from .config import TrainingConfig +from .iteration_params import IterationParams +from .run_context import RunContext +from .utils import ( + CheckpointData, + ModelTrainingState, + estimate_loss, + save_checkpoint_if_needed, +) +from 
.wandb_utils import log_to_wandb + + +def should_save_checkpoint(config: TrainingConfig, mts: ModelTrainingState): + return mts.iter_num % config.eval_interval == 0 + + +def log_and_save_checkpoint( + config: TrainingConfig, + mts: ModelTrainingState, + iteration_params: IterationParams, + train_ds: Dataset, + validation_ds: Dataset, + run_context: RunContext, +): + """ + Save a checkpoint of the current model + training state, evaluate, and optionally upload to huggingface and log to wandb (if configured) + """ + model = mts.model + if config.debug_config.no_eval: + logging.debug("no_eval=True, skipping evaluation and using dummy losses") + losses = {"train": 42.0, "val": 43.0} + else: + losses = estimate_loss( + model=model, + eval_iters=iteration_params.eval_iters, + batch_size=config.batch_size, + split_to_ds={"train": train_ds, "val": validation_ds}, + device=run_context.device, + epoch=mts.epoch, + feature_names={ + "train": config.data_config.train_feature, + "val": ( + config.data_config.validation_feature + or config.data_config.train_feature + ), + }, + ) + new_best_val_loss = False + if losses["val"] < mts.best_val_loss: + mts.best_val_loss = float(losses["val"]) + new_best_val_loss = True + checkpoint_data = CheckpointData( + tokens_per_iter=iteration_params.tokens_per_iter, + losses=losses, + new_best_val_loss=new_best_val_loss, + config=config, + model_training_state=mts, + run_context=run_context, + ) + logging.info( + f"step {mts.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}" + ) + save_checkpoint_if_needed(checkpoint_data) + if config.wandb_config.log: + log_to_wandb(checkpoint_data) diff --git a/src/delphi/train/config/__init__.py b/src/delphi/train/config/__init__.py index ce698e55..55512633 100644 --- a/src/delphi/train/config/__init__.py +++ b/src/delphi/train/config/__init__.py @@ -1,5 +1,6 @@ -from .gigaconfig import GigaConfig +from .model_config import ModelConfig from .optimizer_config import OptimizerConfig +from .training_config import TrainingConfig from .utils import ( build_config_dict_from_files, build_config_from_files, diff --git a/src/delphi/train/config/data_config.py b/src/delphi/train/config/data_config.py new file mode 100644 index 00000000..10fa303c --- /dev/null +++ b/src/delphi/train/config/data_config.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass, field +from typing import Optional + +from beartype import beartype + +from delphi import constants + + +@beartype +@dataclass(frozen=True) +class DataConfig: + train_dataset: str = field( + # TODO: remove default after updating configs to include this field + default=constants.TINYSTORIES_TOKENIZED_HF_DATASET, + metadata={"help": "tokenized dataset on huggingface to use for train"}, + ) + train_split: str = field( + default="train", + metadata={"help": "split of the train dataset to use for train"}, + ) + train_feature: str = field( + default="tokens", + metadata={ + "help": "feature in the train dataset to use for train; should be a list of max_seq_len+1 token ints" + }, + ) + train_sample_limit: Optional[int] = field( + default=None, + metadata={"help": "limit the number of train samples to use"}, + ) + + validation_dataset: Optional[str] = field( + default=None, + metadata={ + "help": ( + "tokenized dataset on huggingface to use for validation. 
" + "If not set, validation defaults to using train_dataset" + ) + }, + ) + validation_split: str = field( + default="validation", + metadata={"help": "split of the validation dataset to use for validation"}, + ) + validation_feature: Optional[str] = field( + default=None, + metadata={ + "help": ( + "feature in the validation dataset to use for validation; " + "should be a list of max_seq_len+1 token ints. " + "If not set, validation defaults to using train_feature." + ) + }, + ) + validation_sample_limit: Optional[int] = field( + default=None, + metadata={"help": "limit the number of validation samples to use"}, + ) diff --git a/src/delphi/train/config/gigaconfig.py b/src/delphi/train/config/training_config.py similarity index 88% rename from src/delphi/train/config/gigaconfig.py rename to src/delphi/train/config/training_config.py index 9d6236d2..25ca1586 100644 --- a/src/delphi/train/config/gigaconfig.py +++ b/src/delphi/train/config/training_config.py @@ -1,10 +1,12 @@ import os from dataclasses import dataclass, field from datetime import datetime +from typing import Optional import platformdirs from beartype import beartype +from .data_config import DataConfig from .debug_config import DebugConfig from .huggingface_config import HuggingfaceConfig from .model_config import ModelConfig @@ -14,7 +16,7 @@ @beartype @dataclass(frozen=True) -class GigaConfig: +class TrainingConfig: model_config: ModelConfig # meta run_name: str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") @@ -72,15 +74,22 @@ class GigaConfig: metadata={"help": "seed used for pseudorandomly sampling data during training"}, ) torch_seed: int = field(default=42, metadata={"help": "seed used for torch"}) + + # data + data_config: DataConfig = field( + default_factory=DataConfig, + metadata={"help": "specify training and validation data"}, + ) + # debugging - train_sample_limit: int = field( - default=-1, + train_sample_limit: Optional[int] = field( + default=None, metadata={ "help": "for debugging: limit size of the training set.# -1 implies no limit" }, ) - val_sample_limit: int = field( - default=-1, + val_sample_limit: Optional[int] = field( + default=None, metadata={ "help": "for debugging: limit size of the validation set. 
-1 implies no limit" }, diff --git a/src/delphi/train/config/utils.py b/src/delphi/train/config/utils.py index 719b522d..efc674f5 100644 --- a/src/delphi/train/config/utils.py +++ b/src/delphi/train/config/utils.py @@ -12,7 +12,7 @@ from delphi.constants import CONFIG_PRESETS_DIR -from .gigaconfig import GigaConfig +from .training_config import TrainingConfig def _merge_dicts(merge_into: dict[str, Any], merge_from: dict[str, Any]): @@ -39,7 +39,7 @@ def get_user_config_path() -> Path: return user_config_path -def get_presets_by_name() -> dict[str, GigaConfig]: +def get_presets_by_name() -> dict[str, TrainingConfig]: return { preset.stem: build_config_from_files([preset]) for preset in get_preset_paths() } @@ -118,22 +118,22 @@ def log_config_recursively( def build_config_from_files_and_overrides( config_files: list[Path], overrides: dict[str, Any], -) -> GigaConfig: +) -> TrainingConfig: combined_config = build_config_dict_from_files(config_files) _merge_dicts(merge_into=combined_config, merge_from=overrides) set_backup_vals(combined_config, config_files) - filter_config_to_actual_config_values(GigaConfig, combined_config) + filter_config_to_actual_config_values(TrainingConfig, combined_config) logging.info("User-set config values:") log_config_recursively( combined_config, logging_fn=logging.info, prefix=" ", indent=" " ) - return from_dict(GigaConfig, combined_config) + return from_dict(TrainingConfig, combined_config) -def build_config_from_files(config_files: list[Path]) -> GigaConfig: +def build_config_from_files(config_files: list[Path]) -> TrainingConfig: return build_config_from_files_and_overrides(config_files, {}) -def load_preset(preset_name: str) -> GigaConfig: +def load_preset(preset_name: str) -> TrainingConfig: preset_path = Path(CONFIG_PRESETS_DIR) / f"{preset_name}.json" # type: ignore return build_config_from_files([preset_path]) diff --git a/src/delphi/train/iteration_params.py b/src/delphi/train/iteration_params.py index b6c624db..e3290bd7 100644 --- a/src/delphi/train/iteration_params.py +++ b/src/delphi/train/iteration_params.py @@ -3,7 +3,7 @@ from datasets import Dataset -from .config import GigaConfig +from .config import TrainingConfig @dataclass @@ -16,7 +16,7 @@ class IterationParams: def set_iteration_params( - config: GigaConfig, train_ds: Dataset, validation_ds: Dataset + config: TrainingConfig, train_ds: Dataset, validation_ds: Dataset ) -> IterationParams: num_batches = len(train_ds) // config.batch_size # we take gradient_accumulation_steps batches per step (one in each microstep) diff --git a/src/delphi/train/train_step.py b/src/delphi/train/train_step.py index 9852d633..64bebbc1 100644 --- a/src/delphi/train/train_step.py +++ b/src/delphi/train/train_step.py @@ -1,110 +1,70 @@ import logging -import time -from collections.abc import Callable, Generator +from collections.abc import Iterable import torch from datasets import Dataset +from transformers import PreTrainedModel -from .config import GigaConfig -from .iteration_params import IterationParams -from .run_context import RunContext -from .utils import EvalData, ModelTrainingState, estimate_loss, get_next_xy, set_lr +from .config import TrainingConfig +from .utils import ModelTrainingState, gen_minibatches def train_step( model_training_state: ModelTrainingState, train_ds: Dataset, - validation_ds: Dataset, - iteration_params: IterationParams, - eval_callbacks: list[Callable], - config: GigaConfig, - train_batch_iter: Generator, - run_context: RunContext, + config: TrainingConfig, + device: 
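
# Hedged sketch of the renamed config plumbing (GigaConfig -> TrainingConfig):
# load the debug preset and override one nested value. The nested-dict shape of
# the overrides argument is an assumption based on how _merge_dicts is used.
from delphi.train.config import build_config_from_files_and_overrides
from delphi.train.config.utils import get_preset_paths

debug_preset = next(p for p in get_preset_paths() if p.stem == "debug")
config = build_config_from_files_and_overrides(
    [debug_preset],
    overrides={"data_config": {"train_sample_limit": 128}},
)
assert config.data_config.train_sample_limit == 128
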
torch.device, + ds_indices: list[int], ): """ - Runs a training step, updating (mutating in place) model_training_state - returns true if training should break, false otherwise + Runs a training step, updating (mutating in place) model_training_state: + - generate gradient_accumulation_steps batches (each batch is batch_size/gradient_accumulation_steps items) + - forward pass, accumulating gradient/gradient_accumulation_steps over gradient_accumulation_steps batches + - clip gradient where gradient exceeds grad_clip (if configured) + - backward pass, updating model weights + - reset grad """ model = model_training_state.model optimizer = model_training_state.optimizer - # here's how each train step works: - # 1. Set learning rate - # 2. (every eval_interval steps) evaluate, log to wandb, save checkpoint - # 3. forward backward update - # 4. log timing - - # 1. determine and set the learning rate for this iteration - model_training_state.lr = set_lr( - iteration_params.lr_decay_iters, - config, - optimizer, - model_training_state.iter_num, - ) - - # 2. evaluate the loss on train/val sets and write checkpoints - if model_training_state.iter_num % config.eval_interval == 0: - if config.debug_config.no_eval: - logging.debug("no_eval=True, skipping evaluation and using dummy losses") - losses = {"train": 42.0, "val": 43.0} - else: - losses = estimate_loss( - model=model, - eval_iters=iteration_params.eval_iters, - batch_size=config.batch_size, - split_to_ds={"train": train_ds, "val": validation_ds}, - device=run_context.device, - epoch=model_training_state.epoch, - ) - new_best_val_loss = False - if losses["val"] < model_training_state.best_val_loss: - model_training_state.best_val_loss = float(losses["val"]) - new_best_val_loss = True - eval_data = EvalData( - tokens_per_iter=iteration_params.tokens_per_iter, - losses=losses, - new_best_val_loss=new_best_val_loss, - config=config, - model_training_state=model_training_state, - run_context=run_context, - ) - logging.info( - f"step {model_training_state.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}" - ) - for callback in eval_callbacks: - callback(eval_data) - - # 3. forward backward update, with optional gradient accumulation to simulate larger batch size if config.debug_config.no_training: + total_loss = 0.0 logging.debug("no_training set, skipping forward backward pass") - loss = torch.Tensor([42.1]).to(run_context.device) else: - for micro_step in range(config.optimizer.gradient_accumulation_steps): - X, Y = get_next_xy(train_batch_iter, run_context.device) - loss = ( - model(X, labels=Y, return_dict=True).loss - / config.optimizer.gradient_accumulation_steps - ) - loss.backward() + minibatches = gen_minibatches( + dataset=train_ds, + indices=ds_indices, + batch_size=config.batch_size, + num_minibatches=config.optimizer.gradient_accumulation_steps, + step=model_training_state.step, + device=device, + feature_name=config.data_config.train_feature, + ) + total_loss = accumulate_gradients( + model=model, + batches=minibatches, + num_batches=config.optimizer.gradient_accumulation_steps, + ) # clip the gradient if config.grad_clip != 0.0: torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) # type: ignore optimizer.step() # flush the gradients as soon as we can, no need for this memory anymore optimizer.zero_grad(set_to_none=True) + model_training_state.train_loss = total_loss - # 4. 
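
# Why accumulate_gradients (defined a few lines below in this patch) divides
# each minibatch loss by num_batches -- a toy torch sketch, not delphi code:
# summing backward() of loss_i / n over n equal-sized minibatches leaves the
# same gradient as one backward() of the mean loss over the full batch.
import torch

w = torch.zeros(3, requires_grad=True)
batch = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

(batch @ w).mean().backward()          # full-batch gradient
full_grad = w.grad.clone()

w.grad = None
for minibatch in batch.split(1):       # two minibatches of size 1
    ((minibatch @ w).mean() / 2).backward()
assert torch.allclose(w.grad, full_grad)
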
log timing - t1 = time.time() - dt = t1 - model_training_state.last_training_step_time - model_training_state.last_training_step_time = t1 - if model_training_state.iter_num % config.log_interval == 0: - # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point - lossf = loss.item() * config.optimizer.gradient_accumulation_steps - logging.debug( - ( - f"{model_training_state.iter_num} | loss {lossf:.4f} | lr {model_training_state.lr:e} | " - f"{dt*1000:.2f}ms" - ) - ) - model_training_state.iter_num += 1 - model_training_state.local_iter_num += 1 + +def accumulate_gradients( + model: PreTrainedModel, + batches: Iterable[tuple[torch.Tensor, torch.Tensor]], + num_batches: int, +) -> float: + """ + Accumulate gradients over multiple batches as if they were a single batch + """ + total_loss = 0.0 + for X, Y in batches: + loss = model(X, labels=Y, return_dict=True).loss / num_batches + total_loss += loss.item() + loss.backward() + return total_loss diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py index ba0603c7..7881eebb 100644 --- a/src/delphi/train/training.py +++ b/src/delphi/train/training.py @@ -1,36 +1,50 @@ import logging import os +import time from dataclasses import fields -from typing import cast import torch -from datasets import Dataset from tqdm import tqdm from transformers import __version__ as transformers_version from delphi import __version__ as delphi_version -from . import wandb_utils -from .config import GigaConfig +from .checkpoint_step import log_and_save_checkpoint, should_save_checkpoint +from .config import TrainingConfig from .iteration_params import set_iteration_params from .run_context import RunContext from .train_step import train_step from .utils import ( ModelTrainingState, - batch_generator, get_device, + get_indices_for_epoch, initialize_model_training_state, - load_delphi_training_dataset, - save_checkpoint_if_needed, + load_tokens_dataset_from_huggingface, + set_lr, + setup_determinism, ) +from .wandb_utils import init_wandb -def run_training(config: GigaConfig) -> tuple[ModelTrainingState, RunContext]: +def setup_training(config: TrainingConfig): + logging.info("Setting up training...") + os.makedirs(config.output_dir, exist_ok=True) + + # torch misc - TODO: check if this is actually needed + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + + # determinism + setup_determinism(config.torch_seed) + + # wandb setup + if config.wandb_config.log: + init_wandb(config=config) + + +def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext]: + setup_training(config) logging.info("Starting training...") - logging.debug("Setting torch.use_deterministic_algorithms(True)") - torch.use_deterministic_algorithms(True) - torch.backends.cudnn.benchmark = False - torch.manual_seed(config.torch_seed) logging.info("Config:") for field in fields(config): logging.info(f" {field.name}: {getattr(config, field.name)}") @@ -46,57 +60,73 @@ def run_training(config: GigaConfig) -> tuple[ModelTrainingState, RunContext]: # load data logging.info("Loading data...") - train_ds = cast( - Dataset, load_delphi_training_dataset("train", limit=config.train_sample_limit) + train_ds = load_tokens_dataset_from_huggingface( + hf_dataset_id=config.data_config.train_dataset, + split=config.data_config.train_split, + tokens_feature=config.data_config.train_feature, + limit=config.data_config.train_sample_limit, ) - validation_ds = cast( - 
Dataset, - load_delphi_training_dataset("validation", limit=config.val_sample_limit), + validation_ds = load_tokens_dataset_from_huggingface( + hf_dataset_id=( + config.data_config.validation_dataset or config.data_config.train_dataset + ), + split=config.data_config.validation_split, + tokens_feature=( + config.data_config.validation_feature or config.data_config.train_feature + ), + limit=config.data_config.validation_sample_limit, ) # derive iteration params (num_batches, num_steps, etc) iteration_params = set_iteration_params(config, train_ds, validation_ds) - # setup - logging.info("Setting up...") - os.makedirs(config.output_dir, exist_ok=True) - torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul - torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn - # model init model_training_state = initialize_model_training_state(config, run_context.device) - # setup eval callbacks - logging.info("Setting eval step callbacks...") - eval_callbacks = [save_checkpoint_if_needed] - logging.info(f" added save_checkpoint_if_needed eval callback") - if config.wandb_config.log: - if config.wandb_config.silence: - wandb_utils.silence_wandb() - wandb_utils.init_wandb(config) - eval_callbacks.append(wandb_utils.log_to_wandb) - logging.info(f" added log_to_wandb callback") - # training loop logging.info("Starting training...") for epoch in range(config.max_epochs): logging.info(f"Epoch: {epoch} / {config.max_epochs - 1}") - train_batch_iter = iter( - batch_generator( - train_ds, config.batch_size, epoch, config.batch_ordering_seed - ) + train_data_indices = get_indices_for_epoch( + dataset_size=len(train_ds), + batch_size=config.batch_size, + epoch=epoch, + ordering_seed=config.batch_ordering_seed, ) model_training_state.epoch = epoch for step in tqdm(range(iteration_params.num_steps)): model_training_state.step = step + if should_save_checkpoint(config, model_training_state): + log_and_save_checkpoint( + config=config, + mts=model_training_state, + iteration_params=iteration_params, + train_ds=train_ds, + validation_ds=validation_ds, + run_context=run_context, + ) + model_training_state.lr = set_lr( + lr_decay_iters=iteration_params.lr_decay_iters, + config=config, + optimizer=model_training_state.optimizer, + iter_num=model_training_state.iter_num, + ) train_step( - model_training_state, - train_ds, - validation_ds, - iteration_params, - eval_callbacks, - config, - train_batch_iter, - run_context, + model_training_state=model_training_state, + train_ds=train_ds, + config=config, + device=run_context.device, + ds_indices=train_data_indices, ) + t1 = time.time() + dt = t1 - model_training_state.last_training_step_time + model_training_state.last_training_step_time = t1 + if model_training_state.iter_num % config.log_interval == 0: + logging.debug( + ( + f"{model_training_state.iter_num} | loss {model_training_state.train_loss:.4f} | lr {model_training_state.lr:e} | " + f"{dt*1000:.2f}ms" + ) + ) + model_training_state.iter_num += 1 return model_training_state, run_context diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py index c974d573..67d25e6b 100644 --- a/src/delphi/train/utils.py +++ b/src/delphi/train/utils.py @@ -6,19 +6,17 @@ from collections.abc import Generator from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import cast +from typing import Optional, cast +import datasets import safetensors.torch as st import torch -from datasets import Dataset +from datasets import Dataset, load_dataset from huggingface_hub import 
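
# Hedged end-to-end sketch of the reworked entry point: tf32 flags, determinism
# and optional wandb init now happen inside run_training via setup_training, so
# a driver only needs to supply a TrainingConfig (preset name is illustrative).
from delphi.train.config.utils import load_preset
from delphi.train.training import run_training

config = load_preset("debug")          # small preset shipped in static/configs
model_training_state, run_context = run_training(config)
print(model_training_state.iter_num, model_training_state.best_val_loss)
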
HfApi from torch.optim import AdamW from transformers import PreTrainedModel -from delphi import constants -from delphi.eval.utils import load_delphi_dataset - -from .config import GigaConfig +from .config import TrainingConfig from .run_context import RunContext from .shuffle import shuffle_list @@ -27,14 +25,11 @@ class ModelTrainingState: """mutable training state - stuff that changes over the course of training""" - model: torch.nn.Module + model: PreTrainedModel optimizer: torch.optim.Optimizer iter_num: int = field( metadata={"help": "total iterations so far across all epochs"} ) - local_iter_num: int = field( - metadata={"help": "total iterations on this instance so far"} - ) best_val_loss: float = field(metadata={"help": "best validation loss so far"}) last_training_step_time: float = field( metadata={"help": "time last iteration ended"} @@ -42,19 +37,30 @@ class ModelTrainingState: epoch: int = field(metadata={"help": "current epoch"}) step: int = field(metadata={"help": "step within current epoch"}) lr: float = field(default=1.0e-5, metadata={"help": "learning rate"}) + train_loss: float = field( + default=0.0, metadata={"help": "loss on most recent train step"} + ) @dataclass -class EvalData: - # values we expose to eval callback functions +class CheckpointData: + """values we expose to assorted checkpoint/eval functions""" + tokens_per_iter: int losses: dict[str, float] new_best_val_loss: bool - config: GigaConfig + config: TrainingConfig model_training_state: ModelTrainingState run_context: RunContext +def setup_determinism(seed: int): + logging.debug(f"Setting up torch determinism (seed={seed})...") + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + torch.manual_seed(seed) + + def get_device(device_str: str = "auto") -> torch.device: """ Get torch device specified by device_str. May pass "auto" to set torch device automatically. @@ -92,10 +98,13 @@ def get_lr( def set_lr( lr_decay_iters: int, - config: GigaConfig, + config: TrainingConfig, optimizer: torch.optim.Optimizer, iter_num: int, ): + """ + Set the learning rate (calculated by get_lr) on the optimizer + """ lr = ( get_lr( iter_num=iter_num, @@ -112,7 +121,7 @@ def set_lr( return lr -def save_checkpoint_if_needed(eval_data: EvalData): +def save_checkpoint_if_needed(eval_data: CheckpointData): mts = eval_data.model_training_state # we save if it's not the first iter AND at least one of: # 1) we have a new best validation loss @@ -134,7 +143,7 @@ def save_checkpoint_if_needed(eval_data: EvalData): def initialize_model_training_state( - config: GigaConfig, device: torch.device + config: TrainingConfig, device: torch.device ) -> ModelTrainingState: t0 = time.time() model = config.model_config.get_model() @@ -171,48 +180,67 @@ def initialize_model_training_state( optimizer=optimizer, last_training_step_time=t0, iter_num=training_state_vals.get("iter_num", 0), - local_iter_num=training_state_vals.get("local_iter_num", 0), best_val_loss=training_state_vals.get("best_val_loss", 1e9), epoch=training_state_vals.get("epoch", 0), step=training_state_vals.get("step", 0), ) -def load_delphi_training_dataset(split: str, limit: int = -1): - """For training, we want (X, Y) pairs, where X is a chunk of text and Y is the next token.) - To construct this, we take the original tokenized dataset, break it into max_seq_len+1 length chunks, - and then take [:-1] as X and [1:] as Y. 
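
# The (X, Y) convention from the docstring above, in miniature: each stored
# chunk holds max_seq_len + 1 tokens, the model input drops the last token and
# the label drops the first (toy values, tiny max_seq_len of 4).
import torch

chunk = torch.tensor([[10, 11, 12, 13, 14]])   # shape (batch=1, max_seq_len + 1)
X, Y = chunk[:, :-1], chunk[:, 1:]             # same slicing as get_xy_batch below
# X == [[10, 11, 12, 13]], Y == [[11, 12, 13, 14]]
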
- """ - if limit == -1: - ds = load_delphi_dataset(constants.TOKENIZED_CORPUS_DATASET, split) - else: - ds = load_delphi_dataset(constants.TOKENIZED_CORPUS_DATASET, split).select( - range(limit) - ) - ds.set_format("torch") - return ds +def get_indices_for_epoch( + dataset_size: int, batch_size: int, epoch: int, ordering_seed: int +) -> list[int]: + """ """ + indices = list(range(dataset_size)) + shuffle_list(indices, seed=ordering_seed + epoch) + return indices -def get_next_xy( - train_batch_iter: Generator, device: torch.device +def get_xy_batch( + dataset: Dataset, + indices: list[int], + batch_size: int, + batch_num: int, + feature_name: str, + device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: - """break a (max_seq_len +1) sequence of tokens into sample [:-1] and label [1:] pairs""" - data = next(train_batch_iter).to(device) - X, Y = data[:, :-1], data[:, 1:] - return X, Y + """ + Get a batch of data from a dataset given a batch number and indices + + Args: + """ + start = batch_num * batch_size + end = (batch_num + 1) * batch_size + batch_indices = indices[start:end] + data = dataset[batch_indices][feature_name].to(device) + return data[:, :-1], data[:, 1:] -def batch_generator( - dataset: Dataset, batch_size: int, epoch: int, ordering_seed: int -) -> Generator[torch.Tensor, None, None]: +def gen_minibatches( + dataset: Dataset, + batch_size: int, + num_minibatches: int, + step: int, + indices: list[int], + device: torch.device, + feature_name: str, +) -> Generator[tuple[torch.Tensor, torch.Tensor], None, None]: """ - Generate batches of training data for a given epoch with pseudorandom determinism + Generate minibatches from a dataset given a step and indices """ - sampler = list(range(len(dataset))) # type: ignore - shuffle_list(sampler, seed=ordering_seed + epoch) - sampler = torch.Tensor(sampler) - for samples in sampler.split(batch_size): - yield dataset[samples]["tokens"] + assert ( + batch_size % num_minibatches == 0 + ), "batch_size must be divisible by num_minibatches" + minibatch_size = batch_size // num_minibatches + first_minibatch_num = num_minibatches * step + for i in range(num_minibatches): + yield get_xy_batch( + dataset=dataset, + indices=indices, + batch_num=first_minibatch_num + i, + batch_size=minibatch_size, + feature_name=feature_name, + device=device, + ) @torch.no_grad() @@ -223,15 +251,30 @@ def estimate_loss( split_to_ds: dict[str, Dataset], device: torch.device, epoch: int, + feature_names: dict[str, str], ) -> dict[str, float]: """helps estimate an arbitrarily accurate loss over either split using many batches""" out = {} model.eval() for split, ds in split_to_ds.items(): - batch_iter = iter(batch_generator(ds, batch_size, epoch, 1234)) + indices = get_indices_for_epoch( + dataset_size=len(ds), + batch_size=batch_size, + epoch=epoch, + ordering_seed=1234, + ) + eval_iters = min(eval_iters, len(ds) // batch_size) losses = torch.zeros(eval_iters) # keep on CPU - for k in range(min(eval_iters, len(ds) // batch_size)): # type: ignore - X, Y = get_next_xy(batch_iter, device) + minibatches = gen_minibatches( + dataset=ds, + batch_size=batch_size, + num_minibatches=eval_iters, + step=0, + indices=indices, + device=device, + feature_name=feature_names[split], + ) + for k, (X, Y) in enumerate(minibatches): loss = model(X, labels=Y, return_dict=True).loss losses[k] = loss.item() out[split] = losses.mean() @@ -240,7 +283,7 @@ def estimate_loss( def save_results( - config: GigaConfig, + config: TrainingConfig, train_results: ModelTrainingState, 
run_context: RunContext, results_path: str, @@ -270,7 +313,6 @@ def save_results( with open(os.path.join(results_path, "training_state.json"), "w") as file: training_state_dict = { "iter_num": train_results.iter_num, - "local_iter_num": train_results.local_iter_num, "best_val_loss": train_results.best_val_loss, "lr": train_results.lr, "epoch": train_results.epoch, @@ -288,3 +330,33 @@ def save_results( repo_id=str(config.huggingface.repo_id), path_in_repo=f"iter_{train_results.iter_num}/", ) + + +def load_tokens_dataset_from_huggingface( + hf_dataset_id: str, + split: str, + tokens_feature: str, + limit: Optional[int] = None, +) -> Dataset: + """Load a dataset from huggingface + + Args: + hf_dataset_id (str): huggingface dataset id e.g. "delphi-suite/v0-tinystories-v2-clean-tokenized" + split (str): split to load, e.g. "train" or "validation" + tokens_feature (str): feature name for tokens, e.g. "tokens" + limit (Optional[int], optional): limit the number of samples. None (default) means no limit (use full dataset split) + """ + ds = cast( + Dataset, + load_dataset( + hf_dataset_id, + split=split, + features=datasets.Features( + {tokens_feature: datasets.Sequence(datasets.Value("int32"))} + ), + ), + ) + if limit is not None and limit > 0: + ds = ds.select(range(limit)) + ds.set_format("torch") + return ds diff --git a/src/delphi/train/wandb_utils.py b/src/delphi/train/wandb_utils.py index b24926df..ac308c78 100644 --- a/src/delphi/train/wandb_utils.py +++ b/src/delphi/train/wandb_utils.py @@ -4,19 +4,18 @@ import wandb -from .config import GigaConfig -from .utils import EvalData +from .config import TrainingConfig +from .utils import CheckpointData def silence_wandb(): - # set env var WANDB_SILENT=true logging.info("silencing wandb output") os.environ["WANDB_SILENT"] = "true" -def init_wandb(config: GigaConfig): +def init_wandb(config: TrainingConfig): # if log level < debug, silence wandb - if logging.getLogger().level > logging.INFO: + if logging.getLogger().level > logging.INFO or config.wandb_config.silence: silence_wandb() wandb.init( entity=config.wandb_config.entity, @@ -26,7 +25,7 @@ def init_wandb(config: GigaConfig): ) -def log_to_wandb(eval_data: EvalData): +def log_to_wandb(eval_data: CheckpointData): mts = eval_data.model_training_state try: wandb.log( diff --git a/tests/train/test_train_step.py b/tests/train/test_train_step.py new file mode 100644 index 00000000..fccc823b --- /dev/null +++ b/tests/train/test_train_step.py @@ -0,0 +1,253 @@ +# TODO: there are some ugly hacks here, and the test states are way too complicated +# clean this up as other parts of the codebase are refactored + +from dataclasses import asdict + +import dacite +import pytest +import torch +from datasets import Dataset +from jaxtyping import Float + +from delphi.train.config import TrainingConfig +from delphi.train.config.utils import load_preset +from delphi.train.train_step import accumulate_gradients, train_step +from delphi.train.utils import ModelTrainingState, get_xy_batch, setup_determinism + + +@pytest.fixture +def dataset(): + ds = Dataset.from_dict( + { + "tokens": [list(range(i, i + 512)) for i in range(64)], + }, + ) + ds.set_format(type="torch") + return ds + + +@pytest.fixture +def model(): + setup_determinism(42) + # TODO: replace this with a model config dict after model_config update is in (next PR) + return load_preset("debug").model_config.get_model() + + +def get_params(model: torch.nn.Module) -> Float[torch.Tensor, "params"]: + params = [ + (name, param) for name, param in 
model.named_parameters() if param.requires_grad + ] + params.sort(key=lambda x: x[0]) + return torch.cat([p.flatten() for _, p in params]) + + +def test_basic_reproducibility(dataset, model): + """ + check that the same batch produces the same gradient + """ + # setup + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + model_training_state = ModelTrainingState( + model=model, + optimizer=optimizer, + iter_num=0, + epoch=0, + step=0, + train_loss=0.0, + lr=0.01, + best_val_loss=float("inf"), + last_training_step_time=0.0, + ) + device = torch.device("cpu") + indices = list(range(64)) + + # train + train_step(model_training_state, dataset, load_preset("debug"), device, indices) + + params = get_params(model) + + assert torch.isclose( + params[[1000, 2000, 3000]], + torch.tensor([-0.01782517, -0.00771354, 0.03517739]), + ).all() + + +def test_accumulate_gradients_accumulates(dataset, model): + """ + check that gradient accumulation works as expected and doesn't reset on each microstep + """ + # setup + indices_set_a = [1, 2, 3, 4, 5, 6, 7, 8, 9] + # different batch but idential last batch (with batches of 3); + # this should result in a different accumulated gradient + indices_set_b = [7, 8, 9, 7, 8, 9, 7, 8, 9] + batch_size = 3 + num_batches = len(indices_set_a) // batch_size + + batches_a = [ + get_xy_batch( + dataset=dataset, + indices=indices_set_a, + batch_size=3, + batch_num=microstep, + feature_name="tokens", + device=torch.device("cpu"), + ) + for microstep in range(num_batches) + ] + batches_b = [ + get_xy_batch( + dataset=dataset, + indices=indices_set_b, + batch_size=3, + batch_num=microstep, + feature_name="tokens", + device=torch.device("cpu"), + ) + for microstep in range(num_batches) + ] + + # accumulate + _total_loss = accumulate_gradients(model, batches_a, len(batches_a)) + + grads_a = torch.cat( + [ + param.grad.clone().detach().flatten() + for param in model.parameters() + if param.grad is not None + ] + ) + + # reset grad on model + model.zero_grad() + + _total_loss = accumulate_gradients(model, batches_b, len(batches_b)) + grads_b = torch.cat( + [ + param.grad.clone().detach().flatten() + for param in model.parameters() + if param.grad is not None + ] + ) + + # test + assert not torch.isclose(grads_a, grads_b).all() + + +def test_accumulate_gradients_consistent(dataset, model): + """ + Validate that the gradients are consistent when the same batch is passed to accumulate_gradients + """ + # setup + indices_set = list(range(1, 10)) + num_batches = 3 + batch_size = 3 + batches_a = [ + get_xy_batch( + dataset=dataset, + indices=indices_set, + batch_size=batch_size, + batch_num=microstep, + feature_name="tokens", + device=torch.device("cpu"), + ) + for microstep in range(num_batches) + ] + batches_aa = [ + get_xy_batch( + dataset=dataset, + indices=indices_set, + batch_size=batch_size, + batch_num=microstep, + feature_name="tokens", + device=torch.device("cpu"), + ) + for microstep in range(num_batches) + ] + + # accumulate + total_loss = accumulate_gradients(model, batches_a, num_batches) + + grads_a = torch.cat( + [ + param.grad.clone().detach().flatten() + for param in model.parameters() + if param.grad is not None + ] + ) + + # reset grad on model + model.zero_grad() + + total_loss = accumulate_gradients(model, batches_aa, num_batches) + grads_aa = torch.cat( + [ + param.grad.clone().detach().flatten() + for param in model.parameters() + if param.grad is not None + ] + ) + + # test + assert torch.isclose(grads_a, grads_aa).all() + + +def 
get_model_training_state(model, optimizer, step): + return ModelTrainingState( + model=model, + optimizer=optimizer, + iter_num=0, + epoch=0, + step=step, + train_loss=0.0, + lr=0.01, + best_val_loss=float("inf"), + last_training_step_time=0.0, + ) + + +def test_train_step_no_training(dataset, model): + """ + Test train_step when no_training is set to True + """ + # setup + config_dict = asdict(load_preset("debug")) + config_dict["debug_config"] = {"no_training": True} + config = dacite.from_dict(TrainingConfig, config_dict) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + model_training_state = get_model_training_state( + model=model, optimizer=optimizer, step=0 + ) + device = torch.device("cpu") + indices = [0, 1, 2, 3] + + # (don't) train + train_step(model_training_state, dataset, config, device, indices) + + # test + assert model_training_state.train_loss == 0.0 + + +def test_train_step_with_training(dataset, model): + """ + Test train_step when training is performed + """ + # setup + config_dict = asdict(load_preset("debug")) + config_dict["debug_config"] = {"no_training": False} + config_dict["batch_size"] = 16 + config_dict["optimizer"] = {"gradient_accumulation_steps": 4} + config_dict["grad_clip"] = 1.0 + config = dacite.from_dict(TrainingConfig, config_dict) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + model_training_state = get_model_training_state( + model=model, optimizer=optimizer, step=0 + ) + device = torch.device("cpu") + indices = list(range(len(dataset))) + + # train + train_step(model_training_state, dataset, config, device, indices) + + # test + assert model_training_state.train_loss > 0.0 diff --git a/tests/train/test_wandb_utils.py b/tests/train/test_wandb_utils.py index a7cbb80e..9b4213c9 100644 --- a/tests/train/test_wandb_utils.py +++ b/tests/train/test_wandb_utils.py @@ -4,20 +4,18 @@ import pytest import torch -import transformers from dacite import from_dict -from delphi.train.config import GigaConfig -from delphi.train.config.utils import load_preset +from delphi.train.config import TrainingConfig from delphi.train.run_context import RunContext -from delphi.train.utils import EvalData, initialize_model_training_state +from delphi.train.utils import CheckpointData, initialize_model_training_state from delphi.train.wandb_utils import init_wandb, log_to_wandb, silence_wandb @pytest.fixture def mock_giga_config(): config = from_dict( - GigaConfig, + TrainingConfig, { "run_name": "test_run", "device": "cpu", @@ -55,8 +53,8 @@ def mock_model_training_state(mock_giga_config): @pytest.fixture -def mock_eval_data(mock_giga_config, mock_model_training_state): - eval_data = EvalData( +def mock_checkpoint_data(mock_giga_config, mock_model_training_state): + eval_data = CheckpointData( model_training_state=mock_model_training_state, tokens_per_iter=1000, losses={"train": 0.5, "val": 0.4}, @@ -91,8 +89,8 @@ def test_init_wandb(mock_wandb_init: MagicMock, mock_giga_config): @patch("wandb.log") -def test_log_to_wandb(mock_wandb_log, mock_eval_data): - log_to_wandb(mock_eval_data) +def test_log_to_wandb(mock_wandb_log, mock_checkpoint_data): + log_to_wandb(mock_checkpoint_data) mock_wandb_log.assert_called_once_with( { "iter": 1,