Commit

cleanup
dlwh committed Dec 2, 2024
1 parent 92247ac commit 381ce3a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/levanter/models/attention.py
@@ -440,9 +440,9 @@ def _bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key):
 NVTE and the Splash Attention kernel require the Q, K, and V to be in a specific format. This function groups the axes
 of Q, K, and V into the right bins to match that format.
-NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD
-the size of the axes is a bit flexible,
-with the following conditions:
+NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD.
+The size of the axes is a bit flexible, with the following conditions:
 - B must be the same for all (TODO: is this true?)
 - S must be the same for K and V. Q's S can be different
 - H: Q's H must be a multiple of K's H (for GQA or MQA)
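The shape conditions listed in the docstring can be sketched as a small validation helper. This is a hypothetical illustration, not Levanter's actual code: `check_bshd_shapes` and its plain-tuple interface are invented for this example, whereas the real `_bin_and_group_axes_by_function` operates on named axes.

```python
def check_bshd_shapes(q_shape, k_shape, v_shape):
    """Validate (B, S, H, D) shapes for Q, K, V under the docstring's rules.

    Hypothetical helper for illustration; returns the GQA group size
    (number of query heads per KV head).
    """
    qB, qS, qH, qD = q_shape
    kB, kS, kH, kD = k_shape
    vB, vS, vH, vD = v_shape

    # B must be the same for all
    assert qB == kB == vB, "batch dims must match"
    # S must be the same for K and V; Q's S may differ (e.g. cross-attention)
    assert kS == vS, "K and V must share a sequence length"
    # Q's H must be a multiple of K's H (GQA or MQA)
    assert qH % kH == 0, "Q heads must be a multiple of K heads"
    return qH // kH


# e.g. GQA with 8 query heads sharing 2 KV heads -> group size 4
group = check_bshd_shapes((2, 128, 8, 64), (2, 128, 2, 64), (2, 128, 2, 64))
print(group)
```

MQA is the special case where K and V have a single head, so the group size equals Q's head count.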
