diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index cdc344c54..39f3cb3df 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -440,9 +440,9 @@ def _bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key): NVTE and the Splash Attention kernel require the Q, K, and V to be in a specific format. This function groups the axes of Q, K, and V into the right bins to match that format. - NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD - the size of the axes is a bit flexible, - with the following conditions: + NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD. + + The size of the axes is a bit flexible, with the following conditions: - B must be the same for all (TODO: is this true?) - S must be the same for K and V. Q's S can be different - H: Q's H must be a multiple of K's H (for GQA or MQA)