Skip to content

Commit

Permalink
enough device puts and we're good
Browse files: browse the repository at this point in the history
  • Loading branch information
dlwh committed Dec 3, 2024
1 parent 5052f36 commit 1a7729a
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions tests/test_attention.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest
from chex import assert_trees_all_close
from jax.sharding import Mesh

import haliax as hax
from haliax import Axis
Expand Down Expand Up @@ -241,16 +242,21 @@ def test_segment_ids_are_respected(impl):
keys = hax.named(keys, (KPos, Head))
values = hax.named(values, (KPos, Head))

query, keys, values = jax.device_put([query, keys, values])
query, keys, values = jax.device_put(
[query, keys, values], jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1)
)

segment_ids = np.array([0, 0, 0] + [1] * (L - 3), dtype=np.int32)
segment_ids = jax.device_put(segment_ids)
segment_ids = jax.device_put(segment_ids, jax.sharding.PositionalSharding(jax.devices()))
segment_ids = hax.named(segment_ids, (Pos,))
mask = AttentionMask(is_causal=True, segment_ids=segment_ids)

result = hax.named_jit(dot_product_attention)(
Pos, KPos, Head, query, keys, values, attn_backend=AttentionBackend(impl), mask=mask, flash_block_size=128
)
devices = jax.devices()

with Mesh(devices, ("dp",)):
result = hax.named_jit(dot_product_attention)(
Pos, KPos, Head, query, keys, values, attn_backend=AttentionBackend(impl), mask=mask, flash_block_size=128
)

# the first 3 positions should all have a value of 300.0
assert_trees_all_close(result.array[0:3, 1], 300.0, atol=1e-3, rtol=1e-3)
Expand Down

0 comments on commit 1a7729a

Please sign in to comment.