Merge pull request #5 from ROCm/roll-back-fmha-common
roll back fmha/common.py
qianfengz authored Feb 29, 2024
2 parents 14c831e + f654b3a commit 99947ff
Showing 2 changed files with 5 additions and 132 deletions.
131 changes: 3 additions & 128 deletions tests/test_mem_eff_attention.py
@@ -320,47 +320,6 @@ def T(t):
return out.permute((0, 2, 1, 3))


# this interface assumes the tensor is in BMHK, but q and k/v might have different number of heads
def ref_attention_mqa(q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=None):
assert q.ndim == 4

B, M, Hq, K = q.shape
_, N, Hkv, Kv = v.shape
nhead_ratio_qk = Hq // Hkv

def attn_bias_head(head: int):
if isinstance(attn_bias, torch.Tensor):
assert attn_bias.ndim == 4
_, H, _, _ = attn_bias.shape
assert H == Hq
bias_bghmn = attn_bias.reshape(B, Hkv, nhead_ratio_qk, M, N)
return bias_bghmn[:, :, head]
if isinstance(attn_bias, fmha.attn_bias.LowerTriangularMaskWithTensorBias):
assert attn_bias._bias.ndim == 4
_, H, _, _ = attn_bias._bias.shape
assert H == Hq
bias_bghmn = attn_bias._bias.reshape(B, Hkv, nhead_ratio_qk, M, N)
return fmha.attn_bias.LowerTriangularMaskWithTensorBias(
bias_bghmn[:, :, head]
)
return attn_bias

q_bmghk = q.reshape((B, M, Hkv, nhead_ratio_qk, K))

return torch.stack(
[
ref_attention_bmhk(
q_bmghk[:, :, :, h],
k,
v,
attn_bias=attn_bias_head(h),
)
for h in range(q_bmghk.shape[3])
],
dim=3,
).reshape((B, M, Hq, Kv))


def _rand_partition(r: random.Random, total: int, n: int) -> List[int]:
# returns list of n nonnegative integers summing to total
idx = {0, total}
@@ -571,92 +530,6 @@ def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, **kwargs)
)


@rocm_only
@pytest.mark.parametrize("hdim_k,hdim_v", [(64, 64), (128, 128)])
@pytest.mark.parametrize("nhead_q,nhead_kv", [(8, 1), (8, 2), (12, 4), (4, 4)])
@pytest.mark.parametrize("seqlen_q,seqlen_kv", [(100, 128), (128, 100), (200, 1000)])
@pytest.mark.parametrize("batches", [100, 64, 1])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
"attn_bias_type", [type(None), torch.Tensor, fmha.attn_bias.LowerTriangularMask]
)
@pytest.mark.parametrize("op", [fmha.ck.FwOp])
def test_mqa_forward(
op,
attn_bias_type,
dtype,
batches: int,
seqlen_kv: int,
seqlen_q: int,
nhead_kv: int,
nhead_q: int,
hdim_v: int,
hdim_k: int,
):
B = batches
M = seqlen_q
N = seqlen_kv
Hq = nhead_q
Hkv = nhead_kv
K = hdim_k
Kv = hdim_v
nhead_ratio_qk = Hq // Hkv

device = torch.device("cuda")

torch.manual_seed(B * M + N * K + Hq * Hkv + Kv)

scale = 3
query = torch.randn((B, M, Hq, K), device=device, dtype=dtype).mul_(scale)
key = torch.randn((B, N, Hkv, K), device=device, dtype=dtype).mul_(scale)
value = torch.randn((B, N, Hkv, Kv), device=device, dtype=dtype).mul_(scale)

attn_bias = None
if attn_bias_type is not None:
attn_bias = create_attn_bias(
attn_bias_type,
batch_size=B,
num_heads=Hq,
num_heads_groups=nhead_ratio_qk,
q_len=M,
kv_len=N,
dtype=dtype,
device=device,
requires_grad=False,
fmt="BMHK",
op=op,
)

inputs = fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias)
reasons = op.not_supported_reasons(inputs)
if reasons:
err_msg = f"{op.NAME}: unsupported ({'/'.join(reasons)})"
# Ensure we free memory to avoid OOMs
del query, key, value, attn_bias, inputs
assert False, err_msg

out = xformers.ops.memory_efficient_attention_forward(
query, key, value, attn_bias, op=op
)
assert not out.isnan().any(), ("Output has NaNs", attn_bias)
out2 = xformers.ops.memory_efficient_attention_forward(
query, key, value, attn_bias, op=op
)
assert torch.allclose(out, out2, atol=0.0, rtol=0.0), (
"Non-deterministic behavior",
attn_bias,
)

ref = ref_attention_mqa(query, key, value, attn_bias)
assert out.shape == ref.shape, out.shape
assert_allclose(
out.float(),
ref,
atol=op.ERROR_ATOL[dtype],
rtol=op.ERROR_RTOL.get(dtype, 1e-5),
)


@cuda_only
@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@@ -2328,7 +2201,9 @@ def test_forward_splitk(

 @cuda_only
 @pytest.mark.parametrize(
-    "op", [fmha.triton_splitk.FwOp, fmha.flash.FwOp], ids=lambda op: op.NAME
+    "op",
+    [fmha.triton_splitk.FwOp, fmha.flash.FwOp, fmha.ck.FwOp],
+    ids=lambda op: op.NAME,
 )
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=str)
 @pytest.mark.parametrize(
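For context on what the first removed block implemented: `ref_attention_mqa` builds a grouped-query reference by viewing the `Hq` query heads as `Hkv` groups of `Hq // Hkv` heads, running the plain BMHK reference per group against the shared key/value heads, and stacking the results back into a `(B, M, Hq, Kv)` output. Below is a minimal standalone sketch of that grouping idea, not the removed code itself: the helper name `grouped_query_reference` is hypothetical, the per-group attention is inlined instead of delegating to `ref_attention_bmhk`, and bias/dropout handling is omitted.

```python
import torch


def grouped_query_reference(q, k, v):
    # q: (B, M, Hq, K); k: (B, N, Hkv, K); v: (B, N, Hkv, Kv), with Hq a multiple of Hkv.
    B, M, Hq, K = q.shape
    _, N, Hkv, Kv = v.shape
    g = Hq // Hkv  # query heads per key/value head

    # View the query heads as Hkv groups of g heads, so each group attends to its own k/v head.
    q_bmghk = q.reshape(B, M, Hkv, g, K)

    outs = []
    for h in range(g):
        qh = q_bmghk[:, :, :, h].float()  # (B, M, Hkv, K)
        attn = torch.einsum("bmhk,bnhk->bhmn", qh, k.float()) * K**-0.5
        attn = attn.softmax(dim=-1)
        outs.append(torch.einsum("bhmn,bnhk->bmhk", attn, v.float()))  # (B, M, Hkv, Kv)

    # Stack the g per-group outputs next to the Hkv axis and flatten back to Hq query heads.
    return torch.stack(outs, dim=3).reshape(B, M, Hq, Kv)
```

The removed `test_mqa_forward` compared `xformers.ops.memory_efficient_attention_forward` on `fmha.ck.FwOp` against exactly this kind of reference for head ratios such as 8:1, 8:2, and 12:4.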
6 changes: 2 additions & 4 deletions xformers/ops/fmha/common.py
@@ -180,13 +180,11 @@ def validate_inputs(self) -> None:
                 and self.value.shape == (B, Mkv, Kv)
             )
         H = self.query.shape[-2]
-        Hkv = self.key.shape[-2]
         if self.query.ndim == 4: # BMHK
             valid_shapes = (
                 self.query.shape == (B, Mq, H, K)
-                and self.key.shape == (B, Mkv, Hkv, key_embed_dim)
-                and self.value.shape == (B, Mkv, Hkv, Kv)
-                and H % Hkv == 0
+                and self.key.shape == (B, Mkv, H, key_embed_dim)
+                and self.value.shape == (B, Mkv, H, Kv)
             )
         G = self.query.shape[2]
         if self.query.ndim == 5: # BMNHK
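The net effect of the `common.py` hunk is that `validate_inputs` once again requires query, key, and value to carry the same head count `H` in the BMHK layout; the relaxed multi-query condition (`Hkv` key/value heads with `H % Hkv == 0`) from the removed lines is gone. A simplified sketch of the reverted shape check, with a hypothetical helper name and the key embedding dimension taken equal to the query's `K`:

```python
import torch


def bmhk_shapes_valid(query, key, value) -> bool:
    # Expected sizes are derived from the query tensor.
    B, Mq, H, K = query.shape
    Mkv = key.shape[1]
    Kv = value.shape[-1]
    # After the revert, key and value must have the same number of heads H as the query;
    # the multi-query relaxation (Hkv heads with H % Hkv == 0) is no longer part of this check.
    return key.shape == (B, Mkv, H, K) and value.shape == (B, Mkv, H, Kv)


# A (B, M, 8, K) query with (B, N, 8, K) key/value passes; the (B, N, 1, K)
# multi-query layout exercised by the removed test now fails this check.
assert bmhk_shapes_valid(torch.empty(2, 16, 8, 64), torch.empty(2, 32, 8, 64), torch.empty(2, 32, 8, 64))
assert not bmhk_shapes_valid(torch.empty(2, 16, 8, 64), torch.empty(2, 32, 1, 64), torch.empty(2, 32, 1, 64))
```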
