Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support position_ids as input for flash attention #23

Merged
merged 6 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions test/test_flash_attention_varlen_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,209 @@ def test_flash_attn_varlen_backward(seqlen_q, seqlen_k, d, dropout_p, causal,
assert torch.allclose(dq_cuda, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
assert torch.allclose(dk_cuda, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [8])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(128, 128),
(1024, 1024),
],
)
@pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d,
dropout_p, causal, local,
alibi, deterministic, mha_type,
dtype):
if d % 8 != 0:
pytest.skip(reason="Expected head_size_og % 8 == 0 to be true")

device = "cuda"
# set seed
torch.random.manual_seed(101)
batch_size = 4
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)

assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else tuple(
torch.randint(0, seqlen_k, (2,)).tolist())
torch.cuda.synchronize()
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
softmax_scale = q.shape[-1]**(-0.5)
k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
do = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
rng_state = torch.Tensor([0, 0]).to(torch.int64).to(device)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)

attention_mask = torch.zeros(
batch_size, seqlen_k, dtype=torch.int32).to(device)
k_lengths = torch.randint(low=2, high=seqlen_k, size=(batch_size,))
for i in range(batch_size):
k_len = k_lengths[i].item()
attention_mask[i, :k_len] = 1
q[i, k_len:, :, :] = 0
k[i, k_len:, :, :] = 0
v[i, k_len:, :, :] = 0
do[i, k_len:, :, :] = 0
q.requires_grad = True
k.requires_grad = True
v.requires_grad = True

q_cuda, k_cuda, v_cuda, do_cuda, dq_cuda, dk_cuda, dv_cuda, \
indices_q, indices_k, cu_seq_lens, max_seq_lens = _unpad_input(
q, k, v, do, dq, dk, dv, attention_mask, seqlen_q, nheads
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_q, max_seqlen_k = max_seq_lens

if alibi:
alibi_slopes = torch.rand(
batch_size, nheads, device=device, dtype=torch.float32) * 0.3
else:
alibi_slopes = None

o_cuda, softmax_lse_cuda, _ = flash_attn_varlen_func(
q_cuda.contiguous(),
k_cuda.contiguous(),
v_cuda.contiguous(),
cu_seqlens_q.contiguous(),
cu_seqlens_k.contiguous(),
max_seqlen_q,
max_seqlen_k,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
dq_cuda, dk_cuda, dv_cuda, softmax_d_cuda = flash_attn_cuda.varlen_bwd(
do_cuda.contiguous(), q_cuda.contiguous(), k_cuda.contiguous(),
v_cuda.contiguous(), o_cuda.contiguous(), softmax_lse_cuda.contiguous(),
dq_cuda, dk_cuda, dv_cuda, cu_seqlens_q, cu_seqlens_k, alibi_slopes,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal,
window_size[0], window_size[1], deterministic, None, rng_state)

dq_cuda = pad_input(dq_cuda, indices_q, batch_size, seqlen_q)
dk_cuda = pad_input(dk_cuda, indices_k, batch_size, seqlen_k)
dv_cuda = pad_input(dv_cuda, indices_k, batch_size, seqlen_k)
softmax_d_cuda = softmax_d_cuda[:, :, :seqlen_q]

q = q.cpu().detach()
k = k.cpu().detach()
v = v.cpu().detach()
do = do.cpu().detach()
rng_state = rng_state.cpu().detach()

dq_cuda = dq_cuda.cpu().detach()
dk_cuda = dk_cuda.cpu().detach()
dv_cuda = dv_cuda.cpu().detach()
softmax_d_cuda = softmax_d_cuda.cpu().detach()
if alibi:
alibi_slopes = alibi_slopes.cpu()
torch.cuda.synchronize()

device = ta.lazy_device()
torch.random.manual_seed(101)

indices_q = indices_q.cpu()
indices_k = indices_k.cpu()

q_xla = q.flatten(0, 1)[indices_q].unsqueeze(0).to(device)
k_xla = k.flatten(0, 1)[indices_k].unsqueeze(0).to(device)
v_xla = v.flatten(0, 1)[indices_k].unsqueeze(0).to(device)
do_xla = do.flatten(0, 1)[indices_q].unsqueeze(0).to(device)

def attention_mask_to_position_ids(attention_mask):
seqlens = attention_mask.sum(dim=1).flatten()
position_ids = torch.cat([
torch.arange(
0, seqlen, dtype=torch.int32, device=attention_mask.device)
for seqlen in seqlens
],
dim=0)
return position_ids.unsqueeze(0)

position_ids_xla = attention_mask_to_position_ids(attention_mask).to(device)

rng_state_xla = rng_state.to(device)
if alibi:
alibi_slopes = alibi_slopes.cpu().to(device)

softmax_lse_xla, o_xla, _, cu_seqlen_q_xla, cu_seqlen_k_xla = torch_xla._XLAC._flash_attention_position_ids_forward(
q_xla.contiguous(), k_xla.contiguous(), v_xla.contiguous(),
position_ids_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale,
False, causal, window_size[0], window_size[1], True, None)

assert torch.allclose(
q_xla.cpu(),
q_cuda.cpu().unsqueeze(0),
rtol=1e-3,
atol=1e-3,
equal_nan=True)
assert torch.allclose(
k_xla.cpu(),
k_cuda.cpu().unsqueeze(0),
rtol=1e-3,
atol=1e-3,
equal_nan=True)
assert torch.allclose(
v_xla.cpu(),
v_cuda.cpu().unsqueeze(0),
rtol=1e-3,
atol=1e-3,
equal_nan=True)
assert torch.allclose(
cu_seqlen_k_xla[:batch_size + 1].cpu(),
cu_seqlens_k.cpu(),
rtol=1e-3,
atol=1e-3,
equal_nan=True)
assert torch.allclose(
o_xla.cpu(),
o_cuda.unsqueeze(0).cpu(),
rtol=1e-2,
atol=1e-2,
equal_nan=True)

q_xla.requires_grad = True
k_xla.requires_grad = True
v_xla.requires_grad = True
o_xla.requires_grad = True
softmax_lse_xla.requires_grad = True

dq_xla, dk_xla, dv_xla, softmax_d_xla = torch_xla._XLAC._flash_attention_position_ids_backward(
do_xla.contiguous(), q_xla.contiguous(), k_xla.contiguous(),
v_xla.contiguous(), o_xla.contiguous(), softmax_lse_xla.contiguous(),
cu_seqlen_q_xla, cu_seqlen_k_xla, alibi_slopes, dropout_p, softmax_scale,
False, causal, window_size[0], window_size[1], deterministic, None,
rng_state_xla)

ta.mark_step(wait=True)
torch.cuda.synchronize()

dq_xla = dq_xla.cpu().detach().squeeze(0)
dk_xla = dk_xla.cpu().detach().squeeze(0)
dv_xla = dv_xla.cpu().detach().squeeze(0)
do_xla = do_xla.cpu().detach().squeeze(0)
softmax_d_xla = softmax_d_xla.cpu().detach()

dq_xla = pad_input(dq_xla, indices_q, batch_size, seqlen_q)
dk_xla = pad_input(dk_xla, indices_k, batch_size, seqlen_k)
dv_xla = pad_input(dv_xla, indices_k, batch_size, seqlen_k)

assert torch.allclose(dq_cuda, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
assert torch.allclose(dk_cuda, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
Loading