
Commit

decast fp8 for ref input, use fp16 as input
micmelesse committed Feb 25, 2025
1 parent 52b5565 commit 84c3259
Showing 2 changed files with 123 additions and 119 deletions.
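
In short, the inputs are now generated in fp16 rather than fp32, and both fp8 tests build their reference inputs by dequantizing the fp8 tensors back to fp16 with the new decast_fp8 helper, so the fp8 Triton kernel and the fp16 reference represent the same already-quantized values. A minimal sketch of that round trip, using the cast_to_fp8/decast_fp8 signatures as they appear in this diff (the shape, device, and final check are illustrative assumptions, not part of the commit):

import torch
from flash_attn.flash_attn_triton_amd.utils import cast_to_fp8, decast_fp8

fp8_dtype = torch.float8_e4m3fnuz
ref_dtype = torch.float16
layout = "bshd"

# illustrative shape: (batch, seqlen, heads, head_dim); assumes a GPU with fp8 support
q = torch.randn(2, 128, 4, 64, dtype=ref_dtype, device="cuda")

# quantize once; descale_q undoes whatever scaling cast_to_fp8 applied
q_fp8, descale_q = cast_to_fp8(q, fp8_dtype, layout)

# build the reference input from the quantized values rather than from the original q
q_ref = decast_fp8(q_fp8, descale_q, ref_dtype, layout)

# q_ref equals q up to fp8 rounding; q_fp8 feeds the fp8 kernel, q_ref feeds the fp16 reference
print((q - q_ref).abs().max())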
216 changes: 97 additions & 119 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -4,7 +4,7 @@
from flash_attn import flash_attn_func, flash_attn_varlen_func

from .utils import DEBUG, DEBUG_TRITON, DEBUG_TRITON_DETAIL, MetaData, \
cast_to_fp8, get_input_shapes, input_helper, varlen_input_helper, compute_alibi_tensor_ref, get_arch, arch_supports_fp8
cast_to_fp8, decast_fp8, get_input_shapes, input_helper, varlen_input_helper, compute_alibi_tensor_ref, get_arch, arch_supports_fp8
from .interface_torch import attention_prefill, attention_decode
from .fwd_ref import attention_forward_pytorch_ref_impl
from .fwd_prefill import attention_prefill_forward_triton_impl
@@ -809,8 +809,8 @@ def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
(4, 6, 6, 2048, 2048, 32),
],
)
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('dropout_p', [0.0, 0.1])
@pytest.mark.parametrize('DEBUG_INPUT', [False])
@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device")
def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT):
@@ -820,43 +820,24 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
alibi_slopes = None
deterministic = False
layout = "bshd"
fp8_dtype = torch.float8_e4m3fnuz
ref_dtype = torch.float16

q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)
q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, ref_dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)
if DEBUG_INPUT:
do = torch.ones_like(q)
else:
do = torch.randn_like(q)

# ref forward pass
# NOTE: bfp16 is not supported by atomic ops. we have to use fp16
q_fp16 = q.clone().to(torch.float16).requires_grad_()
k_fp16 = k.clone().to(torch.float16).requires_grad_()
v_fp16 = v.clone().to(torch.float16).requires_grad_()
do_fp16 = do.clone().to(torch.float16)
out_fp16, lse_fp16, S_dmask_fp16 = flash_attn_func(
q_fp16,
k_fp16,
v_fp16,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)

# ref backward pass
dq_fp16, dk_fp16, dv_fp16 = torch.autograd.grad(out_fp16, (q_fp16, k_fp16, v_fp16), do_fp16)

# ----------------------------------------------------------------
# --- FP8 ---
# ----------------------------------------------------------------
# cast to fp8
q_fp8, descale_q= cast_to_fp8(q, torch.float8_e4m3fnuz, layout)
k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, layout)
v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, layout)
do_fp8, descale_do = cast_to_fp8(do, torch.float8_e4m3fnuz, layout)
q_fp8, descale_q= cast_to_fp8(q, fp8_dtype, layout)
k_fp8, descale_k = cast_to_fp8(k, fp8_dtype, layout)
v_fp8, descale_v = cast_to_fp8(v, fp8_dtype, layout)
do_fp8, descale_do = cast_to_fp8(do, fp8_dtype, layout)

# fp8 forward pass
out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_func(
@@ -879,42 +860,66 @@
# fp8 backward pass
dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8)

# ----------------------------------------------------------------
# --- Reference ---
# ----------------------------------------------------------------
q_ref = decast_fp8(q_fp8, descale_q, ref_dtype, layout)
k_ref = decast_fp8(k_fp8, descale_k, ref_dtype, layout)
v_ref = decast_fp8(v_fp8, descale_v, ref_dtype, layout)
do_ref = decast_fp8(do_fp8, descale_do, ref_dtype, layout)
out_ref, lse_ref, S_dmask_ref = flash_attn_func(
q_ref,
k_ref,
v_ref,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)

# ref backward pass
dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), do_ref)


# compare forward
if DEBUG:
print()
print("Compare fp8 against ref")

if DEBUG:
print("out_fp16:", out_fp16, out_fp16.shape)
print("out_ref:", out_ref, out_ref.shape)
print("out_fp8:", out_fp8, out_fp8.shape)
torch.testing.assert_close(out_fp16.to(torch.float32), out_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)
torch.testing.assert_close(out_ref, out_fp8.to(ref_dtype), atol=ATOL_fp8, rtol=RTOL_fp8)

if DEBUG:
print("lse_fp16:", lse_fp16, lse_fp16.shape)
print("lse_ref:", lse_ref, lse_ref.shape)
print("lse_fp8:", lse_fp8, lse_fp8.shape)
torch.testing.assert_close(lse_fp16.to(torch.float32), lse_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)
torch.testing.assert_close(lse_ref, lse_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)

if DEBUG:
print("S_dmask_fp16:", S_dmask_fp16, S_dmask_fp16.shape if S_dmask_fp16 is not None else None )
print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_fp16 is not None else None)
torch.testing.assert_close(S_dmask_fp16.to(torch.float32) if S_dmask_fp16 is not None else None, S_dmask_fp8.to(torch.float32) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)
# if DEBUG:
# print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape if S_dmask_ref is not None else None )
# print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_ref is not None else None)
# torch.testing.assert_close(S_dmask_ref if S_dmask_ref is not None else None, S_dmask_fp8.to(ref_dtype) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)
# compare backward
if DEBUG:
print("dv_fp16:", dv_fp16, dv_fp16.shape)
print("dv_ref:", dv_ref, dv_ref.shape)
print("dv_fp8:", dv_fp8, dv_fp8.shape)
torch.testing.assert_close(dv_fp16.to(torch.float32), dv_fp8.to(torch.float32),
torch.testing.assert_close(dv_ref, dv_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)

if DEBUG:
print("dk_fp16:", dk_fp16, dk_fp16.shape)
print("dk_ref:", dk_ref, dk_ref.shape)
print("dk_fp8:", dk_fp8, dk_fp8.shape)
torch.testing.assert_close(dk_fp16.to(torch.float32), dk_fp8.to(torch.float32),
torch.testing.assert_close(dk_ref, dk_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)

if DEBUG:
print("dq_fp16:", dq_fp16, dq_fp16.shape)
print("dq_ref:", dq_ref, dq_ref.shape)
print("dq_fp8:", dq_fp8, dq_fp8.shape)
torch.testing.assert_close(dq_fp16.to(torch.float32), dq_fp8.to(torch.float32),
torch.testing.assert_close(dq_ref, dq_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)


@@ -1022,47 +1027,23 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop
alibi_slopes = None
deterministic = False
layout = "thd"
fp8_dtype = torch.float8_e4m3fnuz
ref_dtype = torch.float16

q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, device=device, DEBUG_INPUT=DEBUG_INPUT) # failure due to small seqlen tensors
# q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, device=device, equal_seqlens=True, DEBUG_INPUT=DEBUG_INPUT) # most cases pass
q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, ref_dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
if DEBUG_INPUT:
do = torch.ones_like(q)
else:
do = torch.randn_like(q)

# launch kernel in fp16
q_fp16 = q.clone().to(torch.float16)
k_fp16 = k.clone().to(torch.float16)
v_fp16 = v.clone().to(torch.float16)
do_fp16 = do.clone().to(torch.float16)
out_fp16, lse_fp16, S_dmask_fp16 = flash_attn_varlen_func(
q_fp16,
k_fp16,
v_fp16,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)

# ref backward pass
dq_fp16, dk_fp16, dv_fp16 = torch.autograd.grad(out_fp16, (q_fp16, k_fp16, v_fp16), do_fp16)

# ----------------------------------------------------------------
# --- FP8 ---
# ----------------------------------------------------------------
# cast to fp8
q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, layout, cu_seqlens=metadata.cu_seqlens_q)
k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, layout, cu_seqlens=metadata.cu_seqlens_k)
v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, layout, cu_seqlens=metadata.cu_seqlens_k)
do_fp8, descale_do = cast_to_fp8(do, torch.float8_e4m3fnuz, layout, cu_seqlens=metadata.cu_seqlens_q)
q_fp8, descale_q = cast_to_fp8(q, fp8_dtype, layout, cu_seqlens=metadata.cu_seqlens_q)
k_fp8, descale_k = cast_to_fp8(k, fp8_dtype, layout, cu_seqlens=metadata.cu_seqlens_k)
v_fp8, descale_v = cast_to_fp8(v, fp8_dtype, layout, cu_seqlens=metadata.cu_seqlens_k)
do_fp8, descale_do = cast_to_fp8(do, fp8_dtype, layout, cu_seqlens=metadata.cu_seqlens_q)

# launch kernel in fp8
out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_func(
@@ -1089,75 +1070,72 @@
# fp8 backward pass
dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8)


# ----------------------------------------------------------------
# --- Reference ---
# ----------------------------------------------------------------
q_ref = decast_fp8(q_fp8, descale_q, ref_dtype, layout, cu_seqlens=metadata.cu_seqlens_q)
k_ref = decast_fp8(k_fp8, descale_k, ref_dtype, layout, cu_seqlens=metadata.cu_seqlens_k)
v_ref = decast_fp8(v_fp8, descale_v, ref_dtype, layout, cu_seqlens=metadata.cu_seqlens_k)
do_ref = decast_fp8(do_fp8, descale_do, ref_dtype, layout, cu_seqlens=metadata.cu_seqlens_q)
out_ref, lse_ref, S_dmask_ref = flash_attn_varlen_func(
q_ref,
k_ref,
v_ref,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)

# ref backward pass
dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), do_ref)

# compare forward
if DEBUG:
print()
print("Compare fp8 against ref")

if DEBUG:
print("out_fp16:", out_fp16, out_fp16.shape)
print("out_ref:", out_ref, out_ref.shape)
print("out_fp8:", out_fp8, out_fp8.shape)
torch.testing.assert_close(out_fp16.to(torch.float32), out_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)
torch.testing.assert_close(out_ref, out_fp8.to(ref_dtype), atol=ATOL_fp8, rtol=RTOL_fp8)

if DEBUG:
print("lse_fp16:", lse_fp16, lse_fp16.shape)
print("lse_ref:", lse_ref, lse_ref.shape)
print("lse_fp8:", lse_fp8, lse_fp8.shape)
torch.testing.assert_close(lse_fp16.to(torch.float32), lse_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)
torch.testing.assert_close(lse_ref, lse_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)

if DEBUG:
print("S_dmask_fp16:", S_dmask_fp16, S_dmask_fp16.shape if S_dmask_fp16 is not None else None )
print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_fp16 is not None else None)
torch.testing.assert_close(S_dmask_fp16.to(torch.float32) if S_dmask_fp16 is not None else None, S_dmask_fp8.to(torch.float32) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)
# if DEBUG:
# print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape if S_dmask_ref is not None else None )
# print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_ref is not None else None)
# torch.testing.assert_close(S_dmask_ref if S_dmask_ref is not None else None, S_dmask_fp8.to(ref_dtype) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)

# compare backward
if DEBUG:
print("dv_fp16:", dv_fp16, dv_fp16.shape)
print("dv_ref:", dv_ref, dv_ref.shape)
print("dv_fp8:", dv_fp8, dv_fp8.shape)
print("metadata.cu_seqlens_k:", metadata.cu_seqlens_k)
config = (Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD)
# vis_close(
# dv_fp16.to(torch.float32),
# dv_fp8.to(torch.float32),
# atol=ATOL_fp8,
# rtol=RTOL_fp8,
# equal_nan=EQUAL_NAN,
# layout="thd",
# out_file=f"dv_close_{config}.png",
# cu_seqlens=metadata.cu_seqlens_k
# )
torch.testing.assert_close(dv_fp16.to(torch.float32), dv_fp8.to(torch.float32),
torch.testing.assert_close(dv_ref, dv_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)

if DEBUG:
print("dk_fp16:", dk_fp16, dk_fp16.shape)
print("dk_ref:", dk_ref, dk_ref.shape)
print("dk_fp8:", dk_fp8, dk_fp8.shape)
# vis_close(
# dk_fp16.to(torch.float32),
# dk_fp8.to(torch.float32),
# atol=ATOL_fp8,
# rtol=RTOL_fp8,
# equal_nan=EQUAL_NAN,
# layout="thd",
# out_file=f"dk_close_{config}.png",
# cu_seqlens=metadata.cu_seqlens_k
# )
torch.testing.assert_close(dk_fp16.to(torch.float32), dk_fp8.to(torch.float32),
torch.testing.assert_close(dk_ref, dk_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)

if DEBUG:
print("dq_fp16:", dq_fp16, dq_fp16.shape)
print("dq_ref:", dq_ref, dq_ref.shape)
print("dq_fp8:", dq_fp8, dq_fp8.shape)
# vis_close(
# dq_fp16.to(torch.float32),
# dq_fp8.to(torch.float32),
# atol=ATOL_fp8,
# rtol=RTOL_fp8,
# equal_nan=EQUAL_NAN,
# layout="thd",
# out_file=f"dq_close_{config}.png",
# cu_seqlens=metadata.cu_seqlens_q
# )
torch.testing.assert_close(dq_fp16.to(torch.float32), dq_fp8.to(torch.float32),
torch.testing.assert_close(dq_ref, dq_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)


26 changes: 26 additions & 0 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -1,5 +1,6 @@
import csv
import math
from typing import Optional
import torch
import os
import random
@@ -376,6 +377,31 @@ def cast_to_fp8(
else:
raise ValueError(f"Unknown layout: {layout}")

def decast_fp8(x_fp8: torch.Tensor,
descale_factor: torch.Tensor,
original_dtype: torch.dtype,
layout: str,
cu_seqlens: Optional[torch.Tensor] = None) -> torch.Tensor:
# convert fp8 tensor back to the desired original dtype
x_orig = x_fp8.to(original_dtype)

if layout in ("bshd", "bhsd"):
return x_orig * descale_factor
elif layout == "thd":
if cu_seqlens is None:
raise ValueError("cu_seqlens must be provided for varlen layout ('thd')")
x_out = x_orig.clone()
batch = cu_seqlens.shape[0] - 1
for i in range(batch):
start = int(cu_seqlens[i].item())
end = int(cu_seqlens[i + 1].item())
factor = descale_factor[i].unsqueeze(0).unsqueeze(-1)
x_out[start:end] = x_out[start:end] * factor
return x_out

else:
raise ValueError(f"Unknown layout: {layout}")

# -------------------------------
# Misc
# -------------------------------
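For the varlen ("thd") layout, decast_fp8 rescales each packed sequence with its own descale factor over the cu_seqlens boundaries. A hedged usage sketch of that round trip (the packed sizes below are illustrative assumptions, not taken from the commit):

import torch
from flash_attn.flash_attn_triton_amd.utils import cast_to_fp8, decast_fp8

fp8_dtype = torch.float8_e4m3fnuz
ref_dtype = torch.float16

# two packed sequences of lengths 3 and 5: total_tokens=8, heads=2, head_dim=8
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
x = torch.randn(8, 2, 8, dtype=ref_dtype, device="cuda")

# quantize with per-sequence scales, then dequantize to get the matching reference input
x_fp8, descale_x = cast_to_fp8(x, fp8_dtype, "thd", cu_seqlens=cu_seqlens)
x_ref = decast_fp8(x_fp8, descale_x, ref_dtype, "thd", cu_seqlens=cu_seqlens)

# x_ref matches x up to fp8 rounding and is what the reference attention now consumes
print((x - x_ref).abs().max())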
