diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py
index 826480d09..154fc66ac 100644
--- a/src/levanter/models/loss.py
+++ b/src/levanter/models/loss.py
@@ -61,11 +61,13 @@ def next_token_loss(
         logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed, preferred_element_type=dtype)
         target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype)
         return cross_entropy_and_logsumexp_penalty(
-            logits, Vocab, target_y_full,
+            logits,
+            Vocab,
+            target_y_full,
             reduction=reduction,
             reduction_axis=reduction_axis,
             where=loss_mask,
-            logsumexp_weight=logsumexp_weight
+            logsumexp_weight=logsumexp_weight,
         )
 
     # Compute the loss with optional block-wise processing
@@ -398,7 +400,9 @@ def process_block(block_idx, acc, current_block_size):
 
         # Compute gradients for the current block
         # embeddings has shape [Batch, Seq, Embed], so we need to eliminate Block
-        g_embeddings_b = hax.dot(dLoss, lm_head_b, axis=Block, preferred_element_type=grad_embeddings.dtype)  # [Batch, Seq, Embed]
+        g_embeddings_b = hax.dot(
+            dLoss, lm_head_b, axis=Block, preferred_element_type=grad_embeddings.dtype
+        )  # [Batch, Seq, Embed]
 
         # lm_head has shape [Block, Embed], so we need to eliminate Batch, Seq, etc.
         eliminated_axes_W = hax.axis.without_axes(pred_embeddings.axes, lm_head_b.axes)
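
Note (not part of the patch): the comment in the second hunk relies on hax.dot contracting away the named axis passed via axis=; contracting Block between dLoss ([Batch, Seq, Block]) and lm_head_b ([Block, Embed]) leaves [Batch, Seq, Embed]. A minimal standalone sketch of that behavior, using made-up axis names and sizes rather than the model's real configuration:

    import haliax as hax
    import jax.numpy as jnp

    # Hypothetical axes for illustration only; the real names and sizes come from the model config.
    Batch = hax.Axis("batch", 4)
    Seq = hax.Axis("position", 8)
    Embed = hax.Axis("embed", 16)
    Block = hax.Axis("vocab_block", 32)

    dLoss = hax.ones((Batch, Seq, Block))  # stand-in for the upstream gradient w.r.t. the block logits
    lm_head_b = hax.ones((Block, Embed))   # stand-in for the current block of the LM head

    # Contracting over Block eliminates it, leaving [Batch, Seq, Embed] as the comment describes.
    g_embeddings_b = hax.dot(dLoss, lm_head_b, axis=Block, preferred_element_type=jnp.float32)
    assert set(g_embeddings_b.axes) == {Batch, Seq, Embed}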