diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index 8177cf4ebf30..988438340abe 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -301,35 +301,28 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri


 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=8),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                       num_warps=4),
-        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=8),
-        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1,
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
+                      num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
                       num_warps=4),
-        # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
-        #               num_warps=4),
-        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=8),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                       num_warps=4),
-        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=8),
-        # TODO: This config fails with head_size not pow2 with data mismatches. Check why.
-        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
+                      num_warps=4),
+        # Fall-back config.
         triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
     ],
-    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
+    key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'],
     use_cuda_graph=True,
 )
 @triton.jit
-def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh,
-             stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om,
-             stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k,
-             dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr,
+def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz,
+             stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh,
+             stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q,
+             cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr,
              HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
              MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
              BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
@@ -446,13 +439,13 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
     # have native e^x support in HW.
-    qk_scale = sm_scale * 1.44269504089
+    QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089
     # Q is loaded once at the beginning and shared by all N blocks.
     q_ptrs_mask = offs_m[:, None] < seqlen_q
     if PADDED_HEAD:
         q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
     q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
-    q = (q * qk_scale).to(q.type.element_ty)
+    q = (q * QK_SCALE).to(q.type.element_ty)

     # Here we compute how many full and masked blocks we have.
     padded_block_k = n_extra_tokens != 0
@@ -509,7 +502,10 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s
                                         PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD,
                                         ACTUAL_BLOCK_DMODEL)
     # epilogue
-    acc = acc / l_i[:, None]
+    # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger.
+    l_recip = 1 / l_i[:, None]
+    acc = acc * l_recip
+
     if ENABLE_DROPOUT:
         acc = acc / (1 - dropout_p)
     # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
@@ -1198,6 +1194,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
     # this by converting the NaNs to 0s, which is what they should be out of the softmax.
     nan_mask = torch.isnan(p)
     p[nan_mask == 1] = 0
+    ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v)
     # compare
     torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
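Note on the epilogue change in the third hunk: dividing acc by the softmax row sums is replaced by computing the reciprocal of l_i once per row and multiplying, so the (approximate) reciprocal is evaluated on the small l_i vector rather than element-wise on the much larger accumulator. A minimal NumPy sketch of the equivalence; shapes and names below are illustrative, not taken from the kernel or any autotune config:

# Sketch only: verifies that multiplying by a precomputed per-row reciprocal
# matches the original per-element division, as in the new epilogue.
import numpy as np

BLOCK_M, BLOCK_DMODEL = 128, 64          # illustrative tile sizes
rng = np.random.default_rng(0)
acc = rng.standard_normal((BLOCK_M, BLOCK_DMODEL)).astype(np.float32)
l_i = rng.uniform(1.0, 8.0, size=BLOCK_M).astype(np.float32)

out_div = acc / l_i[:, None]             # old epilogue: one divide per element
l_recip = 1.0 / l_i[:, None]             # new epilogue: one reciprocal per row...
out_mul = acc * l_recip                  # ...then a cheap multiply

assert np.allclose(out_div, out_mul, rtol=1e-5, atol=1e-6)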