diff --git a/scripts/training_config_examples/sample_config.json b/scripts/training_config_examples/sample_config.json index b2de5093..c6fc86b5 100644 --- a/scripts/training_config_examples/sample_config.json +++ b/scripts/training_config_examples/sample_config.json @@ -16,9 +16,8 @@ "batch_size": 64, "max_seq_len": 512, "model_config": { - "model_type": "llama2", - "mamba": null, - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": -1, diff --git a/scripts/training_config_examples/sample_mamba.json b/scripts/training_config_examples/sample_mamba.json index 94a2dd37..81bdf313 100644 --- a/scripts/training_config_examples/sample_mamba.json +++ b/scripts/training_config_examples/sample_mamba.json @@ -16,8 +16,8 @@ "batch_size": 64, "max_seq_len": 512, "model_config": { - "model_type": "mamba", - "mamba": { + "model_type": "MambaForCausalLM", + "model_params": { "vocab_size": 4096, "hidden_size": 768, "state_size": 16, diff --git a/scripts/training_config_examples/sample_transformers_bloom.json b/scripts/training_config_examples/sample_transformers_bloom.json index 7d812ea4..4030dcb9 100644 --- a/scripts/training_config_examples/sample_transformers_bloom.json +++ b/scripts/training_config_examples/sample_transformers_bloom.json @@ -8,7 +8,7 @@ "batch_size": 64, "model_config": { "model_type": "BloomForCausalLM", - "transformers_config": { + "model_params": { "apply_residual_connection_post_layernorm": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py index 2d052974..b45a49d1 100644 --- a/src/delphi/eval/utils.py +++ b/src/delphi/eval/utils.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Callable from typing import cast @@ -85,7 +86,9 @@ def load_delphi_dataset(dataset_name: str, split: str, slice: str = "") -> Datas # or we'd get a Dataset dict. 
See https://github.com/huggingface/datasets/issues/5189 split=f"train{slice}", ) - return cast(Dataset, dataset) + dataset = cast(Dataset, dataset) + logging.info(f" Loaded {data_files_str} ({len(dataset)} entries)") + return dataset def load_validation_dataset(dataset_name: str, slice: str = "") -> Dataset: diff --git a/src/delphi/static/configs/debug.json b/src/delphi/static/configs/debug.json index 5d40717a..1694c587 100644 --- a/src/delphi/static/configs/debug.json +++ b/src/delphi/static/configs/debug.json @@ -8,8 +8,8 @@ "train_sample_limit": 256, "batch_size": 64, "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "hidden_size": 48, "intermediate_size": 48, "num_attention_heads": 2, diff --git a/src/delphi/static/configs/debug_mamba.json b/src/delphi/static/configs/debug_mamba.json index 2a619e59..16d0391a 100644 --- a/src/delphi/static/configs/debug_mamba.json +++ b/src/delphi/static/configs/debug_mamba.json @@ -9,8 +9,8 @@ "train_sample_limit": 64, "batch_size": 8, "model_config": { - "model_type": "mamba", - "mamba": { + "model_type": "MambaForCausalLM", + "model_params": { "vocab_size": 4096, "hidden_size": 48, "state_size": 16, diff --git a/src/delphi/static/configs/debug_transformers_bloom.json b/src/delphi/static/configs/debug_transformers_bloom.json index 6c89e9e2..18720113 100644 --- a/src/delphi/static/configs/debug_transformers_bloom.json +++ b/src/delphi/static/configs/debug_transformers_bloom.json @@ -9,7 +9,7 @@ "batch_size": 64, "model_config": { "model_type": "BloomForCausalLM", - "transformers_config": { + "model_params": { "apply_residual_connection_post_layernorm": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-1.6m.json b/src/delphi/static/configs/v0-llama2-1.6m.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-1.6m.json +++ b/src/delphi/static/configs/v0-llama2-1.6m.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-100k-quick.json b/src/delphi/static/configs/v0-llama2-100k-quick.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-100k-quick.json +++ b/src/delphi/static/configs/v0-llama2-100k-quick.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-100k.json b/src/delphi/static/configs/v0-llama2-100k.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-100k.json +++ b/src/delphi/static/configs/v0-llama2-100k.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-12.8m.json b/src/delphi/static/configs/v0-llama2-12.8m.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-12.8m.json +++ b/src/delphi/static/configs/v0-llama2-12.8m.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git 
a/src/delphi/static/configs/v0-llama2-200k.json b/src/delphi/static/configs/v0-llama2-200k.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-200k.json +++ b/src/delphi/static/configs/v0-llama2-200k.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-25.6m.json b/src/delphi/static/configs/v0-llama2-25.6m.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-25.6m.json +++ b/src/delphi/static/configs/v0-llama2-25.6m.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-3.2m.json b/src/delphi/static/configs/v0-llama2-3.2m.json index 909a898c..978b7007 100644 --- a/src/delphi/static/configs/v0-llama2-3.2m.json +++ b/src/delphi/static/configs/v0-llama2-3.2m.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-400k.json b/src/delphi/static/configs/v0-llama2-400k.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-400k.json +++ b/src/delphi/static/configs/v0-llama2-400k.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-6.4m.json b/src/delphi/static/configs/v0-llama2-6.4m.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-6.4m.json +++ b/src/delphi/static/configs/v0-llama2-6.4m.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/static/configs/v0-llama2-800k.json b/src/delphi/static/configs/v0-llama2-800k.json index 376c2876..a3b275ef 100644 --- a/src/delphi/static/configs/v0-llama2-800k.json +++ b/src/delphi/static/configs/v0-llama2-800k.json @@ -1,7 +1,7 @@ { "model_config": { - "model_type": "llama2", - "llama2": { + "model_type": "LlamaForCausalLM", + "model_params": { "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 1, diff --git a/src/delphi/train/config/gigaconfig.py b/src/delphi/train/config/gigaconfig.py index 4da07941..9d6236d2 100644 --- a/src/delphi/train/config/gigaconfig.py +++ b/src/delphi/train/config/gigaconfig.py @@ -7,7 +7,7 @@ from .debug_config import DebugConfig from .huggingface_config import HuggingfaceConfig -from .models import ModelConfig +from .model_config import ModelConfig from .optimizer_config import OptimizerConfig from .wandb_config import WandbConfig diff --git a/src/delphi/train/config/model_config.py b/src/delphi/train/config/model_config.py new file mode 100644 index 00000000..df6b66d7 --- /dev/null +++ b/src/delphi/train/config/model_config.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass, field +from typing import Any, cast + +import transformers +from beartype import beartype +from beartype.typing import Type + + +@beartype +@dataclass(frozen=True) 
+class ModelConfig: + model_type: str = field( + metadata={ + "help": ( + "Name of any CausalLM Model from the transformers " + "library (e.g. 'BartForCausalLM'). Model configuration arguments, " + "e.g. hidden size, should be specified in model_params" + ) + } + ) + model_params: dict[str, Any] = field( + default_factory=dict, + metadata={ + "help": ( + "config for the transformers model specified by model_type. " + "e.g. {'hidden_size': 128, ...}" + ) + }, + ) + + def get_model(self): + model_class = getattr(transformers, self.model_type) + config_class = cast( + Type[transformers.PretrainedConfig], model_class.config_class + ) + return model_class(config_class(**(self.model_params))) diff --git a/src/delphi/train/config/models/__init__.py b/src/delphi/train/config/models/__init__.py deleted file mode 100644 index cdf71c6a..00000000 --- a/src/delphi/train/config/models/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .model_config import ModelConfig -from .model_types import ModelType, ModelTypes -from .typed_llama_config import TypedLlamaConfig -from .typed_mamba_config import TypedMambaConfig diff --git a/src/delphi/train/config/models/model_config.py b/src/delphi/train/config/models/model_config.py deleted file mode 100644 index 55e3bf36..00000000 --- a/src/delphi/train/config/models/model_config.py +++ /dev/null @@ -1,64 +0,0 @@ -from dataclasses import asdict, dataclass, field -from typing import Any, Optional, Type, cast - -import transformers -from beartype import beartype -from beartype.typing import Type -from transformers import PreTrainedModel - -from .model_types import ModelType, ModelTypes -from .typed_llama_config import TypedLlamaConfig -from .typed_mamba_config import TypedMambaConfig - - -@beartype -@dataclass(frozen=True) -class ModelConfig: - model_type: str = field( - metadata={ - "help": ( - "The model type to train. May be either a predefined " - "type (delphi, mamba) or any CausalLM Model from the transformers " - "library (e.g. BartForCausalLM). Predefined types should " - "specify their respective configs in this model config; " - "transformer library models should specify their model " - "config arguments in transformers_config." - ) - } - ) - transformers_config: dict[str, Any] = field( - default_factory=dict, - metadata={"help": "config for the transformers model specified by model_type"}, - ) - mamba: Optional[TypedMambaConfig] = field( - default=None, - metadata={"help": "config for Delphi mamba model. See TypedMambaConfig"}, - ) - llama2: Optional[TypedLlamaConfig] = field( - default=None, - metadata={"help": "config for Delphi llama2 model. 
See TypedLlamaConfig"}, - ) - - def is_predefined_type(self): - return hasattr(self, self.model_type) - - def get_config_args(self) -> dict[str, Any]: - if self.is_predefined_type(): - return asdict(getattr(self, self.model_type)) - else: - return self.transformers_config - - def get_model_class(self) -> type[PreTrainedModel]: - if self.is_predefined_type(): - model_type = cast(ModelType, ModelTypes.get(self.model_type)) - return model_type.model - else: - model_class = getattr(transformers, self.model_type) - return model_class - - def get_model(self): - model_class = self.get_model_class() - config_class = cast( - Type[transformers.PretrainedConfig], model_class.config_class - ) - return model_class(config_class(**(self.get_config_args()))) diff --git a/src/delphi/train/config/models/model_types.py b/src/delphi/train/config/models/model_types.py deleted file mode 100644 index 5ccd0990..00000000 --- a/src/delphi/train/config/models/model_types.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -For any given model we use, there are three associated types: -- TypedModelConfig: a typed dataclass that defines the arguments to the model. - We use this to enforce some semblance of type safety in configs and code in general. -- PretrainedConfig: a transformers config that defines the model architecture. - The arguments for this are defined in TypedModelConfig. -- PreTrainedModel: a transformers model that implements the model architecture. - Configured by PretrainedConfig. - - This file defines a ModelType dataclass that associated these three types for a given model, - and a ModelTypes container class that defines all the models we use in Delphi along with a - helpful ModelTypes.get() method for getting ModelType from a string. -""" -from dataclasses import dataclass -from typing import Optional - -from beartype import beartype -from beartype.typing import Type -from transformers import LlamaForCausalLM, MambaForCausalLM, PreTrainedModel - -from .typed_llama_config import TypedLlamaConfig -from .typed_mamba_config import TypedMambaConfig -from .typed_model_config import TypedModelConfig - - -@beartype -@dataclass(frozen=True) -class ModelType: - name: str - delphi_config: type[TypedModelConfig] - model: type[PreTrainedModel] - - # Allow for ModelType == 'llama2' - def __eq__(self, other): - if isinstance(other, str): - return self.name == other - else: - return super().__eq__(other) - - def __post_init__(self): - # register the ModelType so ModelTypes.get(model_type_name) works - _model_name_to_model_type[self.name.lower()] = self - - -_model_name_to_model_type: dict[str, ModelType] = {} - - -# define new model types here -class ModelTypes: - MAMBA = ModelType( - name="mamba", - delphi_config=TypedMambaConfig, - model=MambaForCausalLM, - ) - LLAMA2 = ModelType( - name="llama2", - delphi_config=TypedLlamaConfig, - model=LlamaForCausalLM, - ) - - # NEWMODEL = ModelType( # var name should match name - # name="newmodel", # string that will be associated with model in configs, etc - # typed_config=TypedNewModelConfig, # typed dataclass for args to config - # config=NewModelConfig, # transformers config - # model=NewModelForCausalLM, # transformers model - # ) - - @classmethod - def get(cls: Type["ModelTypes"], name: str) -> Optional[ModelType]: - return _model_name_to_model_type.get(name.lower()) diff --git a/src/delphi/train/config/models/typed_llama_config.py b/src/delphi/train/config/models/typed_llama_config.py deleted file mode 100644 index 7e1a2f83..00000000 --- 
a/src/delphi/train/config/models/typed_llama_config.py +++ /dev/null @@ -1,30 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Optional - -from beartype import beartype - -from .typed_model_config import TypedModelConfig - - -@beartype -@dataclass(frozen=True) -class TypedLlamaConfig(TypedModelConfig): - attention_bias: bool = False - attention_dropout: float = 0.0 - bos_token_id: int = -1 - eos_token_id: int = -2 - hidden_act: str = "silu" - hidden_size: int = 288 - initializer_range: float = 0.02 - intermediate_size: int = 288 - max_position_embeddings: int = 512 - num_attention_heads: int = 6 - num_hidden_layers: int = 6 - num_key_value_heads: int = 6 - pretraining_tp: int = 1 - rms_norm_eps: float = 1e-06 - rope_scaling: Optional[dict[str, Any]] = None - rope_theta: float = 10000.0 - tie_word_embeddings: bool = False - use_cache: bool = True - vocab_size: int = 4096 diff --git a/src/delphi/train/config/models/typed_mamba_config.py b/src/delphi/train/config/models/typed_mamba_config.py deleted file mode 100644 index 8b01a932..00000000 --- a/src/delphi/train/config/models/typed_mamba_config.py +++ /dev/null @@ -1,39 +0,0 @@ -from dataclasses import dataclass -from typing import Union - -from beartype import beartype - -from .typed_model_config import TypedModelConfig - - -@beartype -@dataclass(frozen=True) -class TypedMambaConfig(TypedModelConfig): - # model shape - vocab_size: int = 4096 - hidden_size: int = 768 - state_size: int = 16 - num_hidden_layers: int = 32 - conv_kernel: int = 4 - expand: int = 2 - use_bias: bool = False - use_conv_bias: bool = True - # tokens - bos_token_id: int = 0 - eos_token_id: int = 0 - pad_token_id: int = 0 - # time step - time_step_rank: Union[int, str] = "auto" - time_step_scale: float = 1.0 - time_step_min: float = 0.001 - time_step_max: float = 0.1 - time_step_init_scheme: str = "random" # "random" or "uniform" - time_step_floor: float = 0.0001 - # misc - layer_norm_epsilon: float = 1e-05 - hidden_act: str = "silu" - initializer_range: float = 0.1 - residual_in_fp32: bool = True - rescale_prenorm_residual: bool = False - use_cache: bool = True - tie_word_embeddings: bool = True diff --git a/src/delphi/train/config/models/typed_model_config.py b/src/delphi/train/config/models/typed_model_config.py deleted file mode 100644 index 8eae1de8..00000000 --- a/src/delphi/train/config/models/typed_model_config.py +++ /dev/null @@ -1,15 +0,0 @@ -from dataclasses import dataclass - - -@dataclass(frozen=True) -class TypedModelConfig: - """ - This is a dummy class for typing purposes. We could make a Union class that we update - every time we add a TypedModelConfig class, but that would mean remembering to go update - another thing when adding a new TypedModelConfig. - """ - - def __init__(self): - raise NotImplementedError( - "TypedModelConfig is a dummy class to provide typing for actual ModelConfig classes. It shouldn't ever be instantiated." 
- ) diff --git a/src/delphi/train/config/utils.py b/src/delphi/train/config/utils.py index 0f582104..719b522d 100644 --- a/src/delphi/train/config/utils.py +++ b/src/delphi/train/config/utils.py @@ -1,8 +1,10 @@ import json import logging import os +from dataclasses import fields, is_dataclass from datetime import datetime from pathlib import Path +from typing import Type import platformdirs from beartype.typing import Any, Iterable @@ -68,6 +70,24 @@ def build_config_dict_from_files(config_files: list[Path]) -> dict[str, Any]: return combined_config +def filter_config_to_actual_config_values(target_dataclass: Type, config: dict): + """Remove non-config values from config dict. + + This can happen if e.g. being lazy and passing in all args from a script + """ + datafields = fields(target_dataclass) + name_to_field = {f.name: f for f in datafields} + to_remove = [] + for k, v in config.items(): + if k not in name_to_field.keys(): + logging.debug(f"removing non-config-value {k}={v} from config dict") + to_remove.append(k) + elif isinstance(v, dict) and is_dataclass(name_to_field.get(k)): + filter_config_to_actual_config_values(name_to_field[k].type, v) + for k in to_remove: + config.pop(k) + + def set_backup_vals(config: dict[str, Any], config_files: list[Path]): if len(config_files) == 1: prefix = f"{config_files[0].stem}__" @@ -76,10 +96,23 @@ def set_backup_vals(config: dict[str, Any], config_files: list[Path]): if "run_name" not in config: run_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") config["run_name"] = f"{prefix}{run_time}" + logging.info(f"Setting run_name to {config['run_name']}") if "output_dir" not in config: config["output_dir"] = os.path.join( platformdirs.user_data_dir(appname="delphi"), config["run_name"] ) + logging.info(f"Setting output_dir to {config['output_dir']}") + + +def log_config_recursively( + config: dict, logging_fn=logging.info, indent=" ", prefix="" +): + for k, v in config.items(): + if isinstance(v, dict): + logging_fn(f"{prefix}{k}") + log_config_recursively(v, logging_fn, indent, prefix=indent + prefix) + else: + logging_fn(f"{prefix}{k}: {v}") def build_config_from_files_and_overrides( @@ -89,6 +122,11 @@ def build_config_from_files_and_overrides( combined_config = build_config_dict_from_files(config_files) _merge_dicts(merge_into=combined_config, merge_from=overrides) set_backup_vals(combined_config, config_files) + filter_config_to_actual_config_values(GigaConfig, combined_config) + logging.info("User-set config values:") + log_config_recursively( + combined_config, logging_fn=logging.info, prefix=" ", indent=" " + ) return from_dict(GigaConfig, combined_config) diff --git a/src/delphi/train/config/wandb_config.py b/src/delphi/train/config/wandb_config.py index 716dec0a..23fbc3a0 100644 --- a/src/delphi/train/config/wandb_config.py +++ b/src/delphi/train/config/wandb_config.py @@ -7,3 +7,4 @@ class WandbConfig: log: bool = False project: str = "delphi" entity: str = "set_wandb.entity_to_your_wandb_username_to_make_wandb_logging_work" + silence: bool = False diff --git a/src/delphi/train/iteration_params.py b/src/delphi/train/iteration_params.py index 8a7e8d88..b6c624db 100644 --- a/src/delphi/train/iteration_params.py +++ b/src/delphi/train/iteration_params.py @@ -30,9 +30,23 @@ def set_iteration_params( * config.batch_size * config.max_seq_len ) - logging.debug(f"tokens per iteration will be: {tokens_per_iter:,}") - logging.debug( - f"breaks down as: {config.optimizer.gradient_accumulation_steps} grad accum steps * {config.batch_size} batch 
size * {config.max_seq_len} max seq len" + logging.info("Iteration setup:") + logging.info(f" batch size: {config.batch_size}") + logging.info(f" training set size: {len(train_ds)}") + logging.info(f" training batches: {num_batches}") + logging.info( + f" gradient accumulations per step (=batches per step): {config.optimizer.gradient_accumulation_steps}" + ) + logging.info(f" steps per batch: {num_steps}") + logging.info(f" tokens per sequence: {config.max_seq_len}") + logging.info(f" tokens per training step will be: {tokens_per_iter:,}") + logging.info( + f" breaks down as: {config.optimizer.gradient_accumulation_steps} grad accum steps * {config.batch_size} batch size * {config.max_seq_len} tokens per sequence" + ) + logging.info(f" validation set size: {len(validation_ds)}") + logging.info(f" batches per validation step: {eval_iters}") + logging.info( + f" tokens per validation step: {eval_iters * config.batch_size * config.max_seq_len:,}" ) return IterationParams( num_batches, num_steps, eval_iters, lr_decay_iters, tokens_per_iter diff --git a/src/delphi/train/train_step.py b/src/delphi/train/train_step.py index baee2147..9852d633 100644 --- a/src/delphi/train/train_step.py +++ b/src/delphi/train/train_step.py @@ -6,7 +6,6 @@ from datasets import Dataset from .config import GigaConfig -from .config.models import ModelTypes from .iteration_params import IterationParams from .run_context import RunContext from .utils import EvalData, ModelTrainingState, estimate_loss, get_next_xy, set_lr @@ -76,84 +75,36 @@ def train_step( callback(eval_data) # 3. forward backward update, with optional gradient accumulation to simulate larger batch size - logging.debug( - f"gradient accumulation steps: {config.optimizer.gradient_accumulation_steps}, " - f"num_steps: {iteration_params.num_steps}, iter_num: {model_training_state.iter_num}" - ) - for micro_step in range(config.optimizer.gradient_accumulation_steps): - X, Y = get_next_xy(train_batch_iter, run_context.device) - if config.debug_config.no_training: - logging.debug("no_training set, skipping forward backward update") - loss = torch.Tensor([42.1]).to(run_context.device) - else: + if config.debug_config.no_training: + logging.debug("no_training set, skipping forward backward pass") + loss = torch.Tensor([42.1]).to(run_context.device) + else: + for micro_step in range(config.optimizer.gradient_accumulation_steps): + X, Y = get_next_xy(train_batch_iter, run_context.device) loss = ( model(X, labels=Y, return_dict=True).loss / config.optimizer.gradient_accumulation_steps ) loss.backward() - if config.debug_config.no_training: - logging.debug("debug no_training is set, skipping optimizer step") - else: # clip the gradient if config.grad_clip != 0.0: torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) # type: ignore optimizer.step() - - # flush the gradients as soon as we can, no need for this memory anymore - optimizer.zero_grad(set_to_none=True) + # flush the gradients as soon as we can, no need for this memory anymore + optimizer.zero_grad(set_to_none=True) # 4. log timing t1 = time.time() - dt = t1 - model_training_state.t0 - model_training_state.t0 = t1 + dt = t1 - model_training_state.last_training_step_time + model_training_state.last_training_step_time = t1 if model_training_state.iter_num % config.log_interval == 0: # get loss as float, scale up due to the divide above. 
note: this is a CPU-GPU sync point lossf = loss.item() * config.optimizer.gradient_accumulation_steps - if ( - model_training_state.local_iter_num >= 5 - ): # let the training loop settle a bit - mfu = estimate_mfu( - config=config, model=model_training_state.model, timedelta=dt - ) - model_training_state.running_mfu = ( - mfu - if model_training_state.running_mfu == -1.0 - else 0.9 * model_training_state.running_mfu + 0.1 * mfu - ) logging.debug( ( f"{model_training_state.iter_num} | loss {lossf:.4f} | lr {model_training_state.lr:e} | " - f"{dt*1000:.2f}ms | mfu {model_training_state.running_mfu*100:.2f}%" + f"{dt*1000:.2f}ms" ) ) model_training_state.iter_num += 1 model_training_state.local_iter_num += 1 - - -def estimate_mfu(config: GigaConfig, model: torch.nn.Module, timedelta: float) -> float: - """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS""" - # first estimate the number of flops we do per iteration. - # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 - N = sum(p.numel() for p in model.parameters()) - if config.model_config.model_type == ModelTypes.LLAMA2: - cfg = model.config - L, H, Q, T = ( - cfg.num_hidden_layers, - cfg.num_attention_heads, - cfg.hidden_size // cfg.num_attention_heads, - cfg.max_position_embeddings, - ) - else: - logging.debug( - f"estimate_mfu not implemented for {config.model_config.model_type}, setting MFU to -1" - ) - return -1.0 - flops_per_token = 6 * N + 12 * L * H * Q * T - flops_per_fwdbwd = flops_per_token * T - fwdbwd_per_iter = config.batch_size * config.optimizer.gradient_accumulation_steps - flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter - # express our flops throughput as ratio of A100 bfloat16 peak flops - flops_achieved = flops_per_iter * (1.0 / timedelta) # per second - flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS - mfu = flops_achieved / flops_promised - return mfu diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py index ab189f14..ba0603c7 100644 --- a/src/delphi/train/training.py +++ b/src/delphi/train/training.py @@ -45,7 +45,7 @@ def run_training(config: GigaConfig) -> tuple[ModelTrainingState, RunContext]: logging.debug(f"Run context: {run_context}") # load data - logging.debug("Loading data...") + logging.info("Loading data...") train_ds = cast( Dataset, load_delphi_training_dataset("train", limit=config.train_sample_limit) ) @@ -67,14 +67,20 @@ def run_training(config: GigaConfig) -> tuple[ModelTrainingState, RunContext]: model_training_state = initialize_model_training_state(config, run_context.device) # setup eval callbacks + logging.info("Setting eval step callbacks...") eval_callbacks = [save_checkpoint_if_needed] + logging.info(f" added save_checkpoint_if_needed eval callback") if config.wandb_config.log: + if config.wandb_config.silence: + wandb_utils.silence_wandb() wandb_utils.init_wandb(config) eval_callbacks.append(wandb_utils.log_to_wandb) + logging.info(f" added log_to_wandb callback") # training loop logging.info("Starting training...") for epoch in range(config.max_epochs): + logging.info(f"Epoch: {epoch} / {config.max_epochs - 1}") train_batch_iter = iter( batch_generator( train_ds, config.batch_size, epoch, config.batch_ordering_seed diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py index 3d7b2a5c..c974d573 100644 --- a/src/delphi/train/utils.py +++ b/src/delphi/train/utils.py @@ -4,7 +4,7 @@ import os import time from collections.abc import Generator -from dataclasses import asdict, dataclass +from 
dataclasses import asdict, dataclass, field from pathlib import Path from typing import cast @@ -25,16 +25,23 @@ @dataclass class ModelTrainingState: + """mutable training state - stuff that changes over the course of training""" + model: torch.nn.Module optimizer: torch.optim.Optimizer - iter_num: int - local_iter_num: int - best_val_loss: float - running_mfu: float - t0: float - epoch: int - step: int - lr: float = 1.0e-5 + iter_num: int = field( + metadata={"help": "total iterations so far across all epochs"} + ) + local_iter_num: int = field( + metadata={"help": "total iterations on this instance so far"} + ) + best_val_loss: float = field(metadata={"help": "best validation loss so far"}) + last_training_step_time: float = field( + metadata={"help": "time last iteration ended"} + ) + epoch: int = field(metadata={"help": "current epoch"}) + step: int = field(metadata={"help": "step within current epoch"}) + lr: float = field(default=1.0e-5, metadata={"help": "learning rate"}) @dataclass @@ -63,25 +70,6 @@ def get_device(device_str: str = "auto") -> torch.device: return torch.device(device_str) -def get_optimizer( - model: torch.nn.Module, - config: GigaConfig, - output_dir=None, - device: torch.device = torch.device("cpu"), -) -> AdamW: - optimizer = AdamW( - lr=config.optimizer.learning_rate, - params=model.parameters(), - weight_decay=config.optimizer.weight_decay, - betas=(config.optimizer.beta1, config.optimizer.beta2), - ) - if output_dir is not None: - opt_path = os.path.join(output_dir, "opt.pt") - with open(opt_path, "rb") as f: - optimizer.load_state_dict(torch.load(f)) - return optimizer - - def get_lr( iter_num: int, warmup_iters: int, @@ -145,56 +133,48 @@ def save_checkpoint_if_needed(eval_data: EvalData): ) -def load_model_from_checkpoint(config: GigaConfig, output_dir: str) -> torch.nn.Module: - model = config.model_config.get_model() - st.load_model(model, os.path.join(output_dir, "model", "model.safetensors")) - return model - - def initialize_model_training_state( config: GigaConfig, device: torch.device ) -> ModelTrainingState: t0 = time.time() - training_state = None + model = config.model_config.get_model() + model.to(device) # type: ignore + optimizer = AdamW( + lr=config.optimizer.learning_rate, + params=model.parameters(), + weight_decay=config.optimizer.weight_decay, + betas=(config.optimizer.beta1, config.optimizer.beta2), + ) + training_state_vals = dict() if config.init_from == "scratch": - # init a new model from scratch - logging.debug("Initializing a new model from scratch") - model = config.model_config.get_model() - checkpoint = None + logging.info(f" initialized model and optimizer from scratch") # TODO: resume from huggingface model elif config.init_from == "resume": logging.info(f"Resuming training from {config.output_dir}") checkpoint = config.output_dir - model = load_model_from_checkpoint(config, checkpoint) + st.load_model( + model, os.path.join(config.output_dir, "model", "model.safetensors") + ) with open(os.path.join(checkpoint, "training_state.json"), "r") as f: - training_state = json.load(f) - model.to(device) # type: ignore - # optimizer - optimizer = get_optimizer( - model=model, - config=config, - output_dir=config.output_dir - if (Path(config.output_dir) / "opt.safetensors").exists() - else None, - device=device, - ) - epoch = training_state.get("epoch", 0) if training_state is not None else 0 - step = training_state.get("step", 0) if training_state is not None else 0 - best_val_loss = training_state.get("best_val_loss", 1e9) if 
training_state else 1e9 - iter_num = training_state.get("iter_num", 0) if training_state else 0 - local_iter_num = training_state.get("local_iter_num", 0) if training_state else 0 - running_mfu = training_state.get("running_mfu", 0.0) if training_state else -1.0 - checkpoint = None # free up memory + training_state_vals = json.load(f) + opt_state_dict_path = Path(os.path.join(config.output_dir, "opt.pt")) + if opt_state_dict_path.exists(): + with open(opt_state_dict_path, "rb") as f: + logging.info(f" Loading optimizer state from {opt_state_dict_path}") + optimizer.load_state_dict(torch.load(f)) + else: + raise ValueError( + f"{config.init_from} is not one of (scratch, resume), which are the two valid initialization methods. Unable to initialize model." + ) return ModelTrainingState( model=model, optimizer=optimizer, - iter_num=iter_num, - local_iter_num=local_iter_num, - best_val_loss=best_val_loss, - running_mfu=running_mfu, - t0=t0, - epoch=epoch, - step=step, + last_training_step_time=t0, + iter_num=training_state_vals.get("iter_num", 0), + local_iter_num=training_state_vals.get("local_iter_num", 0), + best_val_loss=training_state_vals.get("best_val_loss", 1e9), + epoch=training_state_vals.get("epoch", 0), + step=training_state_vals.get("step", 0), ) @@ -214,10 +194,9 @@ def load_delphi_training_dataset(split: str, limit: int = -1): def get_next_xy( - train_batch_iter: Generator, - device: torch.device - # train_batch_iter: Generator[dict[str, list[int]], None, None], device: torch.device + train_batch_iter: Generator, device: torch.device ) -> tuple[torch.Tensor, torch.Tensor]: + """break a (max_seq_len +1) sequence of tokens into sample [:-1] and label [1:] pairs""" data = next(train_batch_iter).to(device) X, Y = data[:, :-1], data[:, 1:] return X, Y @@ -226,6 +205,9 @@ def get_next_xy( def batch_generator( dataset: Dataset, batch_size: int, epoch: int, ordering_seed: int ) -> Generator[torch.Tensor, None, None]: + """ + Generate batches of training data for a given epoch with pseudorandom determinism + """ sampler = list(range(len(dataset))) # type: ignore shuffle_list(sampler, seed=ordering_seed + epoch) sampler = torch.Tensor(sampler) @@ -257,19 +239,18 @@ def estimate_loss( return out -def upload_to_huggingface(eval_data: EvalData): - model = eval_data.model_training_state.model - if isinstance(model, PreTrainedModel): - model = cast(PreTrainedModel, model) - model.save_pretrained(eval_data.config.output_dir) - - def save_results( config: GigaConfig, train_results: ModelTrainingState, run_context: RunContext, results_path: str, ): + """ + save results to disk, and to huggingface if configured to do so. + + Saves everything required to replicate the current state of training, including optimizer state, + config, context (e.g. 
hardware), training step, etc + """ os.makedirs(results_path, exist_ok=True) with open(os.path.join(results_path, "config.json"), "w") as file: json.dump(asdict(config), file, indent=2) @@ -291,7 +272,6 @@ def save_results( "iter_num": train_results.iter_num, "local_iter_num": train_results.local_iter_num, "best_val_loss": train_results.best_val_loss, - "running_mfu": train_results.running_mfu, "lr": train_results.lr, "epoch": train_results.epoch, "step": train_results.step, diff --git a/src/delphi/train/wandb_utils.py b/src/delphi/train/wandb_utils.py index 89ed5518..b24926df 100644 --- a/src/delphi/train/wandb_utils.py +++ b/src/delphi/train/wandb_utils.py @@ -10,6 +10,7 @@ def silence_wandb(): # set env var WANDB_SILENT=true + logging.info("silencing wandb output") os.environ["WANDB_SILENT"] = "true" @@ -35,7 +36,6 @@ def log_to_wandb(eval_data: EvalData): "loss/train": eval_data.losses["train"], "loss/val": eval_data.losses["val"], "lr": mts.lr, - "mfu": mts.running_mfu * 100, # convert to percentage }, step=mts.iter_num, ) diff --git a/tests/train/config/models/test_model_config.py b/tests/train/config/models/test_model_config.py index aad4a13a..eca5fe69 100644 --- a/tests/train/config/models/test_model_config.py +++ b/tests/train/config/models/test_model_config.py @@ -1,9 +1,8 @@ import pytest from dacite import from_dict -from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM +from transformers import BloomConfig, BloomForCausalLM -from delphi.train.config.models import ModelConfig, TypedLlamaConfig -from delphi.train.config.models.model_config import ModelConfig +from delphi.train.config.model_config import ModelConfig @pytest.fixture @@ -11,8 +10,8 @@ def llama_config(): return from_dict( ModelConfig, { - "model_type": "llama2", - "llama2": {"hidden_size": 49, "num_attention_heads": 7}, + "model_type": "LlamaForCausalLM", + "model_params": {"hidden_size": 49, "num_attention_heads": 7}, }, ) @@ -23,36 +22,20 @@ def bloom_config(): ModelConfig, { "model_type": "BloomForCausalLM", - "transformers_config": {"layer_norm_epsilon": 0.0042}, + "model_params": {"layer_norm_epsilon": 0.0042}, }, ) -def test_deserialziation(llama_config): - direct_llama_config = ModelConfig( - model_type="llama2", - llama2=TypedLlamaConfig(hidden_size=49, num_attention_heads=7), +def test_deserialziation(bloom_config): + direct_bloom_config = ModelConfig( + model_type="BloomForCausalLM", + model_params=dict(layer_norm_epsilon=0.0042), ) - assert llama_config == direct_llama_config + assert bloom_config == direct_bloom_config -def test_model_config_is_predefined_type(llama_config): - assert llama_config.is_predefined_type() - - -def test_model_config_is_not_predefined_type(bloom_config): - assert not bloom_config.is_predefined_type() - - -def test_config_to_model_predefined(llama_config): - model = llama_config.get_model() - - assert isinstance(model, LlamaForCausalLM) - assert isinstance(model.config, LlamaConfig) - assert model.config.hidden_size == 49 - - -def test_config_to_model_generic_type(bloom_config): +def test_config_to_model(bloom_config): model = bloom_config.get_model() assert isinstance(model, BloomForCausalLM) diff --git a/tests/train/test_wandb_utils.py b/tests/train/test_wandb_utils.py index 093fd0c5..a7cbb80e 100644 --- a/tests/train/test_wandb_utils.py +++ b/tests/train/test_wandb_utils.py @@ -4,10 +4,11 @@ import pytest import torch +import transformers from dacite import from_dict from delphi.train.config import GigaConfig -from delphi.train.config.models 
import TypedLlamaConfig +from delphi.train.config.utils import load_preset from delphi.train.run_context import RunContext from delphi.train.utils import EvalData, initialize_model_training_state from delphi.train.wandb_utils import init_wandb, log_to_wandb, silence_wandb @@ -21,8 +22,15 @@ def mock_giga_config(): "run_name": "test_run", "device": "cpu", "model_config": { - "model_type": "llama2", - "llama2": asdict(TypedLlamaConfig()), + "model_type": "LlamaForCausalLM", + "model_params": { + "hidden_size": 48, + "intermediate_size": 48, + "num_attention_heads": 2, + "num_hidden_layers": 2, + "num_key_value_heads": 2, + "vocab_size": 4096, + }, }, "wandb_config": { "log": True, @@ -43,7 +51,6 @@ def mock_model_training_state(mock_giga_config): mts.epoch = 1 mts.iter_num = 1 mts.lr = 0.001 - mts.running_mfu = 3.0 return mts @@ -93,7 +100,6 @@ def test_log_to_wandb(mock_wandb_log, mock_eval_data): "loss/train": 0.5, "loss/val": 0.4, "lr": 0.001, - "mfu": 300.0, }, step=1, )
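
Note on the new config shape: every config now names a transformers *ForCausalLM class in model_type and passes that model's config kwargs in model_params, and ModelConfig.get_model() resolves both dynamically via getattr(transformers, ...) and model_class.config_class. A minimal usage sketch (the model_params values mirror debug.json and the test fixtures; this is not a shipped script):

```python
# Minimal sketch of the new config flow, mirroring tests/train/config/models/test_model_config.py.
from dacite import from_dict
from transformers import LlamaForCausalLM

from delphi.train.config.model_config import ModelConfig

config = from_dict(
    ModelConfig,
    {
        "model_type": "LlamaForCausalLM",  # any CausalLM class name from transformers
        "model_params": {                  # kwargs forwarded to that model's config_class
            "hidden_size": 48,
            "intermediate_size": 48,
            "num_attention_heads": 2,
            "num_hidden_layers": 2,
            "num_key_value_heads": 2,
            "vocab_size": 4096,
        },
    },
)

# get_model() builds LlamaConfig(**model_params) and passes it to LlamaForCausalLM
model = config.get_model()
assert isinstance(model, LlamaForCausalLM)
assert model.config.hidden_size == 48
```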
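filter_config_to_actual_config_values strips keys that do not correspond to GigaConfig dataclass fields before dacite's from_dict runs, so stray CLI arguments don't leak into the config. A standalone sketch of the intended behavior, using made-up demo dataclasses rather than the real delphi config classes:

```python
# Standalone sketch of the config-filtering idea: drop keys that aren't dataclass
# fields, recursing into nested dataclass-typed fields. DemoConfig/OptimizerConfig
# here are illustrative stand-ins, not the real delphi dataclasses.
import logging
from dataclasses import dataclass, field, fields, is_dataclass


@dataclass
class OptimizerConfig:
    learning_rate: float = 5e-4


@dataclass
class DemoConfig:
    batch_size: int = 64
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)


def filter_to_fields(target_dataclass, config: dict) -> None:
    name_to_field = {f.name: f for f in fields(target_dataclass)}
    for k in [k for k in config if k not in name_to_field]:
        logging.debug(f"removing non-config value {k}")
        config.pop(k)
    for k, v in config.items():
        if isinstance(v, dict) and is_dataclass(name_to_field[k].type):
            filter_to_fields(name_to_field[k].type, v)


cfg = {"batch_size": 8, "stray_cli_arg": True, "optimizer": {"learning_rate": 1e-3, "typo": 1}}
filter_to_fields(DemoConfig, cfg)
assert cfg == {"batch_size": 8, "optimizer": {"learning_rate": 1e-3}}
```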
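The expanded iteration logging spells out how the tokens-per-training-step figure is derived: gradient accumulation steps × batch size × max sequence length. A quick back-of-envelope check using batch_size and max_seq_len from sample_config.json and an assumed gradient accumulation value:

```python
# Back-of-envelope check of the "tokens per training step" figure the new logging prints.
gradient_accumulation_steps = 4  # assumed for illustration, not from any shipped config
batch_size = 64                  # from sample_config.json
max_seq_len = 512                # from sample_config.json

tokens_per_step = gradient_accumulation_steps * batch_size * max_seq_len
print(f"{tokens_per_step:,}")    # 131,072
```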
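The new docstring on get_next_xy documents the sample/label split: each training sequence carries max_seq_len + 1 tokens, the inputs are the first max_seq_len tokens, and the labels are the same sequence shifted left by one. A toy illustration of that split:

```python
# Sketch of the split documented on get_next_xy: a batch of (max_seq_len + 1)-token
# sequences becomes inputs X and next-token labels Y.
import torch

batch = torch.arange(2 * 9).reshape(2, 9)  # toy batch with "max_seq_len" of 8
X, Y = batch[:, :-1], batch[:, 1:]

assert X.shape == Y.shape == (2, 8)
assert torch.equal(X[:, 1:], Y[:, :-1])    # Y is X shifted left by one token
```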