diff --git a/tests/test_attention.py b/tests/test_attention.py index 15dfbcebb..4dde69579 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -241,6 +241,8 @@ def test_segment_ids_are_respected(impl): keys = hax.named(keys, (KPos, Head)) values = hax.named(values, (KPos, Head)) + query, keys, values = jax.device_put([query, keys, values]) + segment_ids = np.array([0, 0, 0] + [1] * (L - 3), dtype=np.int32) segment_ids = jax.device_put(segment_ids) segment_ids = hax.named(segment_ids, (Pos,))