From 5052f363fde363129da2f1660e64ecbffd2014f0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 15:55:22 -0800 Subject: [PATCH] this? --- tests/test_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_attention.py b/tests/test_attention.py index 15dfbcebb..4dde69579 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -241,6 +241,8 @@ def test_segment_ids_are_respected(impl): keys = hax.named(keys, (KPos, Head)) values = hax.named(values, (KPos, Head)) + query, keys, values = jax.device_put([query, keys, values]) + segment_ids = np.array([0, 0, 0] + [1] * (L - 3), dtype=np.int32) segment_ids = jax.device_put(segment_ids) segment_ids = hax.named(segment_ids, (Pos,))