From 4533a48c84e834912050369792ce75c11e6e3046 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Mon, 20 May 2024 15:36:58 +0000 Subject: [PATCH] Fix bug in varlen test --- python/perf-kernels/flash-attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 9861ca28b9a8..1d87700c8632 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -833,6 +833,7 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, def get_shape_from_layout(q, k, metadata): if metadata.layout == 'thd': nheads_q, nheads_k = q.shape[1], k.shape[1] + head_size = q.shape[-1] batch = metadata.num_contexts elif metadata.layout == 'bhsd': batch, nheads_q, _, head_size = q.shape