diff --git a/fla/ops/simple_gla/__init__.py b/fla/ops/simple_gla/__init__.py
index da0df46c6..a01ac126d 100644
--- a/fla/ops/simple_gla/__init__.py
+++ b/fla/ops/simple_gla/__init__.py
@@ -2,8 +2,10 @@
 
 from .chunk import chunk_simple_gla
 from .fused_recurrent import fused_recurrent_simple_gla
+from .parallel import ParallelSimpleGLAFunction
 
 __all__ = [
     'chunk_simple_gla',
-    'fused_recurrent_simple_gla'
+    'fused_recurrent_simple_gla',
+    'ParallelSimpleGLAFunction'
 ]
diff --git a/fla/ops/simple_gla/parallel.py b/fla/ops/simple_gla/parallel.py
new file mode 100644
index 000000000..75692a1f2
--- /dev/null
+++ b/fla/ops/simple_gla/parallel.py
@@ -0,0 +1,567 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2024, Songlin Yang, Yu Zhang
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from fla.ops.utils import chunk_global_reversed_cumsum, chunk_local_cumsum
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+
+@triton.heuristics({
+    'NV': lambda args: triton.cdiv(args['V'], args['BV']),
+    'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None
+})
+@triton.jit
+def parallel_simple_gla_fwd_kernel(
+    q,
+    k,
+    v,
+    g,
+    o,
+    attn,
+    s_k_h,
+    s_k_t,
+    s_v_h,
+    s_v_t,
+    scale,
+    B: tl.constexpr,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BS: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NV: tl.constexpr,
+    OUTPUT_ATTENTIONS: tl.constexpr
+):
+    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_k, i_v = i_kv // NV, i_kv % NV
+
+    p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    if OUTPUT_ATTENTIONS:
+        p_a = tl.make_block_ptr(attn + (i_k * B * H + i_bh) * T * T, (T, T), (T, 1), (i_t * BT, 0), (BT, BS), (1, 0))
+
+    # the Q block is kept in the shared memory throughout the whole kernel
+    # [BT, BK]
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_q = (b_q * scale).to(b_q.dtype)
+
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    # Q block and K block have no overlap
+    # no need for mask, thereby saving flops
+    for i_s in range(0, i_t * BT, BS):
+        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (1, s_k_t), (i_k * BK, i_s), (BK, BS), (0, 1))
+        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
+        p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_s,), (BS,), (0,))
+        # [BK, BS]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BS, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        # [BS,]
+        b_g = tl.load(p_g, boundary_check=(0,))
+        b_gn = tl.load(g + i_bh * T + (i_s // BT) * BT + BT - 1)
+
+        b_kg = (b_k * tl.exp(b_gn - b_g)).to(b_k.dtype)
+        # [BT, BS]
+        b_s = tl.dot(b_q, b_kg, allow_tf32=False)
+        # do this check to avoid some layout bugs
+        # [BT, BV]
+        if i_s > 0 and i_s % BT == 0:
+            b_o = b_o * tl.exp(b_gn)
+        b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
+
+        if OUTPUT_ATTENTIONS:
+            tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
+            p_a = tl.advance(p_a, (0, BS))
+
+    tl.debug_barrier()
+
+    p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    # [BT,]
+    b_gq = tl.load(p_g, boundary_check=(0,))
+    # rescale interchunk output
+    b_o *= tl.exp(b_gq)[:, None]
+
+    if OUTPUT_ATTENTIONS:
+        p_a = tl.make_block_ptr(attn + (i_k * B * H + i_bh) * T * T, (T, T), (T, 1), (i_t * BT, i_t * BT), (BT, BS), (1, 0))
+
+    # [BT]
+    o_q = i_t * BT + tl.arange(0, BT)
+    # [BS]
+    o_k = i_t * BT + tl.arange(0, BS)
+    # Q block and K block have overlap.
+    # masks required
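+    # scores in this range are masked to be causal and carry the decay factor
+    # exp(b_gq - b_gk) between query and key positions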
+    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
+        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (1, s_k_t), (i_k * BK, i_s), (BK, BS), (0, 1))
+        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
+        p_gk = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_s,), (BS,), (0,))
+        # [BK, BS]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BS, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        # [BS,]
+        b_gk = tl.load(p_gk, boundary_check=(0,))
+        # [BT, BS]
+        m_s = o_q[:, None] >= o_k[None, :]
+        b_s = tl.where(m_s, tl.dot(b_q, b_k, allow_tf32=False) * tl.exp(b_gq[:, None] - b_gk[None, :]), 0)
+        # [BT, BV]
+        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
+
+        if OUTPUT_ATTENTIONS:
+            tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
+            p_a = tl.advance(p_a, (0, BS))
+        o_k += BS
+
+    p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.jit
+def parallel_simple_gla_bwd_kernel_dq(
+    i_bh,
+    i_t,
+    i_k,
+    i_v,
+    k,
+    v,
+    g,
+    do,
+    dq,
+    s_k_h,
+    s_k_t,
+    s_v_h,
+    s_v_t,
+    scale,
+    B: tl.constexpr,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BS: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+):
+    p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    # [BT, BV]
+    b_do = tl.load(p_do, boundary_check=(0, 1))
+    # [BT, BK]
+    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+
+    for i_s in range(0, i_t * BT, BS):
+        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
+        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (V, T), (1, s_v_t), (i_v * BV, i_s), (BV, BS), (0, 1))
+        p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_s,), (BS,), (0,))
+        # [BS, BK]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BV, BS]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        # [BS]
+        b_g = tl.load(p_g, boundary_check=(0,))
+
+        b_gn = tl.load(g + i_bh * T + (i_s // BT) * BT + BT - 1)
+        # [BT, BS]
+        b_ds = tl.dot(b_do, b_v, allow_tf32=False) * tl.exp(b_gn - b_g)[None, :]
+        # [BT, BK]
+        if i_s > 0 and i_s % BT == 0:
+            b_dq *= tl.exp(b_gn)
+        b_dq += tl.dot(b_ds.to(b_v.dtype), b_k, allow_tf32=False)
+
+    p_gq = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    # [BT,]
+    b_gq = tl.load(p_gq, boundary_check=(0,))
+    # [BT, BK]
+    b_dq *= tl.exp(b_gq)[:, None] * scale
+
+    # [BT]
+    o_q = i_t * BT + tl.arange(0, BT)
+    # [BS]
+    o_k = i_t * BT + tl.arange(0, BS)
+    # Q block and K block have overlap. masks required
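+    # diagonal (intra-chunk) block: dq accumulates under the same causal mask
+    # and gate decay used in the forward pass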
+    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
+        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
+        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (V, T), (1, s_v_t), (i_v * BV, i_s), (BV, BS), (0, 1))
+        p_gk = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_s,), (BS,), (0,))
+        # [BS, BK]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BV, BS]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        # [BS]
+        b_gk = tl.load(p_gk, boundary_check=(0,))
+        # [BT, BS]
+        m_s = o_q[:, None] >= o_k[None, :]
+        b_ds = tl.where(m_s, tl.dot(b_do, b_v, allow_tf32=False) * tl.exp((b_gq[:, None] - b_gk[None, :])), 0) * scale
+        # [BT, BK]
+        b_dq += tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
+
+        o_k += BS
+    p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.jit
+def parallel_simple_gla_bwd_kernel_dkv(
+    i_bh,
+    i_t,
+    i_k,
+    i_v,
+    q,
+    k,
+    v,
+    g,
+    do,
+    dk,
+    dv,
+    s_k_h,
+    s_k_t,
+    s_v_h,
+    s_v_t,
+    scale,
+    B: tl.constexpr,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BS: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+):
+    # compute dk dv
+    p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    p_gk = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    # [BT, BK]
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    # [BT, BV]
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_dv = tl.zeros([BT, BV], dtype=tl.float32)
+    # [BT,]
+    b_gk = tl.load(p_gk, boundary_check=(0,))
+
+    NTS = tl.cdiv(T, BS)
+    # [BT, BK]
+    b_kg = (b_k * tl.exp(tl.load(g + i_bh * T + min(i_t * BT + BT, T) - 1) - b_gk)[:, None]).to(b_k.dtype)
+    for i_s in range(NTS * BS - BS, min((i_t + 1) * BT, T) - BS, -BS):
+        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
+        p_gq = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_s,), (BS,), (0,))
+        # [BS, BK]
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        # [BS,]
+        b_gq = tl.load(p_gq, boundary_check=(0,))
+        # [BS, BV]
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        b_do = (b_do * tl.exp(b_gq)[:, None]).to(b_do.dtype)
+
+        if i_s % BT == 0:
+            # overall decay rate for an entire block
+            b_gn = tl.exp(tl.load(g + i_bh * T + min(i_s + BT, T) - 1))
+            # [BT, BK]
+            b_dk *= b_gn
+            # [BT, BV]
+            b_dv *= b_gn
+        # [BT, BS]
+        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
+        b_s = tl.dot(b_kg, tl.trans(b_q), allow_tf32=False)
+        # [BT, BK]
+        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q, allow_tf32=False)
+        # [BT, BV]
+        b_dv += tl.dot(b_s.to(b_do.dtype), b_do, allow_tf32=False)
+
+    # [BT, BK]
+    b_dk *= tl.exp(tl.load(g + i_bh * T + min(T, i_t * BT + BT) - 1) - b_gk)[:, None] * scale
+    # [BT, BV]
+    b_dv *= scale
+
+    tl.debug_barrier()
+    o_q = i_t * BT + tl.arange(0, BS)
+    o_k = i_t * BT + tl.arange(0, BT)
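+    # diagonal (intra-chunk) block: causal-masked contributions of the current
+    # chunk's queries to dk and dv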
+    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
+        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
+        p_gq = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_s,), (BS,), (0,))
+        # [BS, BK]
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        # [BS, BV]
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        # [BS]
+        b_gq = tl.load(p_gq, boundary_check=(0,))
+        # [BT, BS]
+        m_s = o_k[:, None] <= o_q[None, :]
+        d_s = tl.where(m_s, tl.exp(-b_gk[:, None] + b_gq[None, :]), 0) * scale
+
+        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) * d_s
+        b_s = tl.dot(b_k, tl.trans(b_q), allow_tf32=False) * d_s
+        # [BT, BK]
+        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q, allow_tf32=False)
+        b_dv += tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
+        o_q += BS
+    p_dk = tl.make_block_ptr(dk + (i_v * B * H + i_bh) * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    p_dv = tl.make_block_ptr(dv + (i_k * B * H + i_bh) * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.heuristics({
+    'NV': lambda args: triton.cdiv(args['V'], args['BV'])
+})
+@triton.jit
+def parallel_simple_gla_bwd_kernel(
+    q,
+    k,
+    v,
+    g,
+    do,
+    dq,
+    dk,
+    dv,
+    s_k_h,
+    s_k_t,
+    s_v_h,
+    s_v_t,
+    scale,
+    B: tl.constexpr,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BS: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NV: tl.constexpr
+):
+    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_k, i_v = i_kv // NV, i_kv % NV
+
+    parallel_simple_gla_bwd_kernel_dq(
+        i_bh,
+        i_t,
+        i_k,
+        i_v,
+        k,
+        v,
+        g,
+        do,
+        dq,
+        s_k_h,
+        s_k_t,
+        s_v_h,
+        s_v_t,
+        scale,
+        B=B,
+        H=H,
+        T=T,
+        K=K,
+        V=V,
+        BT=BT,
+        BS=BS,
+        BK=BK,
+        BV=BV
+    )
+    tl.debug_barrier()
+    parallel_simple_gla_bwd_kernel_dkv(
+        i_bh,
+        i_t,
+        i_k,
+        i_v,
+        q,
+        k,
+        v,
+        g,
+        do,
+        dk,
+        dv,
+        s_k_h,
+        s_k_t,
+        s_v_h,
+        s_v_t,
+        scale,
+        B,
+        H,
+        T,
+        K,
+        V,
+        BT,
+        BS,
+        BK,
+        BV
+    )
+
+
+def parallel_simple_gla_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    scale: float,
+    output_attentions: bool = False,
+    chunk_size: int = 64
+):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    BT, BS = chunk_size, 32
+    if torch.cuda.get_device_capability()[0] >= 9:
+        BK = min(256, triton.next_power_of_2(K))
+        BV = min(256, triton.next_power_of_2(V))
+    else:
+        BK = min(128, triton.next_power_of_2(K))
+        BV = min(128, triton.next_power_of_2(V))
+    NK = triton.cdiv(K, BK)
+    NV = triton.cdiv(V, BV)
+    assert BT % BS == 0
+
+    num_stages = 3 if K <= 64 else 2
+    num_warps = 4
+
+    # local cumulative decay in log space
+    g = chunk_local_cumsum(g, BT)
+
+    grid = (NK * NV, triton.cdiv(T, BT), B * H)
+    o = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)
+    attn = q.new_zeros(NK, B, H, T, T) if output_attentions else None
+    parallel_simple_gla_fwd_kernel[grid](
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        o=o,
+        attn=attn,
+        s_k_h=k.stride(1),
+        s_k_t=k.stride(2),
+        s_v_h=v.stride(1),
+        s_v_t=v.stride(2),
+        scale=scale,
+        B=B,
+        H=H,
+        T=T,
+        K=K,
+        V=V,
+        BT=BT,
+        BS=BS,
+        BK=BK,
+        BV=BV,
+        num_stages=num_stages,
+        num_warps=num_warps
+    )
+    o = o.sum(0)
+    if output_attentions:
+        attn = attn.sum(0)
+    return o, g, attn
+
+
+def parallel_simple_gla_bwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    do: torch.Tensor,
+    scale: float,
+    chunk_size: int = 64
+):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    BT, BS = chunk_size, 32
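+    # K/V head dims are tiled into blocks of at most 128; the NK/NV partial
+    # gradients produced by the kernel are summed after the launch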
+    BK = min(128, triton.next_power_of_2(k.shape[-1]))
+    BV = min(128, triton.next_power_of_2(v.shape[-1]))
+    NK = triton.cdiv(K, BK)
+    NV = triton.cdiv(V, BV)
+    assert BT % BS == 0
+
+    num_stages = 3 if K <= 64 else 2
+    num_warps = 4
+
+    dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
+    dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
+    dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)
+    grid = (NK * NV, triton.cdiv(T, BT), B * H)
+    parallel_simple_gla_bwd_kernel[grid](
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        do=do,
+        dq=dq,
+        dk=dk,
+        dv=dv,
+        s_k_h=k.stride(1),
+        s_k_t=k.stride(2),
+        s_v_h=v.stride(1),
+        s_v_t=v.stride(2),
+        scale=scale,
+        B=B,
+        H=H,
+        T=T,
+        K=K,
+        V=V,
+        BT=BT,
+        BS=BS,
+        BK=BK,
+        BV=BV,
+        num_stages=num_stages,
+        num_warps=num_warps
+    )
+    dq = dq.sum(0)
+    dk = dk.sum(0)
+    dv = dv.sum(0)
+    dg = chunk_global_reversed_cumsum((dq * q.float() - dk * k.float()).sum(-1))
+    return dq, dk, dv, dg
+
+
+class ParallelSimpleGLAFunction(torch.autograd.Function):
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, g, scale, output_attentions):
+        BT = 64
+        o, g, attn = parallel_simple_gla_fwd(q, k, v, g, scale, output_attentions, BT)
+        ctx.save_for_backward(q, k, v, g)
+        ctx.scale = scale
+        ctx.BT = BT
+        return o.to(q.dtype), attn
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, da=None):
+        q, k, v, g = ctx.saved_tensors
+        dq, dk, dv, dg = parallel_simple_gla_bwd(q, k, v, g, do, ctx.scale, ctx.BT)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None
+
+
+def parallel_simple_gla(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    scale: float = None,
+    output_attentions: bool = False
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, H, T, K]`
+        k (torch.Tensor):
+            keys of shape `[B, H, T, K]`
+        v (torch.Tensor):
+            values of shape `[B, H, T, V]`
+        g (torch.Tensor):
+            Forget gates of shape `[B, H, T]` applied to keys.
+            Compared to GLA, the gating is head-wise instead of elementwise.
+        scale (Optional[float]):
+            Scale factor for attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        output_attentions (bool):
+            Whether to output the materialized attention scores of shape `[B, H, T, T]`. Default: `False`.
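+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, H, T, V]`.
+        attn (torch.Tensor):
+            Attention scores of shape `[B, H, T, T]` if `output_attentions=True`, otherwise `None`.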
+    """
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    return ParallelSimpleGLAFunction.apply(q, k, v, g, scale, output_attentions)
diff --git a/tests/ops/test_simple_gla.py b/tests/ops/test_simple_gla.py
index 2bd57ce5b..7ee8288e3 100644
--- a/tests/ops/test_simple_gla.py
+++ b/tests/ops/test_simple_gla.py
@@ -6,6 +6,11 @@
 
 from fla.ops.simple_gla import chunk_simple_gla
 from fla.ops.simple_gla.naive import torch_simple_gla_recurrent
+from fla.ops.simple_gla.parallel import parallel_simple_gla
+
+
+def get_abs_err(x, y):
+    return (x-y).flatten().abs().max().item()
 
 
 def get_err_ratio(x, y):
@@ -14,6 +19,12 @@ def get_err_ratio(x, y):
     return err / base
 
 
+def assert_close(prefix, ref, tri, ratio):
+    msg = f"{prefix} diff: {get_abs_err(ref, tri):.6f} ratio: {get_err_ratio(ref, tri):.6f}"
+    print(msg)
+    assert get_err_ratio(ref, tri) < ratio, msg
+
+
 @pytest.mark.parametrize("B", [2])
 @pytest.mark.parametrize("H", [2])
 @pytest.mark.parametrize("T", [100, 512])
@@ -27,7 +38,7 @@ def test_chunk(
     dtype: torch.dtype
 ):
     torch.manual_seed(42)
-    # [B, H, T, d_head]
+    # [B, H, T, D]
     q = torch.randn((B, H, T, D), dtype=dtype, device='cuda').requires_grad_(True)
     k = torch.randn((B, H, T, D), dtype=dtype, device='cuda').requires_grad_(True)
     v = torch.randn((B, H, T, D), dtype=dtype, device='cuda').requires_grad_(True)
@@ -35,6 +46,7 @@ def test_chunk(
     h0 = torch.rand((B, H, D, D), dtype=torch.float32, device='cuda').requires_grad_(True)
     g = F.logsigmoid(g).requires_grad_(True)
     do = torch.randn_like(v)
+
     ref, ref_ht = torch_simple_gla_recurrent(q, k, v, g, initial_state=h0)
     d_ht = torch.randn_like(ref_ht)
     ((ref * do).sum() + (ref_ht * d_ht).sum()).backward()
@@ -44,7 +56,6 @@ def test_chunk(
     ref_dg, g.grad = g.grad.clone(), None
     ref_dh0, h0.grad = h0.grad.clone(), None
 
-    # triton implementation
     tri, tri_ht = chunk_simple_gla(q, k, v, g, initial_state=h0, output_final_state=True)
     ((tri * do).sum() + (tri_ht * d_ht).sum()).backward()
     tri_dq, q.grad = q.grad.clone(), None
@@ -68,6 +79,48 @@ def test_chunk(
         f"dh0 diff: {torch.abs(ref_dh0 - tri_dh0).max()}, ratio: {get_err_ratio(ref_dh0, tri_dh0)}"
 
 
+@pytest.mark.parametrize("B", [4])
+@pytest.mark.parametrize("H", [4])
+@pytest.mark.parametrize("T", [300, 512])
+@pytest.mark.parametrize("D", [32, 64, 100])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_parallel(
+    B: int,
+    H: int,
+    T: int,
+    D: int,
+    dtype: torch.dtype
+):
+    torch.manual_seed(42)
+
+    q = torch.randn((B, H, T, D), dtype=dtype, device='cuda').requires_grad_(True)
+    k = torch.randn((B, H, T, D), dtype=dtype, device='cuda').requires_grad_(True)
+    v = torch.randn((B, H, T, D), dtype=dtype, device='cuda').requires_grad_(True)
+    h0 = torch.zeros((B, H, D, D), dtype=torch.float32, device='cuda')
+    g = F.logsigmoid(torch.randn((B, H, T), dtype=dtype, device='cuda')).requires_grad_(True)
+    do = torch.randn_like(v)
+
+    ref, _ = torch_simple_gla_recurrent(q, k, v, g, initial_state=h0)
+    ref.backward(do)
+    ref_dq, q.grad = q.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dg, g.grad = g.grad.clone(), None
+
+    tri, _ = parallel_simple_gla(q, k, v, g)
+    tri.backward(do)
+    tri_dq, q.grad = q.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dg, g.grad = g.grad.clone(), None
+
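+    # outputs and gradients of the parallel Triton kernels should match the
+    # naive recurrent reference within the given error ratio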
+    assert_close(" o", ref, tri, 0.005)
+    assert_close("dq", ref_dq, tri_dq, 0.005)
+    assert_close("dk", ref_dk, tri_dk, 0.005)
+    assert_close("dv", ref_dv, tri_dv, 0.005)
+    assert_close("dg", ref_dg, tri_dg, 0.005)
+
+
 @pytest.mark.parametrize("vary_A", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16])
 def test_simple_gla_to_mamba2(vary_A, dtype):