diff --git a/fla/ops/delta_rule/chunk.py b/fla/ops/delta_rule/chunk.py index 9a589d30d..fc19195d8 100644 --- a/fla/ops/delta_rule/chunk.py +++ b/fla/ops/delta_rule/chunk.py @@ -15,7 +15,6 @@ triton.Config({}, num_warps=1), triton.Config({}, num_warps=2), triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8) ], key=["BT", "BK", "BV"], ) @@ -80,7 +79,6 @@ def fwd_prepare_dv(q, k, do, BT, scale): triton.Config({}, num_warps=1), triton.Config({}, num_warps=2), triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8) ], key=["BT", "BK", "BV"], ) @@ -141,7 +139,7 @@ def chunk_delta_rule_fwd_kernel_h( b_d = tl.load(p_d, boundary_check=(0, 1)) # [BT, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False) + b_v -= tl.dot(b_d, b_h.to(b_k.dtype)) # [BK, BV] tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) @@ -157,7 +155,6 @@ def chunk_delta_rule_fwd_kernel_h( triton.Config({}, num_warps=1), triton.Config({}, num_warps=2), triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8) ], key=["BT", "BK", "BV"], ) @@ -218,8 +215,7 @@ def chunk_linear_attn_fwd_kernel_o( configs=[ triton.Config({}, num_warps=1), triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8) + triton.Config({}, num_warps=4) ], key=["BT", "BK", "BV"], ) @@ -307,9 +303,7 @@ def chunk_delta_rule_bwd_kernel_dhu( configs=[ triton.Config({}, num_warps=1), triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16) + triton.Config({}, num_warps=4) ], key=["BT", "BK", "BV"], ) @@ -357,8 +351,8 @@ def chunk_delta_rule_bwd_kernel_dqkw( for i_v in range(tl.cdiv(V, BV)): p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), 
(i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1)) p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0)) p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) # [BT, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) @@ -371,7 +365,7 @@ def chunk_delta_rule_bwd_kernel_dqkw( b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) # [BT, BK] b_dq += tl.dot(b_do, b_h, allow_tf32=False) - b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + b_dk += tl.dot(b_v, b_dh, allow_tf32=False) b_dv = tl.load(p_dv, boundary_check=(0, 1)) b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) @@ -479,7 +473,7 @@ def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT, scale): NT = triton.cdiv(T, BT) grid = (NK, NT, B * H) dq = torch.empty_like(q) - dk = torch.empty_like(k) + dk = torch.empty_like(k, dtype=torch.float32) dw = torch.empty_like(w) chunk_delta_rule_bwd_kernel_dqkw[grid]( q, k, v_new, w, h, do, dh, dq, dk, du, dw, @@ -489,7 +483,7 @@ def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT, scale): scale=scale, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, ) - return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype) + return dq, dk, dw class ChunkDeltaRuleFunction(torch.autograd.Function): @@ -539,6 +533,7 @@ def backward(ctx, do, dht): return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, dh0, None, None, None + def chunk_delta_rule( q: torch.Tensor, k: torch.Tensor, @@ -572,9 +567,10 @@ def chunk_delta_rule( assert q.dtype == k.dtype == v.dtype assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16." 
assert len(beta.shape) == 3, "beta must be of shape (batch size, num of head, seq len)." + assert BT in [32, 64], "ChunkDeltaRuleFunction only supports BT=32/64." if scale is None: scale = k.shape[-1] ** -0.5 else: assert scale > 0, "scale must be positive." o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, scale, initial_state, output_final_state) - return o, final_state + return o, final_state \ No newline at end of file diff --git a/fla/ops/delta_rule/wy_fast.py b/fla/ops/delta_rule/wy_fast.py index fdfc02065..6ff7f4403 100644 --- a/fla/ops/delta_rule/wy_fast.py +++ b/fla/ops/delta_rule/wy_fast.py @@ -6,8 +6,6 @@ from einops import rearrange from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous - - # Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 @triton.autotune( configs=[ @@ -17,10 +15,10 @@ triton.Config({}, num_warps=8), triton.Config({}, num_warps=16) ], - key=["BT", "BK", "BV"], + key=["BK"] ) @triton.jit -def fwd_prepare_wy_repr_kernel( +def fwd_prepare_wy_repr_kernel_chunk32( k, v, beta, @@ -38,6 +36,7 @@ def fwd_prepare_wy_repr_kernel( V, BT: tl.constexpr, BK: tl.constexpr, + BC: tl.constexpr, BV: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) @@ -66,22 +65,90 @@ def fwd_prepare_wy_repr_kernel( tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1)) b_A = b_A.to(k.dtype.element_ty) - for i_v in range(tl.cdiv(V, BV)): - p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - b_v = tl.load(p_v, boundary_check=(0, 1)) - b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) - b_u = tl.dot(b_A, b_vb) - p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), 
+ triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BK"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel_chunk64( + k, + v, + beta, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BC, BC], dtype=tl.float32) + b_A2 = tl.zeros([BC, BC], dtype=tl.float32) + b_A3 = tl.zeros([BC, BC], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BC,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + p_beta2 = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,)) + b_beta2 = tl.load(p_beta2, boundary_check=(0,)) for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_k2 = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) - b_w = tl.dot(b_A, b_kb) - p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)) + b_kb2 = (b_k2 * b_beta2[:, None]).to(b_k2.dtype) + b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A2 += tl.dot(b_kb2, tl.trans(b_k2), allow_tf32=False) + b_A3 += tl.dot(b_kb2, tl.trans(b_k), allow_tf32=False) + b_A = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0) + b_A2 = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_a = 
tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i) + b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A2 = tl.where(mask[:, None], b_a2, b_A2) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A3 = -tl.dot(tl.dot(b_A2, b_A3), b_A) + + p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A4 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1)) + # causal mask + tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1)) + + + @triton.autotune( configs=[ @@ -244,8 +311,8 @@ def bwd_prepare_wy_repr_kernel( tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0) - b_dA = tl.dot(b_dA.to(b_A.dtype), b_A, allow_tf32=False) - b_dA = tl.dot(b_A, b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty) for i_k in range(tl.cdiv(K, BK)): @@ -273,12 +340,15 @@ def fwd_prepare_wy_repr(k, v, 
beta, BT): BK = min(triton.next_power_of_2(K), 64) BV = min(triton.next_power_of_2(V), 64) A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype) - fwd_prepare_wy_repr_kernel[(NT, B*H)]( + fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32 + + fwd_fn[(NT, B*H)]( k, v, beta, w, u, A, k.stride(1), k.stride(2), k.stride(3), v.stride(1), v.stride(2), v.stride(3), - T, K, V, BT, BK, BV + T, K, V, BT, BK, 32, BV ) + w, u = fwd_recompute_w_u(k, v, beta, A, BT) return w, u, A @@ -313,7 +383,6 @@ def fwd_recompute_w(k, beta, A, BT): def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT): B, H, T, K, V = *k.shape, v.shape[-1] - NT = triton.cdiv(T, BT) BK = min(triton.next_power_of_2(K), 64) BV = min(triton.next_power_of_2(V), 64) @@ -339,6 +408,7 @@ class WYRepresentationPrepration(torch.autograd.Function): @contiguous @autocast_custom_fwd def forward(ctx, k, v, beta, chunk_size=64): + assert chunk_size in [16, 32, 64] ctx.BT = chunk_size w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT) ctx.save_for_backward(k, v, beta, A) @@ -419,5 +489,4 @@ def naive(k, v, beta, chunk_size): k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad print((k_grad2-k_grad).abs().max()) print((v_grad2-v_grad).abs().max()) - print((beta_grad2-beta_grad).abs().max()) - breakpoint() + print((beta_grad2-beta_grad).abs().max()) \ No newline at end of file diff --git a/tests/ops/test_delta.py b/tests/ops/test_delta.py index 8def94104..c7f4dde51 100644 --- a/tests/ops/test_delta.py +++ b/tests/ops/test_delta.py @@ -6,6 +6,11 @@ from fla.ops.delta_rule import (chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule) +def get_err_ratio(x, y): + err = (x-y).flatten().square().mean().sqrt().item() + base = (x).flatten().square().mean().sqrt().item() + return err / base + @pytest.mark.parametrize("B", [2]) @pytest.mark.parametrize("H", [2]) @@ -17,27 +22,32 @@ def test_fused_chunk_equivalence(B: int, H: int, T: int, D: int, dtype: torch.dt q = 
torch.randn(B, H, T, D, dtype=dtype) k = torch.nn.functional.normalize(torch.randn(B, H, T, D, dtype=torch.float32), p=2, dim=-1).to(dtype) v = torch.randn(B, H, T, D, dtype=dtype) - beta = torch.rand(B, H, T, dtype=dtype).sigmoid().fill_(1) + beta = torch.rand(B, H, T, dtype=dtype).sigmoid() h0 = torch.randn(B, H, D, D, dtype=torch.float32) q, k, v, beta, h0 = map(lambda x: x.cuda().requires_grad_(True), (q, k, v, beta, h0)) do = torch.rand_like(v) - dh0 = torch.rand_like(h0) + dht = torch.rand_like(h0) - o2, h2 = fused_chunk_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone(),scale=scale, output_final_state=True, initial_state=h0.clone()) - ((o2 * do).sum() + (h2 * dh0).sum()).backward(retain_graph=True) - q_grad2, k_grad2, v_grad2, beta_grad2, h0_grad2 = q.grad, k.grad, v.grad, beta.grad, h0.grad + tri, tri_ht = fused_chunk_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone(),scale=scale, output_final_state=True, initial_state=h0.clone()) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad q.grad = k.grad = v.grad = beta.grad = h0.grad = None - o, h1 = fused_recurrent_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone(), scale=scale, output_final_state=True, initial_state=h0.clone()) - ((o * do).sum() + (h1 * dh0).sum()).backward(retain_graph=True) - assert torch.abs(o - o2).max() < 5 - assert torch.abs(h1 - h2).max() < 5 - assert torch.abs(q.grad - q_grad2).max() < 5 - assert torch.abs(k.grad - k_grad2).max() < 5 - assert torch.abs(v.grad - v_grad2).max() < 5 - assert torch.abs(beta.grad - beta_grad2).max() < 5 - assert torch.abs(h0_grad2 - h0.grad).max() < 5 + + ref, ref_ht = fused_recurrent_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone(), scale=scale, output_final_state=True, initial_state=h0.clone()) + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = q.grad, 
k.grad, v.grad, beta.grad, h0.grad
+
+    assert get_err_ratio(tri, ref) < 0.005, f" o diff: {torch.abs(ref - tri).max()}, ref_o_max: {ref.abs().max()}, tri_o_max: {tri.abs().max()}, ratio: {get_err_ratio(ref, tri)}"
+    assert get_err_ratio(tri_ht, ref_ht) < 0.005, f"ht diff: {torch.abs(ref_ht - tri_ht).max()}, ratio: {get_err_ratio(ref_ht, tri_ht)}"
+    assert get_err_ratio(tri_dq, ref_dq) < 0.007, f"dq diff: {torch.abs(ref_dq - tri_dq).max()}, ratio: {get_err_ratio(ref_dq, tri_dq)}"
+    assert get_err_ratio(tri_dk, ref_dk) < 0.007, f"dk diff: {torch.abs(ref_dk - tri_dk).max()}, ratio: {get_err_ratio(ref_dk, tri_dk)}"
+    assert get_err_ratio(tri_dv, ref_dv) < 0.007, f"dv diff: {torch.abs(ref_dv - tri_dv).max()}, ratio: {get_err_ratio(ref_dv, tri_dv)}"
+    assert get_err_ratio(tri_dbeta, ref_dbeta) < 0.007, f"dbeta diff: {torch.abs(ref_dbeta - tri_dbeta).max()}, ref_dbeta_max: {ref_dbeta.abs().max()}, tri_dbeta_max: {tri_dbeta.abs().max()}, ratio: {get_err_ratio(ref_dbeta, tri_dbeta)}"
+    assert get_err_ratio(tri_dh0, ref_dh0) < 0.007, f"dh0 diff: {torch.abs(ref_dh0 - tri_dh0).max()}, ref_dho_max: {ref_dh0.abs().max()}, tri_dh0_max: {tri_dh0.abs().max()}, ratio: {get_err_ratio(ref_dh0, tri_dh0)}"
+
+
 @pytest.mark.parametrize("B", [2])
 @pytest.mark.parametrize("H", [2])
 @pytest.mark.parametrize("T", [256, 486])
