diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 7e92d200b..4cc000e59 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -708,7 +708,9 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po
         max_length=Pos.size,
     )
     ex = {k: v[0] for k, v in ex.items()}
-    input_ids = hax.named(ex["input_ids"], Pos)
+    # padding doesn't do truncation, so we have to do it ourselves.
+    # Truncate from the left since we want to predict the last tokens
+    input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos)
     # mask out padding and anything before the start of the target
     loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1

diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py
index 9d048b24f..6d9165cfc 100644
--- a/src/levanter/doremi.py
+++ b/src/levanter/doremi.py
@@ -1,6 +1,6 @@
 import dataclasses
 import logging
-from typing import Optional, Tuple, TypeVar
+from typing import Mapping, Optional, Tuple, TypeVar

 import equinox as eqx
 import jax.numpy as jnp
@@ -56,7 +56,7 @@ def estimate_mixture_weights(
     loss_fn: ComputeLossFunction[M, T],
     initial_proxy: M,
     ref: M,
-    data_sources: dict[str, AsyncDataset[T]],
+    data_sources: Mapping[str, AsyncDataset[T]],
     sampling_weights: Optional[dict[str, float]] = None,
     *,
     validation_sets: Optional[dict[str, AsyncDataset[T]]] = None,
@@ -184,7 +184,9 @@ def doremi_step(state: DoremiState, ref, batch, domains):

     # we're not actually going to use the trainer for very much but it holds hooks and sets up contexts
     with trainer:
-        tagged_mixture = domain_tagged_mixture(data_sources, sampling_weights, domain_to_index, key=data_key)
+        tagged_mixture: MixtureDataset = domain_tagged_mixture(
+            data_sources, sampling_weights, domain_to_index, key=data_key
+        )
         state = load_checkpoint_or_initialize(
             DoremiState.init,
             trainer.checkpoint_path,
@@ -263,7 +265,7 @@ def _prepare_ref_model(ref, trainer):


 def domain_tagged_mixture(
-    data_sources: dict[str, AsyncDataset[T]],
+    data_sources: Mapping[str, AsyncDataset[T]],
     weights: dict[str, float],
     domain_to_index: dict[str, int],
     *,

diff --git a/src/levanter/main/doremi_lm.py b/src/levanter/main/doremi_lm.py
index 12b3e6ae0..742c3229c 100644
--- a/src/levanter/main/doremi_lm.py
+++ b/src/levanter/main/doremi_lm.py
@@ -109,7 +109,7 @@ def init_proxy_model():
     train_datasets = config.data.training_sets(ref_model.Pos.size)
     valid_datasets = config.data.validation_sets(ref_model.Pos.size)

-    train_datasets = {
+    causal_train_datasets = {
         k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id)
         for k, v in train_datasets.items()
     }
@@ -122,7 +122,7 @@ def init_proxy_model():
         loss_function,
         proxy_model,
         ref=ref_model,
-        data_sources=train_datasets,
+        data_sources=causal_train_datasets,
         trainer_config=config.trainer,
         optimizer=optimizer,
         domain_weight_step_size=config.doremi.domain_weight_step_size,

diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index b411bd59e..f2ad3e7ce 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -268,6 +268,9 @@ def compute_log_probs(model, example):
     checkpointer = trainer.config.checkpointer.create(trainer.run_id)
     checkpointer.wait_until_finished()

+    # This isn't necessary except when Levanter is run in a subprocess (as happens w/ ray)
+    trainer.tracker.finish()
+

 if __name__ == "__main__":
     levanter.config.main(main)()

diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py
index 911e74b09..1a82aa7be 100644
--- a/src/levanter/models/lm_model.py
+++ b/src/levanter/models/lm_model.py
@@ -1,4 +1,5 @@
 import abc
+from dataclasses import dataclass
 from typing import Generic, Optional, Type, TypeVar

 import draccus
@@ -48,6 +49,7 @@ def causal(


 # TODO: for some reason, mypy doesn't like the discover_packages_path argument?
+@dataclass(frozen=True)
 class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"):  # type: ignore
     @property
     @abc.abstractmethod
@@ -69,7 +71,7 @@ def Pos(self) -> Axis:
     def Embed(self) -> Axis:
         pass

-    cross_entropy_block_size: Optional[int] = 64000
+    cross_entropy_block_size: Optional[int] = None
     """
     The block size for computing cross-entropy loss. This is the number of tokens that are processed together
     in a single block. This can be adjusted to fit within memory constraints. It's deliberately set to a large

diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py
index 154fc66ac..d705eda4d 100644
--- a/src/levanter/models/loss.py
+++ b/src/levanter/models/loss.py
@@ -58,7 +58,9 @@ def next_token_loss(

     if block_size is None:
         # Full softmax computation
-        logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed, preferred_element_type=dtype)
+        logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed)
+        if dtype is not None:
+            logits = logits.astype(dtype)
         target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype)
         return cross_entropy_and_logsumexp_penalty(
             logits,
@@ -261,9 +263,10 @@ def process_block(block_idx, acc, current_block_size):

         # Materialize the logits for the current block
         lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)]  # [Contract, Block]
-        logits_b = hax.dot(
-            pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype
-        )  # [Batch, Seq, Block]
+        logits_b = hax.dot(pred_embeddings, lm_head_b, axis=Contract)  # [Batch, Seq, Block]
+
+        if dtype is not None:
+            logits_b = logits_b.astype(dtype)

         # Update max and logsumexp
         max_logit = hax.maximum(max_logit_prev, hax.max(logits_b, axis=Block))  # [Batch, Seq]
@@ -278,7 +281,7 @@ def process_block(block_idx, acc, current_block_size):

         # Update sumV. This is actually unnecessary if we're using one-hot targets
         # sV = sV_prev + hax.sum(target_y_b, axis=Label.name)
-        loss += hax.dot(logits_b, target_y_b, axis=Block, preferred_element_type=dtype)  # [Batch, Seq]
+        loss += hax.dot(logits_b, target_y_b, axis=Block)  # [Batch, Seq]

         return loss, logsumexp, max_logit  # , sV

@@ -351,7 +354,7 @@ def _block_cross_entropy_backward(
     num_blocks = vocab_size // block_size

     grad_embeddings = hax.zeros(pred_embeddings.axes, dtype=pred_embeddings.dtype)
-    grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_embeddings.dtype)
+    grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_lm_head.dtype)

     def process_block(block_idx, acc, current_block_size):
         """
@@ -372,14 +375,15 @@ def process_block(block_idx, acc, current_block_size):

         # Materialize the logits for the current block
         lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)]  # [Contract, Block]
-        logits_b = hax.dot(
-            pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype
-        )  # [Batch, Seq, Block]
+        logits_b = hax.dot(pred_embeddings, lm_head_b, axis=Contract)  # [Batch, Seq, Block]

         # Materialize the target for the current block (one-hot)
         target_y_block = _block_one_hot(Block, start, labels_y, logits_b.dtype)  # [Batch, Seq, Block]

         # materialize the softmax for the current block
+        if dtype is not None:
+            logits_b = logits_b.astype(dtype)
+
         p_b = hax.exp(logits_b - log_z)  # [Batch, Seq, Block]

         delta_b = p_b - target_y_block

diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py
index 0809d9d23..97b61f1dc 100644
--- a/src/levanter/models/mpt.py
+++ b/src/levanter/models/mpt.py
@@ -107,7 +107,7 @@ def from_hf(config: HfMptAttentionConfig):


 @LmConfig.register_subclass("mpt")
-@dataclass
+@dataclass(frozen=True)
 class MptConfig(HFCompatConfig):
     d_model: int = 768
     n_heads: int = 12

diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py
index 360c32171..e819d6459 100644
--- a/src/levanter/tracker/tensorboard.py
+++ b/src/levanter/tracker/tensorboard.py
@@ -43,6 +43,9 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio
             pylogger.exception(f"Error logging artifact {artifact_path} to {log_path}")
             return

+    def finish(self):
+        self.writer.close()
+

 @TrackerConfig.register_subclass("tensorboard")
 @dataclass

diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py
index 8b6816f17..99fd217e5 100644
--- a/src/levanter/tracker/tracker.py
+++ b/src/levanter/tracker/tracker.py
@@ -46,6 +46,14 @@ def log_summary(self, metrics: dict[str, Any]):
     def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
         pass

+    @abc.abstractmethod
+    def finish(self):
+        """
+        Finish the tracker. This is called when the tracker is no longer needed. This can, e.g.,
+        force a commit of all metrics.
+        """
+        pass
+
     def __enter__(self):
         import levanter.tracker.tracker_fns as tracker_fns

@@ -81,6 +89,17 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio
         for tracker in self.loggers:
             tracker.log_artifact(artifact_path, name=name, type=type)

+    def finish(self):
+        excs = []
+        for tracker in self.loggers:
+            try:
+                tracker.finish()
+            except Exception as e:
+                excs.append(e)
+
+        if excs:
+            raise RuntimeError("Errors occurred when finishing trackers") from excs[0]
+

 class TrackerConfig(draccus.PluginRegistry, abc.ABC):
     discover_packages_path = "levanter.tracker"
@@ -109,6 +128,9 @@ def log_summary(self, metrics: dict[str, Any]):
     def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
         pass

+    def finish(self):
+        pass
+

 @TrackerConfig.register_subclass("noop")
 @dataclasses.dataclass

diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py
index 18f0251ec..981bebf83 100644
--- a/src/levanter/tracker/wandb.py
+++ b/src/levanter/tracker/wandb.py
@@ -72,6 +72,10 @@ def log_summary(self, metrics: dict[str, Any]):
     def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
         self.run.log_artifact(artifact_path, name=name, type=type)

+    def finish(self):
+        logger.info("Finishing wandb run...")
+        self.run.finish()
+

 def is_wandb_available():
     try: