Skip to content

Commit

Permalink
hrmph
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Apr 26, 2024
1 parent 1ad2a42 commit 5068dcf
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/levanter/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ def _restore_named_axes(
attn_output = attn_output.unflatten_axis("S", q_class["S"])
attn_output = attn_output.unflatten_axis("H", q_class["H"])
attn_output = attn_output.unflatten_axis("D", v_class["D"])
- output_axes = eqx.filter_eval_shape(
+ reference_out_shape = eqx.filter_eval_shape(
simple_attention_with_dropout,
QPos,
KPos,
Expand All @@ -828,6 +828,6 @@ def _restore_named_axes(
attention_dtype,
precision,
prng=prng,
- ).axes
- attn_output = attn_output.rearrange(output_axes)
+ )
+ attn_output = attn_output.rearrange(reference_out_shape.axes).astype(reference_out_shape.dtype)
return attn_output

0 comments on commit 5068dcf

Please sign in to comment.