[DeltaNet] Optimize WY ops when chunk_size=64
sustcsonglin committed Oct 26, 2024
1 parent 1491996 commit 7516c53
Showing 3 changed files with 142 additions and 96 deletions.
24 changes: 10 additions & 14 deletions fla/ops/delta_rule/chunk.py
@@ -15,7 +15,6 @@
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8)
],
key=["BT", "BK", "BV"],
)
@@ -80,7 +79,6 @@ def fwd_prepare_dv(q, k, do, BT, scale):
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8)
],
key=["BT", "BK", "BV"],
)
@@ -141,7 +139,7 @@ def chunk_delta_rule_fwd_kernel_h(
b_d = tl.load(p_d, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)
b_v -= tl.dot(b_d, b_h.to(b_k.dtype))
# [BK, BV]
tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
@@ -157,7 +155,6 @@
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8)
],
key=["BT", "BK", "BV"],
)
@@ -218,8 +215,7 @@ def chunk_linear_attn_fwd_kernel_o(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8)
triton.Config({}, num_warps=4)
],
key=["BT", "BK", "BV"],
)
@@ -307,9 +303,7 @@ def chunk_delta_rule_bwd_kernel_dhu(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16)
triton.Config({}, num_warps=4)
],
key=["BT", "BK", "BV"],
)
@@ -357,8 +351,8 @@ def chunk_delta_rule_bwd_kernel_dqkw(
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
@@ -371,7 +365,7 @@ def chunk_delta_rule_bwd_kernel_dqkw(
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False)
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
b_dk += tl.dot(b_v, b_dh, allow_tf32=False)

b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
@@ -479,7 +473,7 @@ def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT, scale):
NT = triton.cdiv(T, BT)
grid = (NK, NT, B * H)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dk = torch.empty_like(k, dtype=torch.float32)
dw = torch.empty_like(w)
chunk_delta_rule_bwd_kernel_dqkw[grid](
q, k, v_new, w, h, do, dh, dq, dk, du, dw,
@@ -489,7 +483,7 @@ def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT, scale):
scale=scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
)
return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
return dq, dk, dw


class ChunkDeltaRuleFunction(torch.autograd.Function):
Expand Down Expand Up @@ -539,6 +533,7 @@ def backward(ctx, do, dht):
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, dh0, None, None, None



def chunk_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
@@ -572,9 +567,10 @@ def chunk_delta_rule(
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
assert len(beta.shape) == 3, "beta must be of shape (batch size, num of head, seq len)."
assert BT in [32, 64], "ChunkDeltaRuleFunction only supports BT=32/64."
if scale is None:
scale = k.shape[-1] ** -0.5
else:
assert scale > 0, "scale must be positive."
o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, scale, initial_state, output_final_state)
return o, final_state
return o, final_state
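
A minimal usage sketch of the entry point above. The keyword names, import path, and output shapes are assumptions inferred from the ChunkDeltaRuleFunction.apply call and the shape/dtype asserts; they are not confirmed by this diff:

import torch
from fla.ops.delta_rule import chunk_delta_rule  # import path assumed

B, H, T, K, V = 2, 4, 256, 64, 64
q = torch.randn(B, H, T, K, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, H, T, K, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, H, T, V, dtype=torch.bfloat16, device='cuda')
beta = torch.randn(B, H, T, device='cuda').sigmoid().to(torch.bfloat16)  # (batch, heads, seq len)

# BT must be 32 or 64 per the new assert; scale defaults to K ** -0.5 when None.
o, final_state = chunk_delta_rule(q, k, v, beta, BT=64, output_final_state=True)
print(o.shape)  # (B, H, T, V)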
113 changes: 91 additions & 22 deletions fla/ops/delta_rule/wy_fast.py
@@ -6,8 +6,6 @@
from einops import rearrange

from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
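# In outline (a hedged summary; exact convention per the paper above): the per-chunk
# product of rank-1 updates prod_t (I - beta_t k_t k_t^T) is captured by a single
# triangular matrix A, so w = A @ (beta * k) and u = A @ (beta * v) can be formed
# with dense matmuls instead of a sequential scan over the chunk.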
@triton.autotune(
configs=[
@@ -17,10 +15,10 @@
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16)
],
key=["BT", "BK", "BV"],
key=["BK"]
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
def fwd_prepare_wy_repr_kernel_chunk32(
k,
v,
beta,
@@ -38,6 +36,7 @@ def fwd_prepare_wy_repr_kernel(
V,
BT: tl.constexpr,
BK: tl.constexpr,
BC: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
@@ -66,22 +65,90 @@
tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))
b_A = b_A.to(k.dtype.element_ty)

for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb)
p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))

@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16)
],
key=["BK"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel_chunk64(
k,
v,
beta,
w,
u,
A,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BC: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
b_A = tl.zeros([BC, BC], dtype=tl.float32)
b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
b_A3 = tl.zeros([BC, BC], dtype=tl.float32)
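# b_A and b_A2 accumulate the two BC x BC diagonal blocks (A11, A22) of the BT x BT
# chunk matrix; b_A3 accumulates the lower-left cross block A21 (here BT = 2 * BC).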
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BC,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))

p_beta2 = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,))
b_beta2 = tl.load(p_beta2, boundary_check=(0,))

for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
p_k2 = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
b_w = tl.dot(b_A, b_kb)
p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
b_k2 = tl.load(p_k2, boundary_check=(0, 1))
b_kb2 = (b_k2 * b_beta2[:, None]).to(b_k2.dtype)
b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
b_A2 += tl.dot(b_kb2, tl.trans(b_k2), allow_tf32=False)
b_A3 += tl.dot(b_kb2, tl.trans(b_k), allow_tf32=False)

b_A = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
b_A2 = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
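
# Row-wise forward substitution: in place, this turns each of b_A / b_A2 (currently
# the negated strict lower part of K_beta K^T) into the strictly lower triangle of
# the corresponding block inverse; the identity is added back below.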

for i in range(1, BC):
mask = tl.arange(0, BC) == i
b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
b_A = tl.where(mask[:, None], b_a, b_A)
b_A2 = tl.where(mask[:, None], b_a2, b_A2)

# blockwise computation of lower triangular matrix's inverse
# i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
b_A3 = -tl.dot(tl.dot(b_A2, b_A3), b_A)
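
# Store the assembled BT x BT inverse in four BC x BC tiles: p_A1 <- A11^{-1}
# (top-left), p_A2 <- A22^{-1} (bottom-right), p_A3 <- -A22^{-1} A21 A11^{-1}
# (bottom-left), p_A4 <- zeros (top-right, keeping the result lower triangular).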

p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
p_A4 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
# causal mask
tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))
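
# A quick standalone numerical check (a PyTorch sketch, not part of wy_fast.py) of the
# block-inverse identity used above: for lower-triangular M = [[A11, 0], [A21, A22]],
#   M^{-1} = [[A11^{-1}, 0], [-A22^{-1} @ A21 @ A11^{-1}, A22^{-1}]].
import torch

BC = 32
M = torch.randn(2 * BC, 2 * BC, dtype=torch.float64).tril()
M += 2 * BC * torch.eye(2 * BC, dtype=torch.float64)  # keep M well conditioned
A11, A21, A22 = M[:BC, :BC], M[BC:, :BC], M[BC:, BC:]
inv11, inv22 = torch.linalg.inv(A11), torch.linalg.inv(A22)
block_inv = torch.zeros_like(M)
block_inv[:BC, :BC] = inv11
block_inv[BC:, BC:] = inv22
block_inv[BC:, :BC] = -inv22 @ A21 @ inv11
assert torch.allclose(block_inv, torch.linalg.inv(M))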




@triton.autotune(
configs=[
@@ -244,8 +311,8 @@ def bwd_prepare_wy_repr_kernel(
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A, allow_tf32=False)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype), allow_tf32=False)
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)

for i_k in range(tl.cdiv(K, BK)):
@@ -273,12 +340,15 @@ def fwd_prepare_wy_repr(k, v, beta, BT):
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)
fwd_prepare_wy_repr_kernel[(NT, B*H)](
fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
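# BC is hard-coded to 32 in the launch below: the chunk64 kernel splits each BT=64
# chunk into two 32-row halves and merges their inverses blockwise, while the
# chunk32 kernel inverts a whole 32-row chunk directly.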

fwd_fn[(NT, B*H)](
k, v, beta, w, u, A,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
T, K, V, BT, BK, BV
T, K, V, BT, BK, 32, BV
)
w, u = fwd_recompute_w_u(k, v, beta, A, BT)
return w, u, A
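
# Reference semantics (a hedged PyTorch sketch inferred from the kernels above, not
# part of this file): per chunk, A = (I + tril(diag(beta) K K^T, -1))^{-1}, then
# w = A @ (beta * k) and u = A @ (beta * v). Assumes T % BT == 0.
import torch
from einops import rearrange

def wy_repr_reference(k, v, beta, BT=64):
    k_, v_ = (rearrange(x, 'b h (n c) d -> b h n c d', c=BT) for x in (k, v))
    beta_ = rearrange(beta, 'b h (n c) -> b h n c', c=BT)
    kb = (k_ * beta_[..., None]).float()
    A = torch.eye(BT, device=k.device) + (kb @ k_.float().transpose(-1, -2)).tril(-1)
    A_inv = torch.linalg.inv(A)
    w = rearrange(A_inv @ kb, 'b h n c d -> b h (n c) d')
    u = rearrange(A_inv @ (v_.float() * beta_[..., None]), 'b h n c d -> b h (n c) d')
    return w, u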


@@ -313,7 +383,6 @@ def fwd_recompute_w(k, beta, A, BT):

def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):
B, H, T, K, V = *k.shape, v.shape[-1]

NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
@@ -339,6 +408,7 @@ class WYRepresentationPrepration(torch.autograd.Function):
@contiguous
@autocast_custom_fwd
def forward(ctx, k, v, beta, chunk_size=64):
assert chunk_size in [16, 32, 64]
ctx.BT = chunk_size
w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)
ctx.save_for_backward(k, v, beta, A)
@@ -419,5 +489,4 @@ def naive(k, v, beta, chunk_size):
k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad
print((k_grad2-k_grad).abs().max())
print((v_grad2-v_grad).abs().max())
print((beta_grad2-beta_grad).abs().max())
breakpoint()
print((beta_grad2-beta_grad).abs().max())
