Skip to content

Commit

Permalink
Simplify reductions
Browse files (browse the repository at this point in the history)
  • Loading branch information
yzhangcs committed Feb 2, 2024
1 parent 041f2a4 commit 147d216
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions fla/ops/abc/chunk.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -330,6 +330,7 @@ def chunk_abc_bwd_kernel_dqkv(

b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, (i_t+1)*K), (s_h_d, s_h_t), (i_v * BV, i_t*K+i_k*BK), (BV, BK), (0, 1))
Expand All @@ -355,30 +356,30 @@ def chunk_abc_bwd_kernel_dqkv(
b_r = tl.load(p_r, boundary_check=(0,))
b_do = (b_do * b_z * b_r[None, :]).to(b_do.dtype)
# [BT, BT]
b_ds = tl.where(m_s, tl.dot(b_do, tl.trans(b_v), allow_tf32=False), 0).to(b_v.dtype)
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False) + tl.dot(b_ds, b_k, allow_tf32=False)
b_dq += tl.dot(b_do, b_h, allow_tf32=False)

# [BT, BT]
b_ds = tl.trans(b_ds)
# [BT, BK]
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
# [BT, BV]
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s, b_do, allow_tf32=False)
# the rescale term m cancels the denominator of either v or k out, so in general dk is safe
if NORMQ:
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) * b_k
b_dk += safe_exy(b_m, tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) * b_k)
else:
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
b_dk += safe_exy(b_m, tl.dot(b_ds, tl.trans(b_q), allow_tf32=False))
if not NORMQ:
b_dv = b_v * b_dv
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

b_ds = tl.where(m_s, b_ds, 0.).to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
b_ds = tl.trans(b_ds).to(b_k.dtype)
if NORMQ:
p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_r = tl.make_block_ptr(r + i_bh * s_k_t * NT, (NT*K,), (s_k_d,), (i_t * K + i_k * BK,), (BK,), (0,))
b_z = tl.load(p_z, boundary_check=(0, 1)).to(tl.float32)
b_r = tl.load(p_r, boundary_check=(0,)).to(tl.float32)
b_dq = b_dq * b_z * b_r[None, :]
b_dk = b_dk * b_k + safe_exy(b_m, tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) * b_k)
else:
b_dk += safe_exy(b_m, tl.dot(b_ds, tl.trans(b_q), allow_tf32=False))
b_dq = safe_exy(b_m, b_dq)

p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
Expand Down

0 comments on commit 147d216

Please sign in to comment.