Skip to content

Commit

Permalink
jkandcjkancjka
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed May 20, 2024
1 parent 1866b96 commit 7d21889
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tpu_unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
test:
runs-on: ubuntu-latest
env:
TPU_ZONE: "us-central1-b"
TPU_ZONE: "us-central2-b"

steps:
- name: Checkout code
Expand Down
30 changes: 29 additions & 1 deletion tests/test_attention.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import jax
import jax.numpy as jnp
import pytest
from chex import assert_trees_all_close

import haliax as hax

from levanter.models.attention import AttentionMask, _bin_and_group_axes_by_function, _te_flash_attention
from levanter.models.attention import (
AttentionMask,
_bin_and_group_axes_by_function,
_te_flash_attention,
_tpu_splash_attention,
)
from test_utils import skip_if_module_missing


Expand Down Expand Up @@ -183,3 +189,25 @@ def test_gpt2_attention_uses_te():
attention_dtype=jnp.bfloat16,
)
assert_trees_all_close(out.array, 0.0)


def test_tpu_splash_attention():
    """Check that the TPU splash-attention kernel matches reference dot-product attention.

    Runs only on a TPU backend (skipped elsewhere). Builds random q/k/v with a
    head dim of 128 (the only head size splash attention supports) and asserts
    the splash output agrees with haliax's reference attention to ~1e-3.
    """
    if jax.default_backend() != "tpu":
        pytest.skip("TPU only")

    BLOCK_SIZE = 512

    Head = hax.Axis("Head", 8)
    Key = hax.Axis("Key", 128)  # splash only supports 128
    QPos = hax.Axis("QPos", BLOCK_SIZE * 2)
    KPos = hax.Axis("KPos", BLOCK_SIZE * 2)

    # BUG FIX: original referenced undefined `jrandom`; the file imports `jax`,
    # so use jax.random directly.
    q = hax.random.normal(jax.random.PRNGKey(0), (QPos, Head, Key))
    k = hax.random.normal(jax.random.PRNGKey(1), (KPos, Head, Key))
    v = hax.random.normal(jax.random.PRNGKey(2), (KPos, Head, Key))

    flash_out = _tpu_splash_attention(QPos, KPos, Key, q, k, v, inference=True)
    hax_out = hax.nn.attention.dot_product_attention(KPos, Key, q, k, v)

    # Axes must line up exactly; values only to kernel tolerance.
    assert hax_out.axes == flash_out.axes
    assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3)

0 comments on commit 7d21889

Please sign in to comment.