diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index c44bb0c6a112..122ef9e50d80 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -332,7 +332,7 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s
              dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr,
              HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
              MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
-             BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr,
+             BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
              ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr):
     start_m = tl.program_id(0)
     off_h_q = tl.program_id(1)
@@ -880,8 +880,7 @@ def forward(ctx, q, k, v, o, metadata):
         metadata.check_args(q, k, v, o)
 
         batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, metadata)
-        q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(
-            q, k, v, o, metadata)
+        q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, metadata)
 
         # Get closest power of 2 over or equal to 32.
         padded_d_model = 1 << (head_size - 1).bit_length()
@@ -1115,9 +1114,9 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout,
 
     # Transpose here if layout is bshd so we have same reference code for all layouts
     if layout == 'bshd':
-        q = q.transpose(1,2).clone()
-        k = k.transpose(1,2).clone()
-        v = v.transpose(1,2).clone()
+        q = q.transpose(1, 2).clone()
+        k = k.transpose(1, 2).clone()
+        v = v.transpose(1, 2).clone()
     # Replicate K and V if using MQA/GQA
     if HQ != HK:
         k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
@@ -1142,7 +1141,7 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout,
     ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v)
    # compare
     if layout == 'bshd':
-        ref_out = ref_out.transpose(1,2).clone()
+        ref_out = ref_out.transpose(1, 2).clone()
     torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
 
 