@@ -48,55 +58,26 @@ def test_chunk_equivalence(B: int, H: int, T: int, D: int, dtype: torch.dtype, s
     q = torch.randn(B, H, T, D, dtype=dtype)
     k = torch.nn.functional.normalize(torch.randn(B, H, T, D, dtype=torch.float32), p=2, dim=-1).to(dtype)
     v = torch.randn(B, H, T, D, dtype=dtype)
-    beta = torch.rand(B, H, T, dtype=dtype).sigmoid().fill_(1)
+    beta = torch.rand(B, H, T, dtype=dtype).sigmoid()
     h0 = torch.randn(B, H, D, D, dtype=torch.float32)
     q, k, v, beta, h0 = map(lambda x: x.cuda().requires_grad_(True), (q, k, v, beta, h0))
     do = torch.rand_like(v)
-    dh0 = torch.rand_like(h0)
-
-    o2, h2 = chunk_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone(),scale=scale, output_final_state=True, 
initial_state=h0.clone()) - ((o2 * do).sum() + (h2 * dh0).sum()).backward(retain_graph=True) - q_grad2, k_grad2, v_grad2, beta_grad2, h0_grad2 = q.grad, k.grad, v.grad, beta.grad, h0.grad + dht = torch.rand_like(h0) + + tri, tri_ht = chunk_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone(),scale=scale, output_final_state=True, initial_state=h0.clone(), BT=64) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad q.grad = k.grad = v.grad = beta.grad = h0.grad = None - o, h1 = fused_recurrent_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone(), scale=scale, output_final_state=True, initial_state=h0.clone()) - ((o * do).sum() + (h1 * dh0).sum()).backward(retain_graph=True) - assert torch.abs(o - o2).max() < 5 - assert torch.abs(h1 - h2).max() < 5 - assert torch.abs(q.grad - q_grad2).max() < 5 - assert torch.abs(k.grad - k_grad2).max() < 5 - assert torch.abs(v.grad - v_grad2).max() < 5 - assert torch.abs(beta.grad - beta_grad2).max() < 5 - assert torch.abs(h0_grad2 - h0.grad).max() < 5 - - -# @pytest.mark.parametrize("B", [8]) -# @pytest.mark.parametrize("H", [4]) -# @pytest.mark.parametrize("T", [1024]) -# @pytest.mark.parametrize("D", [128]) -# @pytest.mark.parametrize("dtype", [torch.float]) -# def test_beta_scalar_vector_equivalence(B: int, H: int, T: int, D: int, dtype: torch.dtype): -# torch.manual_seed(17) -# q = torch.randn(B, H, T, D, dtype=dtype) -# k = torch.nn.functional.normalize(torch.randn(B, H, T, D, dtype=dtype), p=2, dim=-1) -# v = torch.randn(B, H, T, D, dtype=dtype) -# beta = torch.rand(B, H, T, D, dtype=dtype).sigmoid() -# q, k, v, beta = map(lambda x: x.cuda().requires_grad_(True), (q, k, v, beta)) -# do = torch.rand_like(v) - -# o = delta_rule_recurrence(q.clone(), k.clone(), v.clone(), beta.clone()) -# o.backward(do, retain_graph=True) -# q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad -# q.grad = k.grad = 
v.grad = beta.grad = None - -# o2, _ = fused_recurrent_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone()) -# o2.backward(do, retain_graph=True) -# q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad -# q.grad = k.grad = v.grad = beta.grad = None + + ref, ref_ht = fused_recurrent_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone(), scale=scale, output_final_state=True, initial_state=h0.clone()) + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, h0.grad -# assert o.allclose(o2, rtol=0, atol=2e-5), f"Diff: {torch.abs(o - o2).max()}" -# assert q_grad.allclose(q_grad2, rtol=0, atol=2e-5), f"Diff: {torch.abs(q_grad - q_grad2).max()}" -# assert k_grad.allclose(k_grad2, rtol=0, atol=2e-5), f"Diff: {torch.abs(k_grad - k_grad2).max()}" -# assert v_grad.allclose(v_grad2, rtol=0, atol=2e-5), f"Diff: {torch.abs(v_grad - v_grad2).max()}" -# # FIXME: this gradient does not match when beta a vector. matches when a scalar. 
-# assert beta_grad.allclose(beta_grad2, rtol=0, atol=1e-3), f"Diff: {torch.abs(beta_grad - beta_grad2).max()}"
+    assert get_err_ratio(tri, ref) < 0.005, f" o diff: {torch.abs(ref - tri).max()}, ref_o_max: {ref.abs().max()}, tri_o_max: {tri.abs().max()}, ratio: {get_err_ratio(ref, tri)}"
+    assert get_err_ratio(tri_ht, ref_ht) < 0.005, f"ht diff: {torch.abs(ref_ht - tri_ht).max()}, ratio: {get_err_ratio(ref_ht, tri_ht)}"
+    assert get_err_ratio(tri_dq, ref_dq) < 0.007, f"dq diff: {torch.abs(ref_dq - tri_dq).max()}, ratio: {get_err_ratio(ref_dq, tri_dq)}"
+    assert get_err_ratio(tri_dk, ref_dk) < 0.007, f"dk diff: {torch.abs(ref_dk - tri_dk).max()}, ratio: {get_err_ratio(ref_dk, tri_dk)}"
+    assert get_err_ratio(tri_dv, ref_dv) < 0.007, f"dv diff: {torch.abs(ref_dv - tri_dv).max()}, ratio: {get_err_ratio(ref_dv, tri_dv)}"
+    assert get_err_ratio(tri_dbeta, ref_dbeta) < 0.007, f"dbeta diff: {torch.abs(ref_dbeta - tri_dbeta).max()}, ref_dbeta_max: {ref_dbeta.abs().max()}, tri_dbeta_max: {tri_dbeta.abs().max()}, ratio: {get_err_ratio(ref_dbeta, tri_dbeta)}"
+    assert get_err_ratio(tri_dh0, ref_dh0) < 0.007, f"dh0 diff: {torch.abs(ref_dh0 - tri_dh0).max()}, ref_dho_max: {ref_dh0.abs().max()}, tri_dh0_max: {tri_dh0.abs().max()}, ratio: {get_err_ratio(ref_dh0, tri_dh0)}"