From 7d218896e37e95229d2490fe54751980a5bbfc87 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Mon, 20 May 2024 15:55:36 -0700
Subject: [PATCH] Add TPU splash attention unit test; fix TPU zone in CI

The new test compares _tpu_splash_attention against the reference
dot_product_attention on a TPU backend. Also adds the missing
`import jax.random as jrandom` (the test uses jrandom.PRNGKey) and
points the CI workflow at the correct TPU zone.
---
 .github/workflows/tpu_unit_tests.yaml |  2 +-
 tests/test_attention.py               | 31 ++++++++++++++++++++++++++-
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/tpu_unit_tests.yaml b/.github/workflows/tpu_unit_tests.yaml
index b2b8ef351..ab56c21b4 100644
--- a/.github/workflows/tpu_unit_tests.yaml
+++ b/.github/workflows/tpu_unit_tests.yaml
@@ -6,7 +6,7 @@ jobs:
   test:
     runs-on: ubuntu-latest
     env:
-      TPU_ZONE: "us-central1-b"
+      TPU_ZONE: "us-central2-b"
 
     steps:
     - name: Checkout code
diff --git a/tests/test_attention.py b/tests/test_attention.py
index c3a156892..1ece10b4b 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -1,10 +1,17 @@
+import jax
+import jax.random as jrandom
 import jax.numpy as jnp
 import pytest
 from chex import assert_trees_all_close
 
 import haliax as hax
 
-from levanter.models.attention import AttentionMask, _bin_and_group_axes_by_function, _te_flash_attention
+from levanter.models.attention import (
+    AttentionMask,
+    _bin_and_group_axes_by_function,
+    _te_flash_attention,
+    _tpu_splash_attention,
+)
 
 from test_utils import skip_if_module_missing
@@ -183,3 +190,25 @@ def test_gpt2_attention_uses_te():
         attention_dtype=jnp.bfloat16,
     )
     assert_trees_all_close(out.array, 0.0)
+
+
+def test_tpu_splash_attention():
+    # Splash attention is a TPU-only kernel; skip everywhere else.
+    if jax.default_backend() != "tpu":
+        pytest.skip("TPU only")
+
+    BLOCK_SIZE = 512
+
+    Head = hax.Axis("Head", 8)
+    Key = hax.Axis("Key", 128)  # splash only supports 128
+    QPos = hax.Axis("QPos", BLOCK_SIZE * 2)
+    KPos = hax.Axis("KPos", BLOCK_SIZE * 2)
+
+    q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Head, Key))
+    k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Head, Key))
+    v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Head, Key))
+
+    flash_out = _tpu_splash_attention(QPos, KPos, Key, q, k, v, inference=True)
+    hax_out = hax.nn.attention.dot_product_attention(KPos, Key, q, k, v)
+
+    assert hax_out.axes == flash_out.axes
+    assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3)