diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py index bf0bd380e..fc6e424a1 100644 --- a/src/levanter/models/loss.py +++ b/src/levanter/models/loss.py @@ -43,7 +43,7 @@ def maybe_fused_next_token_loss( NamedArray: Computed loss. """ # Resolve axes - Pos = pred_embeddings.resolve_axis(Pos) + Pos = pred_embeddings.resolve_axis(Pos.name) Vocab = pred_lm_head.resolve_axis(Vocab) if block_size is None: