diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index 5ff969fa7..3251b47e3 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -813,7 +813,7 @@ def _restore_named_axes(
     attn_output = attn_output.unflatten_axis("S", q_class["S"])
     attn_output = attn_output.unflatten_axis("H", q_class["H"])
     attn_output = attn_output.unflatten_axis("D", v_class["D"])
-    output_axes = eqx.filter_eval_shape(
+    reference_out_shape = eqx.filter_eval_shape(
         simple_attention_with_dropout,
         QPos,
         KPos,
@@ -828,6 +828,6 @@ def _restore_named_axes(
         attention_dtype,
         precision,
         prng=prng,
-    ).axes
-    attn_output = attn_output.rearrange(output_axes)
+    )
+    attn_output = attn_output.rearrange(reference_out_shape.axes).astype(reference_out_shape.dtype)
     return attn_output