
Commit

pre-commit
dlwh committed Nov 6, 2024
1 parent 05afef0 commit 0de1482
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/levanter/models/loss.py
@@ -61,11 +61,13 @@ def next_token_loss(
         logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed, preferred_element_type=dtype)
         target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype)
         return cross_entropy_and_logsumexp_penalty(
-            logits, Vocab, target_y_full,
+            logits,
+            Vocab,
+            target_y_full,
             reduction=reduction,
             reduction_axis=reduction_axis,
             where=loss_mask,
-            logsumexp_weight=logsumexp_weight
+            logsumexp_weight=logsumexp_weight,
         )

     # Compute the loss with optional block-wise processing
@@ -398,7 +400,9 @@ def process_block(block_idx, acc, current_block_size):

         # Compute gradients for the current block
         # embeddings has shape [Batch, Seq, Embed], so we need to eliminate Block
-        g_embeddings_b = hax.dot(dLoss, lm_head_b, axis=Block, preferred_element_type=grad_embeddings.dtype)  # [Batch, Seq, Embed]
+        g_embeddings_b = hax.dot(
+            dLoss, lm_head_b, axis=Block, preferred_element_type=grad_embeddings.dtype
+        )  # [Batch, Seq, Embed]

         # lm_head has shape [Block, Embed], so we need to eliminate Batch, Seq, etc.
         eliminated_axes_W = hax.axis.without_axes(pred_embeddings.axes, lm_head_b.axes)
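
For context on the hax.dot call re-wrapped above: haliax's dot contracts over the named axis passed via axis=, removing it from the result, which is why the in-line comment notes that Block is eliminated and the output is [Batch, Seq, Embed]. Below is a minimal, self-contained sketch of that behavior; it is not part of this commit, and the axis sizes are made up for illustration.

# Sketch (not from this commit): hax.dot drops the contracted named axis.
import haliax as hax

Batch = hax.Axis("batch", 2)
Seq = hax.Axis("seq", 3)
Embed = hax.Axis("embed", 4)
Block = hax.Axis("block", 5)

dLoss = hax.ones((Batch, Seq, Block))   # stand-in for the loss gradient w.r.t. block logits
lm_head_b = hax.ones((Block, Embed))    # stand-in for one block of the LM head

# Contracting over Block leaves (Batch, Seq, Embed), matching the
# "[Batch, Seq, Embed]" comment in the diff above.
g_embeddings_b = hax.dot(dLoss, lm_head_b, axis=Block)
print(g_embeddings_b.axes)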
