decast fp8 for ref input, use fp16 as input
fix minor things

match readme

decast fp8 for ref input, use fp16 as input
micmelesse committed Feb 25, 2025
1 parent f4463e5 commit 5df3f94
Showing 4 changed files with 124 additions and 123 deletions.
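The gist of the test changes below: rather than comparing the FP8 kernel against an independent FP16 run of the original inputs, the tests now decast the FP8 tensors back to FP16 and feed those to the reference path, so both paths see the same quantized values and the comparison isolates kernel error rather than quantization error. A rough sketch of the idea (not the test code itself; the single global scale used here is a simplified assumption):

```python
import torch

def fp8_roundtrip(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fnuz):
    """Sketch: cast to fp8 with one global scale, then decast back so the
    reference input carries the same quantization error as the fp8 path."""
    scale = x.abs().amax() / torch.finfo(fp8_dtype).max  # simplified global scale
    x_fp8 = (x / scale).to(fp8_dtype)                    # tensor handed to the fp8 kernel
    x_ref = x_fp8.to(torch.float16) * scale              # decast input for the reference path
    return x_fp8, scale, x_ref
```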
3 changes: 0 additions & 3 deletions flash_attn/flash_attn_triton_amd/README.md
@@ -58,9 +58,6 @@ Inside the docker, it should open to the flash attention repo with everything in
pytest tests/test_flash_attn_triton_amd.py
```

##### FP8
In our fork, we have modified the API to work with FP8. You provide tensors that are scaled into FP8 range along with their associated descaling factors.
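
Below is a minimal sketch of the intended call pattern, assuming the `cast_to_fp8` helper from `flash_attn/flash_attn_triton_amd/utils.py` and assuming the descaling factors are passed to `flash_attn_func` via `descale_q`/`descale_k`/`descale_v` keyword arguments; those parameter names are an assumption for illustration, not confirmed here.

```python
import torch
from flash_attn import flash_attn_func
from flash_attn.flash_attn_triton_amd.utils import cast_to_fp8

# fp16 inputs in bshd layout: (batch, seqlen, heads, head_dim)
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")

# scale each tensor into fp8 range and keep the matching descaling factors
q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd")
k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd")
v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd")

# hypothetical keyword names for the descaling factors
out = flash_attn_func(q_fp8, k_fp8, v_fp8, causal=True,
                      descale_q=descale_q, descale_k=descale_k, descale_v=descale_v)
```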

##### Credits
AMD Triton kernels team

216 changes: 97 additions & 119 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -4,7 +4,7 @@
from flash_attn import flash_attn_func, flash_attn_varlen_func

from .utils import DEBUG, DEBUG_TRITON, DEBUG_TRITON_DETAIL, MetaData, \
cast_to_fp8, get_input_shapes, input_helper, varlen_input_helper, compute_alibi_tensor_ref, get_arch, arch_supports_fp8
cast_to_fp8, decast_fp8, get_input_shapes, input_helper, varlen_input_helper, compute_alibi_tensor_ref, get_arch, arch_supports_fp8
from .interface_torch import attention_prefill, attention_decode
from .fwd_ref import attention_forward_pytorch_ref_impl
from .fwd_prefill import attention_prefill_forward_triton_impl
@@ -809,8 +809,8 @@ def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
(4, 6, 6, 2048, 2048, 32),
],
)
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('dropout_p', [0.0, 0.1])
@pytest.mark.parametrize('DEBUG_INPUT', [False])
@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device")
def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT):
@@ -820,43 +820,24 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
alibi_slopes = None
deterministic = False
layout = "bshd"
fp8_dtype = torch.float8_e4m3fnuz
ref_dtype = torch.float16

q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)
q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, ref_dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)
if DEBUG_INPUT:
do = torch.ones_like(q)
else:
do = torch.randn_like(q)

# ref forward pass
# NOTE: bfp16 is not supported by atomic ops. we have to use fp16
q_fp16 = q.clone().to(torch.float16).requires_grad_()
k_fp16 = k.clone().to(torch.float16).requires_grad_()
v_fp16 = v.clone().to(torch.float16).requires_grad_()
do_fp16 = do.clone().to(torch.float16)
out_fp16, lse_fp16, S_dmask_fp16 = flash_attn_func(
q_fp16,
k_fp16,
v_fp16,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)

# ref backward pass
dq_fp16, dk_fp16, dv_fp16 = torch.autograd.grad(out_fp16, (q_fp16, k_fp16, v_fp16), do_fp16)

# ----------------------------------------------------------------
# --- FP8 ---
# ----------------------------------------------------------------
# cast to fp8
q_fp8, descale_q= cast_to_fp8(q, torch.float8_e4m3fnuz, layout)
k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, layout)
v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, layout)
do_fp8, descale_do = cast_to_fp8(do, torch.float8_e4m3fnuz, layout)
q_fp8, descale_q= cast_to_fp8(q, fp8_dtype, layout)
k_fp8, descale_k = cast_to_fp8(k, fp8_dtype, layout)
v_fp8, descale_v = cast_to_fp8(v, fp8_dtype, layout)
do_fp8, descale_do = cast_to_fp8(do, fp8_dtype, layout)

# fp8 forward pass
out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_func(
@@ -879,42 +860,66 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
# fp8 backward pass
dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8)

# ----------------------------------------------------------------
# --- Reference ---
# ----------------------------------------------------------------
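# decast the fp8 tensors back to ref_dtype so the reference path sees exactly the
# quantized values the fp8 kernel consumed, not the original fp16 inputs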
q_ref = decast_fp8(q_fp8, descale_q, ref_dtype, layout)
k_ref = decast_fp8(k_fp8, descale_k, ref_dtype, layout)
v_ref = decast_fp8(v_fp8, descale_v, ref_dtype, layout)
do_ref = decast_fp8(do_fp8, descale_do, ref_dtype, layout)
out_ref, lse_ref, S_dmask_ref = flash_attn_func(
q_ref,
k_ref,
v_ref,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)

# ref backward pass
dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), do_ref)


# compare forward
if DEBUG:
print()
print("Compare fp8 against ref")

if DEBUG:
print("out_fp16:", out_fp16, out_fp16.shape)
print("out_ref:", out_ref, out_ref.shape)
print("out_fp8:", out_fp8, out_fp8.shape)
torch.testing.assert_close(out_fp16.to(torch.float32), out_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)
torch.testing.assert_close(out_ref, out_fp8.to(ref_dtype), atol=ATOL_fp8, rtol=RTOL_fp8)

if DEBUG:
print("lse_fp16:", lse_fp16, lse_fp16.shape)
print("lse_ref:", lse_ref, lse_ref.shape)
print("lse_fp8:", lse_fp8, lse_fp8.shape)
torch.testing.assert_close(lse_fp16.to(torch.float32), lse_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)
torch.testing.assert_close(lse_ref, lse_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)

if DEBUG:
print("S_dmask_fp16:", S_dmask_fp16, S_dmask_fp16.shape if S_dmask_fp16 is not None else None )
print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_fp16 is not None else None)
torch.testing.assert_close(S_dmask_fp16.to(torch.float32) if S_dmask_fp16 is not None else None, S_dmask_fp8.to(torch.float32) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)
# if DEBUG:
# print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape if S_dmask_ref is not None else None )
# print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_ref is not None else None)
# torch.testing.assert_close(S_dmask_ref if S_dmask_ref is not None else None, S_dmask_fp8.to(ref_dtype) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)
# compare backward
if DEBUG:
print("dv_fp16:", dv_fp16, dv_fp16.shape)
print("dv_ref:", dv_ref, dv_ref.shape)
print("dv_fp8:", dv_fp8, dv_fp8.shape)
torch.testing.assert_close(dv_fp16.to(torch.float32), dv_fp8.to(torch.float32),
torch.testing.assert_close(dv_ref, dv_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)

if DEBUG:
print("dk_fp16:", dk_fp16, dk_fp16.shape)
print("dk_ref:", dk_ref, dk_ref.shape)
print("dk_fp8:", dk_fp8, dk_fp8.shape)
torch.testing.assert_close(dk_fp16.to(torch.float32), dk_fp8.to(torch.float32),
torch.testing.assert_close(dk_ref, dk_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)

if DEBUG:
print("dq_fp16:", dq_fp16, dq_fp16.shape)
print("dq_ref:", dq_ref, dq_ref.shape)
print("dq_fp8:", dq_fp8, dq_fp8.shape)
torch.testing.assert_close(dq_fp16.to(torch.float32), dq_fp8.to(torch.float32),
torch.testing.assert_close(dq_ref, dq_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)


@@ -1022,47 +1027,23 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop
alibi_slopes = None
deterministic = False
layout = "thd"
fp8_dtype = torch.float8_e4m3fnuz
ref_dtype = torch.float16

q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, device=device, DEBUG_INPUT=DEBUG_INPUT) # failure due to small seqlen tensors
# q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, device=device, equal_seqlens=True, DEBUG_INPUT=DEBUG_INPUT) # most cases pass
q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, ref_dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
if DEBUG_INPUT:
do = torch.ones_like(q)
else:
do = torch.randn_like(q)

# launch kernel in fp16
q_fp16 = q.clone().to(torch.float16)
k_fp16 = k.clone().to(torch.float16)
v_fp16 = v.clone().to(torch.float16)
do_fp16 = do.clone().to(torch.float16)
out_fp16, lse_fp16, S_dmask_fp16 = flash_attn_varlen_func(
q_fp16,
k_fp16,
v_fp16,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)

# ref backward pass
dq_fp16, dk_fp16, dv_fp16 = torch.autograd.grad(out_fp16, (q_fp16, k_fp16, v_fp16), do_fp16)

# ----------------------------------------------------------------
# --- FP8 ---
# ----------------------------------------------------------------
# cast to fp8
q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, layout, cu_seqlens=metadata.cu_seqlens_q)
k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, layout, cu_seqlens=metadata.cu_seqlens_k)
v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, layout, cu_seqlens=metadata.cu_seqlens_k)
do_fp8, descale_do = cast_to_fp8(do, torch.float8_e4m3fnuz, layout, cu_seqlens=metadata.cu_seqlens_q)
q_fp8, descale_q = cast_to_fp8(q, fp8_dtype, layout, cu_seqlens=metadata.cu_seqlens_q)
k_fp8, descale_k = cast_to_fp8(k, fp8_dtype, layout, cu_seqlens=metadata.cu_seqlens_k)
v_fp8, descale_v = cast_to_fp8(v, fp8_dtype, layout, cu_seqlens=metadata.cu_seqlens_k)
do_fp8, descale_do = cast_to_fp8(do, fp8_dtype, layout, cu_seqlens=metadata.cu_seqlens_q)

# launch kernel in fp8
out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_func(
Expand All @@ -1089,75 +1070,72 @@ def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, drop
# fp8 backward pass
dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8)


# ----------------------------------------------------------------
# --- Reference ---
# ----------------------------------------------------------------
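# varlen ('thd') case: decast each sequence over its cu_seqlens token range with its own descale factor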
q_ref = decast_fp8(q_fp8, descale_q, ref_dtype, layout, cu_seqlens=metadata.cu_seqlens_q)
k_ref = decast_fp8(k_fp8, descale_k, ref_dtype, layout, cu_seqlens=metadata.cu_seqlens_k)
v_ref = decast_fp8(v_fp8, descale_v, ref_dtype, layout, cu_seqlens=metadata.cu_seqlens_k)
do_ref = decast_fp8(do_fp8, descale_do, ref_dtype, layout, cu_seqlens=metadata.cu_seqlens_q)
out_ref, lse_ref, S_dmask_ref = flash_attn_varlen_func(
q_ref,
k_ref,
v_ref,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)

# ref backward pass
dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), do_ref)

# compare forward
if DEBUG:
print()
print("Compare fp8 against ref")

if DEBUG:
print("out_fp16:", out_fp16, out_fp16.shape)
print("out_ref:", out_ref, out_ref.shape)
print("out_fp8:", out_fp8, out_fp8.shape)
torch.testing.assert_close(out_fp16.to(torch.float32), out_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)
torch.testing.assert_close(out_ref, out_fp8.to(ref_dtype), atol=ATOL_fp8, rtol=RTOL_fp8)

if DEBUG:
print("lse_fp16:", lse_fp16, lse_fp16.shape)
print("lse_ref:", lse_ref, lse_ref.shape)
print("lse_fp8:", lse_fp8, lse_fp8.shape)
torch.testing.assert_close(lse_fp16.to(torch.float32), lse_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)
torch.testing.assert_close(lse_ref, lse_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8)

if DEBUG:
print("S_dmask_fp16:", S_dmask_fp16, S_dmask_fp16.shape if S_dmask_fp16 is not None else None )
print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_fp16 is not None else None)
torch.testing.assert_close(S_dmask_fp16.to(torch.float32) if S_dmask_fp16 is not None else None, S_dmask_fp8.to(torch.float32) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)
# if DEBUG:
# print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape if S_dmask_ref is not None else None )
# print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_ref is not None else None)
# torch.testing.assert_close(S_dmask_ref if S_dmask_ref is not None else None, S_dmask_fp8.to(ref_dtype) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)

# compare backward
if DEBUG:
print("dv_fp16:", dv_fp16, dv_fp16.shape)
print("dv_ref:", dv_ref, dv_ref.shape)
print("dv_fp8:", dv_fp8, dv_fp8.shape)
print("metadata.cu_seqlens_k:", metadata.cu_seqlens_k)
config = (Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD)
# vis_close(
# dv_fp16.to(torch.float32),
# dv_fp8.to(torch.float32),
# atol=ATOL_fp8,
# rtol=RTOL_fp8,
# equal_nan=EQUAL_NAN,
# layout="thd",
# out_file=f"dv_close_{config}.png",
# cu_seqlens=metadata.cu_seqlens_k
# )
torch.testing.assert_close(dv_fp16.to(torch.float32), dv_fp8.to(torch.float32),
torch.testing.assert_close(dv_ref, dv_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)

if DEBUG:
print("dk_fp16:", dk_fp16, dk_fp16.shape)
print("dk_ref:", dk_ref, dk_ref.shape)
print("dk_fp8:", dk_fp8, dk_fp8.shape)
# vis_close(
# dk_fp16.to(torch.float32),
# dk_fp8.to(torch.float32),
# atol=ATOL_fp8,
# rtol=RTOL_fp8,
# equal_nan=EQUAL_NAN,
# layout="thd",
# out_file=f"dk_close_{config}.png",
# cu_seqlens=metadata.cu_seqlens_k
# )
torch.testing.assert_close(dk_fp16.to(torch.float32), dk_fp8.to(torch.float32),
torch.testing.assert_close(dk_ref, dk_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)

if DEBUG:
print("dq_fp16:", dq_fp16, dq_fp16.shape)
print("dq_ref:", dq_ref, dq_ref.shape)
print("dq_fp8:", dq_fp8, dq_fp8.shape)
# vis_close(
# dq_fp16.to(torch.float32),
# dq_fp8.to(torch.float32),
# atol=ATOL_fp8,
# rtol=RTOL_fp8,
# equal_nan=EQUAL_NAN,
# layout="thd",
# out_file=f"dq_close_{config}.png",
# cu_seqlens=metadata.cu_seqlens_q
# )
torch.testing.assert_close(dq_fp16.to(torch.float32), dq_fp8.to(torch.float32),
torch.testing.assert_close(dq_ref, dq_fp8.to(ref_dtype),
atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN)


26 changes: 26 additions & 0 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -1,5 +1,6 @@
import csv
import math
from typing import Optional
import torch
import os
import random
@@ -376,6 +377,31 @@ def cast_to_fp8(
else:
raise ValueError(f"Unknown layout: {layout}")

def decast_fp8(x_fp8: torch.Tensor,
descale_factor: torch.Tensor,
original_dtype: torch.dtype,
layout: str,
cu_seqlens: Optional[torch.Tensor] = None) -> torch.Tensor:
# convert fp8 tensor back to the desired original dtype
x_orig = x_fp8.to(original_dtype)

if layout in ("bshd", "bhsd"):
return x_orig * descale_factor
elif layout == "thd":
if cu_seqlens is None:
raise ValueError("cu_seqlens must be provided for varlen layout ('thd')")
x_out = x_orig.clone()
batch = cu_seqlens.shape[0] - 1
for i in range(batch):
start = int(cu_seqlens[i].item())
end = int(cu_seqlens[i + 1].item())
factor = descale_factor[i].unsqueeze(0).unsqueeze(-1)
x_out[start:end] = x_out[start:end] * factor
return x_out

else:
raise ValueError(f"Unknown layout: {layout}")
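
# Example round trip (a sketch for illustration only, not used by the library itself):
#   x_fp8, descale = cast_to_fp8(x, torch.float8_e4m3fnuz, "bshd")
#   x_ref = decast_fp8(x_fp8, descale, torch.float16, "bshd")
# x_ref carries the same quantization error as x_fp8, which is how the tests
# build reference inputs matched to the fp8 path.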

# -------------------------------
# Misc
# -------------------------------
2 changes: 1 addition & 1 deletion setup.py
@@ -145,7 +145,7 @@ def validate_and_update_archs(archs):
subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
else:
if IS_ROCM:
if not USE_TRITON_ROCM:
if not SKIP_CK_BUILD:
assert (
os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py")
), "csrc/composable_kernel is missing, please use source distribution or git clone"
