From 55984646c47f1635675bb3f90ee706384f31d943 Mon Sep 17 00:00:00 2001 From: Tianxing Wang <55205022+shadow150519@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:24:10 +0800 Subject: [PATCH] Support position_ids as input for flash attention (#23) * support position_ids input for flash attention --- test/test_flash_attention_varlen_backward.py | 206 +++++++++ test/test_flash_attention_varlen_forward.py | 222 ++++++++- torch_xla/csrc/flash_attention_utils.cpp | 120 ++++- torch_xla/csrc/flash_attention_utils.h | 17 +- torch_xla/csrc/init_python_bindings.cpp | 104 ++++- .../ops/flash_attention_varlen_forward.cpp | 7 +- ...attention_varlen_position_ids_backward.cpp | 434 ++++++++++++++++++ ...h_attention_varlen_position_ids_backward.h | 38 ++ ..._attention_varlen_position_ids_forward.cpp | 328 +++++++++++++ ...sh_attention_varlen_position_ids_forward.h | 34 ++ torch_xla/csrc/tensor_methods.cpp | 46 ++ torch_xla/csrc/tensor_methods.h | 12 + 12 files changed, 1549 insertions(+), 19 deletions(-) create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h diff --git a/test/test_flash_attention_varlen_backward.py b/test/test_flash_attention_varlen_backward.py index 1ee13118ee2..2da5985b53d 100644 --- a/test/test_flash_attention_varlen_backward.py +++ b/test/test_flash_attention_varlen_backward.py @@ -249,3 +249,209 @@ def test_flash_attn_varlen_backward(seqlen_q, seqlen_k, d, dropout_p, causal, assert torch.allclose(dq_cuda, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True) assert torch.allclose(dk_cuda, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True) assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("alibi", [False]) +@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [8]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (128, 128), + (1024, 1024), + ], +) +@pytest.mark.parametrize("dropout_p", [0.0]) +def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, + dropout_p, causal, local, + alibi, deterministic, mha_type, + dtype): + if d % 8 != 0: + pytest.skip(reason="Expected head_size_og % 8 == 0 to be true") + + device = "cuda" + # set seed + torch.random.manual_seed(101) + batch_size = 4 + nheads = 9 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + + assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else tuple( + torch.randint(0, seqlen_k, (2,)).tolist()) + torch.cuda.synchronize() + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + softmax_scale = q.shape[-1]**(-0.5) + k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype) + do = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + rng_state = torch.Tensor([0, 0]).to(torch.int64).to(device) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + attention_mask = torch.zeros( + batch_size, seqlen_k, 
dtype=torch.int32).to(device) + k_lengths = torch.randint(low=2, high=seqlen_k, size=(batch_size,)) + for i in range(batch_size): + k_len = k_lengths[i].item() + attention_mask[i, :k_len] = 1 + q[i, k_len:, :, :] = 0 + k[i, k_len:, :, :] = 0 + v[i, k_len:, :, :] = 0 + do[i, k_len:, :, :] = 0 + q.requires_grad = True + k.requires_grad = True + v.requires_grad = True + + q_cuda, k_cuda, v_cuda, do_cuda, dq_cuda, dk_cuda, dv_cuda, \ + indices_q, indices_k, cu_seq_lens, max_seq_lens = _unpad_input( + q, k, v, do, dq, dk, dv, attention_mask, seqlen_q, nheads + ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_q, max_seqlen_k = max_seq_lens + + if alibi: + alibi_slopes = torch.rand( + batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + else: + alibi_slopes = None + + o_cuda, softmax_lse_cuda, _ = flash_attn_varlen_func( + q_cuda.contiguous(), + k_cuda.contiguous(), + v_cuda.contiguous(), + cu_seqlens_q.contiguous(), + cu_seqlens_k.contiguous(), + max_seqlen_q, + max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + dq_cuda, dk_cuda, dv_cuda, softmax_d_cuda = flash_attn_cuda.varlen_bwd( + do_cuda.contiguous(), q_cuda.contiguous(), k_cuda.contiguous(), + v_cuda.contiguous(), o_cuda.contiguous(), softmax_lse_cuda.contiguous(), + dq_cuda, dk_cuda, dv_cuda, cu_seqlens_q, cu_seqlens_k, alibi_slopes, + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, + window_size[0], window_size[1], deterministic, None, rng_state) + + dq_cuda = pad_input(dq_cuda, indices_q, batch_size, seqlen_q) + dk_cuda = pad_input(dk_cuda, indices_k, batch_size, seqlen_k) + dv_cuda = pad_input(dv_cuda, indices_k, batch_size, seqlen_k) + softmax_d_cuda = softmax_d_cuda[:, :, :seqlen_q] + + q = q.cpu().detach() + k = k.cpu().detach() + v = v.cpu().detach() + do = do.cpu().detach() + rng_state = rng_state.cpu().detach() + + dq_cuda = dq_cuda.cpu().detach() + dk_cuda = dk_cuda.cpu().detach() + dv_cuda = dv_cuda.cpu().detach() + softmax_d_cuda = softmax_d_cuda.cpu().detach() + if alibi: + alibi_slopes = alibi_slopes.cpu() + torch.cuda.synchronize() + + device = ta.lazy_device() + torch.random.manual_seed(101) + + indices_q = indices_q.cpu() + indices_k = indices_k.cpu() + + q_xla = q.flatten(0, 1)[indices_q].unsqueeze(0).to(device) + k_xla = k.flatten(0, 1)[indices_k].unsqueeze(0).to(device) + v_xla = v.flatten(0, 1)[indices_k].unsqueeze(0).to(device) + do_xla = do.flatten(0, 1)[indices_q].unsqueeze(0).to(device) + + def attention_mask_to_position_ids(attention_mask): + seqlens = attention_mask.sum(dim=1).flatten() + position_ids = torch.cat([ + torch.arange( + 0, seqlen, dtype=torch.int32, device=attention_mask.device) + for seqlen in seqlens + ], + dim=0) + return position_ids.unsqueeze(0) + + position_ids_xla = attention_mask_to_position_ids(attention_mask).to(device) + + rng_state_xla = rng_state.to(device) + if alibi: + alibi_slopes = alibi_slopes.cpu().to(device) + + softmax_lse_xla, o_xla, _, cu_seqlen_q_xla, cu_seqlen_k_xla = torch_xla._XLAC._flash_attention_position_ids_forward( + q_xla.contiguous(), k_xla.contiguous(), v_xla.contiguous(), + position_ids_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale, + False, causal, window_size[0], window_size[1], True, None) + + assert torch.allclose( + q_xla.cpu(), + q_cuda.cpu().unsqueeze(0), + rtol=1e-3, + atol=1e-3, + equal_nan=True) + assert torch.allclose( + k_xla.cpu(), + 
k_cuda.cpu().unsqueeze(0), + rtol=1e-3, + atol=1e-3, + equal_nan=True) + assert torch.allclose( + v_xla.cpu(), + v_cuda.cpu().unsqueeze(0), + rtol=1e-3, + atol=1e-3, + equal_nan=True) + assert torch.allclose( + cu_seqlen_k_xla[:batch_size + 1].cpu(), + cu_seqlens_k.cpu(), + rtol=1e-3, + atol=1e-3, + equal_nan=True) + assert torch.allclose( + o_xla.cpu(), + o_cuda.unsqueeze(0).cpu(), + rtol=1e-2, + atol=1e-2, + equal_nan=True) + + q_xla.requires_grad = True + k_xla.requires_grad = True + v_xla.requires_grad = True + o_xla.requires_grad = True + softmax_lse_xla.requires_grad = True + + dq_xla, dk_xla, dv_xla, softmax_d_xla = torch_xla._XLAC._flash_attention_position_ids_backward( + do_xla.contiguous(), q_xla.contiguous(), k_xla.contiguous(), + v_xla.contiguous(), o_xla.contiguous(), softmax_lse_xla.contiguous(), + cu_seqlen_q_xla, cu_seqlen_k_xla, alibi_slopes, dropout_p, softmax_scale, + False, causal, window_size[0], window_size[1], deterministic, None, + rng_state_xla) + + ta.mark_step(wait=True) + torch.cuda.synchronize() + + dq_xla = dq_xla.cpu().detach().squeeze(0) + dk_xla = dk_xla.cpu().detach().squeeze(0) + dv_xla = dv_xla.cpu().detach().squeeze(0) + do_xla = do_xla.cpu().detach().squeeze(0) + softmax_d_xla = softmax_d_xla.cpu().detach() + + dq_xla = pad_input(dq_xla, indices_q, batch_size, seqlen_q) + dk_xla = pad_input(dk_xla, indices_k, batch_size, seqlen_k) + dv_xla = pad_input(dv_xla, indices_k, batch_size, seqlen_k) + + assert torch.allclose(dq_cuda, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True) + assert torch.allclose(dk_cuda, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True) + assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True) diff --git a/test/test_flash_attention_varlen_forward.py b/test/test_flash_attention_varlen_forward.py index b17d0c4e91a..843573436fb 100644 --- a/test/test_flash_attention_varlen_forward.py +++ b/test/test_flash_attention_varlen_forward.py @@ -225,7 +225,223 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, cu_seqlen_q_xla, cu_seqlens_q, rtol=1e-3, atol=1e-3, equal_nan=True) assert torch.allclose( cu_seqlen_k_xla, cu_seqlens_k, rtol=1e-3, atol=1e-3, equal_nan=True) - softmax_lse_xla = softmax_lse_xla[:, :, :max_seqlen_in_batch_q] - softmax_lse = softmax_lse[:, :, :max_seqlen_in_batch_q] + for i in range(len(cu_seq_lens[0]) - 1): + seqlen = cu_seq_lens[0][i + 1] - cu_seq_lens[0][i] + assert torch.allclose( + softmax_lse_xla[i, :, :seqlen], + softmax_lse[i, :, :seqlen], + rtol=1e-2, + atol=1e-3, + equal_nan=True) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False]) +@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [32]) +@pytest.mark.parametrize("softmax_scale", [0.25]) +@pytest.mark.parametrize( + "max_seqlen_q,max_seqlen_k", + [ + (8, 8), + (128, 128), + (2048, 2048), + ], +) +@pytest.mark.parametrize("dropout_p", [0.0]) +def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d, + dropout_p, causal, softmax_scale, + local, alibi, deterministic, + mha_type, dtype): + if d % 8 != 0: + pytest.skip(reason="Expected head_size_og % 8 == 0 to be true") + + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 9 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + + assert nheads % 
nheads_k == 0 + window_size = (-1, -1) if not local else tuple( + torch.randint(0, max_seqlen_k, (2,)).tolist()) + torch.cuda.synchronize() + + def generate_qkv_and_position_ids(batch_size, max_seqlen_q, max_seqlen_k, + dtype, n_heads_q, n_heads_k, device, + head_dim, use_same_seqlen): + ''' + generate varlen qkv and postion_ids and pad total seqlen to 8 + ''' + seq_len_q = torch.randint(1, max_seqlen_q + 1, (batch_size,)) + if use_same_seqlen: + seq_len_k = seq_len_q.clone() + else: + seq_len_k = torch.randint(1, max_seqlen_k + 1, (batch_size,)) + total_q = seq_len_q.sum().item() + total_k = seq_len_k.sum().item() + + padd_q = 0 if total_q % 8 == 0 else 8 - total_q % 8 + padd_k = 0 if total_k % 8 == 0 else 8 - total_k % 8 + if use_same_seqlen: + assert torch.all(seq_len_k == seq_len_q) + + # padding to last q and k + if padd_q: + seq_len_q[-1] += padd_q + total_q += padd_q + assert total_q % 8 == 0 + if padd_k: + seq_len_k[-1] += padd_k + total_k += padd_k + assert total_k % 8 == 0 + + q = torch.randn((1, total_q, n_heads_q, head_dim), + dtype=dtype, + device=device) + k = torch.randn((1, total_k, n_heads_k, head_dim), + dtype=dtype, + device=device) + v = torch.randn((1, total_k, n_heads_k, head_dim), + dtype=dtype, + device=device) + + assert torch.all(seq_len_q > 0) + assert torch.all(seq_len_k > 0) + + position_ids_q = torch.cat([ + torch.arange(0, seq_len, dtype=torch.int32, device=device) + for seq_len in seq_len_q + ], + dim=0).unsqueeze(0) + position_ids_k = torch.cat([ + torch.arange(0, seq_len, dtype=torch.int32, device=device) + for seq_len in seq_len_k + ], + dim=0).unsqueeze(0) + assert position_ids_q.shape[1] % 8 == 0 + assert position_ids_k.shape[1] % 8 == 0 + + return q, k, v, position_ids_q, position_ids_k + + q, k, v, position_ids_q, position_ids_k = generate_qkv_and_position_ids( + batch_size, max_seqlen_q, max_seqlen_k, dtype, nheads, nheads_k, device, + d, max_seqlen_q == max_seqlen_k) + + indices_q = torch.arange( + 0, position_ids_q.size(1), device=device, dtype=torch.int64) + indices_k = torch.arange( + 0, position_ids_k.size(1), device=device, dtype=torch.int64) + + seq_q_start_idx = indices_q[position_ids_q.squeeze() == 0] + seq_k_start_idx = indices_k[position_ids_k.squeeze() == 0] + + cu_seq_lens_q = F.pad( + seq_q_start_idx, (0, 1), value=position_ids_q.size(1)).to(torch.int32) + cu_seq_lens_k = F.pad( + seq_k_start_idx, (0, 1), value=position_ids_k.size(1)).to(torch.int32) + + max_seqlen_q = (cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]).max().item() + max_seqlen_k = (cu_seq_lens_k[1:] - cu_seq_lens_k[:-1]).max().item() + + q_cuda = q.reshape(-1, nheads, d).cuda() + k_cuda = k.reshape(-1, nheads_k, d).cuda() + v_cuda = v.reshape(-1, nheads_k, d).cuda() + + if alibi: + alibi_slopes = torch.rand( + batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + else: + alibi_slopes = None + + out_fa, softmax_lse, _ = flash_attn_varlen_func( + q_cuda.contiguous(), + k_cuda.contiguous(), + v_cuda.contiguous(), + cu_seq_lens_q.contiguous(), + cu_seq_lens_k.contiguous(), + max_seqlen_q, + max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + out_fa = pad_input(out_fa, indices_q, batch_size, max_seqlen_q) + + q = q.cpu().detach() + k = k.cpu().detach() + v = v.cpu().detach() + + out_fa = out_fa.cpu().detach() + softmax_lse = softmax_lse.cpu().detach() + cu_seq_lens_q = cu_seq_lens_q.cpu().detach() + cu_seq_lens_k = 
cu_seq_lens_k.cpu().detach() + + if alibi: + alibi_slopes = alibi_slopes.cpu() + torch.cuda.synchronize() + + device = ta.lazy_device() + torch.random.manual_seed(0) + q_xla = q.to(device) + k_xla = k.to(device) + v_xla = v.to(device) + + position_ids_q_xla = position_ids_q.to(device) + + q_xla.requires_grad = False + k_xla.requires_grad = False + v_xla.requires_grad = False + if alibi: + alibi_slopes = alibi_slopes.cpu().to(device) + + softmax_lse_xla, out_xla, _, cu_seq_len_q_xla, cu_seq_len_k_xla = torch_xla._XLAC._flash_attention_position_ids_forward( + q_xla.contiguous(), k_xla.contiguous(), v_xla.contiguous(), + position_ids_q_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale, + False, causal, window_size[0], window_size[1], True, None) + + ta.mark_step(wait=True) + q_xla = q_xla.cpu().detach() + k_xla = k_xla.cpu().detach() + v_xla = v_xla.cpu().detach() + out_xla = out_xla.cpu().detach() + cu_seq_len_q_xla = cu_seq_len_q_xla.cpu().detach() + cu_seq_len_k_xla = cu_seq_len_k_xla.cpu().detach() + softmax_lse_xla = softmax_lse_xla.cpu().detach() + position_ids_q_xla = position_ids_q_xla.cpu().detach() + + out_xla = out_xla.squeeze(0) + + out_xla = pad_input(out_xla, indices_q, batch_size, max_seqlen_q) + + assert out_xla.shape == out_fa.shape + + assert torch.allclose(q_xla, q, rtol=1e-3, atol=1e-3, equal_nan=True) + assert torch.allclose(k_xla, k, rtol=1e-3, atol=1e-3, equal_nan=True) + assert torch.allclose(v_xla, v, rtol=1e-3, atol=1e-3, equal_nan=True) assert torch.allclose( - softmax_lse_xla, softmax_lse, rtol=1e-2, atol=1e-2, equal_nan=True) + cu_seq_len_k_xla[:batch_size + 1], + cu_seq_lens_k, + rtol=1e-3, + atol=1e-3, + equal_nan=True) + assert torch.allclose(out_xla, out_fa, rtol=1e-2, atol=1e-2, equal_nan=True) + for i in range(batch_size): + start_idx = cu_seq_len_q_xla[i] + end_idx = cu_seq_len_q_xla[i + 1] + seqlen = end_idx - start_idx + assert torch.allclose( + softmax_lse_xla[0, :, start_idx:end_idx], + softmax_lse[i, :, :seqlen], + rtol=1e-3, + atol=1e-3, + equal_nan=True) diff --git a/torch_xla/csrc/flash_attention_utils.cpp b/torch_xla/csrc/flash_attention_utils.cpp index 2c3331c8d1d..6536dc96c6c 100644 --- a/torch_xla/csrc/flash_attention_utils.cpp +++ b/torch_xla/csrc/flash_attention_utils.cpp @@ -236,7 +236,8 @@ void set_backward_params(FlashAttentionBackwardParams& params, const size_t b, FlashAttentionForwardParams get_flash_attention_forward_params( const at::Tensor& q, const at::Tensor& k, const at::Tensor& v, - c10::optional& attention_mask, // (batch_size, seqlen) + c10::optional attention_mask, // (batch_size, seqlen) + c10::optional position_ids, // (1,seqlen_q) c10::optional& alibi_slopes_, const float p_dropout, const float softmax_scale, const bool zero_tensors, const bool is_causal, int window_size_left, int window_size_right, const bool return_softmax) { @@ -274,10 +275,15 @@ FlashAttentionForwardParams get_flash_attention_forward_params( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + if (attention_mask.has_value()) { TORCH_CHECK(attention_mask.value().dtype() == torch::kInt32); CHECK_SHAPE(attention_mask.value(), batch_size, seqlen_k); } + if (position_ids.has_value()) { + TORCH_CHECK(position_ids.value().dtype() == torch::kInt32); + CHECK_SHAPE(position_ids.value(), 1, seqlen_q); + } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = 
round_multiple(head_size_og, 8); @@ -308,7 +314,6 @@ FlashAttentionForwardParams get_flash_attention_forward_params( p_dropout, softmax_scale, is_causal, window_size_left, window_size_right, alibi_slopes_batch_stride, enable_alibi_slopes, /*seqlenq_ngroups_swapped*/ false); - return params; } @@ -382,9 +387,19 @@ FlashAttentionBackwardParams get_flash_attention_backward_params( TORCH_CHECK(cu_seqlens_k.value().dtype() == torch::kInt32); TORCH_CHECK(cu_seqlens_q.value().is_contiguous()); TORCH_CHECK(cu_seqlens_k.value().is_contiguous()); - TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1); - CHECK_SHAPE(cu_seqlens_q.value(), batch_size + 1); - CHECK_SHAPE(cu_seqlens_k.value(), batch_size + 1); + TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 || + batch_size == 1); // now pack qkv batch size only support 1, + // maybe need to change in the future + TORCH_CHECK( + cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) || + cu_seqlens_q.value().sizes() == + torch::IntArrayRef({seqlen_q * batch_size + 1}), + "cu_seqlens_q shape should be batch_size+1 or seqlen_q+1"); + TORCH_CHECK( + cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) || + cu_seqlens_k.value().sizes() == + torch::IntArrayRef({seqlen_k * batch_size + 1}), + "cu_seqlens_k shape should be batch_size+1 or seqlen_k+1"); } int alibi_slopes_batch_stride = 0; @@ -412,6 +427,21 @@ FlashAttentionBackwardParams get_flash_attention_backward_params( return params; } +at::Tensor cu_seqlens_to_indices(const at::Tensor& padded_cu_seqlens, + int& max_seqlen_in_batch, int& total_q, + int& real_batch_size) { + const at::Tensor valid_cu_seqlens = + padded_cu_seqlens.index({padded_cu_seqlens > -1}); + real_batch_size = valid_cu_seqlens.size(0) - 1; + at::Tensor seqs_len = + valid_cu_seqlens.slice(0, 1, valid_cu_seqlens.size(0)) - + valid_cu_seqlens.slice(0, 0, valid_cu_seqlens.size(0) - 1); + max_seqlen_in_batch = seqs_len.max().item(); + total_q = valid_cu_seqlens[-1].item(); + return torch::arange(total_q, + torch::dtype(torch::kInt64).device(torch::kCUDA)); +} + at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size, int seqlen, torch::Dtype scalar_type, int& max_seqlen_in_batch, int& total) { @@ -451,18 +481,92 @@ at::Tensor mask_to_indices(const at::Tensor& attention_mask, return indices; } +torch::Tensor unpad_softmax_lse( + const torch::Tensor& pad_softmax_lse, // (batch_size, nhead, max_seqlen) + const torch::Tensor& cu_seqlens) // (total_seqlen + 1) +{ + int batch_size = pad_softmax_lse.size(0); + int nhead = pad_softmax_lse.size(1); + int max_seqlen = pad_softmax_lse.size(2); + int total, max_seqlen_in_batch; + at::Tensor valid_cu_seqlens = cu_seqlens.slice(0, 0, batch_size + 1); + at::Tensor indices = + cu_seqlens_to_indices(valid_cu_seqlens, batch_size, max_seqlen, + torch::kInt64, max_seqlen_in_batch, total); + at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options()); + result.copy_( + pad_softmax_lse.transpose(1, 2) + .reshape({batch_size * max_seqlen, nhead}) + .index({indices, + torch::indexing::Slice()})); // if packed tensor's batch size + // > 1 is supported in the + // future, need to modify here + // in the future + return result.transpose(0, 1).unsqueeze(0); +} + +torch::Tensor pad_softmax_lse( + const at::Tensor& softmax_lse, // (1,nheads,total_seqlen) + const at::Tensor& cu_seqlens, const int max_seq_len, const int batch_size) { + const int nheads = softmax_lse.size(1); + int max_seqlen_in_batch; + int total; + at::Tensor 
valid_cu_seqlens = cu_seqlens.slice(0, 0, batch_size + 1); + at::Tensor indices = + cu_seqlens_to_indices(valid_cu_seqlens, batch_size, max_seq_len, + torch::kInt32, max_seqlen_in_batch, total); + TORCH_CHECK(indices.size(0) == softmax_lse.size(2), + "indice should be same size with softmax_lse") + + at::Tensor result = + at::zeros({batch_size * max_seq_len, nheads}, softmax_lse.options()); + + result.index_put_({indices, torch::indexing::Slice()}, + softmax_lse.squeeze(0).transpose(0, 1)); + return result.reshape({batch_size, max_seq_len, nheads}) + .transpose(1, 2) + .contiguous(); +} + +at::Tensor position_ids_to_indices(const at::Tensor& position_ids, + int& max_seqlen_in_batch, int& total, + at::Tensor& cu_seqlen, + int& real_batch_size) { + cu_seqlen.fill_(-1); + at::Tensor flatten_position_ids = position_ids.flatten(); + at::Tensor indices = + torch::arange(flatten_position_ids.size(0), + torch::dtype(torch::kInt64).device(torch::kCUDA)); + + at::Tensor batch_seq_start_idx = indices.index({flatten_position_ids == 0}); + real_batch_size = batch_seq_start_idx.size(0); + at::Tensor batch_seqlen_cumsum = at::empty( + {real_batch_size + 1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + batch_seqlen_cumsum.index({torch::indexing::Slice(0, real_batch_size)}) = + batch_seq_start_idx; + total = flatten_position_ids.size(0); + batch_seqlen_cumsum.index({-1}) = total; + + at::Tensor batch_seqlen = + batch_seqlen_cumsum.slice(0, 1, batch_seqlen_cumsum.size(0)) - + batch_seqlen_cumsum.slice(0, 0, batch_seqlen_cumsum.size(0) - 1); + max_seqlen_in_batch = batch_seqlen.max().item(); + cu_seqlen.narrow(0, 0, real_batch_size + 1) = batch_seqlen_cumsum; + return indices; +} + at::Tensor index_first_axis(const at::Tensor& input, const at::Tensor& indices) { torch::IntArrayRef sizes = input.sizes(); - int64_t first_axis_dim = sizes[0]; - auto other_shape = sizes.slice(1, sizes.size() - 1); + int64_t first_axis_dim = sizes[0]; // bs + auto other_shape = sizes.slice(1, sizes.size() - 1); // [a,h] int64_t second_dim = 1; for (auto dim : other_shape) { second_dim *= dim; } - at::Tensor flat_input = torch::flatten(input, 1); + at::Tensor flat_input = torch::flatten(input, 1); // [bs,ah] torch::Tensor repeated_indices = indices.unsqueeze(1).expand({indices.size(0), second_dim}); at::Tensor gather_input = torch::gather(flat_input, 0, repeated_indices); diff --git a/torch_xla/csrc/flash_attention_utils.h b/torch_xla/csrc/flash_attention_utils.h index a2696f25f82..d231f2e5165 100644 --- a/torch_xla/csrc/flash_attention_utils.h +++ b/torch_xla/csrc/flash_attention_utils.h @@ -112,7 +112,8 @@ void set_backward_params(FlashAttentionBackwardParams& params, const size_t b, FlashAttentionForwardParams get_flash_attention_forward_params( const at::Tensor& q, const at::Tensor& k, const at::Tensor& v, - c10::optional& attention_mask, + c10::optional attention_mask, + c10::optional position_ids, c10::optional& alibi_slopes_, const float p_dropout, const float softmax_scale, const bool zero_tensors, const bool is_causal, int window_size_left, int window_size_right, const bool return_softmax); @@ -130,10 +131,18 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size, int seqlen, torch::Dtype scalar_type, int& max_seqlen_in_batch, int& total); +at::Tensor cu_seqlens_to_indices(const at::Tensor& padded_cu_seqlens, + int& max_seqlen_in_batch, int& total_q, + int& real_batch_size); + at::Tensor mask_to_indices(const at::Tensor& attention_mask, int& max_seqlen_in_batch, int& total, at::Tensor& 
cu_seqlen); +at::Tensor position_ids_to_indices(const at::Tensor& position_ids, + int& max_seqlen_in_batch, int& total, + at::Tensor& cu_seqlen, int& real_batch_size); + at::Tensor index_first_axis(const at::Tensor& input, const at::Tensor& indices); // make xla::Shape like input, same elemet_type and dimension as of input, @@ -142,5 +151,11 @@ xla::Shape shape_like(const torch::lazy::Value& input); xla::Shape shape_like(const xla::XlaBuilder* builder, const xla::XlaOp& input); +torch::Tensor unpad_softmax_lse(const torch::Tensor& pad_softmax_lse, + const torch::Tensor& cu_seqlens); + +torch::Tensor pad_softmax_lse(const at::Tensor& softmax_lse, + const at::Tensor& cu_seqlens, + const int max_seq_len, const int batch_size); } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_FLASH_ATTENTION_UTILS_H \ No newline at end of file diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index bce6795ae8d..2ca6b514071 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2409,9 +2409,9 @@ void InitXlaModuleBindings(py::module m) { c10::optional gen_) { // get launch params on at::Tensor auto params = get_flash_attention_forward_params( - q, k, v, attention_mask, alibi_slopes, p_dropout, softmax_scale, - zero_tensors, is_causal, window_size_left, window_size_right, - return_softmax); + q, k, v, attention_mask, c10::nullopt, alibi_slopes, p_dropout, + softmax_scale, zero_tensors, is_causal, window_size_left, + window_size_right, return_softmax); // call flash attention forward XLATensorPtr q_xla = bridge::GetXlaTensor(q); XLATensorPtr k_xla = bridge::GetXlaTensor(k); @@ -2440,6 +2440,50 @@ void InitXlaModuleBindings(py::module m) { } return results; }); + + m.def("_flash_attention_position_ids_forward", + [](const at::Tensor& q, const at::Tensor& k, const at::Tensor& v, + c10::optional& position_ids, + c10::optional& alibi_slopes, const float p_dropout, + const float softmax_scale, const bool zero_tensors, + const bool is_causal, const int window_size_left, + const int window_size_right, const bool return_softmax, + c10::optional gen_) { + // get launch params on at::Tensor + auto params = get_flash_attention_forward_params( + q, k, v, c10::nullopt, position_ids, alibi_slopes, p_dropout, + softmax_scale, zero_tensors, is_causal, window_size_left, + window_size_right, return_softmax); + // call flash attention forward + XLATensorPtr q_xla = bridge::GetXlaTensor(q); + XLATensorPtr k_xla = bridge::GetXlaTensor(k); + XLATensorPtr v_xla = bridge::GetXlaTensor(v); + XLATensorPtr alibi_slopes_xla = + alibi_slopes.has_value() + ? 
bridge::GetXlaTensor(alibi_slopes.value()) + : XLATensorPtr(); + + std::vector xresults; + if (position_ids.has_value()) { + XLATensorPtr position_ids_xla = + bridge::GetXlaTensor(position_ids.value()); + xresults = + tensor_methods::flash_attention_varlen_position_ids_forward( + q_xla, k_xla, v_xla, position_ids_xla, alibi_slopes_xla, + params.ToString()); + } else { + xresults = tensor_methods::flash_attention_forward( + q_xla, k_xla, v_xla, alibi_slopes_xla, params.ToString()); + } + std::vector results; + for (auto& xresult : xresults) { + at::Tensor tensor = bridge::AtenFromXlaTensor(std::move(xresult)); + results.push_back(torch::autograd::make_variable( + tensor, /*requires_grad=*/false)); + } + return results; + }); + m.def( "_flash_attention_backward", [](const at::Tensor& dout, const at::Tensor& q, const at::Tensor& k, @@ -2492,6 +2536,60 @@ void InitXlaModuleBindings(py::module m) { } return results; }); + m.def( + "_flash_attention_position_ids_backward", + [](const at::Tensor& dout, const at::Tensor& q, const at::Tensor& k, + const at::Tensor& v, const at::Tensor& out, + const at::Tensor& softmax_lse, c10::optional& cu_seqlens_q, + c10::optional& cu_seqlens_k, + c10::optional& alibi_slopes, const float p_dropout, + const float softmax_scale, const bool zero_tensors, + const bool is_causal, const int window_size_left, + const int window_size_right, const bool deterministic, + c10::optional gen_, const at::Tensor& rng_state) { + // get launch params on at::Tensor + auto params = get_flash_attention_backward_params( + dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, + alibi_slopes, p_dropout, softmax_scale, zero_tensors, is_causal, + window_size_left, window_size_right, deterministic); + // call flash attention backward + XLATensorPtr dout_xla = bridge::GetXlaTensor(dout); + XLATensorPtr q_xla = bridge::GetXlaTensor(q); + XLATensorPtr k_xla = bridge::GetXlaTensor(k); + XLATensorPtr v_xla = bridge::GetXlaTensor(v); + XLATensorPtr out_xla = bridge::GetXlaTensor(out); + XLATensorPtr softmax_lse_xla = bridge::GetXlaTensor(softmax_lse); + XLATensorPtr rng_state_xla = bridge::GetXlaTensor(rng_state); + XLATensorPtr alibi_slopes_xla = + alibi_slopes.has_value() + ? 
bridge::GetXlaTensor(alibi_slopes.value()) + : XLATensorPtr(); + + std::vector xresults; + if (cu_seqlens_q.has_value() && cu_seqlens_k.has_value()) { + XLATensorPtr cu_seqlens_q_xla = + bridge::GetXlaTensor(cu_seqlens_q.value()); + XLATensorPtr cu_seqlens_k_xla = + bridge::GetXlaTensor(cu_seqlens_k.value()); + xresults = + tensor_methods::flash_attention_varlen_position_ids_backward( + dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla, + cu_seqlens_q_xla, cu_seqlens_k_xla, rng_state_xla, + alibi_slopes_xla, params.ToString()); + } else { + xresults = tensor_methods::flash_attention_backward( + dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla, + rng_state_xla, alibi_slopes_xla, params.ToString()); + } + std::vector results; + for (auto& xresult : xresults) { + at::Tensor tensor = bridge::AtenFromXlaTensor(std::move(xresult)); + results.push_back( + torch::autograd::make_variable(tensor, /*requires_grad=*/false)); + } + return results; + }); + // -------------FlashAttention Integration API End------------------- // -------------Dynamo Integration API Start------------------------- diff --git a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp index 478f9f114df..832e6c280fa 100644 --- a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp +++ b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp @@ -16,7 +16,6 @@ namespace torch_xla { namespace { - xla::Shape NodeOutputShape(const torch::lazy::Value& q) { auto q_shape = xla::SpanToVector(GetXlaShape(q).dimensions()); xla::Shape softmax_lse_shape = xla::ShapeUtil::MakeShape( @@ -98,9 +97,9 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, int max_seqlen_in_batch_k = params.seqlen_k; int total_k = params.b * params.seqlen_k; - at::Tensor indices_k = mask_to_indices(attention_mask, max_seqlen_in_batch_k, total_k, cu_seqlens_k); + auto unpad_k = index_first_axis(k, indices_k); auto unpad_v = index_first_axis(v, indices_k); @@ -113,6 +112,7 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, indices_q = indices_k; } else if (params.seqlen_q == 1) { max_seqlen_in_batch_q = 1; + cu_seqlens_q = torch::arange(0, params.b + 1, opts); indices_q = cu_seqlens_q.slice(/*dim=*/0, /*start=*/0, /*end=*/params.b); total_q = params.b; } else { @@ -121,7 +121,6 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, indices_q = mask_to_indices(attention_mask_slice, max_seqlen_in_batch_q, total_q, cu_seqlens_q); } - at::Tensor unpad_q = index_first_axis(q, indices_q); at::Tensor unpad_output = @@ -241,7 +240,6 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, } TF_VLOG(2) << "Running FlashAttention Forward."; - FP16_SWITCH(!launch_params.is_bf16, [&] { HEADDIM_SWITCH(launch_params.d, [&] { // TODO(wenting.swt): support split_kv @@ -261,6 +259,7 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, cudaEventRecord(xla_wait_torch_event, torch_stream); cudaStreamWaitEvent(stream, xla_wait_torch_event); } + XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_forward, "CUDA"); diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp new file mode 100644 index 00000000000..0e508b80af6 --- /dev/null +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp @@ -0,0 +1,434 @@ +#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h" + +#include +#include 
+#include + +#include "cutlass/numeric_types.h" +#include "flash.h" +#include "static_switch.h" +#include "torch_xla/csrc/flash_attention_utils.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/runtime/tf_logging.h" +#include "torch_xla/csrc/xla_lower_util.h" +#include "xla/service/custom_call_target_registry.h" + +namespace torch_xla { +namespace { + +xla::Shape NodeOutputShape(const torch::lazy::Value& q, + const torch::lazy::Value& k, + const torch::lazy::Value& v, + const torch::lazy::Value& softmax_lse) { + return xla::ShapeUtil::MakeTupleShape( + {shape_like(q), shape_like(k), shape_like(v), shape_like(softmax_lse)}); +} + +void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream, + const bool configure) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, + [&] { run_mha_bwd_(params, stream); }); + }); +} + +// Layout of `buffers` listed above: +// buffers[0] = dout +// buffers[1] = q +// buffers[2] = k +// buffers[3] = v +// buffers[4] = out +// buffers[5] = softmax_lse +// buffers[6] = cu_seqlens_q +// buffers[7] = cu_seqlens_k +// buffers[8] = rng_state +// buffers[9] = alibi_slopes +// buffers[10] = dq // this is output +// buffers[11] = dk // this is output +// buffers[12] = dv // this is output +// buffers[13] = softmax_d // this is output +void custom_call_flash_attention_varlen_position_ids_backward( + cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + std::string opaque_str(opaque, opaque_len); + TF_VLOG(3) + << "custom_call_flash_attention_varlen_position_ids_backward opaque str: " + << opaque_str; + FlashAttentionBackwardParams params; + params.FromString(std::move(opaque_str)); + int buf_offset = params.enable_alibi_slopes; + auto scalar_type = params.is_bf16 ? 
torch::kBFloat16 : torch::kFloat16; + + cudaStream_t torch_stream = at::cuda::getCurrentCUDAStream().stream(); + cudaEvent_t torch_wait_xla_event; + cudaEventCreateWithFlags(&torch_wait_xla_event, cudaEventDisableTiming); + cudaEvent_t xla_wait_torch_event; + cudaEventCreateWithFlags(&xla_wait_torch_event, cudaEventDisableTiming); + cudaEventRecord(torch_wait_xla_event, stream); + cudaStreamWaitEvent(torch_stream, torch_wait_xla_event); + + auto cuda_stream = at::cuda::getCurrentCUDAStream(); + at::cuda::CUDAStreamGuard guard(cuda_stream); + + auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA); + + // Inputs + at::Tensor do_ = torch::from_blob( + buffers[0], {params.b * params.seqlen_q, params.h, params.d}, opts); + at::Tensor q = torch::from_blob( + buffers[1], {params.b * params.seqlen_q, params.h, params.d}, opts); + at::Tensor k = torch::from_blob( + buffers[2], {params.b * params.seqlen_k, params.h_k, params.d}, opts); + at::Tensor v = torch::from_blob( + buffers[3], {params.b * params.seqlen_k, params.h_k, params.d}, opts); + at::Tensor o = torch::from_blob( + buffers[4], {params.b * params.seqlen_q, params.h, params.d}, opts); + at::Tensor softmax_lse = + torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q}, + opts.dtype(torch::kFloat)); + at::Tensor cu_seqlens_q = torch::from_blob( + buffers[6], {params.b * params.seqlen_q + 1}, opts.dtype(torch::kInt32)); + at::Tensor cu_seqlens_k = torch::from_blob( + buffers[7], {params.b * params.seqlen_k + 1}, opts.dtype(torch::kInt32)); + + // Outputs + at::Tensor dq = + torch::from_blob(buffers[9 + buf_offset], + {params.b * params.seqlen_q, params.h, params.d}, opts); + + at::Tensor dk = torch::from_blob( + buffers[10 + buf_offset], + {params.b * params.seqlen_k, params.h_k, params.d}, opts); + + at::Tensor dv = torch::from_blob( + buffers[11 + buf_offset], + {params.b * params.seqlen_k, params.h_k, params.d}, opts); + + at::Tensor dsoftmax_sum = torch::from_blob( + buffers[12 + buf_offset], {params.b, params.h, params.seqlen_q}, + opts.dtype(torch::kFloat)); + + // Fill zeros for outputs. + dq.fill_(0); + dk.fill_(0); + dv.fill_(0); + dsoftmax_sum.fill_(0); + + int max_seqlen_in_batch_q = params.seqlen_q; + int max_seqlen_in_batch_k = params.seqlen_k; + int total_q = params.b * params.seqlen_q; + int total_k = params.b * params.seqlen_k; + int real_batch_size; + at::Tensor indices_q = cu_seqlens_to_indices( + cu_seqlens_q, max_seqlen_in_batch_q, total_q, real_batch_size); + + at::Tensor indices_k; + if (params.seqlen_q == params.seqlen_k) { + indices_k = indices_q; + max_seqlen_in_batch_k = max_seqlen_in_batch_q; + total_k = total_q; + } else { + indices_k = cu_seqlens_to_indices(cu_seqlens_k, max_seqlen_in_batch_k, + total_k, real_batch_size); + } + + auto padded_softmax_lse = pad_softmax_lse( + softmax_lse, cu_seqlens_q, max_seqlen_in_batch_q, real_batch_size); + + Flash_bwd_params launch_params; + + // Reset the parameters + memset(&launch_params, 0, sizeof(launch_params)); + + launch_params.is_bf16 = params.is_bf16; + + // Set the pointers and strides. + launch_params.q_ptr = q.data_ptr(); + launch_params.k_ptr = k.data_ptr(); + launch_params.v_ptr = v.data_ptr(); + launch_params.o_ptr = o.data_ptr(); + + // All stride are in elements, not bytes. 
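+  // Note: the stride values assigned below come from the serialized
+  // FlashAttentionBackwardParams (the opaque string parsed above); they
+  // describe the packed (total_tokens, num_heads, head_dim) layout of the
+  // buffers and are not recomputed here.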
+ launch_params.q_row_stride = params.q_row_stride; + launch_params.k_row_stride = params.k_row_stride; + launch_params.v_row_stride = params.v_row_stride; + launch_params.q_head_stride = params.q_head_stride; + launch_params.k_head_stride = params.k_head_stride; + + launch_params.v_head_stride = params.v_head_stride; + launch_params.o_row_stride = params.o_row_stride; + launch_params.o_head_stride = params.o_head_stride; + + launch_params.cu_seqlens_q = static_cast(cu_seqlens_q.data_ptr()); + launch_params.cu_seqlens_k = static_cast(cu_seqlens_k.data_ptr()); + launch_params.softmax_lse_ptr = padded_softmax_lse.data_ptr(); + + launch_params.alibi_slopes_ptr = buf_offset > 0 ? buffers[9] : nullptr; + + launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride; + + // P = softmax(QK^T) + launch_params.p_ptr = nullptr; // no softmax returned always + + // Set the dimensions. + launch_params.b = real_batch_size; + launch_params.h = params.h; + launch_params.h_k = params.h_k; + launch_params.h_h_k_ratio = params.h_h_k_ratio; + launch_params.seqlen_q = max_seqlen_in_batch_q; + launch_params.seqlen_k = max_seqlen_in_batch_k; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + launch_params.seqlen_q_rounded = round_multiple(max_seqlen_in_batch_q, 128); + launch_params.seqlen_k_rounded = round_multiple(max_seqlen_in_batch_k, 128); + launch_params.d = params.d; + launch_params.d_rounded = params.d_rounded; + + // Set the different scale values. + launch_params.scale_softmax = params.scale_softmax; + launch_params.scale_softmax_log2 = params.scale_softmax_log2; + + launch_params.p_dropout = params.p_dropout; + launch_params.p_dropout_in_uint8_t = params.p_dropout_in_uint8_t; + launch_params.rp_dropout = params.rp_dropout; + launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout; + + if (max_seqlen_in_batch_q == 1) { + params.is_causal = false; + } + if (params.is_causal) { + params.window_size_right = 0; + } + + if (params.window_size_left >= max_seqlen_in_batch_k) { + params.window_size_left = -1; + } + if (params.window_size_right >= max_seqlen_in_batch_k) { + params.window_size_right = -1; + } + + launch_params.is_causal = + params.window_size_left < 0 && params.window_size_right == 0; + + if (params.window_size_left < 0 && params.window_size_right >= 0) { + params.window_size_left = max_seqlen_in_batch_k; + } + if (params.window_size_left >= 0 && params.window_size_right < 0) { + params.window_size_right = max_seqlen_in_batch_k; + } + + launch_params.window_size_left = params.window_size_left; + launch_params.window_size_right = params.window_size_right; + + launch_params.is_seqlens_k_cumulative = true; + + launch_params.do_row_stride = params.do_row_stride; + launch_params.do_head_stride = params.do_head_stride; + + launch_params.dq_row_stride = params.dq_row_stride; + launch_params.dk_row_stride = params.dk_row_stride; + launch_params.dv_row_stride = params.dv_row_stride; + launch_params.dq_head_stride = params.dq_head_stride; + launch_params.dk_head_stride = params.dk_head_stride; + launch_params.dv_head_stride = params.dv_head_stride; + + at::Tensor rounded_dsoftmax_sum = + at::zeros({real_batch_size, params.h, launch_params.seqlen_q_rounded}, + opts.dtype(torch::kFloat)); + + launch_params.do_ptr = do_.data_ptr(); + launch_params.dq_ptr = dq.data_ptr(); + launch_params.dk_ptr = dk.data_ptr(); + launch_params.dv_ptr = dv.data_ptr(); + launch_params.dsoftmax_sum = rounded_dsoftmax_sum.data_ptr(); + + // bool loop = max_seqlen_k > 
blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + + at::Tensor dq_accum; + if (loop) { + if (!params.deterministic) { + dq_accum = torch::empty({total_q + 128 * launch_params.b, launch_params.h, + launch_params.d_rounded}, + opts.dtype(at::kFloat)); + } else { + auto dprops = at::cuda::getCurrentDeviceProperties(); + const int nsplits = (dprops->multiProcessorCount + + launch_params.b * launch_params.h - 1) / + (launch_params.b * launch_params.h); + dq_accum = torch::zeros({nsplits, total_q + 128 * launch_params.b, + launch_params.h, launch_params.d_rounded}, + opts.dtype(at::kFloat)); + } + } + + at::Tensor dk_expanded, dv_expanded; + + if (launch_params.h_k != launch_params.h) { // MQA / GQA + TF_VLOG(2) << "Running FlashAttention Backward as MQA/GQA"; + dk_expanded = + torch::empty({total_k, launch_params.h, launch_params.d}, opts); + dv_expanded = + torch::empty({total_k, launch_params.h, launch_params.d}, opts); + + launch_params.dk_ptr = dk_expanded.data_ptr(); + launch_params.dv_ptr = dv_expanded.data_ptr(); + launch_params.dk_row_stride = dk_expanded.stride(-3); + launch_params.dv_row_stride = dv_expanded.stride(-3); + launch_params.dk_head_stride = dk_expanded.stride(-2); + launch_params.dv_head_stride = dv_expanded.stride(-2); + } else { + TF_VLOG(2) << "Running FlashAttention Backward"; + dk_expanded = dk; + dv_expanded = dv; + } + + launch_params.dq_accum_ptr = loop ? dq_accum.data_ptr() : nullptr; + launch_params.dk_accum_ptr = nullptr; + launch_params.dv_accum_ptr = nullptr; + + launch_params.deterministic = params.deterministic; + launch_params.dq_accum_split_stride = + !launch_params.deterministic ? 0 : dq_accum.stride(0); + + auto launch = &run_mha_bwd; + + auto gen = at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = launch_params.b * launch_params.h * 32; + + bool is_dropout = (1.f - launch_params.p_dropout) > 0.0; + + // TODO(wenting.swt): According to the implementation in + // `flash_attn_varlen_func` of flash-attn v2.5.6, the forward generates + // `rng_state` which is passed as ctx to the backward. Hence, for simplifying + // the logic, the redundant branch where `rng_state` is None has been omitted. 
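+  // buffers[8] holds the rng_state written by the forward pass (philox seed
+  // and offset, see the buffer layout comment above); reusing it lets the
+  // backward pass replay the same dropout pattern.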
+ launch_params.rng_state = reinterpret_cast(buffers[8]); + + launch(launch_params, torch_stream, /*configure=*/false); + + // For MQA/GQA we need to sum dK and dV across the groups + if (launch_params.h_k != launch_params.h) { + at::sum_out(dk, + at::reshape(dk_expanded, {total_k, launch_params.h_k, + launch_params.h / launch_params.h_k, + launch_params.d}), + {2}); + at::sum_out(dv, + at::reshape(dv_expanded, {total_k, launch_params.h_k, + launch_params.h / launch_params.h_k, + launch_params.d}), + {2}); + } + + dsoftmax_sum.copy_(unpad_softmax_lse(rounded_dsoftmax_sum, cu_seqlens_q)); + + cudaEventRecord(xla_wait_torch_event, torch_stream); + cudaStreamWaitEvent(stream, xla_wait_torch_event); +} +XLA_REGISTER_CUSTOM_CALL_TARGET( + custom_call_flash_attention_varlen_position_ids_backward, "CUDA"); + +std::vector BuildFlashAttentionVarlenPositionIdsBackward( + const xla::XlaOp& dout, const xla::XlaOp& q, const xla::XlaOp& k, + const xla::XlaOp& v, const xla::XlaOp& out, const xla::XlaOp& softmax_lse, + const xla::XlaOp& cu_seqlens_q, const xla::XlaOp& cu_seqlens_k, + const xla::XlaOp& rng_state, const xla::XlaOp& alibi_slopes, + const std::string& opaque, const xla::Shape& output_shape) { + auto builder = q.builder(); + std::vector operands{ + dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state}; + std::vector operand_shapes_with_layout{ + shape_like(builder, dout), + shape_like(builder, q), + shape_like(builder, k), + shape_like(builder, v), + shape_like(builder, out), + shape_like(builder, softmax_lse), + builder->GetShape(cu_seqlens_q).value(), + builder->GetShape(cu_seqlens_k).value(), + builder->GetShape(rng_state).value()}; + if (alibi_slopes.valid()) { + operands.push_back(alibi_slopes); + operand_shapes_with_layout.push_back(shape_like(builder, alibi_slopes)); + } + xla::XlaOp result = xla::CustomCallWithLayout( + builder, "custom_call_flash_attention_varlen_position_ids_backward", + std::move(operands), output_shape, std::move(operand_shapes_with_layout), + opaque); + return {xla::GetTupleElement(result, 0), xla::GetTupleElement(result, 1), + xla::GetTupleElement(result, 2), xla::GetTupleElement(result, 3)}; +} + +} // namespace + +FlashAttentionVarlenPositionIdsBackward:: + FlashAttentionVarlenPositionIdsBackward( + const torch::lazy::Value& dout, const torch::lazy::Value& q, + const torch::lazy::Value& k, const torch::lazy::Value& v, + const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, + const torch::lazy::Value& cu_seqlens_q, + const torch::lazy::Value& cu_seqlens_k, + const torch::lazy::Value& rng_state, const std::string params) + : XlaNode(xla_flash_attention_backward, + {dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, + rng_state}, + NodeOutputShape(q, k, v, softmax_lse), + /*num_outputs=*/4, torch::lazy::MHash(params)), + params_(params) {} + +FlashAttentionVarlenPositionIdsBackward:: + FlashAttentionVarlenPositionIdsBackward( + const torch::lazy::Value& dout, const torch::lazy::Value& q, + const torch::lazy::Value& k, const torch::lazy::Value& v, + const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, + const torch::lazy::Value& cu_seqlens_q, + const torch::lazy::Value& cu_seqlens_k, + const torch::lazy::Value& rng_state, + const torch::lazy::Value& alibi_slopes, const std::string params) + : XlaNode(xla_flash_attention_backward, + {dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, + rng_state, alibi_slopes}, + NodeOutputShape(q, k, v, softmax_lse), + /*num_outputs=*/4, 
torch::lazy::MHash(params)), + params_(params) {} + +torch::lazy::NodePtr FlashAttentionVarlenPositionIdsBackward::Clone( + torch::lazy::OpList operands) const { + if (operands.size() > 9) { + torch::lazy::MakeNode( + operands.at(0), operands.at(1), operands.at(2), operands.at(3), + operands.at(4), operands.at(5), operands.at(6), operands.at(7), + operands.at(8), operands.at(9), params_); + } else { + torch::lazy::MakeNode( + operands.at(0), operands.at(1), operands.at(2), operands.at(3), + operands.at(4), operands.at(5), operands.at(6), operands.at(7), + operands.at(8), params_); + } +} + +XlaOpVector FlashAttentionVarlenPositionIdsBackward::Lower( + LoweringContext* loctx) const { + xla::XlaOp dout = loctx->GetOutputOp(operand(0)); + xla::XlaOp q = loctx->GetOutputOp(operand(1)); + xla::XlaOp k = loctx->GetOutputOp(operand(2)); + xla::XlaOp v = loctx->GetOutputOp(operand(3)); + xla::XlaOp out = loctx->GetOutputOp(operand(4)); + xla::XlaOp softmax_lse = loctx->GetOutputOp(operand(5)); + xla::XlaOp cu_seqlens_q = loctx->GetOutputOp(operand(6)); + xla::XlaOp cu_seqlens_k = loctx->GetOutputOp(operand(7)); + xla::XlaOp rng_state = loctx->GetOutputOp(operand(8)); + xla::XlaOp alibi_slopes = + operands().size() > 9 ? loctx->GetOutputOp(operand(9)) : xla::XlaOp(); + std::vector result = BuildFlashAttentionVarlenPositionIdsBackward( + dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, + alibi_slopes, params_, xla_shape()); + + return ReturnOps({result}, loctx); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h new file mode 100644 index 00000000000..42ba728ceb0 --- /dev/null +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h @@ -0,0 +1,38 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_BACKWARD_H_ +#define XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_BACKWARD_H_ + +#include "torch_xla/csrc/flash_attention_utils.h" +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class FlashAttentionVarlenPositionIdsBackward : public XlaNode { + public: + FlashAttentionVarlenPositionIdsBackward( + const torch::lazy::Value& dout, const torch::lazy::Value& q, + const torch::lazy::Value& k, const torch::lazy::Value& v, + const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, + const torch::lazy::Value& cu_seqlens_q, + const torch::lazy::Value& cu_seqlens_k, + const torch::lazy::Value& rng_state, const std::string params); + + FlashAttentionVarlenPositionIdsBackward( + const torch::lazy::Value& dout, const torch::lazy::Value& q, + const torch::lazy::Value& k, const torch::lazy::Value& v, + const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, + const torch::lazy::Value& cu_seqlens_q, + const torch::lazy::Value& cu_seqlens_k, + const torch::lazy::Value& rng_state, + const torch::lazy::Value& alibi_slopes, const std::string params); + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + private: + const std::string params_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_BACKWARD_H_ \ No newline at end of file diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp new file mode 100644 index 00000000000..23d5d1c4cc6 --- /dev/null +++ 
b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp @@ -0,0 +1,328 @@ +#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h" + +#include +#include +#include + +#include "cutlass/numeric_types.h" +#include "flash.h" +#include "static_switch.h" +#include "torch_xla/csrc/flash_attention_utils.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/runtime/tf_logging.h" +#include "torch_xla/csrc/xla_lower_util.h" +#include "xla/service/custom_call_target_registry.h" + +namespace torch_xla { +namespace { + +xla::Shape NodeOutputShape(const torch::lazy::Value& q) { + auto q_shape = xla::SpanToVector(GetXlaShape(q).dimensions()); + xla::Shape softmax_lse_shape = xla::ShapeUtil::MakeShape( + xla::PrimitiveType::F32, + {q_shape[0], q_shape[2], q_shape[1]}); // 1, num_heads, total_q + xla::Shape rng_state_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2}); + xla::Shape cu_seqlens_shape = xla::ShapeUtil::MakeShape( + xla::PrimitiveType::S32, + {q_shape[1] + 1}); // q.shape [1,total_q,num_head,head_dim] + return xla::ShapeUtil::MakeTupleShape({softmax_lse_shape, shape_like(q), + rng_state_shape, cu_seqlens_shape, + cu_seqlens_shape}); +} + +// Layout of `buffers` listed above: +// buffers[0] = q +// buffers[1] = k +// buffers[2] = v +// buffers[3] = attention_mask or position_ids +// buffers[4] = alibi_slopes +// buffers[5] = softmax_lse // this is output +// buffers[6] = out_for_output // this is output +// buffers[7] = rng_state // this is output +// buffers[8] = cu_seqlen_q // this is output +// buffers[9] = cu_seqlen_k // this is output +void custom_call_flash_attention_varlen_position_ids_forward( + cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + std::string opaque_str(opaque, opaque_len); + TF_VLOG(3) + << "custom_call_flash_attention_varlen_position_ids_forward opaque str: " + << opaque_str; + FlashAttentionForwardParams params; + params.FromString(std::move(opaque_str)); + int buf_offset = params.enable_alibi_slopes; + auto scalar_type = params.is_bf16 ? 
torch::kBFloat16 : torch::kFloat16;
+
+  cudaStream_t torch_stream = at::cuda::getCurrentCUDAStream().stream();
+  cudaEvent_t torch_wait_xla_event;
+  cudaEventCreateWithFlags(&torch_wait_xla_event, cudaEventDisableTiming);
+  cudaEvent_t xla_wait_torch_event;
+  cudaEventCreateWithFlags(&xla_wait_torch_event, cudaEventDisableTiming);
+  cudaEventRecord(torch_wait_xla_event, stream);
+  cudaStreamWaitEvent(torch_stream, torch_wait_xla_event);
+
+  auto cuda_stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::CUDAStreamGuard guard(cuda_stream);
+
+  auto opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
+
+  at::Tensor q = torch::from_blob(
+      buffers[0], {params.b * params.seqlen_q, params.h, params.d},
+      opts.dtype(scalar_type));
+  at::Tensor k = torch::from_blob(
+      buffers[1], {params.b * params.seqlen_k, params.h_k, params.d},
+      opts.dtype(scalar_type));
+  at::Tensor v = torch::from_blob(
+      buffers[2], {params.b * params.seqlen_k, params.h_k, params.d},
+      opts.dtype(scalar_type));
+  at::Tensor position_ids =
+      torch::from_blob(buffers[3], {params.b, params.seqlen_k}, opts);
+  at::Tensor softmax_lse = torch::from_blob(
+      buffers[4 + buf_offset], {params.b, params.h, params.seqlen_q},
+      opts.dtype(torch::kFloat));
+  at::Tensor o_output =
+      torch::from_blob(buffers[5 + buf_offset],
+                       {params.b * params.seqlen_q, params.h * params.d},
+                       opts.dtype(scalar_type));
+  at::Tensor cu_seqlens_q =
+      torch::from_blob(buffers[7 + buf_offset], {params.seqlen_q + 1}, opts);
+  at::Tensor cu_seqlens_k =
+      torch::from_blob(buffers[8 + buf_offset], {params.seqlen_k + 1}, opts);
+  at::Tensor rng_state =
+      torch::from_blob(buffers[6 + buf_offset], {2}, opts.dtype(torch::kInt64));
+  softmax_lse.fill_(0);
+  o_output.fill_(0);
+
+  int max_seqlen_in_batch_k = params.seqlen_k;
+  int total_k = params.b * params.seqlen_k;
+  at::Tensor indices_k;
+
+  int real_batch_size;  // The packed qkv has batch size 1, but FlashAttention
+                        // needs the real (pre-packing) batch size.
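+  // position_ids_to_indices() treats every position_id == 0 as the start of a
+  // new sequence: it derives cu_seqlens_k from those start offsets, leaves the
+  // unused tail of the fixed-size cu_seqlens buffer filled with -1, and
+  // reports the real (pre-packing) batch size.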
+  indices_k = position_ids_to_indices(position_ids, max_seqlen_in_batch_k,
+                                      total_k, cu_seqlens_k, real_batch_size);
+  TORCH_CHECK(cu_seqlens_k.size(0) == params.seqlen_k + 1,
+              "cu_seqlens_k should have params.seqlen_k + 1 elements");
+  TORCH_CHECK(indices_k.dtype() == torch::kInt64, "indices_k should be int64");
+
+  int max_seqlen_in_batch_q = max_seqlen_in_batch_k;
+  int total_q = total_k;
+  at::Tensor indices_q;
+
+  TORCH_CHECK(params.seqlen_q == params.seqlen_k,
+              "only identical q and k sequence lengths are supported for now");
+
+  if (params.seqlen_q == params.seqlen_k) {
+    cu_seqlens_q.copy_(cu_seqlens_k);
+    indices_q = indices_k;
+  } else {
+    // TODO:(wangtianxing.wtx) support different seqlen_q and seqlen_k
+    indices_q = position_ids_to_indices(position_ids, max_seqlen_in_batch_q,
+                                        total_q, cu_seqlens_q, real_batch_size);
+  }
+
+  if (max_seqlen_in_batch_q == 1) {
+    params.is_causal = false;
+  }
+  if (params.is_causal) {
+    params.window_size_right = 0;
+  }
+
+  if (params.window_size_left >= max_seqlen_in_batch_k) {
+    params.window_size_left = -1;
+  }
+  if (params.window_size_right >= max_seqlen_in_batch_k) {
+    params.window_size_right = -1;
+  }
+
+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+  at::Tensor pad_softmax_lse =
+      at::empty({real_batch_size, params.h, max_seqlen_in_batch_q},
+                torch::dtype(torch::kFloat).device(torch::kCUDA));
+
+  Flash_fwd_params launch_params;
+
+  // Reset the parameters
+  memset(&launch_params, 0, sizeof(launch_params));
+
+  launch_params.is_bf16 = params.is_bf16;
+
+  // Set the pointers and strides.
+  launch_params.q_ptr = q.data_ptr();
+  launch_params.k_ptr = k.data_ptr();
+  launch_params.v_ptr = v.data_ptr();
+  // All stride are in elements, not bytes.
+  launch_params.q_row_stride = params.q_row_stride;
+  launch_params.k_row_stride = params.k_row_stride;
+  launch_params.v_row_stride = params.v_row_stride;
+  launch_params.q_head_stride = params.q_head_stride;
+  launch_params.k_head_stride = params.k_head_stride;
+  launch_params.v_head_stride = params.v_head_stride;
+  launch_params.o_ptr = o_output.data_ptr();
+  launch_params.o_row_stride = params.o_row_stride;
+  launch_params.o_head_stride = params.o_head_stride;
+
+  launch_params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q.data_ptr());
+  launch_params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k.data_ptr());
+
+  launch_params.seqused_k = static_cast<int*>(nullptr);
+
+  // P = softmax(QK^T)
+  launch_params.p_ptr = nullptr;  // no softmax returned always
+
+  // Softmax sum
+  launch_params.softmax_lse_ptr = pad_softmax_lse.data_ptr();
+
+  // Set the dimensions.
+  launch_params.b = real_batch_size;
+  launch_params.h = params.h;
+  launch_params.h_k = params.h_k;
+  launch_params.h_h_k_ratio = params.h_h_k_ratio;
+  launch_params.seqlen_q = max_seqlen_in_batch_q;
+  launch_params.seqlen_k = max_seqlen_in_batch_k;
+  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+  launch_params.seqlen_q_rounded = round_multiple(max_seqlen_in_batch_q, 128);
+  launch_params.seqlen_k_rounded = round_multiple(max_seqlen_in_batch_k, 128);
+  launch_params.d = params.d;
+  launch_params.d_rounded = params.d_rounded;
+
+  // Set the different scale values.
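+  // scale_softmax is the caller-provided softmax_scale; scale_softmax_log2 is
+  // (as in upstream flash-attn) assumed to be the same value pre-multiplied by
+  // log2(e) so the kernel can use exp2 instead of exp.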
+  launch_params.scale_softmax = params.scale_softmax;
+  launch_params.scale_softmax_log2 = params.scale_softmax_log2;
+
+  launch_params.p_dropout = params.p_dropout;
+  launch_params.p_dropout_in_uint8_t = params.p_dropout_in_uint8_t;
+  launch_params.rp_dropout = params.rp_dropout;
+  launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout;
+
+  launch_params.is_causal =
+      params.window_size_left < 0 && params.window_size_right == 0;
+
+  if (params.window_size_left < 0 && params.window_size_right >= 0) {
+    params.window_size_left = max_seqlen_in_batch_k;
+  }
+  if (params.window_size_left >= 0 && params.window_size_right < 0) {
+    params.window_size_right = max_seqlen_in_batch_k;
+  }
+
+  launch_params.window_size_left = params.window_size_left;
+  launch_params.window_size_right = params.window_size_right;
+
+  launch_params.is_seqlens_k_cumulative = params.is_seqlens_k_cumulative;
+
+  launch_params.alibi_slopes_ptr = buf_offset > 0 ? buffers[4] : nullptr;
+  launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride;
+
+  // set params splitkv
+  launch_params.num_splits = params.num_splits;
+
+  int64_t counter_offset = params.b * params.h * 32;
+
+  // Forward kernel will populate memory with the seed and offset.
+  launch_params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+
+  if ((1.f - launch_params.p_dropout) > 0.0) {
+    // Number of times random numbers are generated per thread, used to offset
+    // the philox counter in the THC random state. We use a custom RNG that
+    // increases the offset by batch_size * nheads * 32.
+    int64_t counter_offset = launch_params.b * launch_params.h * 32;
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+        c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    launch_params.philox_args = gen->philox_cuda_state(counter_offset);
+  }
+  TF_VLOG(2) << "Running FlashAttention Forward.";
+  FP16_SWITCH(!launch_params.is_bf16, [&] {
+    HEADDIM_SWITCH(launch_params.d, [&] {
+      // TODO(wenting.swt): support split_kv
+      run_mha_fwd_<elem_type, kHeadDim>(launch_params, torch_stream);
+    });
+  });
+  softmax_lse.copy_(unpad_softmax_lse(pad_softmax_lse, cu_seqlens_q));
+
+  // TODO(wenting.swt): we should pad and unpad q,k,v when head_size_og % 8 != 0
+  // sync with cudaEvent
+  cudaEventRecord(xla_wait_torch_event, torch_stream);
+  cudaStreamWaitEvent(stream, xla_wait_torch_event);
+}
+
+XLA_REGISTER_CUSTOM_CALL_TARGET(
+    custom_call_flash_attention_varlen_position_ids_forward, "CUDA");
+
+std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsForward(
+    const xla::XlaOp& q, const xla::XlaOp& k, const xla::XlaOp& v,
+    const xla::XlaOp& position_ids, const xla::XlaOp& alibi_slopes,
+    const std::string& opaque, const xla::Shape& output_shape) {
+  auto builder = q.builder();
+  std::vector<xla::XlaOp> operands{q, k, v, position_ids};
+  std::vector<xla::Shape> operand_shapes_with_layout{
+      shape_like(builder, q), shape_like(builder, k), shape_like(builder, v),
+      shape_like(builder, position_ids)};
+  if (alibi_slopes.valid()) {
+    operands.push_back(alibi_slopes);
+    operand_shapes_with_layout.push_back(shape_like(builder, alibi_slopes));
+  }
+  xla::XlaOp result = xla::CustomCallWithLayout(
+      builder, "custom_call_flash_attention_varlen_position_ids_forward",
+      std::move(operands), output_shape, std::move(operand_shapes_with_layout),
+      opaque);
+  return {/*softmax_lse*/ xla::GetTupleElement(result, 0),
+          /*output*/ xla::GetTupleElement(result, 1),
+          /*rng_state*/ xla::GetTupleElement(result, 2),
+          /*cu_seqlen_q*/ xla::GetTupleElement(result, 3),
+          /*cu_seqlen_k*/ xla::GetTupleElement(result, 4)};
+}
+
+}  // namespace
+
+FlashAttentionVarlenPositionIdsForward::FlashAttentionVarlenPositionIdsForward(
+    const torch::lazy::Value& q, const torch::lazy::Value& k,
+    const torch::lazy::Value& v, const torch::lazy::Value& position_ids,
+    const std::string params)
+    : XlaNode(xla_flash_attention_forward, {q, k, v, position_ids},
+              NodeOutputShape(q),
+              /*num_outputs=*/5, torch::lazy::MHash(params)),
+      params_(params) {}
+
+FlashAttentionVarlenPositionIdsForward::FlashAttentionVarlenPositionIdsForward(
+    const torch::lazy::Value& q, const torch::lazy::Value& k,
+    const torch::lazy::Value& v, const torch::lazy::Value& position_ids,
+    const torch::lazy::Value& alibi_slopes, const std::string params)
+    : XlaNode(xla_flash_attention_forward,
+              {q, k, v, position_ids, alibi_slopes}, NodeOutputShape(q),
+              /*num_outputs=*/5, torch::lazy::MHash(params)),
+      params_(params) {}
+
+torch::lazy::NodePtr FlashAttentionVarlenPositionIdsForward::Clone(
+    torch::lazy::OpList operands) const {
+  if (operands.size() > 4) {
+    return torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
+        operands.at(0), operands.at(1), operands.at(2), operands.at(3),
+        operands.at(4), params_);
+  } else {
+    return torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
+        operands.at(0), operands.at(1), operands.at(2), operands.at(3),
+        params_);
+  }
+}
+
+XlaOpVector FlashAttentionVarlenPositionIdsForward::Lower(
+    LoweringContext* loctx) const {
+  xla::XlaOp q = loctx->GetOutputOp(operand(0));
+  xla::XlaOp k = loctx->GetOutputOp(operand(1));
+  xla::XlaOp v = loctx->GetOutputOp(operand(2));
+  xla::XlaOp position_ids = loctx->GetOutputOp(operand(3));
+  xla::XlaOp alibi_slopes =
+      operands().size() > 4 ? loctx->GetOutputOp(operand(4)) : xla::XlaOp();
+  std::vector<xla::XlaOp> result = BuildFlashAttentionVarlenPositionIdsForward(
+      q, k, v, position_ids, alibi_slopes, params_, xla_shape());
+  return ReturnOps({result}, loctx);
+}
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h
new file mode 100644
index 00000000000..2e27ee973b0
--- /dev/null
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h
@@ -0,0 +1,34 @@
+#ifndef XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_FORWARD_H_
+#define XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_FORWARD_H_
+
+#include "torch_xla/csrc/flash_attention_utils.h"
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+
+class FlashAttentionVarlenPositionIdsForward : public XlaNode {
+ public:
+  FlashAttentionVarlenPositionIdsForward(const torch::lazy::Value& q,
+                                         const torch::lazy::Value& k,
+                                         const torch::lazy::Value& v,
+                                         const torch::lazy::Value& position_ids,
+                                         const std::string params);
+
+  FlashAttentionVarlenPositionIdsForward(const torch::lazy::Value& q,
+                                         const torch::lazy::Value& k,
+                                         const torch::lazy::Value& v,
+                                         const torch::lazy::Value& position_ids,
+                                         const torch::lazy::Value& alibi_slopes,
+                                         const std::string params);
+
+  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+ private:
+  const std::string params_;
+};
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_FORWARD_H_
\ No newline at end of file
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 41ddfce8043..44489c0c44e 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -53,6 +53,8 @@
 #include "torch_xla/csrc/ops/flash_attention_forward.h"
 #include "torch_xla/csrc/ops/flash_attention_varlen_backward.h"
 #include "torch_xla/csrc/ops/flash_attention_varlen_forward.h"
+#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h"
+#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h"
 #include "torch_xla/csrc/ops/flip.h"
 #include "torch_xla/csrc/ops/gather.h"
 #include "torch_xla/csrc/ops/generic.h"
@@ -675,6 +677,25 @@ std::vector<XLATensorPtr> flash_attention_varlen_forward(
   }
 }
 
+std::vector<XLATensorPtr> flash_attention_varlen_position_ids_forward(
+    const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v,
+    const XLATensorPtr& position_ids, const XLATensorPtr& alibi_slopes,
+    const std::string& params) {
+  if (alibi_slopes) {
+    torch::lazy::NodePtr node =
+        torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
+            q->GetIrValue(), k->GetIrValue(), v->GetIrValue(),
+            position_ids->GetIrValue(), alibi_slopes->GetIrValue(), params);
+    return q->MakeOutputTensors(node, /*inherit_logical_type=*/false);
+  } else {
+    torch::lazy::NodePtr node =
+        torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
+            q->GetIrValue(), k->GetIrValue(), v->GetIrValue(),
+            position_ids->GetIrValue(), params);
+    return q->MakeOutputTensors(node, /*inherit_logical_type=*/false);
+  }
+}
+
 std::vector<XLATensorPtr> flash_attention_backward(
     const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
     const XLATensorPtr& v, const XLATensorPtr& out,
@@ -720,6 +741,31 @@ std::vector<XLATensorPtr> flash_attention_varlen_backward(
   }
 }
 
+std::vector<XLATensorPtr> flash_attention_varlen_position_ids_backward(
+    const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
+    const XLATensorPtr& v, const XLATensorPtr& out,
+    const XLATensorPtr& softmax_lse, const XLATensorPtr& cu_seqlens_q,
+    const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state,
+    const XLATensorPtr& alibi_slopes, const std::string& params) {
+  if (alibi_slopes) {
+    torch::lazy::NodePtr node =
+        torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsBackward>(
+            dout->GetIrValue(), q->GetIrValue(), k->GetIrValue(),
+            v->GetIrValue(), out->GetIrValue(), softmax_lse->GetIrValue(),
+            cu_seqlens_q->GetIrValue(), cu_seqlens_k->GetIrValue(),
+            rng_state->GetIrValue(), alibi_slopes->GetIrValue(), params);
+    return dout->MakeOutputTensors(node, /*inherit_logical_type=*/false);
+  } else {
+    torch::lazy::NodePtr node =
+        torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsBackward>(
+            dout->GetIrValue(), q->GetIrValue(), k->GetIrValue(),
+            v->GetIrValue(), out->GetIrValue(), softmax_lse->GetIrValue(),
+            cu_seqlens_q->GetIrValue(), cu_seqlens_k->GetIrValue(),
+            rng_state->GetIrValue(), params);
+    return dout->MakeOutputTensors(node, /*inherit_logical_type=*/false);
+  }
+}
+
 std::vector<XLATensorPtr> user_computation(
     const std::string& opname, absl::Span<const XLATensorPtr> inputs,
     runtime::ComputationClient::ComputationPtr computation) {
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index fb574640045..a04a9538463 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -123,6 +123,11 @@ std::vector<XLATensorPtr> flash_attention_varlen_forward(
     const XLATensorPtr& attention_mask, const XLATensorPtr& alibi_slopes,
     const std::string& params);
 
+std::vector<XLATensorPtr> flash_attention_varlen_position_ids_forward(
+    const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v,
+    const XLATensorPtr& position_ids, const XLATensorPtr& alibi_slopes,
+    const std::string& params);
+
 std::vector<XLATensorPtr> flash_attention_backward(
     const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
     const XLATensorPtr& v, const XLATensorPtr& out,
@@ -136,6 +141,13 @@ std::vector<XLATensorPtr> flash_attention_varlen_backward(
     const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state,
     const XLATensorPtr& alibi_slopes, const std::string& params);
 
+std::vector<XLATensorPtr> flash_attention_varlen_position_ids_backward(
+    const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
+    const XLATensorPtr& v, const XLATensorPtr& out,
+    const XLATensorPtr& softmax_lse, const XLATensorPtr& cu_seqlens_q,
+    const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state,
+    const XLATensorPtr& alibi_slopes, const std::string& params);
+
 std::vector<XLATensorPtr> user_computation(
     const std::string& opname, absl::Span<const XLATensorPtr> inputs,
     runtime::ComputationClient::ComputationPtr computation);
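
Note on the packing convention used by the forward custom call: cu_seqlens and the real batch size are reconstructed from the packed position_ids by position_ids_to_indices. The standalone C++ sketch below is not taken from the patch; the name PositionIdsToCuSeqlens and its signature are illustrative only. It shows the core of that conversion under the assumption that a position id resetting to 0 marks the start of a new sequence, whereas the actual helper additionally produces int64 token indices and writes into a fixed-size cu_seqlens buffer of length seqlen + 1.

#include <cstdint>
#include <iostream>
#include <vector>

// Recovers cumulative sequence lengths from packed position_ids.
// A new sequence is assumed to start wherever the position id resets to 0.
std::vector<int32_t> PositionIdsToCuSeqlens(
    const std::vector<int32_t>& position_ids, int* real_batch_size) {
  std::vector<int32_t> cu_seqlens{0};
  for (size_t i = 1; i < position_ids.size(); ++i) {
    if (position_ids[i] == 0) {
      cu_seqlens.push_back(static_cast<int32_t>(i));
    }
  }
  cu_seqlens.push_back(static_cast<int32_t>(position_ids.size()));
  *real_batch_size = static_cast<int>(cu_seqlens.size()) - 1;
  return cu_seqlens;
}

int main() {
  // Two packed sequences of lengths 3 and 4; positions restart at 0.
  std::vector<int32_t> position_ids = {0, 1, 2, 0, 1, 2, 3};
  int real_batch_size = 0;
  auto cu_seqlens = PositionIdsToCuSeqlens(position_ids, &real_batch_size);
  for (int32_t c : cu_seqlens) std::cout << c << " ";  // prints: 0 3 7
  std::cout << "\nreal_batch_size = " << real_batch_size << "\n";  // prints: 2
  return 0;
}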