Skip to content

Commit

Permalink
this?
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Dec 2, 2024
1 parent 4fff4c3 commit 5052f36
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ def test_segment_ids_are_respected(impl):
keys = hax.named(keys, (KPos, Head))
values = hax.named(values, (KPos, Head))

query, keys, values = jax.device_put([query, keys, values])

segment_ids = np.array([0, 0, 0] + [1] * (L - 3), dtype=np.int32)
segment_ids = jax.device_put(segment_ids)
segment_ids = hax.named(segment_ids, (Pos,))
Expand Down

0 comments on commit 5052f36

Please sign in to comment.