triton flash atten module generates wrong results #514

Closed
zhangxiao-stack opened this issue Feb 20, 2024 · 20 comments

Problem Description

Hi,
Can somebody please take a look at this? I just tested this code and it generates wrong results.

"""
Fused Attention
===============

This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)

Extra Credits:
- Original flash attention paper (https://arxiv.org/abs/2205.14135)
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
- Adam P. Goucher for simplified vector math

"""

import pytest
import torch

import triton
import triton.language as tl
import os
os.environ["MFMA_TYPE"] = "16"


@triton.jit
def max_fn(x, y):
    return tl.math.max(x, y)

@triton.jit
def _attn_fwd_inner(
    acc, l_i, m_i, q,
    K_block_ptr, V_block_ptr,
    start_m,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
    N_CTX,
    pre_load_v: tl.constexpr,
):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # causal = False
    else:
        lo, hi = 0, N_CTX
    # loop over k, v and update accumulator
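    # Online-softmax recurrence (in base 2, since q was pre-scaled by log2(e)):
    #   m_new = max(m_old, rowmax(qk))
    #   p     = 2^(qk - m_new)
    #   acc   = acc * 2^(m_old - m_new) + p @ v
    #   l     = l   * 2^(m_old - m_new) + rowsum(p)
    # so acc / l after the loop equals softmax(q @ K^T * sm_scale) @ V for the
    # rows handled by this program.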
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(K_block_ptr)
        if pre_load_v:
            v = tl.load(V_block_ptr)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = tl.where(mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        # -- update output accumulator --
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not pre_load_v:
            v = tl.load(V_block_ptr)
        acc += tl.dot(p.to(tl.float16), v)
        # -- update m_i and l_i
        l_ij = tl.sum(p, 1)
        l_i = l_i * alpha + l_ij
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i


@triton.autotune(
   configs=[
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=1, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=1, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=0, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=0, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=0, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=0, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=0, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=1, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=0, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=0, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=0, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=0, num_warps=4),
       triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=0, num_warps=4),
   ],
   key=['N_CTX', 'STAGE'],
)


@triton.jit
def _attn_fwd(
    Q, K, V, sm_scale, M, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H,
    N_CTX,
    STAGE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    pre_load_v: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qkv_offset = off_hz * stride_qh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qkv_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
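    # The three block pointers view the Q, K, V slices of this (batch, head):
    # Q is pinned to this program's BLOCK_M rows, while K (laid out as
    # (BLOCK_DMODEL, N_CTX)) and V start at position 0 and are advanced along
    # N_CTX inside _attn_fwd_inner.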
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
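    # 1.44269504 ~= log2(e): exp(sm_scale * x) == 2^(sm_scale * x * log2(e)),
    # so folding log2(e) into qk_scale lets the inner loop use exp2 exclusively.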
    # load q: it stays resident throughout, in SRAM on NV GPUs and in VGPRs on AMD GPUs
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(tl.float16)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
            start_m,
            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
            4 - STAGE, offs_m, offs_n,
            N_CTX, pre_load_v,
        )
    # stage 2: on-band
    if STAGE & 2:
        # the barrier makes it easier for the compiler to schedule the
        # two loops independently
        tl.debug_barrier()
        acc, l_i, m_i = _attn_fwd_inner(
            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
            start_m,
            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
            2, offs_m, offs_n,
            N_CTX, pre_load_v,
        )
    # epilogue
    # write back m
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i + tl.math.log2(l_i))
    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + qkv_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))


@triton.jit
def _bwd_preprocess(
    Out, DO,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    # compute
    delta = tl.sum(o * do, axis=1)
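    # delta is the per-row D_i = rowsum(dO * O) term from the Flash Attention
    # backward pass; the backward kernels fold it into dp (dp = dO @ V^T - D_i)
    # before forming ds.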
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)


@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX, P_SEQ,
    num_block_q, num_block_kv,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    CAUSAL: tl.constexpr,
):
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    qk_scale = sm_scale * 1.44269504
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_kz + off_h * stride_kh
    DV += off_z * stride_vz + off_h * stride_vh
    # See fwd pass above for explanation.
    qk_scale = sm_scale * 1.44269504
    for start_n in range(0, num_block_kv):
        if CAUSAL:
            lo = tl.math.max(start_n * BLOCK_M - P_SEQ, 0)
        else:
            lo = 0
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[None, :] * stride_qm + offs_k[:, None] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        l_ptrs = L + off_hz * N_CTX
        # initialize dk and dv
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block_q * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            if CAUSAL:
                qk = tl.where(P_SEQ + offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
            else:
                qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(k))
            l_i = tl.load(l_ptrs + offs_m_curr)
            p = tl.math.exp2(qk * qk_scale - l_i[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, v)
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
            # compute dq
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(Q.dtype.element_ty), k)
            tl.store(dq_ptrs, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        tl.store(dk_ptrs, dk)
        tl.store(dv_ptrs, dv)

@triton.jit
def _bwd_kernel_dk_dv(
    Q, K, V, sm_scale, Out, DO,
    DK, DV,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # Q is consumed starting at an offset that depends on the block ID: each
    # program starts BLOCK_M x D_HEAD further along than the previous one.
    qvk_offset = off_hz * stride_qh
    qdo_offset = qvk_offset + start_m * BLOCK_M * stride_qm
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    # Initialize pointers to Q, K, V
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qdo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, start_m * BLOCK_M),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_vn, stride_vk),
        offsets=(0, start_m * BLOCK_M),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    DO_block_ptr = tl.make_block_ptr(
        base=DO + qdo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # pointer to row-wise quantities in value-like data
    D_ptrs = D + off_hz * N_CTX
    l_ptrs = L + off_hz * N_CTX
    qk_scale = sm_scale * 1.44269504
    # load k and v: they will stay in SRAM throughout
    k = tl.load(K_block_ptr)
    k = (k * qk_scale).to(tl.float16)
    v = tl.load(V_block_ptr)
    dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # This lower loop bound is because of the causal mask. We create a lower triangular
    # result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can
    # be ignored in the GEMM.
    lo = start_m * BLOCK_M
    hi = N_CTX
    # loop over q, do
    for start_n in range(lo, hi, BLOCK_N):
        offs_m_curr = offs_n[:, None] + start_n
        # -- load q, do --
        q = tl.load(Q_block_ptr)
        do = tl.load(DO_block_ptr)
        # -- compute qk ----
        qk = tl.dot(q, k)
        qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf"))
        l_i = tl.load(l_ptrs + offs_m_curr)
        p = tl.math.exp2(qk - l_i)
        # -- compute dv ----
        dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
        # compute dp = dot(v, do)
        Di = tl.load(D_ptrs + offs_m_curr)
        dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di
        dp += tl.dot(do, v)
        # compute ds = p * (dp - delta[:, None])
        ds = p * dp
        # compute dk
        dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
        # update pointers
        Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_N, 0))
        DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_N, 0))
    # initialize pointers to output
    DK_block_ptr = tl.make_block_ptr(
        base=DK + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_kn, stride_kk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    DV_block_ptr = tl.make_block_ptr(
        base=DV + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(DK_block_ptr, (dk * sm_scale).to(tl.float16))
    tl.store(DV_block_ptr, dv.to(tl.float16))

@triton.jit
def _bwd_kernel_dq(
    Q, K, V, sm_scale, Out, DO,
    DQ,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qvk_offset = off_hz * stride_qh
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    # Initialize pointers to Q, K, V
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    DO_block_ptr = tl.make_block_ptr(
        base=DO + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # pointer to row-wise quantities in value-like data
    D_ptrs = D + off_hz * N_CTX
    l_ptrs = L + off_hz * N_CTX
    qk_scale = sm_scale * 1.44269504
    # load q and do: they will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(tl.float16)
    do = tl.load(DO_block_ptr)
    Di = tl.load(D_ptrs + offs_m)
    l_i = tl.load(l_ptrs + offs_m)
    dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # loop over k, v
    lo = 0
    hi = (start_m + 1) * BLOCK_M
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk ----
        qk = tl.dot(q, k)
        qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf"))
        p = tl.math.exp2(qk - l_i[:, None])
        # compute dp = dot(v, do)
        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
        dp += tl.dot(do, v)
        # compute ds = p * (dp - delta[:, None])
        ds = p * dp
        # compute dq. Unfortunately we cannot avoid a transpose here, as this loop
        # uses k in both its normal and transposed forms.
        dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k))
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
    # initialize pointers to output
    DQ_block_ptr = tl.make_block_ptr(
        base=DQ + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(DQ_block_ptr, (dq * sm_scale).to(tl.float16))

empty = torch.empty(128, device="cuda")


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        if torch.version.hip is None:
            BLOCK_M = 128
            BLOCK_N = 64 if Lk <= 64 else 32
            num_stages = 4 if Lk <= 64 else 3
            num_warps = 4 if Lk <= 64 else 8

        stage = 3 if causal else 1
        grid = lambda META: (
            triton.cdiv(q.shape[2], META['BLOCK_M']),
            q.shape[0] * q.shape[1],
            1
        )
        M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1],
            N_CTX=q.shape[2],
            BLOCK_DMODEL=Lk,
            STAGE=stage,
        )

        ## restore the grid for bwd kernel
        best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage)
        block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
        grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
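        # BLOCK_M is recovered by parsing the tuned config's string repr so the
        # backward pass can rebuild the same launch grid over N_CTX that the
        # autotuned forward launch used.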

        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        ctx.split_kernel = split_kernel
        return o

    @staticmethod
    def backward(ctx, do):
        # configuration is not supported
        assert(not (ctx.split_kernel and not ctx.causal))
        if torch.version.hip is not None:
            BLOCK = 64
        else:
            BLOCK = 128
        q, k, v, o, L = ctx.saved_tensors
        do = do.contiguous()
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        delta = torch.empty_like(L)
        do_scaled = torch.empty_like(do)
        # Figure out which BLOCK size the fwd pass used and adjust num_blocks accordingly.
        # If the two were the same we would not need this, but the bwd block size is
        # smaller than the fwd one, so this scaling ensures we loop over all values
        # and do not skip any blocks.
        # Alternatively we could compute a new grid, but this keeps it consistent
        # with fwd and easier to reason about.
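        # Example (assuming the autotune configs above): fwd picks BLOCK_M = 128
        # and BLOCK = 64 on ROCm, so block_scale = 128 // 64 = 2.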
        block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do,
            do_scaled, delta,
            BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )
        if not ctx.split_kernel:
            _bwd_kernel[(ctx.grid[1],)](
                q, k, v, ctx.sm_scale,
                o, do_scaled,
                dq, dk, dv,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                block_scale * ctx.grid[0],
                BLOCK_M=BLOCK, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
                CAUSAL=ctx.causal,
                num_stages=1,
            )
        else :
            dq = torch.zeros_like(q)
            _bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
                q, k, v, ctx.sm_scale,
                o, do_scaled,
                dk, dv,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                BLOCK_M=BLOCK, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
                num_stages=1,
            )
            _bwd_kernel_dq[ctx.grid](
                q, k, v, ctx.sm_scale,
                o, do_scaled,
                dq,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
                num_stages=1,
            )
        # print(h.asm["ttgir"])
        return dq, dk, dv, None, None, None

attention = _attention.apply


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(4, 48, 1024, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),
                          #(4, 48, 8192, 64),
                          #(4, 48, 16384, 64)
                          ])
@pytest.mark.parametrize('causal', [False, True])
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
    torch.manual_seed(20)
    q = (
        torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
        .normal_(mean=0., std=0.5)
        .requires_grad_()
    )
    k = (
        torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
        .normal_(mean=0., std=0.5)
        .requires_grad_()
    )
    v = (
        torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
        .normal_(mean=0., std=0.5)
        .requires_grad_()
    )
    sm_scale = 0.5
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    #print("print fwd ref_out",ref_out)
    # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale)
    #print("print fwd tri_out",tri_out)
    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(4, 48, 1024, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),
                          (1, 16, 8192, 64),
                          ])
def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    causal = True
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    sm_scale = 0,5
    split_kernel = True
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None

    #print("print bwd ref_out",ref_out)
    #print("print bwd tri_out",tri_out)
    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
    if torch.version.hip is None:
        assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
    # The current block size for MI200 series is 64x64. This results in
    # larger differences in float results due to rounding.
    else:
        assert torch.allclose(ref_dv, tri_dv, atol=5e-2, rtol=0)
    assert torch.allclose(ref_dk, tri_dk, atol=5e-2, rtol=0)
    assert torch.allclose(ref_dq, tri_dq, atol=5e-2, rtol=0)


try:
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    FLASH_VER = 2
except BaseException:
    try:
        from flash_attn.flash_attn_interface import flash_attn_func
        FLASH_VER = 1
    except BaseException:
        FLASH_VER = None
HAS_FLASH = FLASH_VER is not None

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd', 'bwd']:
    for causal in [False, True]:
        if mode == 'bwd' and causal == False:
            continue
        configs.append(triton.testing.Benchmark(
            x_names=['N_CTX'],
            x_vals=[2**i for i in range(10, 15)],
            line_arg='provider',
            line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
            line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
            styles=[('red', '-'), ('blue', '-')],
            ylabel='ms',
            plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
            args={
                'H': N_HEADS,
                'BATCH': BATCH,
                'D_HEAD': D_HEAD,
                'dtype': torch.float16,
                'mode': mode,
                'causal': causal})
        )


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
    assert mode in ['fwd', 'bwd']
    warmup = 25
    rep = 100
    split_kernel = False
    # Bwd pass only supports causal=True right now
    if mode == 'bwd':
        causal = True
        split_kernel = True
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash":
        qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        if FLASH_VER == 1:
            lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
            cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
            cu_seqlens[1:] = lengths.cumsum(0)
            qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD)
            fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal)
        elif FLASH_VER == 2:
            fn = lambda: flash_attn_func(qkv, causal=causal)
        else:
            raise ValueError(f'unknown {FLASH_VER = }')
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == 'bwd':
        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
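    # do_bench reports milliseconds, so total_flops / ms * 1e-9 is TFLOP/s,
    # which is the unit shown in the benchmark tables below.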
    return total_flops / ms * 1e-9


# only works on post-Ampere GPUs right now
bench_flash_attention.run(save_path='.', print_data=True)

Operating System

NAME="Ubuntu" VERSION="20.04.5 LTS (Focal Fossa)"

CPU

Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz

GPU

AMD Instinct MI210

ROCm Version

ROCm 5.5.0

ROCm Component

No response

Steps to Reproduce

Steps

pytest 06-fused-attention.py

errors

06-fused-attention.py ..F...FFFF                                                                                                                     

============================================================================================================= FAILURES ==============================
__________________________________________________________________________________________________ test_op_fwd[False-4-48-4096-64] __________________

Z = 4, H = 48, N_CTX = 4096, D_HEAD = 64, causal = False, dtype = torch.float16

    @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                             [(4, 48, 1024, 64),
                              (4, 48, 2048, 64),
                              (4, 48, 4096, 64),
                              #(4, 48, 8192, 64),
                              #(4, 48, 16384, 64)
                              ])
    @pytest.mark.parametrize('causal', [False, True])
    def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
        torch.manual_seed(20)
        q = (
            torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
            .normal_(mean=0., std=0.5)
            .requires_grad_()
        )
        k = (
            torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
            .normal_(mean=0., std=0.5)
            .requires_grad_()
        )
        v = (
            torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
            .normal_(mean=0., std=0.5)
            .requires_grad_()
        )
        sm_scale = 0.5
        dout = torch.randn_like(q)
        # reference implementation
        M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
        p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
        if causal:
            p[:, :, M == 0] = float("-inf")
        p = torch.softmax(p.float(), dim=-1).half()
        ref_out = torch.matmul(p, v)
        #print("print fwd ref_out",ref_out)
        # triton implementation
        tri_out = attention(q, k, v, causal, sm_scale)
        #print("print fwd tri_out",tri_out)
        # compare
>       assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x7fa4d75d5520>(tensor([[[[ 3.7460e-03,  2.2659e-02, -2.0050e-02,  ..., -1.227422e-02, -1.3916e-02]]]], device='cuda:0', dtype=torch.float16,\n       grad_fn=<UnsafeViewBackward0>), tensor([[[[0.0658, 0.0659, 0.0700,  ..., 0.069.0671,  ..., 0.0643, 0.0666, 0.0669]]]],\n       device='cuda:0', dtype=torch.float16, grad_fn=<_attentionBackward>), atol=0.01, rtol=0)
E        +    where <built-in method allclose of type object at 0x7fa4d75d5520> = torch.allclose

06-fused-attention.py:695: AssertionError
_____________________________________________________________________________________________________ test_op_bwd[4-48-1024-64] _____________________

Z = 4, H = 48, N_CTX = 1024, D_HEAD = 64, dtype = torch.float16

    @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                             [(4, 48, 1024, 64),
                              (4, 48, 2048, 64),
                              (4, 48, 4096, 64),
                              (1, 16, 8192, 64),
                              ])
    def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
        torch.manual_seed(20)
        causal = True
        q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        sm_scale = 0,5
        split_kernel = True
        dout = torch.randn_like(q)
        # reference implementation
        M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
>       p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
E       TypeError: only integer tensors of a single element can be converted to an index

06-fused-attention.py:715: TypeError
_____________________________________________________________________________________________________ test_op_bwd[4-48-2048-64] _____________________

Z = 4, H = 48, N_CTX = 2048, D_HEAD = 64, dtype = torch.float16

    @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                             [(4, 48, 1024, 64),
                              (4, 48, 2048, 64),
                              (4, 48, 4096, 64),
                              (1, 16, 8192, 64),
                              ])
    def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
        torch.manual_seed(20)
        causal = True
        q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        sm_scale = 0,5
        split_kernel = True
        dout = torch.randn_like(q)
        # reference implementation
        M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
>       p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
E       TypeError: only integer tensors of a single element can be converted to an index

06-fused-attention.py:715: TypeError
_____________________________________________________________________________________________________ test_op_bwd[4-48-4096-64] _____________________

Z = 4, H = 48, N_CTX = 4096, D_HEAD = 64, dtype = torch.float16

    @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                             [(4, 48, 1024, 64),
                              (4, 48, 2048, 64),
                              (4, 48, 4096, 64),
                              (1, 16, 8192, 64),
                              ])
    def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
        torch.manual_seed(20)
        causal = True
        q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        sm_scale = 0,5
        split_kernel = True
        dout = torch.randn_like(q)
        # reference implementation
        M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
>       p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
E       TypeError: only integer tensors of a single element can be converted to an index

06-fused-attention.py:715: TypeError
_____________________________________________________________________________________________________ test_op_bwd[1-16-8192-64] _____________________

Z = 1, H = 16, N_CTX = 8192, D_HEAD = 64, dtype = torch.float16

    @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                             [(4, 48, 1024, 64),
                              (4, 48, 2048, 64),
                              (4, 48, 4096, 64),
                              (1, 16, 8192, 64),
                              ])
    def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
        torch.manual_seed(20)
        causal = True
        q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
        sm_scale = 0,5
        split_kernel = True
        dout = torch.randn_like(q)
        # reference implementation
        M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
>       p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
E       TypeError: only integer tensors of a single element can be converted to an index

06-fused-attention.py:715: TypeError
====================================================================================================== short test summary info ======================
FAILED 06-fused-attention.py::test_op_fwd[False-4-48-4096-64] - AssertionError: assert False
FAILED 06-fused-attention.py::test_op_bwd[4-48-1024-64] - TypeError: only integer tensors of a single element can be converted to an index
FAILED 06-fused-attention.py::test_op_bwd[4-48-2048-64] - TypeError: only integer tensors of a single element can be converted to an index
FAILED 06-fused-attention.py::test_op_bwd[4-48-4096-64] - TypeError: only integer tensors of a single element can be converted to an index
FAILED 06-fused-attention.py::test_op_bwd[1-16-8192-64] - TypeError: only integer tensors of a single element can be converted to an index
============================================================================================== 5 failed, 5 passed in 192.25s (0:03:12) =============

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

ROCk module is loaded
=====================    
HSA System Attributes    
=====================    
Runtime Version:         1.1
System Timestamp Freq.:  1000.000000MHz
Sig. Max Wait Duration:  18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model:           LARGE                              
System Endianness:       LITTLE                             

==========               
HSA Agents               
==========               
*******                  
Agent 1                  
*******                  
  Name:                    Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz
  Uuid:                    CPU-XX                             
  Marketing Name:          Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz
  Vendor Name:             CPU                                
  Feature:                 None specified                     
  Profile:                 FULL_PROFILE                       
  Float Round Mode:        NEAR                               
  Max Queue Number:        0(0x0)                             
  Queue Min Size:          0(0x0)                             
  Queue Max Size:          0(0x0)                             
  Queue Type:              MULTI                              
  Node:                    0                                  
  Device Type:             CPU                                
  Cache Info:              
    L1:                      32768(0x8000) KB                   
  Chip ID:                 0(0x0)                             
  ASIC Revision:           0(0x0)                             
  Cacheline Size:          64(0x40)                           
  Max Clock Freq. (MHz):   3300                               
  BDFID:                   0                                  
  Internal Node ID:        0                                  
  Compute Unit:            24                                 
  SIMDs per CU:            0                                  
  Shader Engines:          0                                  
  Shader Arrs. per Eng.:   0                                  
  WatchPts on Addr. Ranges:1                                  
  Features:                None
  Pool Info:               
    Pool 1                   
      Segment:                 GLOBAL; FLAGS: FINE GRAINED        
      Size:                    32582668(0x1f12c0c) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
    Pool 2                   
      Segment:                 GLOBAL; FLAGS: KERNARG, FINE GRAINED
      Size:                    32582668(0x1f12c0c) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
    Pool 3                   
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED      
      Size:                    32582668(0x1f12c0c) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
  ISA Info:                
*******                  
Agent 2                  
*******                  
  Name:                    Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz
  Uuid:                    CPU-XX                             
  Marketing Name:          Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz
  Vendor Name:             CPU                                
  Feature:                 None specified                     
  Profile:                 FULL_PROFILE                       
  Float Round Mode:        NEAR                               
  Max Queue Number:        0(0x0)                             
  Queue Min Size:          0(0x0)                             
  Queue Max Size:          0(0x0)                             
  Queue Type:              MULTI                              
  Node:                    1                                  
  Device Type:             CPU                                
  Cache Info:              
    L1:                      32768(0x8000) KB                   
  Chip ID:                 0(0x0)                             
  ASIC Revision:           0(0x0)                             
  Cacheline Size:          64(0x40)                           
  Max Clock Freq. (MHz):   3300                               
  BDFID:                   0                                  
  Internal Node ID:        1                                  
  Compute Unit:            24                                 
  SIMDs per CU:            0                                  
  Shader Engines:          0                                  
  Shader Arrs. per Eng.:   0                                  
  WatchPts on Addr. Ranges:1                                  
  Features:                None
  Pool Info:               
    Pool 1                   
      Segment:                 GLOBAL; FLAGS: FINE GRAINED        
      Size:                    49496132(0x2f34044) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
    Pool 2                   
      Segment:                 GLOBAL; FLAGS: KERNARG, FINE GRAINED
      Size:                    49496132(0x2f34044) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
    Pool 3                   
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED      
      Size:                    49496132(0x2f34044) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       TRUE                               
  ISA Info:                
*******                  
Agent 3                  
*******                  
  Name:                    gfx90a                             
  Uuid:                    GPU-a3f49e3629c87732               
  Marketing Name:          AMD Instinct MI210                 
  Vendor Name:             AMD                                
  Feature:                 KERNEL_DISPATCH                    
  Profile:                 BASE_PROFILE                       
  Float Round Mode:        NEAR                               
  Max Queue Number:        128(0x80)                          
  Queue Min Size:          64(0x40)                           
  Queue Max Size:          131072(0x20000)                    
  Queue Type:              MULTI                              
  Node:                    2                                  
  Device Type:             GPU                                
  Cache Info:              
    L1:                      16(0x10) KB                        
    L2:                      8192(0x2000) KB                    
  Chip ID:                 29711(0x740f)                      
  ASIC Revision:           1(0x1)                             
  Cacheline Size:          64(0x40)                           
  Max Clock Freq. (MHz):   1700                               
  BDFID:                   1792                               
  Internal Node ID:        2                                  
  Compute Unit:            104                                
  SIMDs per CU:            4                                  
  Shader Engines:          8                                  
  Shader Arrs. per Eng.:   1                                  
  WatchPts on Addr. Ranges:4                                  
  Features:                KERNEL_DISPATCH 
  Fast F16 Operation:      TRUE                               
  Wavefront Size:          64(0x40)                           
  Workgroup Max Size:      1024(0x400)                        
  Workgroup Max Size per Dimension:
    x                        1024(0x400)                        
    y                        1024(0x400)                        
    z                        1024(0x400)                        
  Max Waves Per CU:        32(0x20)                           
  Max Work-item Per CU:    2048(0x800)                        
  Grid Max Size:           4294967295(0xffffffff)             
  Grid Max Size per Dimension:
    x                        4294967295(0xffffffff)             
    y                        4294967295(0xffffffff)             
    z                        4294967295(0xffffffff)             
  Max fbarriers/Workgrp:   32                                 
  Pool Info:               
    Pool 1                   
      Segment:                 GLOBAL; FLAGS: COARSE GRAINED      
      Size:                    67092480(0x3ffc000) KB             
      Allocatable:             TRUE                               
      Alloc Granule:           4KB                                
      Alloc Alignment:         4KB                                
      Accessible by all:       FALSE                              
    Pool 2                   
      Segment:                 GROUP                              
      Size:                    64(0x40) KB                        
      Allocatable:             FALSE                              
      Alloc Granule:           0KB                                
      Alloc Alignment:         0KB                                
      Accessible by all:       FALSE                              
  ISA Info:                
    ISA 1                    
      Name:                    amdgcn-amd-amdhsa--gfx90a:sramecc+:xnack-
      Machine Models:          HSA_MACHINE_MODEL_LARGE            
      Profiles:                HSA_PROFILE_BASE                   
      Default Rounding Mode:   NEAR                               
      Default Rounding Mode:   NEAR                               
      Fast f16:                TRUE                               
      Workgroup Max Size:      1024(0x400)                        
      Workgroup Max Size per Dimension:
        x                        1024(0x400)                        
        y                        1024(0x400)                        
        z                        1024(0x400)                        
      Grid Max Size:           4294967295(0xffffffff)             
      Grid Max Size per Dimension:
        x                        4294967295(0xffffffff)             
        y                        4294967295(0xffffffff)             
        z                        4294967295(0xffffffff)             
      FBarrier Max Size:       32                                 
*** Done ***             

Additional Information

No response

zhangxiao-stack changed the title from "[Issue]: triton flash atten module generates wrong results" to "triton flash atten module generates wrong results" on Feb 20, 2024
@zhanglx13

I tried your code on MI250 with ROCm 5.6, and test_op_fwd passes after I change best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage) to best_config = _attn_fwd.get_best_config(). The bwd tests did not work.
It seems that this is a very old version of the flash-attention kernel from the repo. Can you try the latest one on the tip of the triton-mlir branch?
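For reference, the workaround amounts to this one-line change in the forward pass of the posted script (a sketch of the suggestion above, not a general API recommendation):

# before: look up the tuned config by key
best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage)
# after: let the autotuner return its best config directly
best_config = _attn_fwd.get_best_config()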

@zhangxiao-stack
Author

@zhanglx13 thanks for your reply, I will try the latest one

@zhangxiao-stack
Author

@zhanglx13 I updated to the latest version and the results look OK now. Additionally, if I integrate Triton Flash Attention with vLLM, which Triton FA code should I use?

@zhanglx13

Let's ask @vgokhale for help

@vgokhale
Collaborator

You can use this one. Thanks!

@zhangxiao-stack
Author

@vgokhale thanks for your reply. I just tested this code (python/perf-kernels/flash-attention.py) and it crashes with a core dump:

fused-attention-fwd-d128-causal=False-bias=False:
    BATCH     H    N_CTX     Triton
0    16.0  16.0   1024.0  61.429768
1     8.0  16.0   2048.0  68.134906
2     4.0  16.0   4096.0  71.954407
3     2.0  16.0   8192.0  73.787859
4     1.0  16.0  16384.0  74.633200
5     2.0  48.0   1024.0  58.086475
6     2.0  48.0   2048.0  66.287027
7     2.0  48.0   4096.0  71.869249
8     2.0  48.0   8192.0  73.561685
9     2.0  48.0  16384.0  73.963175
10    8.0  16.0   1989.0  64.129571
11    4.0  16.0   4097.0  65.608565
12    2.0  16.0   8122.0  72.668382
13    1.0  16.0  16281.0  73.702959
14    2.0  48.0   1021.0  55.679723
15    2.0  48.0   2001.0  62.954609
16    2.0  48.0   3996.0  68.704991
17    2.0  48.0   8181.0  73.365176

fused-attention-fwd-d128-causal=False-bias=True:
    BATCH     H    N_CTX     Triton
0    16.0  16.0   1024.0  36.203390
1     8.0  16.0   2048.0  46.371167
2     4.0  16.0   4096.0  52.490864
3     2.0  16.0   8192.0  55.326495
4     1.0  16.0  16384.0  56.978791
5     2.0  48.0   1024.0  31.577685
6     2.0  48.0   2048.0  43.559168
7     2.0  48.0   4096.0  52.591746
8     2.0  48.0   8192.0  55.379588
9     2.0  48.0  16384.0  57.056650
10    8.0  16.0   1989.0  38.848884
11    4.0  16.0   4097.0  44.244679
12    2.0  16.0   8122.0  53.214392
13    1.0  16.0  16281.0  55.554325
14    2.0  48.0   1021.0  27.077047
15    2.0  48.0   2001.0  36.845493
16    2.0  48.0   3996.0  47.236803
17    2.0  48.0   8181.0  54.186119
python3: /home/runner/work/triton/triton/llvm-project/mlir/lib/Analysis/SliceAnalysis.cpp:109: void getBackwardSliceImpl(mlir::Operation *, SetVector<mlir::Operation *> *, mlir::BackwardSliceOptions): Assertion `parentOp->getNumRegions() == 1 && parentOp->getRegion(0).getBlocks().size() == 1' failed.
Aborted (core dumped)

@vgokhale
Collaborator

This is a bit weird: it completes the benchmark and then core dumps somewhere afterwards. These numbers are also pretty low, even for an MI210. A few points:

  1. Is it possible to try with a more recent ROCm version / docker image? 5.5 is quite old.
  2. When you git fetch'ed triton-mlir to use latest, I assume you rebuilt Triton? If not, that would be good to do.
  3. This version is a little bit in flux this week. If possible, can you wait while we push all updates? Once we push the new kernel (sometime this week), you'll be working with new code so fixing issues in this one, if any, may not make sense.

@zhangxiao-stack
Author

@vgokhale Thanks for your reply. I will update the ROCm version and try again. Also, is there a Flash-Decoding algorithm implemented in Triton?

@vgokhale
Collaborator

Hi @zhangxiao-stack,

flash-attention.py should now run without the core dump you faced above. Can you try with the latest triton-mlir?

Re. flash-decoding, we are working on this. We expect a first version this week.

https://github.com/ROCm/triton/pull/492/files

@zhangxiao-stack
Author

Hi, @vgokhale
I fetched the latest triton-mlir and rebuilt Triton. flash-attention.py runs without the core dump, but the numbers are pretty low:

fused-attention-fwd-d128-causal=False:
    BATCH     H  N_CTX_Q  N_CTX_K     Triton
0    16.0  16.0   1024.0   1024.0  79.214163
1     8.0  16.0   2048.0   2048.0  86.808619
2     4.0  16.0   4096.0   4096.0  90.950733
3     2.0  16.0   8192.0   8192.0  93.136436
4     1.0  16.0  16384.0  16384.0  93.901165
5     2.0  48.0   1024.0   1024.0  75.596109
6     2.0  48.0   2048.0   1024.0  76.366666
7     2.0  48.0   4096.0   8192.0  92.684200
8     2.0  48.0   8192.0   4096.0  90.778724
9     2.0  48.0  16384.0   8192.0  92.719959
10    8.0  16.0   1989.0  15344.0  90.903315
11    4.0  16.0   4097.0    163.0  34.711483
12    2.0  16.0   8122.0   2159.0  85.837150
13    1.0  16.0  16281.0      7.0   2.098211
14    2.0  48.0   1021.0   1020.0  73.826180
15    2.0  48.0   2001.0   2048.0  81.380739
16    2.0  48.0   3996.0   9639.0  90.536361
17    2.0  48.0   8181.0   1021.0  78.771301
fused-attention-fwd-d128-causal=True:
    BATCH     H  N_CTX_Q  N_CTX_K     Triton
0    16.0  16.0   1024.0   1024.0  12.886308
1     8.0  16.0   2048.0   2048.0  13.848956
2     4.0  16.0   4096.0   4096.0  14.383211
3     2.0  16.0   8192.0   8192.0  14.646250
4     1.0  16.0  16384.0  16384.0  14.777015
5     2.0  48.0   1024.0   1024.0  12.670611
6     2.0  48.0   2048.0   1024.0  25.142167
7     2.0  48.0   4096.0   8192.0   9.854398
8     2.0  48.0   8192.0   4096.0  28.748776
9     2.0  48.0  16384.0   8192.0  29.420537
10    8.0  16.0   1989.0  15344.0   7.728812
11    4.0  16.0   4097.0    163.0  38.106455
12    2.0  16.0   8122.0   2159.0  49.793807
13    1.0  16.0  16281.0      7.0   1.635808
14    2.0  48.0   1021.0   1020.0  12.503968
15    2.0  48.0   2001.0   2048.0  13.502583
16    2.0  48.0   3996.0   9639.0   9.313698
17    2.0  48.0   8181.0   1021.0  91.321226
fused-attention-varlen-fwd-d128:
    BATCH     HQ    HK    N_CTX     Triton
0     2.0   16.0   4.0   1024.0   5.716963
1     8.0   16.0   2.0   2048.0   4.230196
2     4.0   16.0   8.0   4096.0  16.136856
3     2.0   16.0   4.0   8192.0  13.651264
4     2.0   16.0   8.0  16384.0  82.597363
5     2.0   48.0  12.0   1024.0  17.080724
6     2.0   48.0  24.0   2048.0  17.039837
7     2.0   48.0   8.0   4096.0  55.623457
8     2.0   48.0   4.0   8192.0  79.230780
9     2.0   48.0   2.0  16384.0  87.740621
10    2.0   64.0  32.0   1024.0  22.524289
11    4.0   64.0  16.0   2048.0  33.392685
12    4.0   64.0   8.0   4096.0  37.548845
13    4.0   64.0  32.0   8192.0  66.148101
14    4.0  128.0  16.0  16384.0  81.158887
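
As a rough sanity check on what these numbers mean (assuming the Triton column reports TFLOP/s, which is how these fused-attention benchmarks are usually set up - an assumption, not read from the script): the forward pass is dominated by two GEMMs, so throughput can be back-computed from kernel time.

def attn_fwd_tflops(batch, heads, n_ctx_q, n_ctx_k, d_head, ms, causal=False):
    # QK^T and P @ V are each 2 * Nq * Nk * D multiply-adds, i.e. ~4 * B * H * Nq * Nk * D flops
    flops = 4.0 * batch * heads * n_ctx_q * n_ctx_k * d_head
    if causal:
        flops *= 0.5
    return flops / (ms * 1e-3) / 1e12

# Row 0 of the first table (B=16, H=16, N_CTX_Q=N_CTX_K=1024, D=128) at ~79 TFLOP/s
# corresponds to a kernel time of roughly 1.7 ms:
print(attn_fwd_tflops(16, 16, 1024, 1024, 128, ms=1.74))  # ~79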

@vgokhale
Collaborator

Hmm, these look as expected on an MI210. What baseline are you comparing against that makes these seem low?
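
If it helps to establish one, here is a minimal sketch (mine, not from the repo) that times torch.nn.functional.scaled_dot_product_attention on a shape from the table, using the same 4*B*H*Nq*Nk*D flop count, so the Triton numbers can be compared against stock PyTorch on the same card (assumes PyTorch >= 2.0 with a ROCm build):

import torch
import torch.nn.functional as F

def sdpa_tflops(B, H, N, D, causal=False, iters=20, dtype=torch.float16):
    q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=dtype) for _ in range(3))
    for _ in range(3):  # warm-up
        F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    flops = 4.0 * B * H * N * N * D * (0.5 if causal else 1.0)
    return flops / (ms * 1e-3) / 1e12

print(sdpa_tflops(16, 16, 1024, 128))  # compare with row 0 of the causal=False table above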

@zhangxiao-stack
Author

Hi @vgokhale. Sorry for the late reply. First, I got it wrong earlier; second, below are flash decoding results on an MI210:

fused-attention-d128-fwd-causal=False:
        B   Mq       Mkv    Hq  Hkv      K      Triton
0   256.0  1.0     256.0  16.0  1.0  128.0  593.210161
1   128.0  1.0     512.0  16.0  1.0  128.0  588.659704
2    64.0  1.0    1024.0  16.0  1.0  128.0  582.193434
3    32.0  1.0    2048.0  16.0  1.0  128.0  578.703046
4    16.0  1.0    4096.0  16.0  1.0  128.0  581.533909
5     8.0  1.0    8192.0  16.0  1.0  128.0  581.341326
6     4.0  1.0   16384.0  16.0  1.0  128.0  590.296090
7     2.0  1.0   32768.0  16.0  1.0  128.0  580.572486
8     1.0  1.0   65536.0  16.0  1.0  128.0  582.854867
9     1.0  1.0  131072.0  16.0  1.0  128.0  580.364287
10  256.0  1.0     256.0  16.0  2.0  128.0  585.848987
11  128.0  1.0     512.0  16.0  2.0  128.0  583.298564
12   64.0  1.0    1024.0  16.0  2.0  128.0  579.619348
13   32.0  1.0    2048.0  16.0  2.0  128.0  587.042093
14   16.0  1.0    4096.0  16.0  2.0  128.0  581.362069
15    8.0  1.0    8192.0  16.0  2.0  128.0  586.579323
16    4.0  1.0   16384.0  16.0  2.0  128.0  587.934375
17    2.0  1.0   32768.0  16.0  2.0  128.0  586.112440
18    1.0  1.0   65536.0  16.0  2.0  128.0  588.137567
19    1.0  1.0  131072.0  16.0  2.0  128.0  585.379064

Do these results look OK?

@vgokhale
Collaborator

vgokhale commented Mar 5, 2024

What script are you using for flash decoding? I don't think we have one checked in at the top of the triton-mlir branch yet.

@zhangxiao-stack
Author

@zhangxiao-stack
Author

@vgokhale Hi, python/perf-kernels/06-attention-decode.py now generates wrong results:

Mismatched elements: 8192 / 8192 (100.0%)
Greatest absolute difference: 0.4072265625 at index (0, 9, 0, 61) (up to 0.021 allowed)
Greatest relative difference: 0.407958984375 at index (0, 9, 0, 111) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[2-1-32768-16-1-128] - AssertionError: Tensor-likes are not close!

Mismatched elements: 4096 / 4096 (100.0%)
Greatest absolute difference: 0.8779296875 at index (0, 9, 0, 59) (up to 0.021 allowed)
Greatest relative difference: 0.87353515625 at index (0, 9, 0, 33) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[1-1-65536-16-1-128] - AssertionError: Tensor-likes are not close!

Mismatched elements: 2048 / 2048 (100.0%)
Greatest absolute difference: 0.99853515625 at index (0, 3, 0, 121) (up to 0.021 allowed)
Greatest relative difference: 0.99365234375 at index (0, 3, 0, 28) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[1-1-131072-16-1-128] - AssertionError: Tensor-likes are not close!

Mismatched elements: 2048 / 2048 (100.0%)
Greatest absolute difference: 1.0048828125 at index (0, 0, 0, 58) (up to 0.021 allowed)
Greatest relative difference: 1.0 at index (0, 0, 0, 0) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[8-1-8192-16-2-128] - AssertionError: Tensor-likes are not close!

Mismatched elements: 16384 / 16384 (100.0%)
Greatest absolute difference: 0.0791015625 at index (6, 11, 0, 55) (up to 0.021 allowed)
Greatest relative difference: 0.080810546875 at index (6, 11, 0, 55) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[4-1-16384-16-2-128] - AssertionError: Tensor-likes are not close!

Mismatched elements: 8192 / 8192 (100.0%)
Greatest absolute difference: 0.4091796875 at index (3, 1, 0, 122) (up to 0.021 allowed)
Greatest relative difference: 0.408447265625 at index (3, 1, 0, 97) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[2-1-32768-16-2-128] - AssertionError: Tensor-likes are not close!

Mismatched elements: 4096 / 4096 (100.0%)
Greatest absolute difference: 0.87939453125 at index (0, 9, 0, 55) (up to 0.021 allowed)
Greatest relative difference: 0.87646484375 at index (0, 9, 0, 90) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[1-1-65536-16-2-128] - AssertionError: Tensor-likes are not close!

Mismatched elements: 2048 / 2048 (100.0%)
Greatest absolute difference: 1.0 at index (0, 3, 0, 89) (up to 0.021 allowed)
Greatest relative difference: 0.994140625 at index (0, 9, 0, 2) (up to 0 allowed)
FAILED decoding1.py::test_op_fwd_int4_kv[1-1-131072-16-2-128] - AssertionError: Tensor-likes are not close!

Mismatched elements: 2048 / 2048 (100.0%)
Greatest absolute difference: 1.005859375 at index (0, 11, 0, 85) (up to 0.021 allowed)
Greatest relative difference: 1.0 at index (0, 0, 0, 0) (up to 0 allowed)
============================================================================================= 19 failed, 22 passed in 69.18s (0:01:09

@vgokhale
Collaborator

Hi @scxiao, since you have sent a PR, I imagine it passes all unit tests?

@zhangxiao-stack
Author

zhangxiao-stack commented Mar 20, 2024

@vgokhale Hi, after making modifications based on pull request #541:
1. The official python/perf-kernels/flash-attention.py unit test test_op_bwd fails with large absolute differences:

pytest  flash-attention.py -k test_op_bwd
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(4, 48, 1024, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),

FAILED flash-attention.py::test_op_bwd[4-48-1024-64] - AssertionError: Tensor-likes are not close!

Mismatched elements: 8701164 / 12582912 (69.2%)
Greatest absolute difference: 6.38818359375 at index (3, 40, 0, 19) (up to 0.05 allowed)
Greatest relative difference: inf at index (2, 28, 26, 36) (up to 0 allowed)
FAILED flash-attention.py::test_op_bwd[4-48-2048-64] - AssertionError: Tensor-likes are not close!

Mismatched elements: 12801585 / 25165824 (50.9%)
Greatest absolute difference: 7.03271484375 at index (2, 16, 0, 6) (up to 0.05 allowed)
Greatest relative difference: inf at index (2, 39, 1344, 41) (up to 0 allowed)
FAILED flash-attention.py::test_op_bwd[4-48-4096-64] - AssertionError: Tensor-likes are not close!

Mismatched elements: 6034803 / 50331648 (12.0%)
Greatest absolute difference: 5.860401153564453 at index (1, 25, 0, 57) (up to 0.05 allowed)
Greatest relative difference: inf at index (1, 0, 48, 1) (up to 0 allowed)

2. The official python/perf-kernels/flash-attention.py unit test test_op_bwd fails at (1, 16, 8192, 64) with a memory access fault:

pytest  flash-attention.py -k test_op_bwd
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
(1, 16, 8192, 64),

flash-attention.py Memory access fault by GPU node-10 (Agent handle: 0x6e2de00) on address 0x7eff6bdeb000. Reason: Unknown.

@vgokhale
Collaborator

For the FA bwd kernel, please use python/tutorials/06-fused-attention.py.

We are currently working on supporting bwd in the perf-kernels folder - until then, the tutorials folder is the right one to use for bwd.
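
Until bwd lands there, a minimal sketch of the kind of bwd check the tutorial runs, written against a plain PyTorch reference. The import path and the attention(q, k, v, causal, sm_scale) signature are assumptions about your checkout, so adjust them to whatever python/tutorials/06-fused-attention.py actually exports:

import torch
# hypothetical module name - point this at the attention op defined in
# python/tutorials/06-fused-attention.py in your checkout
from fused_attention_tutorial import attention

torch.manual_seed(0)
Z, H, N_CTX, D_HEAD = 4, 48, 1024, 64
dtype = torch.float16
q, k, v = (torch.randn(Z, H, N_CTX, D_HEAD, device="cuda", dtype=dtype, requires_grad=True)
           for _ in range(3))
sm_scale = 0.5
dout = torch.randn_like(q)

# reference: plain PyTorch causal attention
mask = torch.tril(torch.ones(N_CTX, N_CTX, device="cuda", dtype=torch.bool))
scores = (q @ k.transpose(2, 3)).float() * sm_scale
scores = scores.masked_fill(~mask, float("-inf"))
ref = torch.softmax(scores, dim=-1).to(dtype) @ v
ref.backward(dout)
ref_dq, ref_dk, ref_dv = q.grad.clone(), k.grad.clone(), v.grad.clone()
q.grad = k.grad = v.grad = None

# kernel under test (signature assumed, as noted above)
tri = attention(q, k, v, True, sm_scale)
tri.backward(dout)
torch.testing.assert_close(tri, ref, atol=2e-2, rtol=0)
torch.testing.assert_close(q.grad, ref_dq, atol=2e-2, rtol=0)
torch.testing.assert_close(k.grad, ref_dk, atol=2e-2, rtol=0)
torch.testing.assert_close(v.grad, ref_dv, atol=2e-2, rtol=0)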

@jerryyin
Member

jerryyin commented Jun 5, 2024

Closing due to no new updates.

@jerryyin jerryyin closed this as completed Jun 5, 2024