
blech think i figured out splash attention
dlwh committed May 21, 2024
1 parent b05e093 commit 39ec0d7
Showing 2 changed files with 7 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/levanter/models/attention.py
@@ -759,6 +759,9 @@ def _tpu_splash_attention(

q_class, k_class, v_class = _bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key)

+ # pre-divide the query by sqrt(d) to match the reference implementation
+ query = query / jnp.sqrt(query.resolve_axis(Key).size)

q_: jax.Array = _reshape_axes_for_bshd_bins(query, q_class, output_order=list("BHSD")).array
k_ = _reshape_axes_for_bshd_bins(key, k_class, output_order=list("BHSD")).array
v_ = _reshape_axes_for_bshd_bins(value, v_class, output_order=list("BHSD")).array
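The change above pre-scales the query by 1/sqrt(d) before the tensors are reshaped and handed to the splash attention kernel, presumably because that kernel applies no such scaling itself while the reference dot-product attention does. A minimal plain-JAX sketch (hypothetical shapes; `unscaled_attention` stands in for a kernel without the scaling) of why pre-dividing the query is equivalent to scaling the logits:

# Plain-JAX sketch, not levanter's code: softmax((q / sqrt(d)) @ k.T)
# equals softmax((q @ k.T) / sqrt(d)), so pre-dividing q lets an
# unscaled attention kernel reproduce scaled dot-product attention.
import jax
import jax.numpy as jnp

B, S, D = 2, 8, 16  # hypothetical batch, sequence length, head dim
q = jax.random.normal(jax.random.PRNGKey(0), (B, S, D))
k = jax.random.normal(jax.random.PRNGKey(1), (B, S, D))
v = jax.random.normal(jax.random.PRNGKey(2), (B, S, D))

def unscaled_attention(q, k, v):
    # stand-in for a kernel that applies no 1/sqrt(d) scaling
    weights = jax.nn.softmax(jnp.einsum("bqd,bkd->bqk", q, k), axis=-1)
    return jnp.einsum("bqk,bkd->bqd", weights, v)

def reference_attention(q, k, v):
    # standard scaled dot-product attention
    logits = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(D)
    return jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(logits, axis=-1), v)

assert jnp.allclose(unscaled_attention(q / jnp.sqrt(D), k, v),
                    reference_attention(q, k, v), atol=1e-5)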
6 changes: 4 additions & 2 deletions tests/test_attention.py
@@ -207,8 +207,10 @@ def test_tpu_splash_attention():
k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Head, Key))
v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Head, Key))

- flash_out = _tpu_splash_attention(QPos, KPos, Key, q, k, v, inference=True)
- hax_out = hax.nn.attention.dot_product_attention(KPos, Key, q, k, v)
+ mask = AttentionMask.causal()
+
+ flash_out = _tpu_splash_attention(QPos, KPos, Key, q, k, v, inference=True, mask=mask, block_size=BLOCK_SIZE)
+ hax_out = hax.nn.attention.dot_product_attention(KPos, Key, q, k, v, mask=mask.materialize(QPos, KPos))

assert hax_out.axes == flash_out.axes
assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3)
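The updated test runs both paths under a causal mask: the splash path takes the AttentionMask object (plus a block_size, presumably defined earlier in the test module), while the haliax reference takes the mask materialized into an explicit array over (QPos, KPos). A rough plain-JAX sketch, with hypothetical shapes, of what such a materialized causal mask amounts to:

# Plain-JAX sketch, not levanter's API: a materialized causal mask is a
# (q_len, k_len) boolean array that is True where key position k <= query
# position q; masked logits are set to -inf before the softmax.
import jax
import jax.numpy as jnp

S, D = 8, 16  # hypothetical sequence length and head dim
q = jax.random.normal(jax.random.PRNGKey(0), (S, D))
k = jax.random.normal(jax.random.PRNGKey(1), (S, D))
v = jax.random.normal(jax.random.PRNGKey(2), (S, D))

causal = jnp.tril(jnp.ones((S, S), dtype=bool))  # the "materialized" mask
logits = q @ k.T / jnp.sqrt(D)
logits = jnp.where(causal, logits, -jnp.inf)
out = jax.nn.softmax(logits, axis=-1) @ v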
