diff --git a/tests/test_attention.py b/tests/test_attention.py index 4dde69579..140e3cf7c 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -4,6 +4,7 @@ import numpy as np import pytest from chex import assert_trees_all_close +from jax.sharding import Mesh import haliax as hax from haliax import Axis @@ -241,16 +242,21 @@ def test_segment_ids_are_respected(impl): keys = hax.named(keys, (KPos, Head)) values = hax.named(values, (KPos, Head)) - query, keys, values = jax.device_put([query, keys, values]) + query, keys, values = jax.device_put( + [query, keys, values], jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1) + ) segment_ids = np.array([0, 0, 0] + [1] * (L - 3), dtype=np.int32) - segment_ids = jax.device_put(segment_ids) + segment_ids = jax.device_put(segment_ids, jax.sharding.PositionalSharding(jax.devices())) segment_ids = hax.named(segment_ids, (Pos,)) mask = AttentionMask(is_causal=True, segment_ids=segment_ids) - result = hax.named_jit(dot_product_attention)( - Pos, KPos, Head, query, keys, values, attn_backend=AttentionBackend(impl), mask=mask, flash_block_size=128 - ) + devices = jax.devices() + + with Mesh(devices, ("dp",)): + result = hax.named_jit(dot_product_attention)( + Pos, KPos, Head, query, keys, values, attn_backend=AttentionBackend(impl), mask=mask, flash_block_size=128 + ) # the first 3 positions should all have a value of 300.0 assert_trees_all_close(result.array[0:3, 1], 300.0, atol=1e-3, rtol=1e-3)