diff --git a/config/llama3_small_fast.yaml b/config/llama3_small_fast.yaml new file mode 100644 index 000000000..df1df9f96 --- /dev/null +++ b/config/llama3_small_fast.yaml @@ -0,0 +1,32 @@ +data: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized/openwebtext_llama3/" + tokenizer: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF" +model: + type: llama + hidden_dim: 768 + intermediate_dim: 2048 + num_heads: 12 + num_kv_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true +trainer: + tracker: + - type: wandb + project: "levanter" + tags: [ "openwebtext", "llama", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 diff --git a/config/llama_7b_with_dclm.yaml b/config/llama_7b_with_dclm.yaml index 980e64e41..11a182f09 100644 --- a/config/llama_7b_with_dclm.yaml +++ b/config/llama_7b_with_dclm.yaml @@ -17,7 +17,7 @@ trainer: mp: p=f32,c=bfloat16 train_batch_size: 2048 - num_train_steps: 70000 # 280B / 4M + num_train_steps: 480000 # 2T / 4M steps_per_eval: 1000 tensor_parallel_axes: ["mlp", "heads"] fsdp_axis: "embed" diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index b04648079..1a9342c54 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -201,7 +201,8 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env): # ray doesn't merge the runtime envs properly, so we have to do it ourselves # we need to do a deep merge - runtime_env = mergedeep.merge({}, runtime_env, remote_fn._runtime_env, strategy=mergedeep.Strategy.ADDITIVE) + sources = [e for e in [remote_fn._runtime_env, runtime_env] if e is not None] + runtime_env = mergedeep.merge({}, *sources, strategy=mergedeep.Strategy.ADDITIVE) remote_fn = remote_fn.options( runtime_env=runtime_env, diff --git a/src/levanter/models/backpack.py b/src/levanter/models/backpack.py index 2a955395f..4de8accc7 100644 --- a/src/levanter/models/backpack.py +++ b/src/levanter/models/backpack.py @@ -401,7 +401,7 @@ def init(Vocab: Axis, config: BackpackConfig, *, key): ) @named_call - def __call__( + def activations( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None ) -> NamedArray: k_embed, k_transformer, k_senses, k_sa = haliax.jax_utils.maybe_rng_split(key, 4) @@ -428,9 +428,10 @@ def __call__( scale = self.config.Senses.size hidden_states = hidden_states / scale - lm_logits = self.embeddings.unembed(hidden_states) + return hidden_states - return lm_logits + def get_lm_head(self) -> hax.NamedArray: + return self.embeddings.token_embeddings def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None): new_embeddings = self.embeddings.resize_embeddings(new_size, key=key) diff --git a/src/levanter/models/gemma.py b/src/levanter/models/gemma.py index af5cc44be..c38acf5ef 100644 --- a/src/levanter/models/gemma.py +++ b/src/levanter/models/gemma.py @@ -339,6 +339,9 @@ def vocab_size(self) -> int: def Vocab(self) -> Axis: return self.embeddings.Vocab + def get_lm_head(self) -> hax.NamedArray: + return self.embeddings.token_embeddings.weight + @classmethod def init(cls, Vocab: Axis, config: GemmaConfig, *, key) -> "GemmaLMHeadModel": k_t, k_emb = jrandom.split(key, 2) @@ -346,7 
+349,7 @@ def init(cls, Vocab: Axis, config: GemmaConfig, *, key) -> "GemmaLMHeadModel": embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb) return GemmaLMHeadModel(transformer, embeddings) - def __call__( + def activations( self, input_ids: NamedArray, attn_mask: Optional[Union[NamedArray, AttentionMask]] = None, @@ -363,8 +366,7 @@ def __call__( """ x = self.embeddings.embed(input_ids) x = self.transformer(x, attn_mask=attn_mask, key=key) - lm_logits = self.embeddings.unembed(x) - return lm_logits + return x def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[GemmaConfig]": new_embeddings = self.embeddings.resize_embeddings(new_size, key=key) diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index a921074e9..28e878193 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -391,15 +391,17 @@ def init(cls, Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2LMHeadModel": return Gpt2LMHeadModel(transformer, embeddings) - def __call__( + def activations( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None ) -> NamedArray: k_embed, k_transformer = haliax.jax_utils.maybe_rng_split(key, 2) x = self.embeddings.embed(input_ids, key=k_embed) x = self.transformer(x, attn_mask, key=k_transformer) - lm_logits = self.embeddings.unembed(x) - return lm_logits + return x + + def get_lm_head(self) -> hax.NamedArray: + return self.embeddings.token_embeddings.weight def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None) -> "Gpt2LMHeadModel": new_embeddings = self.embeddings.resize_embeddings(new_size, key=key) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 1e09ffbc5..85861da6a 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -557,6 +557,31 @@ def __call__( lm_logits = self.embeddings.unembed(x) return lm_logits + def activations( + self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None + ) -> NamedArray: + """ + Compute the activations for the next token in a sequence. + Args: + input_ids: token IDs with shape {Pos} + attn_mask: attention mask with shape {Pos, KeyPos} + key: PRNGKey for random number generation + + Returns: + NamedArray: activations with shape {Pos, Embed} + + """ + x = self.embeddings.embed(input_ids) + x = self.transformer(x, attn_mask=attn_mask, key=key) + + return x + + def get_lm_head(self) -> hax.NamedArray: + if self.lm_head is None: + return self.embeddings.token_embeddings.weight + else: + return self.lm_head.weight + def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]": new_Vocab = self.Vocab.resize(new_size) k1, k2 = maybe_rng_split(key, 2) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 468f6a4a4..911e74b09 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -64,6 +64,18 @@ def KeyPos(self) -> Axis: def Pos(self) -> Axis: pass + @property + @abc.abstractmethod + def Embed(self) -> Axis: + pass + + cross_entropy_block_size: Optional[int] = 64000 + """ + The block size for computing cross-entropy loss. This is the number of vocabulary entries that are processed together + in a single block. This can be adjusted to fit within memory constraints. It's deliberately set to a large + value because it is usually faster to compute the loss in larger blocks.
+ """ + def flops_per_token(self, vocab_size: int) -> Optional[float]: return None @@ -94,17 +106,58 @@ def Pos(self) -> Axis: def KeyPos(self) -> Axis: return self.config.KeyPos + @property + def Embed(self) -> Axis: + return self.config.Embed + @classmethod @abc.abstractmethod def init(cls, Vocab: Axis, config: LmConfigT, *, key: PRNGKey) -> "LmHeadModel[LmConfigT]": pass - @abc.abstractmethod def __call__( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None ) -> NamedArray: + """ + Compute the logits for the next token in a sequence. + Args: + input_ids: token IDs with shape [..., Pos] + attn_mask: attention mask with shape [..., Pos, KeyPos] + key: PRNGKey for random number generation + + Returns: + NamedArray: logits with shape [..., Pos, Vocab] + + """ + x = self.activations(input_ids, attn_mask, key=key) + lm_logits = hax.dot(x, self.get_lm_head(), axis=self.Embed) + + return lm_logits + + @abc.abstractmethod + def activations( + self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None + ) -> NamedArray: + """ + Compute the activations for the next token in a sequence. + Args: + input_ids: token IDs with shape {Pos} + attn_mask: attention mask with shape {Pos, KeyPos} + key: PRNGKey for random number generation + + Returns: + NamedArray: activations with shape {Pos, Embed} + + """ pass + @abc.abstractmethod + def get_lm_head(self) -> hax.NamedArray: + """ + The language modeling head of the model. Should have shape {Embed, Vocab}. + """ + raise NotImplementedError("get_lm_head not implemented") + @abc.abstractmethod def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "LmHeadModel[LmConfigT]": """ @@ -133,19 +186,21 @@ def compute_next_token_loss( across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not reduced, and the result is a named array with axes (*batch axes, sequence_length). """ - logits = model(example.tokens, example.attn_mask, key=key) - if loss_dtype is not None: - logits = logits.astype(loss_dtype) + activations = model.activations(example.tokens, example.attn_mask, key=key) loss = next_token_loss( model.Pos, + model.Embed, model.Vocab, - logits, + activations, + model.get_lm_head(), example.tokens, loss_mask=example.loss_mask, reduction=reduction, reduction_axis=reduction_axis, logsumexp_weight=logsumexp_weight, + dtype=loss_dtype, + block_size=model.config.cross_entropy_block_size, ) return loss diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py index 1ef7e81f9..154fc66ac 100644 --- a/src/levanter/models/loss.py +++ b/src/levanter/models/loss.py @@ -1,5 +1,8 @@ +import functools from typing import Optional +import equinox +import jax import jax.numpy as jnp import haliax as hax @@ -9,34 +12,77 @@ def next_token_loss( Pos: hax.AxisSelector, + Embed: hax.AxisSelector, Vocab: hax.AxisSelector, - pred_ids: NamedArray, + pred_embeddings: NamedArray, + pred_lm_head: NamedArray, true_ids: NamedArray, loss_mask: Optional[NamedArray] = None, reduction: Optional[hax.ReductionFunction] = hax.mean, reduction_axis: Optional[hax.AxisSelection] = None, logsumexp_weight: Optional[float] = None, -): - Pos, Vocab = pred_ids.resolve_axis((Pos, Vocab)) - # need to roll the target tokens back by one so that each token is predicting the next token + block_size: Optional[int] = None, + dtype: Optional[jnp.dtype] = jnp.float32, +) -> NamedArray: + """ + Compute the next token loss with optional block-wise processing. 
+ + Args: + Pos (hax.AxisSelector): Position axis selector. + Vocab (hax.AxisSelector): Vocabulary axis selector. + pred_embeddings (NamedArray): Predicted embeddings. + pred_lm_head (NamedArray): Language model head weights. + true_ids (NamedArray): True token IDs. + loss_mask (Optional[NamedArray]): Mask to apply to the loss. + reduction (Optional[hax.ReductionFunction]): Reduction function. + reduction_axis (Optional[hax.AxisSelection]): Axis to apply reduction. + logsumexp_weight (Optional[float]): Weight for logsumexp penalty. + block_size (Optional[int]): Size of each block for processing. + + Returns: + NamedArray: Computed loss. + """ + # Resolve axes + Pos = pred_embeddings.resolve_axis(Pos) + Vocab = pred_lm_head.resolve_axis(Vocab) + + # Shift target tokens to predict the next token target_y = hax.roll(true_ids, -1, Pos) - target_y = hax.nn.one_hot(target_y, Vocab, dtype=pred_ids.dtype) # type: ignore - # one everywhere except the last token + # Create a mask that excludes the last token not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) # type: ignore if loss_mask is not None: loss_mask = loss_mask * not_last_loss_mask else: loss_mask = not_last_loss_mask - return cross_entropy_and_logsumexp_penalty( - pred_ids, - Vocab, - target_y, + if block_size is None: + # Full softmax computation + logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed, preferred_element_type=dtype) + target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype) + return cross_entropy_and_logsumexp_penalty( + logits, + Vocab, + target_y_full, + reduction=reduction, + reduction_axis=reduction_axis, + where=loss_mask, + logsumexp_weight=logsumexp_weight, + ) + + # Compute the loss with optional block-wise processing + return fused_cross_entropy_loss_and_logsumexp_penalty( + pred_embeddings, + pred_lm_head, + Contract=Embed, + Label=Vocab, + target_y=target_y, reduction=reduction, reduction_axis=reduction_axis, where=loss_mask, logsumexp_weight=logsumexp_weight, + block_size=block_size, + dtype=dtype, ) @@ -58,3 +104,345 @@ def cross_entropy_and_logsumexp_penalty( loss = loss + logsumexp_weight * (log_normalizers**2) return hax.nn.loss.maybe_reduce_loss(loss, reduction, reduction_axis, where) + + +def fused_cross_entropy_loss_and_logsumexp_penalty( + pred_embeddings: NamedArray, + pred_lm_head: NamedArray, + Contract: hax.AxisSelector, + Label: hax.AxisSelector, + target_y: NamedArray, + *, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + where: Optional[NamedArray] = None, + logsumexp_weight: float | None = 0.0, + block_size: int, + dtype: Optional[jnp.dtype] = jnp.float32, +) -> NamedArray: + """ + Compute the cross-entropy loss and logsumexp penalty using embeddings and lm_head, + with optional block-wise processing. + + Args: + pred_embeddings (NamedArray): Predicted embeddings. + pred_lm_head (NamedArray): Language model head weights. + Contract (hax.AxisSelector): Axis to contract over. + Label (hax.AxisSelector): Label (Vocab) axis. + target_y (NamedArray): One-hot encoded target tokens. + reduction (Optional[hax.ReductionFunction]): Reduction function. + reduction_axis (Optional[hax.AxisSelection]): Axis to apply reduction. + where (Optional[NamedArray]): Mask to apply to the loss. + logsumexp_weight (float): Weight for logsumexp penalty. + block_size (int): Size of each block for processing. + dtype (Optional[jnp.dtype]): Data type for the loss. + + Returns: + NamedArray: Computed loss. 
+ """ + + # Block-wise softmax computation + loss, log_normalizers = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), Contract, Label, target_y, block_size, dtype=dtype + ) + + if logsumexp_weight is not None and (not isinstance(logsumexp_weight, (int, float)) or logsumexp_weight != 0.0): + loss = loss + logsumexp_weight * (log_normalizers**2) + + return hax.nn.loss.maybe_reduce_loss(loss, reduction, reduction_axis, where) + + +@equinox.filter_custom_vjp +def _blockwise_cross_entropy_loss( + # pred_embeddings: NamedArray, + # pred_lm_head: NamedArray, + pred: tuple[NamedArray, NamedArray], + Contract: hax.Axis, + Label: hax.Axis, + labels_y: NamedArray, + block_size: int, + dtype: Optional[jnp.dtype], +) -> tuple[NamedArray, NamedArray]: + """ + Compute cross-entropy loss and log normalizers in a block-wise manner without materializing the full logits. + + Args: + pred_embeddings (NamedArray): Predicted embeddings. + pred_lm_head (NamedArray): Language model head weights. + Contract (hax.Axis): Axis to contract over. + Label (hax.AxisSelector): Label (Vocab) axis. + labels_y (NamedArray): label tensor. + block_size (int): Size of each block for processing. + dtype (Optional[jnp.dtype]): Data type for the loss. + + Notes: + labels_y being anything other than the label tensor would remove any benefits + + TODO: but if XLA smart enough to optimize it out? + + Returns: + tuple[NamedArray, NamedArray]: tuple of loss and log_normalizers. + """ + + return _block_cross_entropy_forward(None, pred, Contract, Label, labels_y, block_size, dtype)[0] + + +def _block_cross_entropy_forward( + ignore, + pred: tuple[NamedArray, NamedArray], + Contract: hax.Axis, + Label: hax.Axis, + labels_y: NamedArray, + block_size: int, + dtype: Optional[jnp.dtype], +) -> tuple[tuple[NamedArray, NamedArray], tuple[NamedArray]]: + """ + Forward pass for block-wise cross-entropy loss. + + This function computes the cross-entropy loss and log-sum-exp (`log_z`) in a block-wise manner + to maintain memory efficiency by processing subsets of the vocabulary at a time. + + Args: + ignore: Placeholder argument (unused). + pred (Tuple[NamedArray, NamedArray]): Tuple containing predicted embeddings and language model head weights. + Contract (hax.Axis): Axis to contract over (e.g., embedding axis). + Label (hax.Axis): Label axis (e.g., vocabulary axis). + labels_y (NamedArray): True target labels [Batch, Seq]. + block_size (int): Number of vocabulary tokens per block. + dtype (Optional[jnp.dtype]): Data type for the computations. + + Returns: + Tuple: + - Tuple[NamedArray, NamedArray]: Computed loss and logsumexp. + - Tuple[NamedArray]: Residuals needed for the backward pass. 
+ """ + vocab_size = Label.size + + pred_embeddings, pred_lm_head = pred + + # + # if num_blocks == 1: + # # No need for block-wise processing + # logits = hax.dot(pred_embeddings, pred_lm_head, axis=Contract) + # labels_y = hax.nn.one_hot(labels_y, Label, dtype=pred_embeddings.dtype) + # return cross_entropy_loss_and_log_normalizers(logits, Label, labels_y) + # + # ensure block size divides vocab size + if vocab_size % block_size != 0: + has_stragglers = True + else: + has_stragglers = False + + num_blocks = vocab_size // block_size + + # Initialize accumulators: loss, logsumexp, max_logits + initial_O = hax.zeros(labels_y.axes) + initial_logsumexp = hax.full(labels_y.axes, -jnp.inf) + initial_max = hax.full(labels_y.axes, -jnp.inf) + # We don't need this b/c we're using one-hot targets + # initial_sumV = hax.full(labels_y.axes, 0.0) + + def process_block(block_idx, acc, current_block_size): + """ + Process a single block of the Vocab dimension. + + Args: + block_idx (int): Index of the current block. + acc (tuple[NamedArray, NamedArray, jnp.ndarray]): Accumulators for loss, logsumexp, and max logits. + current_block_size (int): Size of the current block (used for stragglers). + + Returns: + tuple[NamedArray, NamedArray, jnp.ndarray]: Updated accumulators + """ + loss, logsumexp_prev, max_logit_prev = acc + + start = block_idx * block_size + Block = Label.resize(current_block_size) + + # Materialize the logits for the current block + lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)] # [Contract, Block] + logits_b = hax.dot( + pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype + ) # [Batch, Seq, Block] + + # Update max and logsumexp + max_logit = hax.maximum(max_logit_prev, hax.max(logits_b, axis=Block)) # [Batch, Seq] + # reweight the previous logsumexp by the new max, fold in the new logits' contribution + logsumexp = max_logit + hax.log( + hax.exp(logsumexp_prev - max_logit) + hax.sum(hax.exp(logits_b - max_logit), axis=Block) + ) # [Batch, Seq] + + # Materialize the target for the current block (one-hot) + target_y_b = _block_one_hot(Block, start, labels_y, logits_b.dtype) # [Batch, Seq, Block] + + # Update sumV. 
This is actually unnecessary if we're using one-hot targets + # sV = sV_prev + hax.sum(target_y_b, axis=Label.name) + + loss += hax.dot(logits_b, target_y_b, axis=Block, preferred_element_type=dtype) # [Batch, Seq] + + return loss, logsumexp, max_logit # , sV + + if num_blocks == 0: + o = initial_O + log_z = initial_logsumexp + max_logits = initial_max + elif num_blocks == 1: + o, log_z, max_logits = process_block(0, (initial_O, initial_logsumexp, initial_max), vocab_size) + else: + (o, log_z, max_logits) = jax.lax.fori_loop( + lower=0, + upper=num_blocks, + body_fun=functools.partial(process_block, current_block_size=block_size), + init_val=(initial_O, initial_logsumexp, initial_max), # , initial_sumV + ) + + if has_stragglers: + # Handle the stragglers + remainder_size = vocab_size - num_blocks * block_size + o, log_z, _ = process_block(num_blocks, (o, log_z, max_logits), remainder_size) + + # unnecessary if we're using one-hot targets + # logz_outer = hax.einsum("->...", log_z, sum_v) + o = log_z - o + + return (o, log_z), (log_z,) + + +def _block_cross_entropy_backward( + residuals: tuple[NamedArray,], + grad_in: tuple[NamedArray, NamedArray], + ignore, + pred: tuple[NamedArray, NamedArray], + Contract: hax.Axis, + Label: hax.Axis, + labels_y: NamedArray, + block_size: int, + dtype: Optional[jnp.dtype], +) -> tuple[NamedArray, NamedArray]: + """ + Compute the gradients of the block-wise cross-entropy loss. + + Args: + residuals (tuple[NamedArray, NamedArray]): Residuals from the forward pass. + grad_in (tuple[NamedArray, NamedArray]): Incoming gradients. + pred (tuple[NamedArray, NamedArray]): Predictions. + Contract (hax.Axis): Axis to contract over. + Label (hax.Axis): Label axis. + labels_y (NamedArray): Target labels. + block_size (int): Size of each block. + dtype (Optional[jnp.dtype]): Data type for the loss. + + Returns: + tuple[NamedArray, NamedArray]: Gradients. + """ + + (log_z,) = residuals + grad_loss, grad_log_z = grad_in + + vocab_size = Label.size + + pred_embeddings, pred_lm_head = pred + + if vocab_size % block_size != 0: + has_stragglers = True + else: + has_stragglers = False + + num_blocks = vocab_size // block_size + + grad_embeddings = hax.zeros(pred_embeddings.axes, dtype=pred_embeddings.dtype) + grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_embeddings.dtype) + + def process_block(block_idx, acc, current_block_size): + """ + Process a single block of the Vocab dimension. + + Args: + block_idx (int): Index of the current block. + acc (tuple[NamedArray, NamedArray]): Accumulators for gradients. + current_block_size (int): Size of the current block (used for stragglers). + + Returns: + tuple[NamedArray, NamedArray]: Updated accumulators. 
+ """ + grad_embeddings_prev, grad_lm_head_prev = acc + + start = block_idx * block_size + Block = Label.resize(current_block_size) + + # Materialize the logits for the current block + lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)] # [Contract, Block] + logits_b = hax.dot( + pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype + ) # [Batch, Seq, Block] + + # Materialize the target for the current block (one-hot) + target_y_block = _block_one_hot(Block, start, labels_y, logits_b.dtype) # [Batch, Seq, Block] + + # materialize the softmax for the current block + p_b = hax.exp(logits_b - log_z) # [Batch, Seq, Block] + + delta_b = p_b - target_y_block + + # # dLoss/dL = g_loss * delta_b + g_log_z * probs_b + # # = g_loss * (probs_b - Y) + g_log_z * probs_b + # # = (g_loss + g_log_z) * probs_b - g_loss * Y + + # Compute gradients. We get None if the gradient is not provided. + if grad_loss.array is not None: + dLoss = grad_loss * delta_b # [Batch, Seq, Block] + else: + dLoss = 0.0 + + # Add the gradient of the logsumexp term (should be None if not provided) + if grad_log_z.array is not None: + dLoss += grad_log_z * p_b # [Batch, Seq, Block] + + # Compute gradients for the current block + # embeddings has shape [Batch, Seq, Embed], so we need to eliminate Block + g_embeddings_b = hax.dot( + dLoss, lm_head_b, axis=Block, preferred_element_type=grad_embeddings.dtype + ) # [Batch, Seq, Embed] + + # lm_head has shape [Block, Embed], so we need to eliminate Batch, Seq, etc. + eliminated_axes_W = hax.axis.without_axes(pred_embeddings.axes, lm_head_b.axes) + g_lm_head_b = hax.dot( + dLoss, pred_embeddings, axis=eliminated_axes_W, preferred_element_type=grad_lm_head_prev.dtype + ) # [Block, Embed] + + g_lm_head = grad_lm_head_prev.at[Label, hax.dslice(start, Block)].set(g_lm_head_b) + g_embeddings = grad_embeddings_prev + g_embeddings_b + + return g_embeddings, g_lm_head + + if num_blocks == 0: + pass + elif num_blocks == 1: + grad_embeddings, grad_lm_head = process_block(0, (grad_embeddings, grad_lm_head), vocab_size) + else: + grad_embeddings, grad_lm_head = jax.lax.fori_loop( + lower=0, + upper=num_blocks, + body_fun=functools.partial(process_block, current_block_size=block_size), + init_val=(grad_embeddings, grad_lm_head), + ) + + if has_stragglers: + # Handle the stragglers + remainder_size = vocab_size - num_blocks * block_size + grad_embeddings, grad_lm_head = process_block(num_blocks, (grad_embeddings, grad_lm_head), remainder_size) + + return grad_embeddings.astype(pred_embeddings.dtype), grad_lm_head.astype(pred_lm_head.dtype) + + +_blockwise_cross_entropy_loss.def_fwd(_block_cross_entropy_forward) +_blockwise_cross_entropy_loss.def_bwd(_block_cross_entropy_backward) + + +def _block_one_hot(LBlock, block_start, labels, dtype): + end = block_start + LBlock.size + target_is_in_this_block = hax.logical_and(labels >= block_start, labels < end) + target_y_block = hax.nn.one_hot(labels - block_start, LBlock, dtype=dtype) + # 0 out the logits that are not in this block + target_y_block *= target_is_in_this_block + return target_y_block diff --git a/src/levanter/models/mistral.py b/src/levanter/models/mistral.py index b48bfbe91..764e18aea 100644 --- a/src/levanter/models/mistral.py +++ b/src/levanter/models/mistral.py @@ -175,7 +175,11 @@ def init(cls, Vocab: Axis, config: MistralConfig, *, key) -> "MistralLMHeadModel lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True) return MistralLMHeadModel(transformer, embeddings, lm_head) 
- def __call__( + def get_lm_head(self) -> hax.NamedArray: + assert self.lm_head.bias is None + return self.lm_head.weight + + def activations( self, input_ids: NamedArray, attn_mask: Optional[Union[NamedArray, AttentionMask]] = None, @@ -193,8 +197,7 @@ def __call__( k_t, k_head = maybe_rng_split(key, 2) x = self.embeddings.embed(input_ids) x = self.transformer(x, attn_mask=attn_mask, key=k_t) - lm_logits = self.lm_head(x, key=k_head) - return lm_logits + return x def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[MistralConfig]": new_Vocab = self.Vocab.resize(new_size) diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py index 00044a4ed..0809d9d23 100644 --- a/src/levanter/models/mpt.py +++ b/src/levanter/models/mpt.py @@ -447,14 +447,15 @@ def init(cls, Vocab: Axis, config: MptConfig, *, key): return MptLmHeadModel(wte, transformer, config) @named_call - def __call__( + def activations( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray], *, key=None ) -> NamedArray: hidden_states = self.wte.embed(input_ids) hidden_states = self.transformer(hidden_states, attention_mask=attn_mask, key=key) - output_logits = self.wte.unembed(hidden_states) + return hidden_states - return output_logits + def get_lm_head(self) -> hax.NamedArray: + return self.wte.weight def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "MptLmHeadModel": if new_size == self.vocab_size: diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 5e8657fc0..558bbfceb 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -840,7 +840,7 @@ class _ShardFinished: path_to_shard: str -@ray.remote(num_cpus=1) +@ray.remote(num_cpus=1, runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"})) def _core_writer_task( parent, cache_dir, diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 8e98eaedb..92d7af4ac 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -499,7 +499,8 @@ def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwa grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) mbs = self.config.microbatch_size grad_fn = microbatched(grad_fn, self.TrainBatch, mbs, self.parameter_axis_mapping, self.compute_axis_mapping) - return grad_fn(model, *batch, **batch_kwargs) + with hax.axis_mapping(self.compute_axis_mapping): + return grad_fn(model, *batch, **batch_kwargs) def _initialize_global_tracker(config, run_id): diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py index 7a5475738..a0002b1c1 100644 --- a/tests/test_hf_gpt2_serialize.py +++ b/tests/test_hf_gpt2_serialize.py @@ -19,7 +19,7 @@ from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef from levanter.models.attention import AttentionMask from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel -from levanter.models.loss import next_token_loss +from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss from levanter.optim import AdamConfig from levanter.utils.tree_utils import inference_mode from test_utils import arrays_only, skip_if_no_torch @@ -132,12 +132,10 @@ def torch_loss(model, input_ids) -> torch.Tensor: return model(input_ids, labels=input_ids)[0] torch_out = torch_loss(torch_model, torch.from_numpy(onp.array(input.array)).to(torch.int64).unsqueeze(0)) - causal_mask = AttentionMask.causal() - def compute_loss(model, input_ids): - pred_y = model(input_ids, key=None, attn_mask=causal_mask) - - return 
next_token_loss(model.Pos, model.Vocab, pred_y, input_ids).scalar() + def compute_loss(model: LmHeadModel, input_ids): + example = LmExample.causal(input_ids) + return compute_next_token_loss(model, example, key=None).scalar() jax_compute_grad = equinox.filter_value_and_grad(compute_loss, has_aux=False) jax_grad: Gpt2LMHeadModel diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100644 index 000000000..30d140ede --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,325 @@ +# test_loss.py +import math + +import equinox +import jax.numpy as jnp +import jax.random +import pytest + +import haliax as hax +from haliax import NamedArray + +# Import the loss functions under test. +# _blockwise_cross_entropy_loss is private, but the tests exercise it directly. +from levanter.models.loss import _blockwise_cross_entropy_loss, cross_entropy_loss_and_log_normalizers +from levanter.utils.jax_utils import key_iterator + + +Batch = hax.Axis("batch", size=2) +Seq = hax.Axis("seq", size=3) +Embed = hax.Axis("embed", size=8) +Vocab = hax.Axis("vocab", size=16) + + +@pytest.fixture +def test_data(): + """ + Create synthetic test data for cross-entropy loss computation. + """ + + key = key_iterator(jax.random.PRNGKey(0)) + + # Random embeddings, scaled by 1/sqrt(embed_size) + pred_embeddings = hax.random.normal(next(key), (Batch, Seq, Embed), dtype=jnp.float32) / math.sqrt(Embed.size) + + # Random LM head weights, scaled by 1/sqrt(embed_size) + pred_lm_head = hax.random.normal(next(key), (Vocab, Embed), dtype=jnp.float32) / math.sqrt(Embed.size) + + # Random target token IDs in [0, vocab_size) + true_ids = hax.random.randint(next(key), (Batch, Seq), 0, Vocab.size) + + return pred_embeddings, pred_lm_head, true_ids + + +def test_basic_equivalence(test_data): + """ + Test that block-wise loss equals full loss when block_size perfectly divides vocab_size. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Compute full loss + logits_full = hax.dot(pred_embeddings, pred_lm_head, axis="embed") + target_y_full = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype) + loss_full, norm_full = cross_entropy_loss_and_log_normalizers(logits_full, Vocab, target_y_full) + + loss_block, norm_this = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=8, + dtype=pred_embeddings.dtype, + ) + + # Assert that the losses are close + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Block-wise loss does not match full loss." + + +def test_single_block(test_data): + """ + Test behavior when vocab_size equals block_size. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Compute full loss + loss_full, sumexp_full = _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids) + + # Compute block-wise loss with a single block (block_size == vocab_size == 16) + with jax.disable_jit(): + loss_block, sumexp_block = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=Vocab.size, + dtype=pred_embeddings.dtype, + ) + + # Assert that the losses are close + assert hax.all( + hax.isclose(sumexp_full, sumexp_block, atol=1e-3, rtol=1e-3) + ), "Single block-wise sumexp does not match full sumexp." + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Single block-wise loss does not match full loss."
+ + +def _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids): + logits_full = hax.dot(pred_embeddings, pred_lm_head, axis="embed") + target_y_full = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype) + loss_full, sumexp_full = cross_entropy_loss_and_log_normalizers(logits_full, Vocab, target_y_full) + return loss_full, sumexp_full + + +def test_multiple_blocks(test_data): + """ + Test block-wise loss with multiple blocks. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Compute full loss + loss_full, logz_full = _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids) + + # Compute block-wise loss with block_size=1 (16 blocks of one vocab entry each) + loss_block, logz_block = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=1, + dtype=pred_embeddings.dtype, + ) + + # Assert that the losses are close + assert hax.all( + hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3) + ), "Multiple block-wise logz does not match full logz." + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Multiple block-wise loss does not match full loss." + + +def test_block_size_not_dividing_vocab(test_data): + pred_embeddings, pred_lm_head, true_ids = test_data + + # Set block_size that does not divide vocab_size + block_size = 3 # does not divide vocab_size=16 + + # the straggler block should handle the remainder + loss_block, logz_block = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=block_size, + dtype=pred_embeddings.dtype, + ) + + # Compute full loss + loss_full, logz_full = cross_entropy_loss_and_log_normalizers( + pred_y=hax.dot(pred_embeddings, pred_lm_head, axis="embed"), + Label=Vocab, + target_y=hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype), + ) + + # Assert that the losses are close + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Block-wise loss does not match full loss." + assert hax.all( + hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3) + ), "Block-wise logz does not match full logz." + + +def test_vocab_size_less_than_block_size(test_data): + """ + Test behavior when vocab_size is less than block_size. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Set block_size greater than vocab_size + block_size = 20 # vocab_size=16 + + # a single partial block should cover the whole vocab + loss_block, logz_block = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=block_size, + dtype=pred_embeddings.dtype, + ) + + # Compute full loss + loss_full, logz_full = cross_entropy_loss_and_log_normalizers( + pred_y=hax.dot(pred_embeddings, pred_lm_head, axis="embed"), + Label=Vocab, + target_y=hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype), + ) + + # Assert that the losses are close + assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)), "loss does not match full loss." + assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)), "logz does not match full logz." + + +def test_large_vocab(): + """ + Test block-wise loss with a larger vocabulary.
+ """ + Batch = hax.Axis("batch", size=4) + Seq = hax.Axis("seq", size=5) + Embed = hax.Axis("embed", size=6) + Vocab = hax.Axis("vocab", size=12) + + pred_embeddings = NamedArray( + jnp.ones((Batch.size, Seq.size, Embed.size)), + axes=(Batch, Seq, Embed), + ) + pred_lm_head = NamedArray( + jnp.ones((Embed.size, Vocab.size)), + axes=(Embed, Vocab), + ) + true_ids = NamedArray( + jnp.zeros((Batch.size, Seq.size), dtype=jnp.int32), + axes=(Batch, Seq), + ) + + # Compute full loss + loss_full, logz_full = cross_entropy_loss_and_log_normalizers( + pred_y=hax.dot(pred_embeddings, pred_lm_head, axis="embed"), + Label=Vocab, + target_y=hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype), + ) + + # Compute block-wise loss with block_size=3 (vocab_size=12 is divisible by 3) + loss_block, logz_block = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=3, + dtype=pred_embeddings.dtype, + ) + + # Assert that the losses are close + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Large vocab block-wise loss does not match full loss." + assert hax.all( + hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3) + ), "Large vocab block-wise logz does not match full logz." + + +@pytest.mark.parametrize("block_size", [1, 2, 3, 4, 5]) +def test_gradient_block_cross_entropy(block_size, test_data): + """ + Test the gradient of block-wise cross-entropy loss. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Compute block-wise loss + def custom_fn(pred): + pred_embeddings, pred_lm_head = pred + a, b = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=block_size, + dtype=pred_embeddings.dtype, + ) + + return (a.mean() + b.mean()).scalar() + + g_embed, g_head, = equinox.filter_grad( + custom_fn + )((pred_embeddings, pred_lm_head)) + + # compute directly + + def direct_fn(pred): + pred_embeddings, pred_lm_head = pred + logits = hax.dot(pred_embeddings, pred_lm_head, axis="embed") + target_y = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype) + loss, logz = cross_entropy_loss_and_log_normalizers(logits, Vocab, target_y) + return (loss.mean() + logz.mean()).scalar() + + g_embed_direct, g_head_direct = equinox.filter_grad(direct_fn)((pred_embeddings, pred_lm_head)) + + assert hax.all( + hax.isclose(g_embed, g_embed_direct, atol=1e-3, rtol=1e-3) + ), "Gradient of embeddings does not match." + assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-3, rtol=1e-3)), "Gradient of lm_head does not match." + + +def test_grad_loss_without_logz(test_data): + """ + Test the gradient of block-wise cross-entropy loss without logz. 
+ """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Compute block-wise loss + def custom_fn(pred): + pred_embeddings, pred_lm_head = pred + a, b = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=2, + dtype=pred_embeddings.dtype, + ) + + return a.mean().scalar() + + g_embed, g_head, = equinox.filter_grad( + custom_fn + )((pred_embeddings, pred_lm_head)) + + # compute directly + + def direct_fn(pred): + pred_embeddings, pred_lm_head = pred + logits = hax.dot(pred_embeddings, pred_lm_head, axis="embed", preferred_element_type=pred_embeddings.dtype) + target_y = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype) + loss, _ = cross_entropy_loss_and_log_normalizers(logits, Vocab, target_y) + return loss.mean().scalar() + + g_embed_direct, g_head_direct = equinox.filter_grad(direct_fn)((pred_embeddings, pred_lm_head)) + + assert hax.all( + hax.isclose(g_embed, g_embed_direct, atol=1e-3, rtol=1e-3) + ), "Gradient of embeddings does not match." + assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-3, rtol=1e-3)), "Gradient of lm_head does not match." diff --git a/tests/test_text.py b/tests/test_text.py index a2645c1f9..e4e51acbc 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -26,6 +26,7 @@ def test_dont_blow_up_without_validation_set(): def test_lm_example_handles_ignore_id(): Pos = hax.Axis("Pos", 10) Vocab = hax.Axis("vocab", Pos.size + 1) + Embed = hax.Axis("embed", 10) tokens = hax.arange(Pos, dtype=jnp.int32) ignore_id = 6 @@ -34,11 +35,12 @@ def test_lm_example_handles_ignore_id(): ex_no_ignore = LmExample.causal(tokens) assert ex_ignore.loss_mask[Pos, ignore_id - 1] == 0 - distr = -100 * hax.nn.one_hot(ignore_id, Vocab) - distr = distr.broadcast_axis(Pos) + logits = hax.ones((Pos, Embed)) + lm_head = hax.zeros((Embed, Vocab)) + lm_head = lm_head.at[Vocab, ignore_id].set(-100) - ignored_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_ignore.loss_mask) - no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask) + ignored_loss = next_token_loss(Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_ignore.loss_mask) + no_ignore_loss = next_token_loss(Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_no_ignore.loss_mask) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size