From 147d216a3dbba86e94ec6878e7bf16c37b7f8bad Mon Sep 17 00:00:00 2001 From: yzhangcs Date: Fri, 2 Feb 2024 13:12:18 +0800 Subject: [PATCH] Simplify reductions --- fla/ops/abc/chunk.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/fla/ops/abc/chunk.py b/fla/ops/abc/chunk.py index 7842383cd..3c6da5770 100644 --- a/fla/ops/abc/chunk.py +++ b/fla/ops/abc/chunk.py @@ -330,6 +330,7 @@ def chunk_abc_bwd_kernel_dqkv( b_dq = tl.zeros([BT, BK], dtype=tl.float32) b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) for i_v in range(tl.cdiv(V, BV)): p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, (i_t+1)*K), (s_h_d, s_h_t), (i_v * BV, i_t*K+i_k*BK), (BV, BK), (0, 1)) @@ -355,30 +356,30 @@ def chunk_abc_bwd_kernel_dqkv( b_r = tl.load(p_r, boundary_check=(0,)) b_do = (b_do * b_z * b_r[None, :]).to(b_do.dtype) # [BT, BT] - b_ds = tl.where(m_s, tl.dot(b_do, tl.trans(b_v), allow_tf32=False), 0).to(b_v.dtype) + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) # [BT, BK] - b_dq += tl.dot(b_do, b_h, allow_tf32=False) + tl.dot(b_ds, b_k, allow_tf32=False) + b_dq += tl.dot(b_do, b_h, allow_tf32=False) - # [BT, BT] - b_ds = tl.trans(b_ds) # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s, b_do, allow_tf32=False) # the rescale term m cancels the denominator of either v or k out, so in general dk is safe - if NORMQ: - b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) * b_k - b_dk += safe_exy(b_m, tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) * b_k) - else: - b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) - b_dk += safe_exy(b_m, tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)) + if not NORMQ: b_dv = b_v * b_dv tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) - + b_ds = tl.where(m_s, b_ds, 0.).to(b_k.dtype) + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_ds = tl.trans(b_ds).to(b_k.dtype) if NORMQ: p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_r = tl.make_block_ptr(r + i_bh * s_k_t * NT, (NT*K,), (s_k_d,), (i_t * K + i_k * BK,), (BK,), (0,)) b_z = tl.load(p_z, boundary_check=(0, 1)).to(tl.float32) b_r = tl.load(p_r, boundary_check=(0,)).to(tl.float32) b_dq = b_dq * b_z * b_r[None, :] + b_dk = b_dk * b_k + safe_exy(b_m, tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) * b_k) + else: + b_dk += safe_exy(b_m, tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)) b_dq = safe_exy(b_m, b_dq) p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))