Make training more legible #74

Merged · 13 commits · Mar 27, 2024
5 changes: 4 additions & 1 deletion src/delphi/eval/utils.py
@@ -1,3 +1,4 @@
+import logging
 from collections.abc import Callable
 from typing import cast
 
@@ -85,7 +86,9 @@ def load_delphi_dataset(dataset_name: str, split: str, slice: str = "") -> Dataset
         # or we'd get a Dataset dict. See https://github.com/huggingface/datasets/issues/5189
         split=f"train{slice}",
     )
-    return cast(Dataset, dataset)
+    dataset = cast(Dataset, dataset)
+    logging.info(f" Loaded {data_files_str} ({len(dataset)} entries)")
+    return dataset
 
 
 def load_validation_dataset(dataset_name: str, slice: str = "") -> Dataset:
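
The practical effect: dataset loads now announce themselves at INFO level instead of finishing silently. A minimal usage sketch, assuming the caller enables INFO logging (the dataset name is hypothetical, for illustration only):

import logging

from delphi.eval.utils import load_delphi_dataset

logging.basicConfig(level=logging.INFO)

# Hypothetical dataset name, for illustration only.
dataset = load_delphi_dataset("delphi-suite/example-dataset", split="validation")
# With INFO logging enabled, the load now prints something like:
#   INFO:root: Loaded <data_files_str> (<len(dataset)> entries)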
38 changes: 38 additions & 0 deletions src/delphi/train/config/utils.py
@@ -1,8 +1,10 @@
 import json
 import logging
 import os
+from dataclasses import fields
 from datetime import datetime
 from pathlib import Path
+from typing import Type
 
 import platformdirs
 from beartype.typing import Any, Iterable
@@ -68,6 +70,24 @@ def build_config_dict_from_files(config_files: list[Path]) -> dict[str, Any]:
     return combined_config
 
 
+def filter_config_to_actual_config_values(target_dataclass: Type, config: dict):
+    """Remove non-config values from config dict.
+
+    This can happen if e.g. being lazy and passing in all args from a script
+    """
+    datafields = fields(target_dataclass)
+    name_to_field = {f.name: f for f in datafields}
+    to_remove = []
+    for k, v in config.items():
+        if k not in name_to_field.keys():
+            logging.debug(f"removing non-config-value {k}={v} from config dict")
+            to_remove.append(k)
+        elif isinstance(v, dict) and isinstance(name_to_field[k].type, type):
+            filter_config_to_actual_config_values(name_to_field[k].type, v)
+    for k in to_remove:
+        config.pop(k)
+
+
 def set_backup_vals(config: dict[str, Any], config_files: list[Path]):
     if len(config_files) == 1:
         prefix = f"{config_files[0].stem}__"
@@ -76,10 +96,23 @@ def set_backup_vals(config: dict[str, Any], config_files: list[Path]):
     if "run_name" not in config:
         run_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
         config["run_name"] = f"{prefix}{run_time}"
+        logging.info(f"Setting run_name to {config['run_name']}")
     if "output_dir" not in config:
         config["output_dir"] = os.path.join(
             platformdirs.user_data_dir(appname="delphi"), config["run_name"]
         )
+        logging.info(f"Setting output_dir to {config['output_dir']}")
 
 
+def log_config_recursively(
+    config: dict, logging_fn=logging.info, indent=" ", prefix=""
+):
+    for k, v in config.items():
+        if isinstance(v, dict):
+            logging_fn(f"{prefix}{k}")
+            log_config_recursively(v, logging_fn, indent, prefix=indent + prefix)
+        else:
+            logging_fn(f"{prefix}{k}: {v}")
+
+
 def build_config_from_files_and_overrides(
@@ -89,6 +122,11 @@ def build_config_from_files_and_overrides(
     combined_config = build_config_dict_from_files(config_files)
     _merge_dicts(merge_into=combined_config, merge_from=overrides)
     set_backup_vals(combined_config, config_files)
+    filter_config_to_actual_config_values(GigaConfig, combined_config)
+    logging.info("User-set config values:")
+    log_config_recursively(
+        combined_config, logging_fn=logging.info, prefix=" ", indent=" "
+    )
     return from_dict(GigaConfig, combined_config)
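
Review note: taken together, filter_config_to_actual_config_values prunes dict keys that have no matching dataclass field (recursing into fields whose annotated type is itself a class), and log_config_recursively prints whatever survives, indenting nested dicts. A self-contained sketch of the intended behavior, using toy dataclasses rather than the real GigaConfig:

from dataclasses import dataclass, field

from delphi.train.config.utils import (
    filter_config_to_actual_config_values,
    log_config_recursively,
)


# Toy stand-ins for illustration; not the real config classes.
@dataclass
class ToyOptimizerConfig:
    lr: float = 3e-4


@dataclass
class ToyConfig:
    batch_size: int = 32
    optimizer: ToyOptimizerConfig = field(default_factory=ToyOptimizerConfig)


config = {
    "batch_size": 64,
    "optimizer": {"lr": 1e-3, "not_a_field": True},  # not_a_field gets pruned
    "script_only_arg": "dropped",  # no matching dataclass field, pruned
}
filter_config_to_actual_config_values(ToyConfig, config)
# config is now {"batch_size": 64, "optimizer": {"lr": 0.001}}
log_config_recursively(config, logging_fn=print)
# batch_size: 64
# optimizer
#  lr: 0.001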
1 change: 1 addition & 0 deletions src/delphi/train/config/wandb_config.py
@@ -7,3 +7,4 @@ class WandbConfig:
     log: bool = False
     project: str = "delphi"
     entity: str = "set_wandb.entity_to_your_wandb_username_to_make_wandb_logging_work"
+    silence: bool = False
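
The new silence flag is consumed in training.py below, but the wandb_utils.silence_wandb helper it triggers is not part of this diff. A plausible sketch of such a helper, assuming it relies on wandb's documented WANDB_SILENT environment variable:

import os


def silence_wandb() -> None:
    # WANDB_SILENT=true is a documented wandb setting that redirects
    # wandb's console chatter to its log files.
    os.environ["WANDB_SILENT"] = "true"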
20 changes: 17 additions & 3 deletions src/delphi/train/iteration_params.py
@@ -30,9 +30,23 @@
         * config.batch_size
         * config.max_seq_len
     )
-    logging.debug(f"tokens per iteration will be: {tokens_per_iter:,}")
-    logging.debug(
-        f"breaks down as: {config.optimizer.gradient_accumulation_steps} grad accum steps * {config.batch_size} batch size * {config.max_seq_len} max seq len"
+    logging.info("Iteration setup:")
+    logging.info(f" batch size: {config.batch_size}")
+    logging.info(f" training set size: {len(train_ds)}")
+    logging.info(f" training batches: {num_batches}")
+    logging.info(
+        f" gradient accumulations per step (=batches per step): {config.optimizer.gradient_accumulation_steps}"
+    )
+    logging.info(f" steps per batch: {num_steps}")
+    logging.info(f" tokens per sequence: {config.max_seq_len}")
+    logging.info(f" tokens per training step will be: {tokens_per_iter:,}")
+    logging.info(
+        f" breaks down as: {config.optimizer.gradient_accumulation_steps} grad accum steps * {config.batch_size} batch size * {config.max_seq_len} tokens per sequence"
+    )
+    logging.info(f" validation set size: {len(validation_ds)}")
+    logging.info(f" batches per validation step: {eval_iters}")
+    logging.info(
+        f" tokens per validation step: {eval_iters * config.batch_size * config.max_seq_len:,}"
     )
     return IterationParams(
         num_batches, num_steps, eval_iters, lr_decay_iters, tokens_per_iter
70 changes: 11 additions & 59 deletions src/delphi/train/train_step.py
@@ -76,84 +76,36 @@ def train_step(
         callback(eval_data)
 
     # 3. forward backward update, with optional gradient accumulation to simulate larger batch size
-    logging.debug(
-        f"gradient accumulation steps: {config.optimizer.gradient_accumulation_steps}, "
-        f"num_steps: {iteration_params.num_steps}, iter_num: {model_training_state.iter_num}"
-    )
-    for micro_step in range(config.optimizer.gradient_accumulation_steps):
-        X, Y = get_next_xy(train_batch_iter, run_context.device)
-        if config.debug_config.no_training:
-            logging.debug("no_training set, skipping forward backward update")
-            loss = torch.Tensor([42.1]).to(run_context.device)
-        else:
+    if config.debug_config.no_training:
+        logging.debug("no_training set, skipping forward backward pass")
+        loss = torch.Tensor([42.1]).to(run_context.device)
+    else:
+        for micro_step in range(config.optimizer.gradient_accumulation_steps):
+            X, Y = get_next_xy(train_batch_iter, run_context.device)
             loss = (
                 model(X, labels=Y, return_dict=True).loss
                 / config.optimizer.gradient_accumulation_steps
             )
             loss.backward()
-    if config.debug_config.no_training:
-        logging.debug("debug no_training is set, skipping optimizer step")
-    else:
         # clip the gradient
         if config.grad_clip != 0.0:
             torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)  # type: ignore
         optimizer.step()
 
-    # flush the gradients as soon as we can, no need for this memory anymore
-    optimizer.zero_grad(set_to_none=True)
+        # flush the gradients as soon as we can, no need for this memory anymore
+        optimizer.zero_grad(set_to_none=True)
 
     # 4. log timing
     t1 = time.time()
-    dt = t1 - model_training_state.t0
-    model_training_state.t0 = t1
+    dt = t1 - model_training_state.last_training_step_time
+    model_training_state.last_training_step_time = t1
     if model_training_state.iter_num % config.log_interval == 0:
         # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
         lossf = loss.item() * config.optimizer.gradient_accumulation_steps
-        if (
-            model_training_state.local_iter_num >= 5
-        ):  # let the training loop settle a bit
-            mfu = estimate_mfu(
-                config=config, model=model_training_state.model, timedelta=dt
-            )
-            model_training_state.running_mfu = (
-                mfu
-                if model_training_state.running_mfu == -1.0
-                else 0.9 * model_training_state.running_mfu + 0.1 * mfu
-            )
         logging.debug(
             (
                 f"{model_training_state.iter_num} | loss {lossf:.4f} | lr {model_training_state.lr:e} | "
-                f"{dt*1000:.2f}ms | mfu {model_training_state.running_mfu*100:.2f}%"
+                f"{dt*1000:.2f}ms"
             )
         )
     model_training_state.iter_num += 1
     model_training_state.local_iter_num += 1
-
-
-def estimate_mfu(config: GigaConfig, model: torch.nn.Module, timedelta: float) -> float:
-    """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS"""
-    # first estimate the number of flops we do per iteration.
-    # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
-    N = sum(p.numel() for p in model.parameters())
-    if config.model_config.model_type == ModelTypes.LLAMA2:
-        cfg = model.config
-        L, H, Q, T = (
-            cfg.num_hidden_layers,
-            cfg.num_attention_heads,
-            cfg.hidden_size // cfg.num_attention_heads,
-            cfg.max_position_embeddings,
-        )
-    else:
-        logging.debug(
-            f"estimate_mfu not implemented for {config.model_config.model_type}, setting MFU to -1"
-        )
-        return -1.0
-    flops_per_token = 6 * N + 12 * L * H * Q * T
-    flops_per_fwdbwd = flops_per_token * T
-    fwdbwd_per_iter = config.batch_size * config.optimizer.gradient_accumulation_steps
-    flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
-    # express our flops throughput as ratio of A100 bfloat16 peak flops
-    flops_achieved = flops_per_iter * (1.0 / timedelta)  # per second
-    flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
-    mfu = flops_achieved / flops_promised
-    return mfu
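
Review note: the rewrite hoists the no_training check out of the micro-step loop, moves optimizer.step() and zero_grad() into the training branch, and drops the MFU estimate from the timing log. The accumulation pattern itself is the standard PyTorch one; a self-contained sketch with a toy model (not this repo's types):

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_accum_steps = 4

for micro_step in range(grad_accum_steps):
    X, Y = torch.randn(2, 8), torch.randn(2, 1)
    # Scale the loss so the accumulated gradients match one large batch
    # of grad_accum_steps * batch_size examples.
    loss = torch.nn.functional.mse_loss(model(X), Y) / grad_accum_steps
    loss.backward()  # gradients accumulate across micro-steps

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip before stepping
optimizer.step()
optimizer.zero_grad(set_to_none=True)  # free gradient memory until the next step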
8 changes: 7 additions & 1 deletion src/delphi/train/training.py
@@ -45,7 +45,7 @@ def run_training(config: GigaConfig) -> tuple[ModelTrainingState, RunContext]:
     logging.debug(f"Run context: {run_context}")
 
     # load data
-    logging.debug("Loading data...")
+    logging.info("Loading data...")
     train_ds = cast(
         Dataset, load_delphi_training_dataset("train", limit=config.train_sample_limit)
     )
@@ -67,14 +67,20 @@ def run_training(config: GigaConfig) -> tuple[ModelTrainingState, RunContext]:
     model_training_state = initialize_model_training_state(config, run_context.device)
 
     # setup eval callbacks
+    logging.info("Setting eval step callbacks...")
     eval_callbacks = [save_checkpoint_if_needed]
+    logging.info(f" added save_checkpoint_if_needed eval callback")
     if config.wandb_config.log:
+        if config.wandb_config.silence:
+            wandb_utils.silence_wandb()
         wandb_utils.init_wandb(config)
         eval_callbacks.append(wandb_utils.log_to_wandb)
+        logging.info(f" added log_to_wandb callback")
 
     # training loop
+    logging.info("Starting training...")
     for epoch in range(config.max_epochs):
+        logging.info(f"Epoch: {epoch} / {config.max_epochs - 1}")
         train_batch_iter = iter(
             batch_generator(
                 train_ds, config.batch_size, epoch, config.batch_ordering_seed
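
One note on surfacing all of the above: these messages go through the root logger at INFO level, so the entry point has to opt in (the default threshold is WARNING). A minimal sketch, assuming the training script does not already configure logging:

import logging

# Show INFO-level messages, which the default WARNING threshold would hide.
logging.basicConfig(level=logging.INFO, format="%(message)s")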