Commit
More datatype support for backwards kernel
groenenboomj committed May 8, 2024
1 parent 375fa6c commit fa2e4e7
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions python/perf-kernels/flash-attention.py
@@ -211,7 +211,7 @@ def _attn_fwd_inner(
# our optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += (bias * 1.44269504089)

if alibi_slope is not None:
# Compute the global position of each token within the sequence
global_m_positions = start_m*BLOCK_M + tl.arange(0, BLOCK_M)
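The log2(e) factor comes from rewriting e^x as 2^(x * log2(e)) so the kernel can use the hardware exp2 path; a quick standalone check of that identity (plain NumPy, not part of the kernel):

import numpy as np

LOG2_E = 1.44269504089  # same constant as in the kernel

x = np.linspace(-5.0, 5.0, 11)
# exp(x) == 2 ** (x * log2(e)), so an additive bias folded into the exponent
# must be pre-multiplied by log2(e) as well.
assert np.allclose(np.exp(x), np.exp2(x * LOG2_E))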
@@ -425,7 +425,7 @@ def attn_fwd(
bias_ptr = None

if USE_ALIBI != 0:
a_offset = off_z * stride_az + off_h_q * stride_ah
alibi_slope = tl.load(alibi_slopes + a_offset)
else:
alibi_slope = None
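For reference, the slope loaded per (batch, head) above is usually expanded into a position-dependent additive bias. A minimal host-side sketch of one common ALiBi formulation, with illustrative names (this is not the kernel's exact code):

import torch

def alibi_bias_reference(slopes, seqlen_q, seqlen_k):
    # slopes: (Z, H) tensor -- one slope per (batch, head), the value the
    # kernel loads via off_z * stride_az + off_h_q * stride_ah.
    m = torch.arange(seqlen_q).view(1, 1, seqlen_q, 1)
    n = torch.arange(seqlen_k).view(1, 1, 1, seqlen_k)
    # Bias gets more negative as |query position - key position| grows.
    return -slopes.view(slopes.shape[0], slopes.shape[1], 1, 1) * (m - n).abs()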
@@ -668,14 +668,14 @@ def _bwd_kernel_dk_dv(
do = tl.load(DO_block_ptr)
# Compute dV.
ppT = pT
- ppT = ppT.to(tl.float16)
+ ppT = ppT.to(do.dtype)
dv += tl.dot(ppT, do)
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
dsT = pT * (dpT - Di[None, :])
- dsT = dsT.to(tl.float16)
+ dsT = dsT.to(qT.dtype)
dk += tl.dot(dsT, tl.trans(qT))
# Increment pointers.
curr_m += step_m
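The casts above are the point of this commit: using do.dtype and qT.dtype instead of a hard-coded tl.float16 lets the same backward path accept bfloat16 inputs. A rough single-block PyTorch reference of what the dV/dK accumulation computes, with block-shaped tensors named after the kernel's variables (illustration only):

import torch

def dk_dv_block_reference(pT, do, v, qT, Di, dk, dv):
    # pT: (BLOCK_N, BLOCK_M) transposed softmax probabilities, float32
    # do: (BLOCK_M, D) upstream gradient in the input dtype (fp16 or bf16)
    # v:  (BLOCK_N, D), qT: (D, BLOCK_M), Di: (BLOCK_M,), dk/dv: float32 accumulators
    # Casting to do.dtype / qT.dtype (instead of a fixed float16) keeps the
    # fp16 and bf16 paths identical.
    dv = dv + pT.to(do.dtype) @ do                      # dV   = P^T @ dO
    dpT = (v @ do.transpose(-1, -2)).to(torch.float32)  # dP^T = V @ dO^T
    dsT = pT * (dpT - Di[None, :])                      # dS^T = P^T * (dP^T - D)
    dk = dk + dsT.to(qT.dtype) @ qT.transpose(-1, -2)   # dK   = dS^T @ Q
    return dk, dv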
@@ -733,7 +733,7 @@ def _bwd_kernel_dq(dq, q, K, V,
vT = tl.load(VT_block_ptr)
dp = tl.dot(do, vT).to(tl.float32)
ds = p * (dp - Di[:, None])
- ds = ds.to(tl.float16)
+ ds = ds.to(kT.dtype)
# Compute dQ.0.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
@@ -813,7 +813,7 @@ def _attn_bwd(Q, K, V, sm_scale,
v = tl.load(V_block_ptr)

num_steps = BLOCK_N1 // MASK_BLOCK_M1

dk, dv = _bwd_kernel_dk_dv(
dk, dv,
Q, k, v, sm_scale,
@@ -931,7 +931,7 @@ def _attn_bwd(Q, K, V, sm_scale,
)
dq *= LN2
tl.store(DQ_block_ptr, dq.to(q.dtype))
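The trailing dq *= LN2 compensates for the base-2 softmax trick: kT is pre-scaled by sm_scale * log2(e) so the kernel can use exp2, which leaves an extra log2(e) factor in the accumulated dq, and multiplying by ln 2 = 1/log2(e) cancels it. A tiny sanity check of that identity:

import math

LOG2_E = 1.44269504089  # factor folded into the pre-scaled kT
LN2 = math.log(2.0)     # the de-scale factor applied to dq
assert abs(LOG2_E * LN2 - 1.0) < 1e-9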

empty = torch.empty(128, device="cuda")


@@ -940,7 +940,7 @@ class _attention(torch.autograd.Function):
def forward(ctx, q, k, v, o, metadata):
# NOTE: a large bias tensor leads to overflow during pointer arithmetic
if (metadata.bias is not None):
assert(metadata.bias.numel() < 2 ** 31)

if o is None:
o = torch.empty_like(q, dtype=v.dtype)
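The assert above guards against overflow in the kernel's pointer arithmetic: offsets into the bias tensor are effectively 32-bit, so the total element count must stay below 2^31. A quick illustration with example shapes (not taken from the tests):

Z, H, N_CTX_Q, N_CTX_K = 2, 16, 8192, 8192   # example shapes only
bias_numel = Z * H * N_CTX_Q * N_CTX_K
print(bias_numel, bias_numel < 2 ** 31)       # 2147483648 False -> would trip the assert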
@@ -991,7 +991,7 @@ def forward(ctx, q, k, v, o, metadata):
metadata.bias.stride(2), metadata.bias.stride(3))
else:
bias_strides = (0,0,0,0)

if metadata.alibi_slopes is not None:
alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1))
else:
@@ -1181,7 +1181,7 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to

scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale
if causal:
mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"),
diagonal=N_CTX_K-N_CTX_Q)
scores[:, :, mask==0] = float("-inf")
if use_alibi:
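In the reference check above, the diagonal=N_CTX_K-N_CTX_Q offset aligns the causal boundary to the bottom-right corner when query and key lengths differ, so the last query row can attend to every key. A small illustration with toy lengths:

import torch

N_CTX_Q, N_CTX_K = 3, 5
mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K), diagonal=N_CTX_K - N_CTX_Q)
print(mask)
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])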
@@ -1343,7 +1343,8 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16
@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None])
@pytest.mark.parametrize('torch_sdpa_test', [True])
@pytest.mark.parametrize('causal', [False, True])
- def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, dtype=torch.float16):
+ @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, dtype):
torch.manual_seed(20)
if qseqlen_not_equal_kseqlen is not None:
seqlen_q = qseqlen_not_equal_kseqlen
@@ -1366,7 +1367,7 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd
k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
o = torch.empty_like(q)

if causal:
input_metadata.need_causal()

@@ -1527,7 +1528,7 @@ def bench_flash_attention(
# else:
# bias = None
bias = None

flops_per_matmul = 0
if varlen:
q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype)
