diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index fb3211356..d8bbc1ba8 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -759,6 +759,9 @@ def _tpu_splash_attention(
     q_class, k_class, v_class = _bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key)
 
+    # pre-divide q_ by sqrt(d) to match the reference implementation
+    query = query / jnp.sqrt(query.resolve_axis(Key).size)
+
     q_: jax.Array = _reshape_axes_for_bshd_bins(query, q_class, output_order=list("BHSD")).array
     k_ = _reshape_axes_for_bshd_bins(key, k_class, output_order=list("BHSD")).array
     v_ = _reshape_axes_for_bshd_bins(value, v_class, output_order=list("BHSD")).array
 
diff --git a/tests/test_attention.py b/tests/test_attention.py
index 5677faa10..6d95316fc 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -207,8 +207,10 @@ def test_tpu_splash_attention():
     k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Head, Key))
     v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Head, Key))
 
-    flash_out = _tpu_splash_attention(QPos, KPos, Key, q, k, v, inference=True)
-    hax_out = hax.nn.attention.dot_product_attention(KPos, Key, q, k, v)
+    mask = AttentionMask.causal()
+
+    flash_out = _tpu_splash_attention(QPos, KPos, Key, q, k, v, inference=True, mask=mask, block_size=BLOCK_SIZE)
+    hax_out = hax.nn.attention.dot_product_attention(KPos, Key, q, k, v, mask=mask.materialize(QPos, KPos))
 
     assert hax_out.axes == flash_out.axes
     assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3)
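
Note on the scaling change: the added comment says the query is pre-divided by sqrt(d) to match the reference implementation, which implies the splash kernel is handed already-scaled queries rather than applying 1/sqrt(d) to the logits itself. The snippet below is a minimal, self-contained sketch in plain JAX (not Levanter's API; reference_attention and unscaled_kernel are hypothetical stand-ins) showing that dividing q by sqrt(d) up front is mathematically equivalent to the usual softmax(q k^T / sqrt(d)) v.

import jax
import jax.numpy as jnp

def reference_attention(q, k, v):
    # standard scaled dot-product attention: softmax(q k^T / sqrt(d)) v
    d = q.shape[-1]
    logits = q @ k.T / jnp.sqrt(d)
    return jax.nn.softmax(logits, axis=-1) @ v

def unscaled_kernel(q, k, v):
    # stand-in for a kernel that computes raw q.k logits with no scaling of its own
    return jax.nn.softmax(q @ k.T, axis=-1) @ v

q, k, v = (jax.random.normal(jax.random.PRNGKey(i), (8, 16)) for i in range(3))
d = q.shape[-1]

out_ref = reference_attention(q, k, v)
out_pre = unscaled_kernel(q / jnp.sqrt(d), k, v)  # pre-divide q, as the patch does
assert jnp.allclose(out_ref, out_pre, atol=1e-5)

This only holds if the kernel computes raw q.k logits; a kernel that already applied its own 1/sqrt(d) would end up double-scaled by this pre-division.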