From 381ce3ac0f0d9bcd7682721ae8e732120ef67165 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 09:02:42 -0800 Subject: [PATCH] cleanup --- src/levanter/models/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index cdc344c54..39f3cb3df 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -440,9 +440,9 @@ def _bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key): NVTE and the Splash Attention kernel require the Q, K, and V to be in a specific format. This function groups the axes of Q, K, and V into the right bins to match that format. - NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD - the size of the axes is a bit flexible, - with the following conditions: + NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD. + + The size of the axes is a bit flexible, with the following conditions: - B must be the same for all (TODO: is this true?) - S must be the same for K and V. Q's S can be different - H: Q's H must be a multiple of K's H (for GQA or MQA)