From 141832e24c5932bc7326fa56322ff7781f4399af Mon Sep 17 00:00:00 2001 From: tianxingwang <wangtianxing.wtx@alibaba-inc.com> Date: Mon, 2 Dec 2024 13:43:36 +0800 Subject: [PATCH 1/5] support position_ids input for flash attention --- test/test_flash_attention_forward.py | 3 + test/test_flash_attention_varlen_backward.py | 180 ++++++++ test/test_flash_attention_varlen_forward.py | 193 +++++++- torch_xla/csrc/flash_attention_utils.cpp | 98 +++- torch_xla/csrc/flash_attention_utils.h | 14 +- torch_xla/csrc/init_python_bindings.cpp | 99 +++- .../ops/flash_attention_varlen_forward.cpp | 27 +- ...attention_varlen_position_ids_backward.cpp | 430 ++++++++++++++++++ ...h_attention_varlen_position_ids_backward.h | 38 ++ ..._attention_varlen_position_ids_forward.cpp | 325 +++++++++++++ ...sh_attention_varlen_position_ids_forward.h | 34 ++ torch_xla/csrc/tensor_methods.cpp | 48 ++ torch_xla/csrc/tensor_methods.h | 12 + 13 files changed, 1470 insertions(+), 31 deletions(-) create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h diff --git a/test/test_flash_attention_forward.py b/test/test_flash_attention_forward.py index 9e019ed29325..6414855a9dd2 100644 --- a/test/test_flash_attention_forward.py +++ b/test/test_flash_attention_forward.py @@ -136,3 +136,6 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, assert torch.allclose(softmax_lse_xla, softmax_lse, rtol=1e-2, atol=1e-2) assert torch.allclose(out_xla, out_fa, rtol=1e-2, atol=1e-2) + + + diff --git a/test/test_flash_attention_varlen_backward.py b/test/test_flash_attention_varlen_backward.py index 1ee13118ee20..751b9aafd74f 100644 --- a/test/test_flash_attention_varlen_backward.py +++ b/test/test_flash_attention_varlen_backward.py @@ -249,3 +249,183 @@ def test_flash_attn_varlen_backward(seqlen_q, seqlen_k, d, dropout_p, causal, assert torch.allclose(dq_cuda, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True) assert torch.allclose(dk_cuda, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True) assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True) + + + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("alibi", [False]) +@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [8]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (128, 128), + (1024, 1024), + ], +) +@pytest.mark.parametrize("dropout_p", [0.0]) +def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, dropout_p, causal, + local, alibi, deterministic, mha_type, + dtype): + if d % 8 != 0: + pytest.skip(reason="Expected head_size_og % 8 == 0 to be true") + + device = "cuda" + # set seed + torch.random.manual_seed(101) + batch_size = 4 + nheads = 9 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + + assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else tuple( + torch.randint(0, seqlen_k, (2,)).tolist()) + torch.cuda.synchronize() + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + softmax_scale = q.shape[-1]**(-0.5) + k = torch.randn(batch_size, 
seqlen_k, nheads_k, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype) + do = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + rng_state = torch.Tensor([0, 0]).to(torch.int64).to(device) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + attention_mask = torch.zeros( + batch_size, seqlen_k, dtype=torch.int32).to(device) + k_lengths = torch.randint(low=2, high=seqlen_k, size=(batch_size,)) + for i in range(batch_size): + k_len = k_lengths[i].item() + attention_mask[i, :k_len] = 1 + q[i, k_len:, :, :] = 0 + k[i, k_len:, :, :] = 0 + v[i, k_len:, :, :] = 0 + do[i, k_len:, :, :] = 0 + q.requires_grad = True + k.requires_grad = True + v.requires_grad = True + + q_cuda, k_cuda, v_cuda, do_cuda, dq_cuda, dk_cuda, dv_cuda, \ + indices_q, indices_k, cu_seq_lens, max_seq_lens = _unpad_input( + q, k, v, do, dq, dk, dv, attention_mask, seqlen_q, nheads + ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_q, max_seqlen_k = max_seq_lens + + if alibi: + alibi_slopes = torch.rand( + batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + else: + alibi_slopes = None + + o_cuda, softmax_lse_cuda, _ = flash_attn_varlen_func( + q_cuda.contiguous(), + k_cuda.contiguous(), + v_cuda.contiguous(), + cu_seqlens_q.contiguous(), + cu_seqlens_k.contiguous(), + max_seqlen_q, + max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + dq_cuda, dk_cuda, dv_cuda, softmax_d_cuda = flash_attn_cuda.varlen_bwd( + do_cuda.contiguous(), q_cuda.contiguous(), k_cuda.contiguous(), + v_cuda.contiguous(), o_cuda.contiguous(), softmax_lse_cuda.contiguous(), + dq_cuda, dk_cuda, dv_cuda, cu_seqlens_q, cu_seqlens_k, alibi_slopes, + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, + window_size[0], window_size[1], deterministic, None, rng_state) + + dq_cuda = pad_input(dq_cuda, indices_q, batch_size, seqlen_q) + dk_cuda = pad_input(dk_cuda, indices_k, batch_size, seqlen_k) + dv_cuda = pad_input(dv_cuda, indices_k, batch_size, seqlen_k) + softmax_d_cuda = softmax_d_cuda[:, :, :seqlen_q] + + q = q.cpu().detach() + k = k.cpu().detach() + v = v.cpu().detach() + do = do.cpu().detach() + rng_state = rng_state.cpu().detach() + + dq_cuda = dq_cuda.cpu().detach() + dk_cuda = dk_cuda.cpu().detach() + dv_cuda = dv_cuda.cpu().detach() + softmax_d_cuda = softmax_d_cuda.cpu().detach() + if alibi: + alibi_slopes = alibi_slopes.cpu() + torch.cuda.synchronize() + + device = ta.lazy_device() + torch.random.manual_seed(101) + + indices_q = indices_q.cpu() + indices_k = indices_k.cpu() + + q_xla = q.flatten(0,1)[indices_q].unsqueeze(0).to(device) + k_xla = k.flatten(0,1)[indices_k].unsqueeze(0).to(device) + v_xla = v.flatten(0,1)[indices_k].unsqueeze(0).to(device) + do_xla = do.flatten(0,1)[indices_q].unsqueeze(0).to(device) + + def attention_mask_to_position_ids(attention_mask): + seqlens = attention_mask.sum(dim=1).flatten() + position_ids = torch.cat([torch.arange(0,seqlen,dtype=torch.int32,device=attention_mask.device) for seqlen in seqlens],dim=0) + return position_ids.unsqueeze(0) + + position_ids_xla = attention_mask_to_position_ids(attention_mask).to(device) + + rng_state_xla = rng_state.to(device) + if alibi: + alibi_slopes = alibi_slopes.cpu().to(device) + + + softmax_lse_xla, o_xla, _, cu_seqlen_q_xla, cu_seqlen_k_xla = 
torch_xla._XLAC._flash_attention_position_ids_forward( + q_xla.contiguous(), k_xla.contiguous(), v_xla.contiguous(), + position_ids_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale, + False, causal, window_size[0], window_size[1], True, None) + + assert torch.allclose(q_xla.cpu(), q_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True) + assert torch.allclose(k_xla.cpu(), k_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True) + assert torch.allclose(v_xla.cpu(), v_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True) + assert torch.allclose( + cu_seqlen_k_xla[:batch_size+1].cpu(), cu_seqlens_k.cpu(), rtol=1e-3, atol=1e-3, equal_nan=True) + assert torch.allclose(o_xla.cpu(), o_cuda.unsqueeze(0).cpu(), rtol=1e-2, atol=1e-2, equal_nan=True) + + q_xla.requires_grad = True + k_xla.requires_grad = True + v_xla.requires_grad = True + o_xla.requires_grad = True + softmax_lse_xla.requires_grad = True + + dq_xla, dk_xla, dv_xla, softmax_d_xla = torch_xla._XLAC._flash_attention_position_ids_backward( + do_xla.contiguous(), q_xla.contiguous(), k_xla.contiguous(), + v_xla.contiguous(), o_xla.contiguous(), softmax_lse_xla.contiguous(), + cu_seqlen_q_xla, cu_seqlen_k_xla, alibi_slopes, dropout_p, softmax_scale, + False, causal, window_size[0], window_size[1], deterministic, None, + rng_state_xla) + + ta.mark_step(wait=True) + torch.cuda.synchronize() + + dq_xla = dq_xla.cpu().detach().squeeze(0) + dk_xla = dk_xla.cpu().detach().squeeze(0) + dv_xla = dv_xla.cpu().detach().squeeze(0) + do_xla = do_xla.cpu().detach().squeeze(0) + softmax_d_xla = softmax_d_xla.cpu().detach() + + dq_xla = pad_input(dq_xla, indices_q, batch_size, seqlen_q) + dk_xla = pad_input(dk_xla, indices_k, batch_size, seqlen_k) + dv_xla = pad_input(dv_xla, indices_k, batch_size, seqlen_k) + + assert torch.allclose(dq_cuda, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True) + assert torch.allclose(dk_cuda, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True) + assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True) + diff --git a/test/test_flash_attention_varlen_forward.py b/test/test_flash_attention_varlen_forward.py index b17d0c4e91a2..30e4a4aa5287 100644 --- a/test/test_flash_attention_varlen_forward.py +++ b/test/test_flash_attention_varlen_forward.py @@ -178,6 +178,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, return_attn_probs=True, ) + out_fa = pad_input(out_fa, indices_q, batch_size, seqlen_q) q = q.cpu().detach() @@ -225,7 +226,193 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, cu_seqlen_q_xla, cu_seqlens_q, rtol=1e-3, atol=1e-3, equal_nan=True) assert torch.allclose( cu_seqlen_k_xla, cu_seqlens_k, rtol=1e-3, atol=1e-3, equal_nan=True) - softmax_lse_xla = softmax_lse_xla[:, :, :max_seqlen_in_batch_q] - softmax_lse = softmax_lse[:, :, :max_seqlen_in_batch_q] + for i in range(len(cu_seq_lens[0]) - 1): + seqlen = cu_seq_lens[0][i + 1] - cu_seq_lens[0][i] + assert torch.allclose(softmax_lse_xla[i,:,:seqlen],softmax_lse[i,:,:seqlen],rtol=1e-2,atol=1e-3,equal_nan=True) + + + + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False]) +@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [32]) +@pytest.mark.parametrize("softmax_scale", [0.25]) +@pytest.mark.parametrize( + 
"max_seqlen_q,max_seqlen_k", + [ + (8, 8), + (128, 128), + (2048, 2048), + ], +) +@pytest.mark.parametrize("dropout_p", [0.0]) +def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d, dropout_p, causal, + softmax_scale, local, alibi, deterministic, mha_type, + dtype): + if d % 8 != 0: + pytest.skip(reason="Expected head_size_og % 8 == 0 to be true") + + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 9 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + + assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else tuple( + torch.randint(0, max_seqlen_k, (2,)).tolist()) + torch.cuda.synchronize() + + def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype, n_heads_q, n_heads_k, device, head_dim, use_same_seqlen): + ''' + generate varlen qkv and postion_ids and pad total seqlen to 8 + ''' + seq_len_q = torch.randint(1,max_seqlen_q+1,(batch_size,)) + if use_same_seqlen: + seq_len_k = seq_len_q.clone() + else: + seq_len_k = torch.randint(1,max_seqlen_k+1,(batch_size,)) + total_q = seq_len_q.sum().item() + total_k = seq_len_k.sum().item() + + padd_q = 0 if total_q % 8 == 0 else 8 - total_q % 8 + padd_k = 0 if total_k % 8 == 0 else 8 - total_k % 8 + if use_same_seqlen: + assert torch.all(seq_len_k == seq_len_q) + + # padding to last q and k + if padd_q: + seq_len_q[-1] += padd_q + total_q += padd_q + assert total_q % 8 == 0 + if padd_k: + seq_len_k[-1] += padd_k + total_k += padd_k + assert total_k % 8 == 0 + + q = torch.randn((1,total_q,n_heads_q,head_dim),dtype=dtype,device=device) + k = torch.randn((1,total_k,n_heads_k,head_dim),dtype=dtype,device=device) + v = torch.randn((1,total_k,n_heads_k,head_dim),dtype=dtype,device=device) + + assert torch.all(seq_len_q > 0) + assert torch.all(seq_len_k > 0) + + position_ids_q = torch.cat([torch.arange(0,seq_len,dtype=torch.int32,device=device) for seq_len in seq_len_q],dim=0).unsqueeze(0) + position_ids_k = torch.cat([torch.arange(0,seq_len,dtype=torch.int32,device=device) for seq_len in seq_len_k],dim=0).unsqueeze(0) + assert position_ids_q.shape[1] % 8 == 0 + assert position_ids_k.shape[1] % 8 == 0 + + return q,k,v,position_ids_q,position_ids_k + + + q,k,v,position_ids_q,position_ids_k = generate_qkv_and_position_ids(batch_size,max_seqlen_q,max_seqlen_k,dtype,nheads,nheads_k,device,d,max_seqlen_q == max_seqlen_k) + + indices_q = torch.arange(0,position_ids_q.size(1),device=device,dtype=torch.int64) + indices_k = torch.arange(0,position_ids_k.size(1),device=device,dtype=torch.int64) + + seq_q_start_idx = indices_q[position_ids_q.squeeze() == 0] + seq_k_start_idx = indices_k[position_ids_k.squeeze() == 0] + + cu_seq_lens_q = F.pad(seq_q_start_idx,(0,1),value=position_ids_q.size(1)).to(torch.int32) + cu_seq_lens_k = F.pad(seq_k_start_idx,(0,1),value=position_ids_k.size(1)).to(torch.int32) + + max_seqlen_q = (cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]).max().item() + max_seqlen_k = (cu_seq_lens_k[1:] - cu_seq_lens_k[:-1]).max().item() + + q_cuda = q.reshape(-1,nheads,d).cuda() + k_cuda = k.reshape(-1,nheads_k,d).cuda() + v_cuda = v.reshape(-1,nheads_k,d).cuda() + + + if alibi: + alibi_slopes = torch.rand( + batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + else: + alibi_slopes = None + + + out_fa, softmax_lse, _ = flash_attn_varlen_func( + q_cuda.contiguous(), + k_cuda.contiguous(), + v_cuda.contiguous(), + cu_seq_lens_q.contiguous(), + cu_seq_lens_k.contiguous(), + max_seqlen_q, + max_seqlen_k, + dropout_p=dropout_p, + 
softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + out_fa = pad_input(out_fa, indices_q, batch_size, max_seqlen_q) + + q = q.cpu().detach() + k = k.cpu().detach() + v = v.cpu().detach() + + out_fa = out_fa.cpu().detach() + softmax_lse = softmax_lse.cpu().detach() + cu_seq_lens_q = cu_seq_lens_q.cpu().detach() + cu_seq_lens_k = cu_seq_lens_k.cpu().detach() + + if alibi: + alibi_slopes = alibi_slopes.cpu() + torch.cuda.synchronize() + + device = ta.lazy_device() + torch.random.manual_seed(0) + q_xla = q.to(device) + k_xla = k.to(device) + v_xla = v.to(device) + + position_ids_q_xla = position_ids_q.to(device) + + q_xla.requires_grad = False + k_xla.requires_grad = False + v_xla.requires_grad = False + if alibi: + alibi_slopes = alibi_slopes.cpu().to(device) + + softmax_lse_xla, out_xla, _, cu_seq_len_q_xla, cu_seq_len_k_xla = torch_xla._XLAC._flash_attention_position_ids_forward( + q_xla.contiguous(), k_xla.contiguous(), v_xla.contiguous(), + position_ids_q_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale, + False, causal, window_size[0], window_size[1], True, None) + + + ta.mark_step(wait=True) + q_xla = q_xla.cpu().detach() + k_xla = k_xla.cpu().detach() + v_xla = v_xla.cpu().detach() + out_xla = out_xla.cpu().detach() + cu_seq_len_q_xla = cu_seq_len_q_xla.cpu().detach() + cu_seq_len_k_xla = cu_seq_len_k_xla.cpu().detach() + softmax_lse_xla = softmax_lse_xla.cpu().detach() + position_ids_q_xla = position_ids_q_xla.cpu().detach() + + out_xla = out_xla.squeeze(0) + + out_xla = pad_input(out_xla, indices_q, batch_size, max_seqlen_q) + + assert out_xla.shape == out_fa.shape + + assert torch.allclose(q_xla, q, rtol=1e-3, atol=1e-3, equal_nan=True) + assert torch.allclose(k_xla, k, rtol=1e-3, atol=1e-3, equal_nan=True) + assert torch.allclose(v_xla, v, rtol=1e-3, atol=1e-3, equal_nan=True) assert torch.allclose( - softmax_lse_xla, softmax_lse, rtol=1e-2, atol=1e-2, equal_nan=True) + cu_seq_len_k_xla[:batch_size+1], cu_seq_lens_k, rtol=1e-3, atol=1e-3, equal_nan=True) + assert torch.allclose(out_xla, out_fa, rtol=1e-2, atol=1e-2, equal_nan=True) + for i in range(batch_size): + start_idx = cu_seq_len_q_xla[i] + end_idx = cu_seq_len_q_xla[i+1] + seqlen = end_idx - start_idx + assert torch.allclose(softmax_lse_xla[0,:,start_idx:end_idx],softmax_lse[i,:,:seqlen],rtol=1e-3,atol=1e-3,equal_nan=True) diff --git a/torch_xla/csrc/flash_attention_utils.cpp b/torch_xla/csrc/flash_attention_utils.cpp index 2c3331c8d1d9..546a00c71947 100644 --- a/torch_xla/csrc/flash_attention_utils.cpp +++ b/torch_xla/csrc/flash_attention_utils.cpp @@ -2,6 +2,7 @@ #include <ATen/cuda/CUDAContext.h> #include <torch/extension.h> +#include <iostream> #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -236,7 +237,8 @@ void set_backward_params(FlashAttentionBackwardParams& params, const size_t b, FlashAttentionForwardParams get_flash_attention_forward_params( const at::Tensor& q, const at::Tensor& k, const at::Tensor& v, - c10::optional<at::Tensor>& attention_mask, // (batch_size, seqlen) + c10::optional<at::Tensor> attention_mask, // (batch_size, seqlen) + c10::optional<at::Tensor> position_ids, // (1,seqlen_q) c10::optional<at::Tensor>& alibi_slopes_, const float p_dropout, const float softmax_scale, const bool zero_tensors, const bool is_causal, int window_size_left, int window_size_right, const bool return_softmax) { @@ -258,7 +260,7 @@ FlashAttentionForwardParams 
get_flash_attention_forward_params( const auto sizes = q.sizes(); const int batch_size = sizes[0]; - const int seqlen_q = sizes[1]; + const int seqlen_q = sizes[1]; const int num_heads = sizes[2]; const int head_size_og = sizes[3]; const int seqlen_k = k.size(1); @@ -274,10 +276,15 @@ FlashAttentionForwardParams get_flash_attention_forward_params( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + if (attention_mask.has_value()) { TORCH_CHECK(attention_mask.value().dtype() == torch::kInt32); CHECK_SHAPE(attention_mask.value(), batch_size, seqlen_k); } + if (position_ids.has_value()){ + TORCH_CHECK(position_ids.value().dtype() == torch::kInt32); + CHECK_SHAPE(position_ids.value(),1, seqlen_q); + } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); @@ -308,7 +315,6 @@ FlashAttentionForwardParams get_flash_attention_forward_params( p_dropout, softmax_scale, is_causal, window_size_left, window_size_right, alibi_slopes_batch_stride, enable_alibi_slopes, /*seqlenq_ngroups_swapped*/ false); - return params; } @@ -382,9 +388,13 @@ FlashAttentionBackwardParams get_flash_attention_backward_params( TORCH_CHECK(cu_seqlens_k.value().dtype() == torch::kInt32); TORCH_CHECK(cu_seqlens_q.value().is_contiguous()); TORCH_CHECK(cu_seqlens_k.value().is_contiguous()); - TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1); - CHECK_SHAPE(cu_seqlens_q.value(), batch_size + 1); - CHECK_SHAPE(cu_seqlens_k.value(), batch_size + 1); + TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 || batch_size == 1); + TORCH_CHECK(cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size+1}) || + cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q+1}), + "cu_seqlens_q shape should be batch_size+1 or seqlen_q"); + TORCH_CHECK(cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size+1}) || + cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k+1}), + "cu_seqlens_k shape should be batch_size+1 or seqlen_k"); } int alibi_slopes_batch_stride = 0; @@ -412,6 +422,18 @@ FlashAttentionBackwardParams get_flash_attention_backward_params( return params; } + +at::Tensor cu_seqlens_to_indices(const at::Tensor& padded_cu_seqlens, + int& max_seqlen_in_batch,int& total_q, int& real_batch_size) { + const at::Tensor valid_cu_seqlens = padded_cu_seqlens.index({padded_cu_seqlens > -1}); + real_batch_size = valid_cu_seqlens.size(0)-1; + at::Tensor seqs_len = valid_cu_seqlens.slice(0,1,valid_cu_seqlens.size(0)) - + valid_cu_seqlens.slice(0,0,valid_cu_seqlens.size(0) - 1); + max_seqlen_in_batch = seqs_len.max().item<int>(); + total_q = valid_cu_seqlens[-1].item<int>(); + return torch::arange(total_q,torch::dtype(torch::kInt64).device(torch::kCUDA)); +} + at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size, int seqlen, torch::Dtype scalar_type, int& max_seqlen_in_batch, int& total) { @@ -422,10 +444,10 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size, auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA); torch::Tensor rows = - torch::arange(shape[0], opts.dtype(torch::kInt32)).unsqueeze(1); + torch::arange(shape[0], opts.dtype(torch::kInt32)).unsqueeze(1); // (batch_size,1) torch::Tensor cols = - torch::arange(shape[1], opts.dtype(torch::kInt32)).unsqueeze(0); - torch::Tensor mask = cols < 
nonzero_counts.unsqueeze(1); + torch::arange(shape[1], opts.dtype(torch::kInt32)).unsqueeze(0); // (1,seqlen) + torch::Tensor mask = cols < nonzero_counts.unsqueeze(1); // (1,seqlen) < (batch_size, 1) max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>(); torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32)); @@ -451,19 +473,69 @@ at::Tensor mask_to_indices(const at::Tensor& attention_mask, return indices; } +torch::Tensor unpad_softmax_lse( + const torch::Tensor& pad_softmax_lse, // (batch_size, nhead, max_seqlen) + const torch::Tensor& cu_seqlens) // (total_seqlen + 1) +{ + int batch_size = pad_softmax_lse.size(0); + int nhead = pad_softmax_lse.size(1); + int max_seqlen = pad_softmax_lse.size(2); + int total, max_seqlen_in_batch; + at::Tensor valid_cu_seqlens = cu_seqlens.slice(0,0,batch_size+1); + at::Tensor indices = cu_seqlens_to_indices(valid_cu_seqlens,batch_size,max_seqlen,torch::kInt64,max_seqlen_in_batch,total); + at::Tensor result = at::empty({total,nhead},pad_softmax_lse.options()); + result.copy_(pad_softmax_lse.transpose(1,2).reshape({batch_size*max_seqlen,nhead}).index({indices,torch::indexing::Slice()})); + return result.transpose(0,1).unsqueeze(0); +} + +torch::Tensor pad_softmax_lse(const at::Tensor& softmax_lse, // (1,nheads,total_seqlen) + const at::Tensor& cu_seqlens,const int max_seq_len, const int batch_size){ + const int nheads = softmax_lse.size(1); + int max_seqlen_in_batch; + int total; + at::Tensor valid_cu_seqlens = cu_seqlens.slice(0,0,batch_size+1); + at::Tensor indices = cu_seqlens_to_indices(valid_cu_seqlens,batch_size,max_seq_len, + torch::kInt32,max_seqlen_in_batch,total); + TORCH_CHECK(indices.size(0) == softmax_lse.size(2),"indice should be same size with softmax_lse") + + at::Tensor result = at::empty({batch_size * max_seq_len,nheads},softmax_lse.options()); + + result.index_put_({indices,torch::indexing::Slice()},softmax_lse.squeeze(0).transpose(0,1)); + return result.reshape({batch_size,max_seq_len,nheads}).transpose(1,2).contiguous(); +} + +at::Tensor position_ids_to_indices(const at::Tensor& position_ids, int& max_seqlen_in_batch, int& total, at::Tensor& cu_seqlen, int& real_batch_size){ + at::Tensor flatten_position_ids = position_ids.flatten(); + at::Tensor indices = torch::arange(flatten_position_ids.size(0), torch::dtype(torch::kInt64).device(torch::kCUDA)); + + at::Tensor batch_seq_start_idx = indices.index({flatten_position_ids == 0}); + real_batch_size = batch_seq_start_idx.size(0); + at::Tensor batch_seqlen_cumsum = at::empty({real_batch_size+1},torch::dtype(torch::kInt32).device(torch::kCUDA)); + batch_seqlen_cumsum.index({torch::indexing::Slice(0,real_batch_size)}) = batch_seq_start_idx; + total = flatten_position_ids.size(0); + batch_seqlen_cumsum.index({-1}) = total; + + at::Tensor batch_seqlen = batch_seqlen_cumsum.slice(0,1,batch_seqlen_cumsum.size(0)) - batch_seqlen_cumsum.slice(0,0,batch_seqlen_cumsum.size(0)-1); + max_seqlen_in_batch = batch_seqlen.max().item<int>(); + cu_seqlen.narrow(0,0,real_batch_size+1) = batch_seqlen_cumsum; + return indices; +} + + + at::Tensor index_first_axis(const at::Tensor& input, const at::Tensor& indices) { torch::IntArrayRef sizes = input.sizes(); - int64_t first_axis_dim = sizes[0]; - auto other_shape = sizes.slice(1, sizes.size() - 1); + int64_t first_axis_dim = sizes[0]; // bs + auto other_shape = sizes.slice(1, sizes.size() - 1); // [a,h] int64_t second_dim = 1; for (auto dim : other_shape) { second_dim *= dim; } - at::Tensor flat_input = torch::flatten(input, 1); - 
torch::Tensor repeated_indices = + at::Tensor flat_input = torch::flatten(input, 1); // [bs,ah] + torch::Tensor repeated_indices = indices.unsqueeze(1).expand({indices.size(0), second_dim}); at::Tensor gather_input = torch::gather(flat_input, 0, repeated_indices); std::vector<int64_t> reshaped_size = {-1}; diff --git a/torch_xla/csrc/flash_attention_utils.h b/torch_xla/csrc/flash_attention_utils.h index a2696f25f82a..1692e1bd8678 100644 --- a/torch_xla/csrc/flash_attention_utils.h +++ b/torch_xla/csrc/flash_attention_utils.h @@ -112,7 +112,8 @@ void set_backward_params(FlashAttentionBackwardParams& params, const size_t b, FlashAttentionForwardParams get_flash_attention_forward_params( const at::Tensor& q, const at::Tensor& k, const at::Tensor& v, - c10::optional<at::Tensor>& attention_mask, + c10::optional<at::Tensor> attention_mask, + c10::optional<at::Tensor> position_ids, c10::optional<at::Tensor>& alibi_slopes_, const float p_dropout, const float softmax_scale, const bool zero_tensors, const bool is_causal, int window_size_left, int window_size_right, const bool return_softmax); @@ -130,10 +131,17 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size, int seqlen, torch::Dtype scalar_type, int& max_seqlen_in_batch, int& total); +at::Tensor cu_seqlens_to_indices(const at::Tensor& padded_cu_seqlens, + int& max_seqlen_in_batch,int& total_q,int& real_batch_size); + at::Tensor mask_to_indices(const at::Tensor& attention_mask, int& max_seqlen_in_batch, int& total, at::Tensor& cu_seqlen); +at::Tensor position_ids_to_indices(const at::Tensor& position_ids, + int& max_seqlen_in_batch, int& total, + at::Tensor& cu_seqlen, int& real_batch_size); + at::Tensor index_first_axis(const at::Tensor& input, const at::Tensor& indices); // make xla::Shape like input, same elemet_type and dimension as of input, @@ -142,5 +150,9 @@ xla::Shape shape_like(const torch::lazy::Value& input); xla::Shape shape_like(const xla::XlaBuilder* builder, const xla::XlaOp& input); +torch::Tensor unpad_softmax_lse(const torch::Tensor& pad_softmax_lse, const torch::Tensor& cu_seqlens); + +torch::Tensor pad_softmax_lse(const at::Tensor& softmax_lse,const at::Tensor& cu_seqlens, + const int max_seq_len, const int batch_size); } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_FLASH_ATTENTION_UTILS_H \ No newline at end of file diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index bce6795ae8db..55d70bc03ed6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2409,7 +2409,7 @@ void InitXlaModuleBindings(py::module m) { c10::optional<at::Generator> gen_) { // get launch params on at::Tensor auto params = get_flash_attention_forward_params( - q, k, v, attention_mask, alibi_slopes, p_dropout, softmax_scale, + q, k, v, attention_mask, c10::nullopt, alibi_slopes, p_dropout, softmax_scale, zero_tensors, is_causal, window_size_left, window_size_right, return_softmax); // call flash attention forward @@ -2440,6 +2440,50 @@ void InitXlaModuleBindings(py::module m) { } return results; }); + + m.def("_flash_attention_position_ids_forward", + [](const at::Tensor& q, const at::Tensor& k, const at::Tensor& v, + c10::optional<at::Tensor>& position_ids, + c10::optional<at::Tensor>& alibi_slopes, const float p_dropout, + const float softmax_scale, const bool zero_tensors, + const bool is_causal, const int window_size_left, + const int window_size_right, const bool return_softmax, + c10::optional<at::Generator> gen_) { + 
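+          // Expected Python-side call, mirroring the varlen position_ids tests
+          // (illustrative sketch, not an additional API):
+          //   softmax_lse, out, rng_state, cu_seqlens_q, cu_seqlens_k =
+          //       torch_xla._XLAC._flash_attention_position_ids_forward(
+          //           q, k, v, position_ids, alibi_slopes, p_dropout,
+          //           softmax_scale, zero_tensors, is_causal, window_size_left,
+          //           window_size_right, return_softmax, None)
+          // where q/k/v are packed as (1, total_seqlen, num_heads, head_dim)
+          // and position_ids has shape (1, total_seqlen), restarting at 0 for
+          // each packed sequence.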
// get launch params on at::Tensor + auto params = get_flash_attention_forward_params( + q, k, v, c10::nullopt, position_ids, alibi_slopes, p_dropout, softmax_scale, + zero_tensors, is_causal, window_size_left, window_size_right, + return_softmax); + // call flash attention forward + XLATensorPtr q_xla = bridge::GetXlaTensor(q); + XLATensorPtr k_xla = bridge::GetXlaTensor(k); + XLATensorPtr v_xla = bridge::GetXlaTensor(v); + XLATensorPtr alibi_slopes_xla = + alibi_slopes.has_value() + ? bridge::GetXlaTensor(alibi_slopes.value()) + : XLATensorPtr(); + + std::vector<XLATensorPtr> xresults; + if (position_ids.has_value()) { + + XLATensorPtr position_ids_xla = + bridge::GetXlaTensor(position_ids.value()); + xresults = tensor_methods::flash_attention_varlen_position_ids_forward( + q_xla, k_xla, v_xla, position_ids_xla, alibi_slopes_xla, + params.ToString()); + } else { + xresults = tensor_methods::flash_attention_forward( + q_xla, k_xla, v_xla, alibi_slopes_xla, params.ToString()); + } + std::vector<at::Tensor> results; + for (auto& xresult : xresults) { + at::Tensor tensor = bridge::AtenFromXlaTensor(std::move(xresult)); + results.push_back(torch::autograd::make_variable( + tensor, /*requires_grad=*/false)); + } + return results; + }); + m.def( "_flash_attention_backward", [](const at::Tensor& dout, const at::Tensor& q, const at::Tensor& k, @@ -2492,6 +2536,59 @@ void InitXlaModuleBindings(py::module m) { } return results; }); + m.def( + "_flash_attention_position_ids_backward", + [](const at::Tensor& dout, const at::Tensor& q, const at::Tensor& k, + const at::Tensor& v, const at::Tensor& out, + const at::Tensor& softmax_lse, c10::optional<at::Tensor>& cu_seqlens_q, + c10::optional<at::Tensor>& cu_seqlens_k, + c10::optional<at::Tensor>& alibi_slopes, const float p_dropout, + const float softmax_scale, const bool zero_tensors, + const bool is_causal, const int window_size_left, + const int window_size_right, const bool deterministic, + c10::optional<at::Generator> gen_, const at::Tensor& rng_state) { + // get launch params on at::Tensor + auto params = get_flash_attention_backward_params( + dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, + alibi_slopes, p_dropout, softmax_scale, zero_tensors, is_causal, + window_size_left, window_size_right, deterministic); + // call flash attention backward + XLATensorPtr dout_xla = bridge::GetXlaTensor(dout); + XLATensorPtr q_xla = bridge::GetXlaTensor(q); + XLATensorPtr k_xla = bridge::GetXlaTensor(k); + XLATensorPtr v_xla = bridge::GetXlaTensor(v); + XLATensorPtr out_xla = bridge::GetXlaTensor(out); + XLATensorPtr softmax_lse_xla = bridge::GetXlaTensor(softmax_lse); + XLATensorPtr rng_state_xla = bridge::GetXlaTensor(rng_state); + XLATensorPtr alibi_slopes_xla = + alibi_slopes.has_value() + ? 
bridge::GetXlaTensor(alibi_slopes.value()) + : XLATensorPtr(); + + std::vector<XLATensorPtr> xresults; + if (cu_seqlens_q.has_value() && cu_seqlens_k.has_value()) { + XLATensorPtr cu_seqlens_q_xla = + bridge::GetXlaTensor(cu_seqlens_q.value()); + XLATensorPtr cu_seqlens_k_xla = + bridge::GetXlaTensor(cu_seqlens_k.value()); + xresults = tensor_methods::flash_attention_varlen_position_ids_backward( + dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla, + cu_seqlens_q_xla, cu_seqlens_k_xla, rng_state_xla, + alibi_slopes_xla, params.ToString()); + } else { + xresults = tensor_methods::flash_attention_backward( + dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla, + rng_state_xla, alibi_slopes_xla, params.ToString()); + } + std::vector<at::Tensor> results; + for (auto& xresult : xresults) { + at::Tensor tensor = bridge::AtenFromXlaTensor(std::move(xresult)); + results.push_back( + torch::autograd::make_variable(tensor, /*requires_grad=*/false)); + } + return results; + }); + // -------------FlashAttention Integration API End------------------- // -------------Dynamo Integration API Start------------------------- diff --git a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp index 478f9f114df2..1a7cfed5bfa3 100644 --- a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp +++ b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp @@ -3,6 +3,7 @@ #include <ATen/cuda/CUDAContext.h> #include <c10/cuda/CUDAGuard.h> #include <torch/extension.h> +#include <iostream> #include "cutlass/numeric_types.h" #include "flash.h" @@ -16,7 +17,6 @@ namespace torch_xla { namespace { - xla::Shape NodeOutputShape(const torch::lazy::Value& q) { auto q_shape = xla::SpanToVector(GetXlaShape(q).dimensions()); xla::Shape softmax_lse_shape = xla::ShapeUtil::MakeShape( @@ -36,7 +36,7 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& q) { // buffers[0] = q // buffers[1] = k // buffers[2] = v -// buffers[3] = attention_mask +// buffers[3] = attention_mask // buffers[4] = alibi_slopes // buffers[5] = softmax_lse // this is output // buffers[6] = out_for_output // this is output @@ -47,6 +47,7 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { + std::string opaque_str(opaque, opaque_len); TF_VLOG(3) << "custom_call_flash_attention_varlen_forward opaque str: " << opaque_str; @@ -98,21 +99,22 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, int max_seqlen_in_batch_k = params.seqlen_k; int total_k = params.b * params.seqlen_k; - at::Tensor indices_k = mask_to_indices(attention_mask, max_seqlen_in_batch_k, - total_k, cu_seqlens_k); - auto unpad_k = index_first_axis(k, indices_k); + total_k,cu_seqlens_k); + + auto unpad_k = index_first_axis(k, indices_k); auto unpad_v = index_first_axis(v, indices_k); - int max_seqlen_in_batch_q = max_seqlen_in_batch_k; + int max_seqlen_in_batch_q = max_seqlen_in_batch_k; int total_q = total_k; at::Tensor indices_q; if (params.seqlen_q == params.seqlen_k) { cu_seqlens_q.copy_(cu_seqlens_k); indices_q = indices_k; - } else if (params.seqlen_q == 1) { + } else if (params.seqlen_q == 1){ max_seqlen_in_batch_q = 1; + cu_seqlens_q = torch::arange(0,params.b+1,opts); indices_q = cu_seqlens_q.slice(/*dim=*/0, /*start=*/0, /*end=*/params.b); total_q = params.b; } else { @@ -120,15 +122,14 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, /*dim=*/1, /*start=*/-params.seqlen_q, /*end=*/torch::indexing::None); 
indices_q = mask_to_indices(attention_mask_slice, max_seqlen_in_batch_q, total_q, cu_seqlens_q); - } - + } at::Tensor unpad_q = index_first_axis(q, indices_q); at::Tensor unpad_output = torch::zeros({total_q, params.h * params.d}, opts.dtype(scalar_type)); at::Tensor unpad_softmax_lse = torch::zeros( - {params.b, params.h, max_seqlen_in_batch_q}, opts.dtype(torch::kFloat)); - + {params.b, params.h, max_seqlen_in_batch_q}, opts.dtype(torch::kFloat)); + if (max_seqlen_in_batch_q == 1) { params.is_causal = false; } @@ -241,7 +242,6 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, } TF_VLOG(2) << "Running FlashAttention Forward."; - FP16_SWITCH(!launch_params.is_bf16, [&] { HEADDIM_SWITCH(launch_params.d, [&] { // TODO(wenting.swt): support split_kv @@ -261,6 +261,7 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, cudaEventRecord(xla_wait_torch_event, torch_stream); cudaStreamWaitEvent(stream, xla_wait_torch_event); } + XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_forward, "CUDA"); @@ -321,7 +322,7 @@ torch::lazy::NodePtr FlashAttentionVarlenForward::Clone( } } -XlaOpVector FlashAttentionVarlenForward::Lower(LoweringContext* loctx) const { +XlaOpVector FlashAttentionVarlenForward::Lower(LoweringContext* loctx) const{ xla::XlaOp q = loctx->GetOutputOp(operand(0)); xla::XlaOp k = loctx->GetOutputOp(operand(1)); xla::XlaOp v = loctx->GetOutputOp(operand(2)); diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp new file mode 100644 index 000000000000..435b40e822cb --- /dev/null +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp @@ -0,0 +1,430 @@ +#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h" + +#include <ATen/cuda/CUDAContext.h> +#include <c10/cuda/CUDAGuard.h> +#include <torch/extension.h> + +#include "cutlass/numeric_types.h" +#include "flash.h" +#include "static_switch.h" +#include "torch_xla/csrc/flash_attention_utils.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/runtime/tf_logging.h" +#include "torch_xla/csrc/xla_lower_util.h" +#include "xla/service/custom_call_target_registry.h" + +namespace torch_xla { +namespace { + +xla::Shape NodeOutputShape(const torch::lazy::Value& q, + const torch::lazy::Value& k, + const torch::lazy::Value& v, + const torch::lazy::Value& softmax_lse) { + return xla::ShapeUtil::MakeTupleShape( + {shape_like(q), shape_like(k), shape_like(v), shape_like(softmax_lse)}); +} + +void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream, + const bool configure) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, + [&] { run_mha_bwd_<elem_type, kHeadDim>(params, stream); }); + }); +} + +// Layout of `buffers` listed above: +// buffers[0] = dout +// buffers[1] = q +// buffers[2] = k +// buffers[3] = v +// buffers[4] = out +// buffers[5] = softmax_lse +// buffers[6] = cu_seqlens_q +// buffers[7] = cu_seqlens_k +// buffers[8] = rng_state +// buffers[9] = alibi_slopes +// buffers[10] = dq // this is output +// buffers[11] = dk // this is output +// buffers[12] = dv // this is output +// buffers[13] = softmax_d // this is output +void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t stream, + void** buffers, + const char* opaque, + size_t opaque_len) { + std::string opaque_str(opaque, opaque_len); + TF_VLOG(3) << 
"custom_call_flash_attention_varlen_position_ids_backward opaque str: " + << opaque_str; + FlashAttentionBackwardParams params; + params.FromString(std::move(opaque_str)); + int buf_offset = params.enable_alibi_slopes; + auto scalar_type = params.is_bf16 ? torch::kBFloat16 : torch::kFloat16; + + cudaStream_t torch_stream = at::cuda::getCurrentCUDAStream().stream(); + cudaEvent_t torch_wait_xla_event; + cudaEventCreateWithFlags(&torch_wait_xla_event, cudaEventDisableTiming); + cudaEvent_t xla_wait_torch_event; + cudaEventCreateWithFlags(&xla_wait_torch_event, cudaEventDisableTiming); + cudaEventRecord(torch_wait_xla_event, stream); + cudaStreamWaitEvent(torch_stream, torch_wait_xla_event); + + auto cuda_stream = at::cuda::getCurrentCUDAStream(); + at::cuda::CUDAStreamGuard guard(cuda_stream); + + auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA); + + // Inputs + at::Tensor do_ = torch::from_blob( + buffers[0], {params.b * params.seqlen_q, params.h, params.d}, opts); + at::Tensor q = torch::from_blob( + buffers[1], {params.b * params.seqlen_q, params.h, params.d}, opts); + at::Tensor k = torch::from_blob( + buffers[2], {params.b * params.seqlen_k, params.h_k, params.d}, opts); + at::Tensor v = torch::from_blob( + buffers[3], {params.b * params.seqlen_k, params.h_k, params.d}, opts); + at::Tensor o = torch::from_blob( + buffers[4], {params.b * params.seqlen_q, params.h, params.d}, opts); + at::Tensor softmax_lse = + torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q}, + opts.dtype(torch::kFloat)); + at::Tensor cu_seqlens_q = + torch::from_blob(buffers[6], {params.seqlen_q + 1}, opts.dtype(torch::kInt32)); + at::Tensor cu_seqlens_k = + torch::from_blob(buffers[7], {params.seqlen_k + 1}, opts.dtype(torch::kInt32)); + + // Outputs + at::Tensor dq = + torch::from_blob(buffers[9 + buf_offset], + {params.b * params.seqlen_q, params.h, params.d}, opts); + + at::Tensor dk = torch::from_blob( + buffers[10 + buf_offset], + {params.b * params.seqlen_k, params.h_k, params.d}, opts); + + at::Tensor dv = torch::from_blob( + buffers[11 + buf_offset], + {params.b * params.seqlen_k, params.h_k, params.d}, opts); + + at::Tensor dsoftmax_sum = torch::from_blob( + buffers[12 + buf_offset], {params.b, params.h, params.seqlen_q}, + opts.dtype(torch::kFloat)); + + // Fill zeros for outputs. + dq.fill_(0); + dk.fill_(0); + dv.fill_(0); + dsoftmax_sum.fill_(0); + + int max_seqlen_in_batch_q = params.seqlen_q; + int max_seqlen_in_batch_k = params.seqlen_k; + int total_q = params.b * params.seqlen_q; + int total_k = params.b * params.seqlen_k; + int real_batch_size; + at::Tensor indices_q = + cu_seqlens_to_indices(cu_seqlens_q,max_seqlen_in_batch_q,total_q,real_batch_size); + + at::Tensor indices_k; + if (params.seqlen_q == params.seqlen_k) { + indices_k = indices_q; + max_seqlen_in_batch_k = max_seqlen_in_batch_q; + total_k = total_q; + } else { + indices_k = + cu_seqlens_to_indices(cu_seqlens_k, max_seqlen_in_batch_k, total_k,real_batch_size); + } + + auto padded_softmax_lse = pad_softmax_lse(softmax_lse,cu_seqlens_q,max_seqlen_in_batch_q,real_batch_size); + + Flash_bwd_params launch_params; + + // Reset the parameters + memset(&launch_params, 0, sizeof(launch_params)); + + launch_params.is_bf16 = params.is_bf16; + + // Set the pointers and strides. + launch_params.q_ptr = q.data_ptr(); + launch_params.k_ptr = k.data_ptr(); + launch_params.v_ptr = v.data_ptr(); + launch_params.o_ptr = o.data_ptr(); + + // All stride are in elements, not bytes. 
+ launch_params.q_row_stride = params.q_row_stride; + launch_params.k_row_stride = params.k_row_stride; + launch_params.v_row_stride = params.v_row_stride; + launch_params.q_head_stride = params.q_head_stride; + launch_params.k_head_stride = params.k_head_stride; + + launch_params.v_head_stride = params.v_head_stride; + launch_params.o_row_stride = params.o_row_stride; + launch_params.o_head_stride = params.o_head_stride; + + launch_params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q.data_ptr()); + launch_params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k.data_ptr()); + launch_params.softmax_lse_ptr = padded_softmax_lse.data_ptr(); + + launch_params.alibi_slopes_ptr = buf_offset > 0 ? buffers[9] : nullptr; + + launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride; + + // P = softmax(QK^T) + launch_params.p_ptr = nullptr; // no softmax returned always + + // Set the dimensions. + launch_params.b = real_batch_size; + launch_params.h = params.h; + launch_params.h_k = params.h_k; + launch_params.h_h_k_ratio = params.h_h_k_ratio; + launch_params.seqlen_q = max_seqlen_in_batch_q; + launch_params.seqlen_k = max_seqlen_in_batch_k; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + launch_params.seqlen_q_rounded = round_multiple(max_seqlen_in_batch_q, 128); + launch_params.seqlen_k_rounded = round_multiple(max_seqlen_in_batch_k, 128); + launch_params.d = params.d; + launch_params.d_rounded = params.d_rounded; + + // Set the different scale values. + launch_params.scale_softmax = params.scale_softmax; + launch_params.scale_softmax_log2 = params.scale_softmax_log2; + + launch_params.p_dropout = params.p_dropout; + launch_params.p_dropout_in_uint8_t = params.p_dropout_in_uint8_t; + launch_params.rp_dropout = params.rp_dropout; + launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout; + + if (max_seqlen_in_batch_q == 1) { + params.is_causal = false; + } + if (params.is_causal) { + params.window_size_right = 0; + } + + if (params.window_size_left >= max_seqlen_in_batch_k) { + params.window_size_left = -1; + } + if (params.window_size_right >= max_seqlen_in_batch_k) { + params.window_size_right = -1; + } + + launch_params.is_causal = + params.window_size_left < 0 && params.window_size_right == 0; + + if (params.window_size_left < 0 && params.window_size_right >= 0) { + params.window_size_left = max_seqlen_in_batch_k; + } + if (params.window_size_left >= 0 && params.window_size_right < 0) { + params.window_size_right = max_seqlen_in_batch_k; + } + + launch_params.window_size_left = params.window_size_left; + launch_params.window_size_right = params.window_size_right; + + launch_params.is_seqlens_k_cumulative = true; + + launch_params.do_row_stride = params.do_row_stride; + launch_params.do_head_stride = params.do_head_stride; + + launch_params.dq_row_stride = params.dq_row_stride; + launch_params.dk_row_stride = params.dk_row_stride; + launch_params.dv_row_stride = params.dv_row_stride; + launch_params.dq_head_stride = params.dq_head_stride; + launch_params.dk_head_stride = params.dk_head_stride; + launch_params.dv_head_stride = params.dv_head_stride; + + at::Tensor rounded_dsoftmax_sum = + at::zeros({real_batch_size, params.h, launch_params.seqlen_q_rounded}, + opts.dtype(torch::kFloat)); + + launch_params.do_ptr = do_.data_ptr(); + launch_params.dq_ptr = dq.data_ptr(); + launch_params.dk_ptr = dk.data_ptr(); + launch_params.dv_ptr = dv.data_ptr(); + launch_params.dsoftmax_sum = rounded_dsoftmax_sum.data_ptr(); + + // bool loop = 
max_seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + + at::Tensor dq_accum; + if (loop) { + if (!params.deterministic) { + dq_accum = torch::empty({total_q + 128 * launch_params.b, launch_params.h, + launch_params.d_rounded}, + opts.dtype(at::kFloat)); + } else { + auto dprops = at::cuda::getCurrentDeviceProperties(); + const int nsplits = (dprops->multiProcessorCount + + launch_params.b * launch_params.h - 1) / + (launch_params.b * launch_params.h); + dq_accum = torch::zeros({nsplits, total_q + 128 * launch_params.b, + launch_params.h, launch_params.d_rounded}, + opts.dtype(at::kFloat)); + } + } + + at::Tensor dk_expanded, dv_expanded; + + if (launch_params.h_k != launch_params.h) { // MQA / GQA + TF_VLOG(2) << "Running FlashAttention Backward as MQA/GQA"; + dk_expanded = + torch::empty({total_k, launch_params.h, launch_params.d}, opts); + dv_expanded = + torch::empty({total_k, launch_params.h, launch_params.d}, opts); + + launch_params.dk_ptr = dk_expanded.data_ptr(); + launch_params.dv_ptr = dv_expanded.data_ptr(); + launch_params.dk_row_stride = dk_expanded.stride(-3); + launch_params.dv_row_stride = dv_expanded.stride(-3); + launch_params.dk_head_stride = dk_expanded.stride(-2); + launch_params.dv_head_stride = dv_expanded.stride(-2); + } else { + TF_VLOG(2) << "Running FlashAttention Backward"; + dk_expanded = dk; + dv_expanded = dv; + } + + launch_params.dq_accum_ptr = loop ? dq_accum.data_ptr() : nullptr; + launch_params.dk_accum_ptr = nullptr; + launch_params.dv_accum_ptr = nullptr; + + launch_params.deterministic = params.deterministic; + launch_params.dq_accum_split_stride = + !launch_params.deterministic ? 0 : dq_accum.stride(0); + + auto launch = &run_mha_bwd; + + auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = launch_params.b * launch_params.h * 32; + + bool is_dropout = (1.f - launch_params.p_dropout) > 0.0; + + // TODO(wenting.swt): According to the implementation in + // `flash_attn_varlen_func` of flash-attn v2.5.6, the forward generates + // `rng_state` which is passed as ctx to the backward. Hence, for simplifying + // the logic, the redundant branch where `rng_state` is None has been omitted. 
+ launch_params.rng_state = reinterpret_cast<uint64_t*>(buffers[8]); + + launch(launch_params, torch_stream, /*configure=*/false); + + + // For MQA/GQA we need to sum dK and dV across the groups + if (launch_params.h_k != launch_params.h) { + at::sum_out(dk, + at::reshape(dk_expanded, {total_k, launch_params.h_k, + launch_params.h / launch_params.h_k, + launch_params.d}), + {2}); + at::sum_out(dv, + at::reshape(dv_expanded, {total_k, launch_params.h_k, + launch_params.h / launch_params.h_k, + launch_params.d}), + {2}); + } + + dsoftmax_sum.copy_(unpad_softmax_lse(rounded_dsoftmax_sum,cu_seqlens_q)); + + cudaEventRecord(xla_wait_torch_event, torch_stream); + cudaStreamWaitEvent(stream, xla_wait_torch_event); +} +XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_position_ids_backward, + "CUDA"); + +std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsBackward( + const xla::XlaOp& dout, const xla::XlaOp& q, const xla::XlaOp& k, + const xla::XlaOp& v, const xla::XlaOp& out, const xla::XlaOp& softmax_lse, + const xla::XlaOp& cu_seqlens_q, const xla::XlaOp& cu_seqlens_k, + const xla::XlaOp& rng_state, const xla::XlaOp& alibi_slopes, + const std::string& opaque, const xla::Shape& output_shape) { + auto builder = q.builder(); + std::vector<xla::XlaOp> operands{ + dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state}; + std::vector<xla::Shape> operand_shapes_with_layout{ + shape_like(builder, dout), + shape_like(builder, q), + shape_like(builder, k), + shape_like(builder, v), + shape_like(builder, out), + shape_like(builder, softmax_lse), + builder->GetShape(cu_seqlens_q).value(), + builder->GetShape(cu_seqlens_k).value(), + builder->GetShape(rng_state).value()}; + if (alibi_slopes.valid()) { + operands.push_back(alibi_slopes); + operand_shapes_with_layout.push_back(shape_like(builder, alibi_slopes)); + } + xla::XlaOp result = xla::CustomCallWithLayout( + builder, "custom_call_flash_attention_varlen_position_ids_backward", + std::move(operands), output_shape, std::move(operand_shapes_with_layout), + opaque); + return {xla::GetTupleElement(result, 0), xla::GetTupleElement(result, 1), + xla::GetTupleElement(result, 2), xla::GetTupleElement(result, 3)}; +} + +} // namespace + +FlashAttentionVarlenPositionIdsBackward::FlashAttentionVarlenPositionIdsBackward( + const torch::lazy::Value& dout, const torch::lazy::Value& q, + const torch::lazy::Value& k, const torch::lazy::Value& v, + const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, + const torch::lazy::Value& cu_seqlens_q, + const torch::lazy::Value& cu_seqlens_k, const torch::lazy::Value& rng_state, + const std::string params) + : XlaNode(xla_flash_attention_backward, + {dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, + rng_state}, + NodeOutputShape(q, k, v, softmax_lse), + /*num_outputs=*/4, torch::lazy::MHash(params)), + params_(params) {} + +FlashAttentionVarlenPositionIdsBackward::FlashAttentionVarlenPositionIdsBackward( + const torch::lazy::Value& dout, const torch::lazy::Value& q, + const torch::lazy::Value& k, const torch::lazy::Value& v, + const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, + const torch::lazy::Value& cu_seqlens_q, + const torch::lazy::Value& cu_seqlens_k, const torch::lazy::Value& rng_state, + const torch::lazy::Value& alibi_slopes, const std::string params) + : XlaNode(xla_flash_attention_backward, + {dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, + rng_state, alibi_slopes}, + NodeOutputShape(q, k, v, softmax_lse), + 
/*num_outputs=*/4, torch::lazy::MHash(params)), + params_(params) {} + +torch::lazy::NodePtr FlashAttentionVarlenPositionIdsBackward::Clone( + torch::lazy::OpList operands) const { + if (operands.size() > 9) { + torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsBackward>( + operands.at(0), operands.at(1), operands.at(2), operands.at(3), + operands.at(4), operands.at(5), operands.at(6), operands.at(7), + operands.at(8), operands.at(9), params_); + } else { + torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsBackward>( + operands.at(0), operands.at(1), operands.at(2), operands.at(3), + operands.at(4), operands.at(5), operands.at(6), operands.at(7), + operands.at(8), params_); + } +} + +XlaOpVector FlashAttentionVarlenPositionIdsBackward::Lower(LoweringContext* loctx) const { + xla::XlaOp dout = loctx->GetOutputOp(operand(0)); + xla::XlaOp q = loctx->GetOutputOp(operand(1)); + xla::XlaOp k = loctx->GetOutputOp(operand(2)); + xla::XlaOp v = loctx->GetOutputOp(operand(3)); + xla::XlaOp out = loctx->GetOutputOp(operand(4)); + xla::XlaOp softmax_lse = loctx->GetOutputOp(operand(5)); + xla::XlaOp cu_seqlens_q = loctx->GetOutputOp(operand(6)); + xla::XlaOp cu_seqlens_k = loctx->GetOutputOp(operand(7)); + xla::XlaOp rng_state = loctx->GetOutputOp(operand(8)); + xla::XlaOp alibi_slopes = + operands().size() > 9 ? loctx->GetOutputOp(operand(9)) : xla::XlaOp(); + std::vector<xla::XlaOp> result = BuildFlashAttentionVarlenPositionIdsBackward( + dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, + alibi_slopes, params_, xla_shape()); + + return ReturnOps({result}, loctx); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h new file mode 100644 index 000000000000..42ba728ceb02 --- /dev/null +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h @@ -0,0 +1,38 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_BACKWARD_H_ +#define XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_BACKWARD_H_ + +#include "torch_xla/csrc/flash_attention_utils.h" +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class FlashAttentionVarlenPositionIdsBackward : public XlaNode { + public: + FlashAttentionVarlenPositionIdsBackward( + const torch::lazy::Value& dout, const torch::lazy::Value& q, + const torch::lazy::Value& k, const torch::lazy::Value& v, + const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, + const torch::lazy::Value& cu_seqlens_q, + const torch::lazy::Value& cu_seqlens_k, + const torch::lazy::Value& rng_state, const std::string params); + + FlashAttentionVarlenPositionIdsBackward( + const torch::lazy::Value& dout, const torch::lazy::Value& q, + const torch::lazy::Value& k, const torch::lazy::Value& v, + const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, + const torch::lazy::Value& cu_seqlens_q, + const torch::lazy::Value& cu_seqlens_k, + const torch::lazy::Value& rng_state, + const torch::lazy::Value& alibi_slopes, const std::string params); + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + private: + const std::string params_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_BACKWARD_H_ \ No newline at end of file diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp 
b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp new file mode 100644 index 000000000000..9403879add54 --- /dev/null +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp @@ -0,0 +1,325 @@ +#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h" + +#include <ATen/cuda/CUDAContext.h> +#include <c10/cuda/CUDAGuard.h> +#include <torch/extension.h> +#include <iostream> + +#include "cutlass/numeric_types.h" +#include "flash.h" +#include "static_switch.h" +#include "torch_xla/csrc/flash_attention_utils.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/runtime/tf_logging.h" +#include "torch_xla/csrc/xla_lower_util.h" +#include "xla/service/custom_call_target_registry.h" + +namespace torch_xla { +namespace { + +xla::Shape NodeOutputShape(const torch::lazy::Value& q) { + auto q_shape = xla::SpanToVector(GetXlaShape(q).dimensions()); + xla::Shape softmax_lse_shape = xla::ShapeUtil::MakeShape( + xla::PrimitiveType::F32, + {q_shape[0], q_shape[2], + q_shape[1]}); // 1, num_heads, total_q + xla::Shape rng_state_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2}); + xla::Shape cu_seqlens_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {q_shape[1] + 1}); // q.shape [1,total_q,num_head,head_dim] + return xla::ShapeUtil::MakeTupleShape({softmax_lse_shape, shape_like(q), + rng_state_shape, cu_seqlens_shape, + cu_seqlens_shape}); +} + +// Layout of `buffers` listed above: +// buffers[0] = q +// buffers[1] = k +// buffers[2] = v +// buffers[3] = attention_mask or position_ids +// buffers[4] = alibi_slopes +// buffers[5] = softmax_lse // this is output +// buffers[6] = out_for_output // this is output +// buffers[7] = rng_state // this is output +// buffers[8] = cu_seqlen_q // this is output +// buffers[9] = cu_seqlen_k // this is output +void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream, + void** buffers, + const char* opaque, + size_t opaque_len) { + + std::string opaque_str(opaque, opaque_len); + TF_VLOG(3) << "custom_call_flash_attention_varlen_position_ids_forward opaque str: " + << opaque_str; + FlashAttentionForwardParams params; + params.FromString(std::move(opaque_str)); + int buf_offset = params.enable_alibi_slopes; + auto scalar_type = params.is_bf16 ? 
torch::kBFloat16 : torch::kFloat16; + + cudaStream_t torch_stream = at::cuda::getCurrentCUDAStream().stream(); + cudaEvent_t torch_wait_xla_event; + cudaEventCreateWithFlags(&torch_wait_xla_event, cudaEventDisableTiming); + cudaEvent_t xla_wait_torch_event; + cudaEventCreateWithFlags(&xla_wait_torch_event, cudaEventDisableTiming); + cudaEventRecord(torch_wait_xla_event, stream); + cudaStreamWaitEvent(torch_stream, torch_wait_xla_event); + + auto cuda_stream = at::cuda::getCurrentCUDAStream(); + at::cuda::CUDAStreamGuard guard(cuda_stream); + + auto opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + + at::Tensor q = torch::from_blob( + buffers[0], {params.b * params.seqlen_q, params.h, params.d}, + opts.dtype(scalar_type)); + at::Tensor k = torch::from_blob( + buffers[1], {params.b * params.seqlen_k, params.h_k, params.d}, + opts.dtype(scalar_type)); + at::Tensor v = torch::from_blob( + buffers[2], {params.b * params.seqlen_k, params.h_k, params.d}, + opts.dtype(scalar_type)); + at::Tensor position_ids = + torch::from_blob(buffers[3], {params.b, params.seqlen_k}, opts); + at::Tensor softmax_lse = torch::from_blob( + buffers[4 + buf_offset], {params.b, params.h, params.seqlen_q}, + opts.dtype(torch::kFloat)); + at::Tensor o_output = + torch::from_blob(buffers[5 + buf_offset], + {params.b * params.seqlen_q, params.h * params.d}, + opts.dtype(scalar_type)); + at::Tensor cu_seqlens_q = + torch::from_blob(buffers[7 + buf_offset], {params.seqlen_q + 1}, opts); + at::Tensor cu_seqlens_k = + torch::from_blob(buffers[8 + buf_offset], {params.seqlen_k + 1}, opts); + at::Tensor rng_state = + torch::from_blob(buffers[6 + buf_offset], {2}, opts.dtype(torch::kInt64)); + softmax_lse.fill_(0); + o_output.fill_(0); + cu_seqlens_k.fill_(-1); + + int max_seqlen_in_batch_k = params.seqlen_k; + int total_k = params.b * params.seqlen_k; + at::Tensor indices_k; + + int real_batch_size; // packed qkv's batch size is 1, but fa need to know real batch size before packing. 
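// Note on the packed layout handled below: q/k/v arrive with an effective batch size of 1, so the real batch size, cu_seqlens and max sequence length are all recovered from position_ids, where every 0 marks the start of a new sequence. A minimal Python sketch of that mapping (helper name and code are illustrative only, mirroring what position_ids_to_indices computes, and assuming packed int32 position_ids):
import torch

def position_ids_to_cu_seqlens(position_ids: torch.Tensor):
  # position_ids: packed (1, total) tensor, e.g. [[0, 1, 2, 0, 1, 0, 1, 2, 3]].
  flat = position_ids.flatten()
  starts = torch.nonzero(flat == 0, as_tuple=False).flatten()  # sequence starts
  real_batch_size = starts.numel()
  cu_seqlens = torch.cat([starts.to(torch.int32),
                          torch.tensor([flat.numel()], dtype=torch.int32)])
  max_seqlen_in_batch = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
  return cu_seqlens, real_batch_size, max_seqlen_in_batch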
+ indices_k = position_ids_to_indices(position_ids, max_seqlen_in_batch_k,total_k, + cu_seqlens_k,real_batch_size); + TORCH_CHECK(cu_seqlens_k.size(0) == params.seqlen_k + 1, "cu_seqlen'shape should be params.seqlen_k"); + TORCH_CHECK(indices_k.dtype() == torch::kInt64, "indice should be int64"); + + int max_seqlen_in_batch_q = max_seqlen_in_batch_k; + int total_q = total_k; + at::Tensor indices_q; + + TORCH_CHECK(params.seqlen_q == params.seqlen_k, "now only support same seqlen for q and k"); + + if (params.seqlen_q == params.seqlen_k) { + cu_seqlens_q.copy_(cu_seqlens_k); + indices_q = indices_k; + } else { + // TODO:(wangtianxing.wtx) support different seqlen_q and seqlen_k + indices_q = position_ids_to_indices(position_ids, max_seqlen_in_batch_q, + total_q, cu_seqlens_q,real_batch_size); + } + + if (max_seqlen_in_batch_q == 1) { + params.is_causal = false; + } + if (params.is_causal) { + params.window_size_right = 0; + } + + if (params.window_size_left >= max_seqlen_in_batch_k) { + params.window_size_left = -1; + } + if (params.window_size_right >= max_seqlen_in_batch_k) { + params.window_size_right = -1; + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + at::Tensor pad_softmax_lse = at::empty({real_batch_size,params.h,max_seqlen_in_batch_q},torch::dtype(torch::kFloat).device(torch::kCUDA)); + + Flash_fwd_params launch_params; + + // Reset the parameters + memset(&launch_params, 0, sizeof(launch_params)); + + launch_params.is_bf16 = params.is_bf16; + + // Set the pointers and strides. + launch_params.q_ptr = q.data_ptr(); + launch_params.k_ptr = k.data_ptr(); + launch_params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + launch_params.q_row_stride = params.q_row_stride; + launch_params.k_row_stride = params.k_row_stride; + launch_params.v_row_stride = params.v_row_stride; + launch_params.q_head_stride = params.q_head_stride; + launch_params.k_head_stride = params.k_head_stride; + launch_params.v_head_stride = params.v_head_stride; + launch_params.o_ptr = o_output.data_ptr(); + launch_params.o_row_stride = params.o_row_stride; + launch_params.o_head_stride = params.o_head_stride; + + launch_params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q.data_ptr()); + launch_params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k.data_ptr()); + + launch_params.seqused_k = static_cast<int*>(nullptr); + + // P = softmax(QK^T) + launch_params.p_ptr = nullptr; // no softmax returned always + + // Softmax sum + launch_params.softmax_lse_ptr = pad_softmax_lse.data_ptr(); + + // Set the dimensions. + launch_params.b = real_batch_size; + launch_params.h = params.h; + launch_params.h_k = params.h_k; + launch_params.h_h_k_ratio = params.h_h_k_ratio; + launch_params.seqlen_q = max_seqlen_in_batch_q; + launch_params.seqlen_k = max_seqlen_in_batch_k; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + launch_params.seqlen_q_rounded = round_multiple(max_seqlen_in_batch_q, 128); + launch_params.seqlen_k_rounded = round_multiple(max_seqlen_in_batch_k, 128); + launch_params.d = params.d; + launch_params.d_rounded = params.d_rounded; + + // Set the different scale values. 
+ launch_params.scale_softmax = params.scale_softmax; + launch_params.scale_softmax_log2 = params.scale_softmax_log2; + + launch_params.p_dropout = params.p_dropout; + launch_params.p_dropout_in_uint8_t = params.p_dropout_in_uint8_t; + launch_params.rp_dropout = params.rp_dropout; + launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout; + + launch_params.is_causal = + params.window_size_left < 0 && params.window_size_right == 0; + + if (params.window_size_left < 0 && params.window_size_right >= 0) { + params.window_size_left = max_seqlen_in_batch_k; + } + if (params.window_size_left >= 0 && params.window_size_right < 0) { + params.window_size_right = max_seqlen_in_batch_k; + } + + launch_params.window_size_left = params.window_size_left; + launch_params.window_size_right = params.window_size_right; + + launch_params.is_seqlens_k_cumulative = params.is_seqlens_k_cumulative; + + launch_params.alibi_slopes_ptr = buf_offset > 0 ? buffers[4] : nullptr; + launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride; + + // set params splitkv + launch_params.num_splits = params.num_splits; + + int64_t counter_offset = params.b * params.h * 32; + + // Forward kernel will populate memory with the seed and offset. + launch_params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr()); + + if ((1.f - launch_params.p_dropout) > 0.0) { + // number of times random will be generated per thread, to offset philox + // counter in thc random state We use a custom RNG that increases the offset + // by batch_size * nheads * 32. + int64_t counter_offset = launch_params.b * launch_params.h * 32; + auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard<std::mutex> lock(gen->mutex_); + launch_params.philox_args = gen->philox_cuda_state(counter_offset); + } + TF_VLOG(2) << "Running FlashAttention Forward."; + FP16_SWITCH(!launch_params.is_bf16, [&] { + HEADDIM_SWITCH(launch_params.d, [&] { + // TODO(wenting.swt): support split_kv + run_mha_fwd_<elem_type, kHeadDim>(launch_params, torch_stream); + }); + }); + softmax_lse.copy_(unpad_softmax_lse(pad_softmax_lse,cu_seqlens_q)); + + // TODO(wenting.swt): we should pad and unpad q,k,v when head_size_og % 8 != 0 + // sync with cudaEvent + cudaEventRecord(xla_wait_torch_event, torch_stream); + cudaStreamWaitEvent(stream, xla_wait_torch_event); +} + +XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_position_ids_forward, + "CUDA"); + +std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsForward( + const xla::XlaOp& q, const xla::XlaOp& k, const xla::XlaOp& v, + const xla::XlaOp& position_ids, const xla::XlaOp& alibi_slopes, + const std::string& opaque, const xla::Shape& output_shape) { + auto builder = q.builder(); + std::vector<xla::XlaOp> operands{q, k, v, position_ids}; + std::vector<xla::Shape> operand_shapes_with_layout{ + shape_like(builder, q), shape_like(builder, k), shape_like(builder, v), + shape_like(builder, position_ids)}; + if (alibi_slopes.valid()) { + operands.push_back(alibi_slopes); + operand_shapes_with_layout.push_back(shape_like(builder, alibi_slopes)); + } + xla::XlaOp result = xla::CustomCallWithLayout( + builder, "custom_call_flash_attention_varlen_position_ids_forward", + std::move(operands), output_shape, std::move(operand_shapes_with_layout), + opaque); + return {/*softmax_lse*/ xla::GetTupleElement(result, 0), + /*output*/ 
xla::GetTupleElement(result, 1), + /*rng_state*/ xla::GetTupleElement(result, 2), + /*cu_seqlen_q*/ xla::GetTupleElement(result, 3), + /*cu_seqlen_k*/ xla::GetTupleElement(result, 4)}; +} + +} // namespace + +FlashAttentionVarlenPositionIdsForward::FlashAttentionVarlenPositionIdsForward( + const torch::lazy::Value& q, const torch::lazy::Value& k, + const torch::lazy::Value& v, const torch::lazy::Value& position_ids, + const std::string params) + : XlaNode(xla_flash_attention_forward, {q, k, v, position_ids}, + NodeOutputShape(q), + /*num_outputs=*/5, torch::lazy::MHash(params)), + params_(params) {} + +FlashAttentionVarlenPositionIdsForward::FlashAttentionVarlenPositionIdsForward( + const torch::lazy::Value& q, const torch::lazy::Value& k, + const torch::lazy::Value& v, const torch::lazy::Value& position_ids, + const torch::lazy::Value& alibi_slopes, const std::string params) + : XlaNode(xla_flash_attention_forward, + {q, k, v, position_ids, alibi_slopes}, NodeOutputShape(q), + /*num_outputs=*/5, torch::lazy::MHash(params)), + params_(params) {} + +torch::lazy::NodePtr FlashAttentionVarlenPositionIdsForward::Clone( + torch::lazy::OpList operands) const { + if (operands.size() > 4) { + return torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>( + operands.at(0), operands.at(1), operands.at(2), operands.at(3), + operands.at(4), params_); + } else { + return torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>( + operands.at(0), operands.at(1), operands.at(2), operands.at(3), + params_); + } +} + +XlaOpVector FlashAttentionVarlenPositionIdsForward::Lower(LoweringContext* loctx) const { + xla::XlaOp q = loctx->GetOutputOp(operand(0)); + xla::XlaOp k = loctx->GetOutputOp(operand(1)); + xla::XlaOp v = loctx->GetOutputOp(operand(2)); + xla::XlaOp position_ids = loctx->GetOutputOp(operand(3)); + xla::XlaOp alibi_slopes = + operands().size() > 4 ? 
loctx->GetOutputOp(operand(4)) : xla::XlaOp(); + std::vector<xla::XlaOp> result = BuildFlashAttentionVarlenPositionIdsForward( + q, k, v, position_ids, alibi_slopes, params_, xla_shape()); + return ReturnOps({result}, loctx); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h new file mode 100644 index 000000000000..2a37eec1c499 --- /dev/null +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h @@ -0,0 +1,34 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_FORWARD_H_ +#define XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_FORWARD_H_ + +#include "torch_xla/csrc/flash_attention_utils.h" +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class FlashAttentionVarlenPositionIdsForward : public XlaNode { + public: + FlashAttentionVarlenPositionIdsForward(const torch::lazy::Value& q, + const torch::lazy::Value& k, + const torch::lazy::Value& v, + const torch::lazy::Value& position_ids, + const std::string params); + + FlashAttentionVarlenPositionIdsForward(const torch::lazy::Value& q, + const torch::lazy::Value& k, + const torch::lazy::Value& v, + const torch::lazy::Value& position_ids, + const torch::lazy::Value& alibi_slopes, + const std::string params); + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + private: + const std::string params_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_FORWARD_H_ \ No newline at end of file diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 41ddfce80438..3992c31289b3 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -53,6 +53,8 @@ #include "torch_xla/csrc/ops/flash_attention_forward.h" #include "torch_xla/csrc/ops/flash_attention_varlen_backward.h" #include "torch_xla/csrc/ops/flash_attention_varlen_forward.h" +#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h" +#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h" #include "torch_xla/csrc/ops/flip.h" #include "torch_xla/csrc/ops/gather.h" #include "torch_xla/csrc/ops/generic.h" @@ -675,6 +677,26 @@ std::vector<XLATensorPtr> flash_attention_varlen_forward( } } +std::vector<XLATensorPtr> flash_attention_varlen_position_ids_forward( + const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v, + const XLATensorPtr& position_ids, const XLATensorPtr& alibi_slopes, + const std::string& params) { + if (alibi_slopes) { + torch::lazy::NodePtr node = + torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>( + q->GetIrValue(), k->GetIrValue(), v->GetIrValue(), + position_ids->GetIrValue(), alibi_slopes->GetIrValue(), params); + return q->MakeOutputTensors(node, /*inherit_logical_type=*/false); + } else { + torch::lazy::NodePtr node = + torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>( + q->GetIrValue(), k->GetIrValue(), v->GetIrValue(), + position_ids->GetIrValue(), params); + return q->MakeOutputTensors(node, /*inherit_logical_type=*/false); + } +} + + std::vector<XLATensorPtr> flash_attention_backward( const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v, const XLATensorPtr& out, @@ -720,6 +742,32 @@ std::vector<XLATensorPtr> flash_attention_varlen_backward( } } + +std::vector<XLATensorPtr> 
flash_attention_varlen_position_ids_backward( + const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k, + const XLATensorPtr& v, const XLATensorPtr& out, + const XLATensorPtr& softmax_lse, const XLATensorPtr& cu_seqlens_q, + const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state, + const XLATensorPtr& alibi_slopes, const std::string& params) { + if (alibi_slopes) { + torch::lazy::NodePtr node = + torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsBackward>( + dout->GetIrValue(), q->GetIrValue(), k->GetIrValue(), + v->GetIrValue(), out->GetIrValue(), softmax_lse->GetIrValue(), + cu_seqlens_q->GetIrValue(), cu_seqlens_k->GetIrValue(), + rng_state->GetIrValue(), alibi_slopes->GetIrValue(), params); + return dout->MakeOutputTensors(node, /*inherit_logical_type=*/false); + } else { + torch::lazy::NodePtr node = + torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsBackward>( + dout->GetIrValue(), q->GetIrValue(), k->GetIrValue(), + v->GetIrValue(), out->GetIrValue(), softmax_lse->GetIrValue(), + cu_seqlens_q->GetIrValue(), cu_seqlens_k->GetIrValue(), + rng_state->GetIrValue(), params); + return dout->MakeOutputTensors(node, /*inherit_logical_type=*/false); + } +} + std::vector<XLATensorPtr> user_computation( const std::string& opname, absl::Span<const XLATensorPtr> inputs, runtime::ComputationClient::ComputationPtr computation) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index fb574640045c..a04a9538463d 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -123,6 +123,11 @@ std::vector<XLATensorPtr> flash_attention_varlen_forward( const XLATensorPtr& attention_mask, const XLATensorPtr& alibi_slopes, const std::string& params); +std::vector<XLATensorPtr> flash_attention_varlen_position_ids_forward( + const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v, + const XLATensorPtr& position_ids, const XLATensorPtr& alibi_slopes, + const std::string& params); + std::vector<XLATensorPtr> flash_attention_backward( const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v, const XLATensorPtr& out, @@ -136,6 +141,13 @@ std::vector<XLATensorPtr> flash_attention_varlen_backward( const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state, const XLATensorPtr& alibi_slopes, const std::string& params); +std::vector<XLATensorPtr> flash_attention_varlen_position_ids_backward( + const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k, + const XLATensorPtr& v, const XLATensorPtr& out, + const XLATensorPtr& softmax_lse, const XLATensorPtr& cu_seqlens_q, + const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state, + const XLATensorPtr& alibi_slopes, const std::string& params); + std::vector<XLATensorPtr> user_computation( const std::string& opname, absl::Span<const XLATensorPtr> inputs, runtime::ComputationClient::ComputationPtr computation); From 09067b6b7ca8c21134c7c793ed3a1b7b78358856 Mon Sep 17 00:00:00 2001 From: tianxingwang <wangtianxing.wtx@alibaba-inc.com> Date: Mon, 2 Dec 2024 14:11:05 +0800 Subject: [PATCH 2/5] reformat files --- test/test_flash_attention_varlen_backward.py | 9 +- test/test_flash_attention_varlen_forward.py | 21 +-- torch_xla/csrc/flash_attention_utils.cpp | 148 ++++++++++-------- torch_xla/csrc/flash_attention_utils.h | 17 +- torch_xla/csrc/init_python_bindings.cpp | 31 ++-- .../ops/flash_attention_varlen_forward.cpp | 24 +-- ...attention_varlen_position_ids_backward.cpp | 76 ++++----- 
..._attention_varlen_position_ids_forward.cpp | 54 ++++--- ...sh_attention_varlen_position_ids_forward.h | 18 +-- torch_xla/csrc/tensor_methods.cpp | 22 ++- 10 files changed, 225 insertions(+), 195 deletions(-) diff --git a/test/test_flash_attention_varlen_backward.py b/test/test_flash_attention_varlen_backward.py index 751b9aafd74f..e71edec0a75f 100644 --- a/test/test_flash_attention_varlen_backward.py +++ b/test/test_flash_attention_varlen_backward.py @@ -251,8 +251,6 @@ def test_flash_attn_varlen_backward(seqlen_q, seqlen_k, d, dropout_p, causal, assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True) - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("deterministic", [False]) @@ -366,7 +364,7 @@ def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, dropout_ device = ta.lazy_device() torch.random.manual_seed(101) - + indices_q = indices_q.cpu() indices_k = indices_k.cpu() @@ -374,7 +372,7 @@ def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, dropout_ k_xla = k.flatten(0,1)[indices_k].unsqueeze(0).to(device) v_xla = v.flatten(0,1)[indices_k].unsqueeze(0).to(device) do_xla = do.flatten(0,1)[indices_q].unsqueeze(0).to(device) - + def attention_mask_to_position_ids(attention_mask): seqlens = attention_mask.sum(dim=1).flatten() position_ids = torch.cat([torch.arange(0,seqlen,dtype=torch.int32,device=attention_mask.device) for seqlen in seqlens],dim=0) @@ -385,7 +383,6 @@ def attention_mask_to_position_ids(attention_mask): rng_state_xla = rng_state.to(device) if alibi: alibi_slopes = alibi_slopes.cpu().to(device) - softmax_lse_xla, o_xla, _, cu_seqlen_q_xla, cu_seqlen_k_xla = torch_xla._XLAC._flash_attention_position_ids_forward( q_xla.contiguous(), k_xla.contiguous(), v_xla.contiguous(), @@ -420,7 +417,7 @@ def attention_mask_to_position_ids(attention_mask): dv_xla = dv_xla.cpu().detach().squeeze(0) do_xla = do_xla.cpu().detach().squeeze(0) softmax_d_xla = softmax_d_xla.cpu().detach() - + dq_xla = pad_input(dq_xla, indices_q, batch_size, seqlen_q) dk_xla = pad_input(dk_xla, indices_k, batch_size, seqlen_k) dv_xla = pad_input(dv_xla, indices_k, batch_size, seqlen_k) diff --git a/test/test_flash_attention_varlen_forward.py b/test/test_flash_attention_varlen_forward.py index 30e4a4aa5287..119f90d6c6cd 100644 --- a/test/test_flash_attention_varlen_forward.py +++ b/test/test_flash_attention_varlen_forward.py @@ -228,10 +228,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, cu_seqlen_k_xla, cu_seqlens_k, rtol=1e-3, atol=1e-3, equal_nan=True) for i in range(len(cu_seq_lens[0]) - 1): seqlen = cu_seq_lens[0][i + 1] - cu_seq_lens[0][i] - assert torch.allclose(softmax_lse_xla[i,:,:seqlen],softmax_lse[i,:,:seqlen],rtol=1e-2,atol=1e-3,equal_nan=True) - - - + assert torch.allclose(softmax_lse_xla[i,:,:seqlen],softmax_lse[i,:,:seqlen],rtol=1e-2,atol=1e-3,equal_nan=True) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -262,7 +259,7 @@ def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d, drop torch.random.manual_seed(0) batch_size = 4 nheads = 9 - nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 window_size = (-1, -1) if not local else tuple( @@ -280,7 +277,7 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype, 
seq_len_k = torch.randint(1,max_seqlen_k+1,(batch_size,)) total_q = seq_len_q.sum().item() total_k = seq_len_k.sum().item() - + padd_q = 0 if total_q % 8 == 0 else 8 - total_q % 8 padd_k = 0 if total_k % 8 == 0 else 8 - total_k % 8 if use_same_seqlen: @@ -301,24 +298,23 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype, v = torch.randn((1,total_k,n_heads_k,head_dim),dtype=dtype,device=device) assert torch.all(seq_len_q > 0) - assert torch.all(seq_len_k > 0) + assert torch.all(seq_len_k > 0) position_ids_q = torch.cat([torch.arange(0,seq_len,dtype=torch.int32,device=device) for seq_len in seq_len_q],dim=0).unsqueeze(0) position_ids_k = torch.cat([torch.arange(0,seq_len,dtype=torch.int32,device=device) for seq_len in seq_len_k],dim=0).unsqueeze(0) assert position_ids_q.shape[1] % 8 == 0 assert position_ids_k.shape[1] % 8 == 0 - + return q,k,v,position_ids_q,position_ids_k - q,k,v,position_ids_q,position_ids_k = generate_qkv_and_position_ids(batch_size,max_seqlen_q,max_seqlen_k,dtype,nheads,nheads_k,device,d,max_seqlen_q == max_seqlen_k) - + indices_q = torch.arange(0,position_ids_q.size(1),device=device,dtype=torch.int64) indices_k = torch.arange(0,position_ids_k.size(1),device=device,dtype=torch.int64) seq_q_start_idx = indices_q[position_ids_q.squeeze() == 0] seq_k_start_idx = indices_k[position_ids_k.squeeze() == 0] - + cu_seq_lens_q = F.pad(seq_q_start_idx,(0,1),value=position_ids_q.size(1)).to(torch.int32) cu_seq_lens_k = F.pad(seq_k_start_idx,(0,1),value=position_ids_k.size(1)).to(torch.int32) @@ -335,7 +331,6 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype, batch_size, nheads, device=device, dtype=torch.float32) * 0.3 else: alibi_slopes = None - out_fa, softmax_lse, _ = flash_attn_varlen_func( q_cuda.contiguous(), @@ -415,4 +410,4 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype, start_idx = cu_seq_len_q_xla[i] end_idx = cu_seq_len_q_xla[i+1] seqlen = end_idx - start_idx - assert torch.allclose(softmax_lse_xla[0,:,start_idx:end_idx],softmax_lse[i,:,:seqlen],rtol=1e-3,atol=1e-3,equal_nan=True) + assert torch.allclose(softmax_lse_xla[0,:,start_idx:end_idx],softmax_lse[i,:,:seqlen],rtol=1e-3,atol=1e-3,equal_nan=True) diff --git a/torch_xla/csrc/flash_attention_utils.cpp b/torch_xla/csrc/flash_attention_utils.cpp index 546a00c71947..950eaf50ce39 100644 --- a/torch_xla/csrc/flash_attention_utils.cpp +++ b/torch_xla/csrc/flash_attention_utils.cpp @@ -2,6 +2,7 @@ #include <ATen/cuda/CUDAContext.h> #include <torch/extension.h> + #include <iostream> #include "absl/strings/numbers.h" @@ -238,7 +239,7 @@ void set_backward_params(FlashAttentionBackwardParams& params, const size_t b, FlashAttentionForwardParams get_flash_attention_forward_params( const at::Tensor& q, const at::Tensor& k, const at::Tensor& v, c10::optional<at::Tensor> attention_mask, // (batch_size, seqlen) - c10::optional<at::Tensor> position_ids, // (1,seqlen_q) + c10::optional<at::Tensor> position_ids, // (1,seqlen_q) c10::optional<at::Tensor>& alibi_slopes_, const float p_dropout, const float softmax_scale, const bool zero_tensors, const bool is_causal, int window_size_left, int window_size_right, const bool return_softmax) { @@ -260,7 +261,7 @@ FlashAttentionForwardParams get_flash_attention_forward_params( const auto sizes = q.sizes(); const int batch_size = sizes[0]; - const int seqlen_q = sizes[1]; + const int seqlen_q = sizes[1]; const int num_heads = sizes[2]; const int head_size_og = sizes[3]; const int seqlen_k = 
k.size(1); @@ -276,14 +277,14 @@ FlashAttentionForwardParams get_flash_attention_forward_params( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); - + if (attention_mask.has_value()) { TORCH_CHECK(attention_mask.value().dtype() == torch::kInt32); CHECK_SHAPE(attention_mask.value(), batch_size, seqlen_k); } - if (position_ids.has_value()){ + if (position_ids.has_value()) { TORCH_CHECK(position_ids.value().dtype() == torch::kInt32); - CHECK_SHAPE(position_ids.value(),1, seqlen_q); + CHECK_SHAPE(position_ids.value(), 1, seqlen_q); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -388,13 +389,16 @@ FlashAttentionBackwardParams get_flash_attention_backward_params( TORCH_CHECK(cu_seqlens_k.value().dtype() == torch::kInt32); TORCH_CHECK(cu_seqlens_q.value().is_contiguous()); TORCH_CHECK(cu_seqlens_k.value().is_contiguous()); - TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 || batch_size == 1); - TORCH_CHECK(cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size+1}) || - cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q+1}), - "cu_seqlens_q shape should be batch_size+1 or seqlen_q"); - TORCH_CHECK(cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size+1}) || - cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k+1}), - "cu_seqlens_k shape should be batch_size+1 or seqlen_k"); + TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 || + batch_size == 1); + TORCH_CHECK( + cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) || + cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q + 1}), + "cu_seqlens_q shape should be batch_size+1 or seqlen_q"); + TORCH_CHECK( + cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) || + cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k + 1}), + "cu_seqlens_k shape should be batch_size+1 or seqlen_k"); } int alibi_slopes_batch_stride = 0; @@ -422,16 +426,19 @@ FlashAttentionBackwardParams get_flash_attention_backward_params( return params; } - at::Tensor cu_seqlens_to_indices(const at::Tensor& padded_cu_seqlens, - int& max_seqlen_in_batch,int& total_q, int& real_batch_size) { - const at::Tensor valid_cu_seqlens = padded_cu_seqlens.index({padded_cu_seqlens > -1}); - real_batch_size = valid_cu_seqlens.size(0)-1; - at::Tensor seqs_len = valid_cu_seqlens.slice(0,1,valid_cu_seqlens.size(0)) - - valid_cu_seqlens.slice(0,0,valid_cu_seqlens.size(0) - 1); + int& max_seqlen_in_batch, int& total_q, + int& real_batch_size) { + const at::Tensor valid_cu_seqlens = + padded_cu_seqlens.index({padded_cu_seqlens > -1}); + real_batch_size = valid_cu_seqlens.size(0) - 1; + at::Tensor seqs_len = + valid_cu_seqlens.slice(0, 1, valid_cu_seqlens.size(0)) - + valid_cu_seqlens.slice(0, 0, valid_cu_seqlens.size(0) - 1); max_seqlen_in_batch = seqs_len.max().item<int>(); total_q = valid_cu_seqlens[-1].item<int>(); - return torch::arange(total_q,torch::dtype(torch::kInt64).device(torch::kCUDA)); + return torch::arange(total_q, + torch::dtype(torch::kInt64).device(torch::kCUDA)); } at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size, @@ -443,11 +450,12 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size, std::array<int64_t, 2> shape = {batch_size, seqlen}; auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA); - torch::Tensor rows = - 
torch::arange(shape[0], opts.dtype(torch::kInt32)).unsqueeze(1); // (batch_size,1) - torch::Tensor cols = - torch::arange(shape[1], opts.dtype(torch::kInt32)).unsqueeze(0); // (1,seqlen) - torch::Tensor mask = cols < nonzero_counts.unsqueeze(1); // (1,seqlen) < (batch_size, 1) + torch::Tensor rows = torch::arange(shape[0], opts.dtype(torch::kInt32)) + .unsqueeze(1); // (batch_size,1) + torch::Tensor cols = torch::arange(shape[1], opts.dtype(torch::kInt32)) + .unsqueeze(0); // (1,seqlen) + torch::Tensor mask = + cols < nonzero_counts.unsqueeze(1); // (1,seqlen) < (batch_size, 1) max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>(); torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32)); @@ -475,67 +483,85 @@ at::Tensor mask_to_indices(const at::Tensor& attention_mask, torch::Tensor unpad_softmax_lse( const torch::Tensor& pad_softmax_lse, // (batch_size, nhead, max_seqlen) - const torch::Tensor& cu_seqlens) // (total_seqlen + 1) + const torch::Tensor& cu_seqlens) // (total_seqlen + 1) { - int batch_size = pad_softmax_lse.size(0); - int nhead = pad_softmax_lse.size(1); - int max_seqlen = pad_softmax_lse.size(2); - int total, max_seqlen_in_batch; - at::Tensor valid_cu_seqlens = cu_seqlens.slice(0,0,batch_size+1); - at::Tensor indices = cu_seqlens_to_indices(valid_cu_seqlens,batch_size,max_seqlen,torch::kInt64,max_seqlen_in_batch,total); - at::Tensor result = at::empty({total,nhead},pad_softmax_lse.options()); - result.copy_(pad_softmax_lse.transpose(1,2).reshape({batch_size*max_seqlen,nhead}).index({indices,torch::indexing::Slice()})); - return result.transpose(0,1).unsqueeze(0); + int batch_size = pad_softmax_lse.size(0); + int nhead = pad_softmax_lse.size(1); + int max_seqlen = pad_softmax_lse.size(2); + int total, max_seqlen_in_batch; + at::Tensor valid_cu_seqlens = cu_seqlens.slice(0, 0, batch_size + 1); + at::Tensor indices = + cu_seqlens_to_indices(valid_cu_seqlens, batch_size, max_seqlen, + torch::kInt64, max_seqlen_in_batch, total); + at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options()); + result.copy_(pad_softmax_lse.transpose(1, 2) + .reshape({batch_size * max_seqlen, nhead}) + .index({indices, torch::indexing::Slice()})); + return result.transpose(0, 1).unsqueeze(0); } -torch::Tensor pad_softmax_lse(const at::Tensor& softmax_lse, // (1,nheads,total_seqlen) - const at::Tensor& cu_seqlens,const int max_seq_len, const int batch_size){ - const int nheads = softmax_lse.size(1); - int max_seqlen_in_batch; - int total; - at::Tensor valid_cu_seqlens = cu_seqlens.slice(0,0,batch_size+1); - at::Tensor indices = cu_seqlens_to_indices(valid_cu_seqlens,batch_size,max_seq_len, - torch::kInt32,max_seqlen_in_batch,total); - TORCH_CHECK(indices.size(0) == softmax_lse.size(2),"indice should be same size with softmax_lse") - - at::Tensor result = at::empty({batch_size * max_seq_len,nheads},softmax_lse.options()); - - result.index_put_({indices,torch::indexing::Slice()},softmax_lse.squeeze(0).transpose(0,1)); - return result.reshape({batch_size,max_seq_len,nheads}).transpose(1,2).contiguous(); +torch::Tensor pad_softmax_lse( + const at::Tensor& softmax_lse, // (1,nheads,total_seqlen) + const at::Tensor& cu_seqlens, const int max_seq_len, const int batch_size) { + const int nheads = softmax_lse.size(1); + int max_seqlen_in_batch; + int total; + at::Tensor valid_cu_seqlens = cu_seqlens.slice(0, 0, batch_size + 1); + at::Tensor indices = + cu_seqlens_to_indices(valid_cu_seqlens, batch_size, max_seq_len, + torch::kInt32, max_seqlen_in_batch, total); + 
TORCH_CHECK(indices.size(0) == softmax_lse.size(2), + "indice should be same size with softmax_lse") + + at::Tensor result = + at::empty({batch_size * max_seq_len, nheads}, softmax_lse.options()); + + result.index_put_({indices, torch::indexing::Slice()}, + softmax_lse.squeeze(0).transpose(0, 1)); + return result.reshape({batch_size, max_seq_len, nheads}) + .transpose(1, 2) + .contiguous(); } -at::Tensor position_ids_to_indices(const at::Tensor& position_ids, int& max_seqlen_in_batch, int& total, at::Tensor& cu_seqlen, int& real_batch_size){ +at::Tensor position_ids_to_indices(const at::Tensor& position_ids, + int& max_seqlen_in_batch, int& total, + at::Tensor& cu_seqlen, + int& real_batch_size) { at::Tensor flatten_position_ids = position_ids.flatten(); - at::Tensor indices = torch::arange(flatten_position_ids.size(0), torch::dtype(torch::kInt64).device(torch::kCUDA)); - + at::Tensor indices = + torch::arange(flatten_position_ids.size(0), + torch::dtype(torch::kInt64).device(torch::kCUDA)); + at::Tensor batch_seq_start_idx = indices.index({flatten_position_ids == 0}); real_batch_size = batch_seq_start_idx.size(0); - at::Tensor batch_seqlen_cumsum = at::empty({real_batch_size+1},torch::dtype(torch::kInt32).device(torch::kCUDA)); - batch_seqlen_cumsum.index({torch::indexing::Slice(0,real_batch_size)}) = batch_seq_start_idx; + at::Tensor batch_seqlen_cumsum = at::empty( + {real_batch_size + 1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + batch_seqlen_cumsum.index({torch::indexing::Slice(0, real_batch_size)}) = + batch_seq_start_idx; total = flatten_position_ids.size(0); batch_seqlen_cumsum.index({-1}) = total; - at::Tensor batch_seqlen = batch_seqlen_cumsum.slice(0,1,batch_seqlen_cumsum.size(0)) - batch_seqlen_cumsum.slice(0,0,batch_seqlen_cumsum.size(0)-1); + at::Tensor batch_seqlen = + batch_seqlen_cumsum.slice(0, 1, batch_seqlen_cumsum.size(0)) - + batch_seqlen_cumsum.slice(0, 0, batch_seqlen_cumsum.size(0) - 1); max_seqlen_in_batch = batch_seqlen.max().item<int>(); - cu_seqlen.narrow(0,0,real_batch_size+1) = batch_seqlen_cumsum; + cu_seqlen.narrow(0, 0, real_batch_size + 1) = batch_seqlen_cumsum; return indices; } - - at::Tensor index_first_axis(const at::Tensor& input, const at::Tensor& indices) { torch::IntArrayRef sizes = input.sizes(); - int64_t first_axis_dim = sizes[0]; // bs - auto other_shape = sizes.slice(1, sizes.size() - 1); // [a,h] + int64_t first_axis_dim = sizes[0]; // bs + auto other_shape = sizes.slice(1, sizes.size() - 1); // [a,h] int64_t second_dim = 1; for (auto dim : other_shape) { second_dim *= dim; } - at::Tensor flat_input = torch::flatten(input, 1); // [bs,ah] - torch::Tensor repeated_indices = + at::Tensor flat_input = torch::flatten(input, 1); // [bs,ah] + torch::Tensor repeated_indices = indices.unsqueeze(1).expand({indices.size(0), second_dim}); at::Tensor gather_input = torch::gather(flat_input, 0, repeated_indices); std::vector<int64_t> reshaped_size = {-1}; diff --git a/torch_xla/csrc/flash_attention_utils.h b/torch_xla/csrc/flash_attention_utils.h index 1692e1bd8678..d231f2e5165c 100644 --- a/torch_xla/csrc/flash_attention_utils.h +++ b/torch_xla/csrc/flash_attention_utils.h @@ -132,15 +132,16 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size, int& max_seqlen_in_batch, int& total); at::Tensor cu_seqlens_to_indices(const at::Tensor& padded_cu_seqlens, - int& max_seqlen_in_batch,int& total_q,int& real_batch_size); + int& max_seqlen_in_batch, int& total_q, + int& real_batch_size); at::Tensor mask_to_indices(const 
at::Tensor& attention_mask, int& max_seqlen_in_batch, int& total, at::Tensor& cu_seqlen); -at::Tensor position_ids_to_indices(const at::Tensor& position_ids, - int& max_seqlen_in_batch, int& total, - at::Tensor& cu_seqlen, int& real_batch_size); +at::Tensor position_ids_to_indices(const at::Tensor& position_ids, + int& max_seqlen_in_batch, int& total, + at::Tensor& cu_seqlen, int& real_batch_size); at::Tensor index_first_axis(const at::Tensor& input, const at::Tensor& indices); @@ -150,9 +151,11 @@ xla::Shape shape_like(const torch::lazy::Value& input); xla::Shape shape_like(const xla::XlaBuilder* builder, const xla::XlaOp& input); -torch::Tensor unpad_softmax_lse(const torch::Tensor& pad_softmax_lse, const torch::Tensor& cu_seqlens); +torch::Tensor unpad_softmax_lse(const torch::Tensor& pad_softmax_lse, + const torch::Tensor& cu_seqlens); -torch::Tensor pad_softmax_lse(const at::Tensor& softmax_lse,const at::Tensor& cu_seqlens, - const int max_seq_len, const int batch_size); +torch::Tensor pad_softmax_lse(const at::Tensor& softmax_lse, + const at::Tensor& cu_seqlens, + const int max_seq_len, const int batch_size); } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_FLASH_ATTENTION_UTILS_H \ No newline at end of file diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 55d70bc03ed6..2ca6b5140716 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2409,9 +2409,9 @@ void InitXlaModuleBindings(py::module m) { c10::optional<at::Generator> gen_) { // get launch params on at::Tensor auto params = get_flash_attention_forward_params( - q, k, v, attention_mask, c10::nullopt, alibi_slopes, p_dropout, softmax_scale, - zero_tensors, is_causal, window_size_left, window_size_right, - return_softmax); + q, k, v, attention_mask, c10::nullopt, alibi_slopes, p_dropout, + softmax_scale, zero_tensors, is_causal, window_size_left, + window_size_right, return_softmax); // call flash attention forward XLATensorPtr q_xla = bridge::GetXlaTensor(q); XLATensorPtr k_xla = bridge::GetXlaTensor(k); @@ -2451,9 +2451,9 @@ void InitXlaModuleBindings(py::module m) { c10::optional<at::Generator> gen_) { // get launch params on at::Tensor auto params = get_flash_attention_forward_params( - q, k, v, c10::nullopt, position_ids, alibi_slopes, p_dropout, softmax_scale, - zero_tensors, is_causal, window_size_left, window_size_right, - return_softmax); + q, k, v, c10::nullopt, position_ids, alibi_slopes, p_dropout, + softmax_scale, zero_tensors, is_causal, window_size_left, + window_size_right, return_softmax); // call flash attention forward XLATensorPtr q_xla = bridge::GetXlaTensor(q); XLATensorPtr k_xla = bridge::GetXlaTensor(k); @@ -2465,12 +2465,12 @@ void InitXlaModuleBindings(py::module m) { std::vector<XLATensorPtr> xresults; if (position_ids.has_value()) { - XLATensorPtr position_ids_xla = bridge::GetXlaTensor(position_ids.value()); - xresults = tensor_methods::flash_attention_varlen_position_ids_forward( - q_xla, k_xla, v_xla, position_ids_xla, alibi_slopes_xla, - params.ToString()); + xresults = + tensor_methods::flash_attention_varlen_position_ids_forward( + q_xla, k_xla, v_xla, position_ids_xla, alibi_slopes_xla, + params.ToString()); } else { xresults = tensor_methods::flash_attention_forward( q_xla, k_xla, v_xla, alibi_slopes_xla, params.ToString()); @@ -2483,7 +2483,7 @@ void InitXlaModuleBindings(py::module m) { } return results; }); - + m.def( "_flash_attention_backward", [](const at::Tensor& dout, const 
at::Tensor& q, const at::Tensor& k, @@ -2571,10 +2571,11 @@ void InitXlaModuleBindings(py::module m) { bridge::GetXlaTensor(cu_seqlens_q.value()); XLATensorPtr cu_seqlens_k_xla = bridge::GetXlaTensor(cu_seqlens_k.value()); - xresults = tensor_methods::flash_attention_varlen_position_ids_backward( - dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla, - cu_seqlens_q_xla, cu_seqlens_k_xla, rng_state_xla, - alibi_slopes_xla, params.ToString()); + xresults = + tensor_methods::flash_attention_varlen_position_ids_backward( + dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla, + cu_seqlens_q_xla, cu_seqlens_k_xla, rng_state_xla, + alibi_slopes_xla, params.ToString()); } else { xresults = tensor_methods::flash_attention_backward( dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla, diff --git a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp index 1a7cfed5bfa3..92b06f1addf5 100644 --- a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp +++ b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp @@ -3,6 +3,7 @@ #include <ATen/cuda/CUDAContext.h> #include <c10/cuda/CUDAGuard.h> #include <torch/extension.h> + #include <iostream> #include "cutlass/numeric_types.h" @@ -36,7 +37,7 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& q) { // buffers[0] = q // buffers[1] = k // buffers[2] = v -// buffers[3] = attention_mask +// buffers[3] = attention_mask // buffers[4] = alibi_slopes // buffers[5] = softmax_lse // this is output // buffers[6] = out_for_output // this is output @@ -47,7 +48,6 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { - std::string opaque_str(opaque, opaque_len); TF_VLOG(3) << "custom_call_flash_attention_varlen_forward opaque str: " << opaque_str; @@ -100,21 +100,21 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, int max_seqlen_in_batch_k = params.seqlen_k; int total_k = params.b * params.seqlen_k; at::Tensor indices_k = mask_to_indices(attention_mask, max_seqlen_in_batch_k, - total_k,cu_seqlens_k); - - auto unpad_k = index_first_axis(k, indices_k); + total_k, cu_seqlens_k); + + auto unpad_k = index_first_axis(k, indices_k); auto unpad_v = index_first_axis(v, indices_k); - int max_seqlen_in_batch_q = max_seqlen_in_batch_k; + int max_seqlen_in_batch_q = max_seqlen_in_batch_k; int total_q = total_k; at::Tensor indices_q; if (params.seqlen_q == params.seqlen_k) { cu_seqlens_q.copy_(cu_seqlens_k); indices_q = indices_k; - } else if (params.seqlen_q == 1){ + } else if (params.seqlen_q == 1) { max_seqlen_in_batch_q = 1; - cu_seqlens_q = torch::arange(0,params.b+1,opts); + cu_seqlens_q = torch::arange(0, params.b + 1, opts); indices_q = cu_seqlens_q.slice(/*dim=*/0, /*start=*/0, /*end=*/params.b); total_q = params.b; } else { @@ -122,14 +122,14 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream, /*dim=*/1, /*start=*/-params.seqlen_q, /*end=*/torch::indexing::None); indices_q = mask_to_indices(attention_mask_slice, max_seqlen_in_batch_q, total_q, cu_seqlens_q); - } + } at::Tensor unpad_q = index_first_axis(q, indices_q); at::Tensor unpad_output = torch::zeros({total_q, params.h * params.d}, opts.dtype(scalar_type)); at::Tensor unpad_softmax_lse = torch::zeros( - {params.b, params.h, max_seqlen_in_batch_q}, opts.dtype(torch::kFloat)); - + {params.b, params.h, max_seqlen_in_batch_q}, opts.dtype(torch::kFloat)); + if (max_seqlen_in_batch_q == 1) { params.is_causal = false; } @@ -322,7 +322,7 
@@ torch::lazy::NodePtr FlashAttentionVarlenForward::Clone( } } -XlaOpVector FlashAttentionVarlenForward::Lower(LoweringContext* loctx) const{ +XlaOpVector FlashAttentionVarlenForward::Lower(LoweringContext* loctx) const { xla::XlaOp q = loctx->GetOutputOp(operand(0)); xla::XlaOp k = loctx->GetOutputOp(operand(1)); xla::XlaOp v = loctx->GetOutputOp(operand(2)); diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp index 435b40e822cb..5cb54805a30f 100644 --- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp @@ -34,7 +34,7 @@ void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream, } // Layout of `buffers` listed above: -// buffers[0] = dout +// buffers[0] = dout // buffers[1] = q // buffers[2] = k // buffers[3] = v @@ -48,13 +48,13 @@ void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream, // buffers[11] = dk // this is output // buffers[12] = dv // this is output // buffers[13] = softmax_d // this is output -void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t stream, - void** buffers, - const char* opaque, - size_t opaque_len) { +void custom_call_flash_attention_varlen_position_ids_backward( + cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { std::string opaque_str(opaque, opaque_len); - TF_VLOG(3) << "custom_call_flash_attention_varlen_position_ids_backward opaque str: " - << opaque_str; + TF_VLOG(3) + << "custom_call_flash_attention_varlen_position_ids_backward opaque str: " + << opaque_str; FlashAttentionBackwardParams params; params.FromString(std::move(opaque_str)); int buf_offset = params.enable_alibi_slopes; @@ -87,10 +87,10 @@ void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t strea at::Tensor softmax_lse = torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q}, opts.dtype(torch::kFloat)); - at::Tensor cu_seqlens_q = - torch::from_blob(buffers[6], {params.seqlen_q + 1}, opts.dtype(torch::kInt32)); - at::Tensor cu_seqlens_k = - torch::from_blob(buffers[7], {params.seqlen_k + 1}, opts.dtype(torch::kInt32)); + at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.seqlen_q + 1}, + opts.dtype(torch::kInt32)); + at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.seqlen_k + 1}, + opts.dtype(torch::kInt32)); // Outputs at::Tensor dq = @@ -120,8 +120,8 @@ void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t strea int total_q = params.b * params.seqlen_q; int total_k = params.b * params.seqlen_k; int real_batch_size; - at::Tensor indices_q = - cu_seqlens_to_indices(cu_seqlens_q,max_seqlen_in_batch_q,total_q,real_batch_size); + at::Tensor indices_q = cu_seqlens_to_indices( + cu_seqlens_q, max_seqlen_in_batch_q, total_q, real_batch_size); at::Tensor indices_k; if (params.seqlen_q == params.seqlen_k) { @@ -129,12 +129,13 @@ void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t strea max_seqlen_in_batch_k = max_seqlen_in_batch_q; total_k = total_q; } else { - indices_k = - cu_seqlens_to_indices(cu_seqlens_k, max_seqlen_in_batch_k, total_k,real_batch_size); + indices_k = cu_seqlens_to_indices(cu_seqlens_k, max_seqlen_in_batch_k, + total_k, real_batch_size); } - auto padded_softmax_lse = pad_softmax_lse(softmax_lse,cu_seqlens_q,max_seqlen_in_batch_q,real_batch_size); - + auto padded_softmax_lse = pad_softmax_lse( + softmax_lse, cu_seqlens_q, 
max_seqlen_in_batch_q, real_batch_size); + Flash_bwd_params launch_params; // Reset the parameters @@ -308,7 +309,6 @@ void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t strea launch_params.rng_state = reinterpret_cast<uint64_t*>(buffers[8]); launch(launch_params, torch_stream, /*configure=*/false); - // For MQA/GQA we need to sum dK and dV across the groups if (launch_params.h_k != launch_params.h) { @@ -324,13 +324,13 @@ void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t strea {2}); } - dsoftmax_sum.copy_(unpad_softmax_lse(rounded_dsoftmax_sum,cu_seqlens_q)); + dsoftmax_sum.copy_(unpad_softmax_lse(rounded_dsoftmax_sum, cu_seqlens_q)); cudaEventRecord(xla_wait_torch_event, torch_stream); cudaStreamWaitEvent(stream, xla_wait_torch_event); } -XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_position_ids_backward, - "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET( + custom_call_flash_attention_varlen_position_ids_backward, "CUDA"); std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsBackward( const xla::XlaOp& dout, const xla::XlaOp& q, const xla::XlaOp& k, @@ -365,13 +365,14 @@ std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsBackward( } // namespace -FlashAttentionVarlenPositionIdsBackward::FlashAttentionVarlenPositionIdsBackward( - const torch::lazy::Value& dout, const torch::lazy::Value& q, - const torch::lazy::Value& k, const torch::lazy::Value& v, - const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, - const torch::lazy::Value& cu_seqlens_q, - const torch::lazy::Value& cu_seqlens_k, const torch::lazy::Value& rng_state, - const std::string params) +FlashAttentionVarlenPositionIdsBackward:: + FlashAttentionVarlenPositionIdsBackward( + const torch::lazy::Value& dout, const torch::lazy::Value& q, + const torch::lazy::Value& k, const torch::lazy::Value& v, + const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, + const torch::lazy::Value& cu_seqlens_q, + const torch::lazy::Value& cu_seqlens_k, + const torch::lazy::Value& rng_state, const std::string params) : XlaNode(xla_flash_attention_backward, {dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state}, @@ -379,13 +380,15 @@ FlashAttentionVarlenPositionIdsBackward::FlashAttentionVarlenPositionIdsBackward /*num_outputs=*/4, torch::lazy::MHash(params)), params_(params) {} -FlashAttentionVarlenPositionIdsBackward::FlashAttentionVarlenPositionIdsBackward( - const torch::lazy::Value& dout, const torch::lazy::Value& q, - const torch::lazy::Value& k, const torch::lazy::Value& v, - const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, - const torch::lazy::Value& cu_seqlens_q, - const torch::lazy::Value& cu_seqlens_k, const torch::lazy::Value& rng_state, - const torch::lazy::Value& alibi_slopes, const std::string params) +FlashAttentionVarlenPositionIdsBackward:: + FlashAttentionVarlenPositionIdsBackward( + const torch::lazy::Value& dout, const torch::lazy::Value& q, + const torch::lazy::Value& k, const torch::lazy::Value& v, + const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse, + const torch::lazy::Value& cu_seqlens_q, + const torch::lazy::Value& cu_seqlens_k, + const torch::lazy::Value& rng_state, + const torch::lazy::Value& alibi_slopes, const std::string params) : XlaNode(xla_flash_attention_backward, {dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, alibi_slopes}, @@ -408,7 +411,8 @@ torch::lazy::NodePtr FlashAttentionVarlenPositionIdsBackward::Clone( } } 
-XlaOpVector FlashAttentionVarlenPositionIdsBackward::Lower(LoweringContext* loctx) const { +XlaOpVector FlashAttentionVarlenPositionIdsBackward::Lower( + LoweringContext* loctx) const { xla::XlaOp dout = loctx->GetOutputOp(operand(0)); xla::XlaOp q = loctx->GetOutputOp(operand(1)); xla::XlaOp k = loctx->GetOutputOp(operand(2)); diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp index 9403879add54..945a807236cc 100644 --- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp @@ -3,6 +3,7 @@ #include <ATen/cuda/CUDAContext.h> #include <c10/cuda/CUDAGuard.h> #include <torch/extension.h> + #include <iostream> #include "cutlass/numeric_types.h" @@ -22,12 +23,12 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& q) { auto q_shape = xla::SpanToVector(GetXlaShape(q).dimensions()); xla::Shape softmax_lse_shape = xla::ShapeUtil::MakeShape( xla::PrimitiveType::F32, - {q_shape[0], q_shape[2], - q_shape[1]}); // 1, num_heads, total_q + {q_shape[0], q_shape[2], q_shape[1]}); // 1, num_heads, total_q xla::Shape rng_state_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2}); - xla::Shape cu_seqlens_shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {q_shape[1] + 1}); // q.shape [1,total_q,num_head,head_dim] + xla::Shape cu_seqlens_shape = xla::ShapeUtil::MakeShape( + xla::PrimitiveType::S32, + {q_shape[1] + 1}); // q.shape [1,total_q,num_head,head_dim] return xla::ShapeUtil::MakeTupleShape({softmax_lse_shape, shape_like(q), rng_state_shape, cu_seqlens_shape, cu_seqlens_shape}); @@ -44,14 +45,13 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& q) { // buffers[7] = rng_state // this is output // buffers[8] = cu_seqlen_q // this is output // buffers[9] = cu_seqlen_k // this is output -void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream, - void** buffers, - const char* opaque, - size_t opaque_len) { - +void custom_call_flash_attention_varlen_position_ids_forward( + cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { std::string opaque_str(opaque, opaque_len); - TF_VLOG(3) << "custom_call_flash_attention_varlen_position_ids_forward opaque str: " - << opaque_str; + TF_VLOG(3) + << "custom_call_flash_attention_varlen_position_ids_forward opaque str: " + << opaque_str; FlashAttentionForwardParams params; params.FromString(std::move(opaque_str)); int buf_offset = params.enable_alibi_slopes; @@ -102,17 +102,20 @@ void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream int total_k = params.b * params.seqlen_k; at::Tensor indices_k; - int real_batch_size; // packed qkv's batch size is 1, but fa need to know real batch size before packing. - indices_k = position_ids_to_indices(position_ids, max_seqlen_in_batch_k,total_k, - cu_seqlens_k,real_batch_size); - TORCH_CHECK(cu_seqlens_k.size(0) == params.seqlen_k + 1, "cu_seqlen'shape should be params.seqlen_k"); + int real_batch_size; // packed qkv's batch size is 1, but fa need to know + // real batch size before packing. 
+ indices_k = position_ids_to_indices(position_ids, max_seqlen_in_batch_k, + total_k, cu_seqlens_k, real_batch_size); + TORCH_CHECK(cu_seqlens_k.size(0) == params.seqlen_k + 1, + "cu_seqlen'shape should be params.seqlen_k"); TORCH_CHECK(indices_k.dtype() == torch::kInt64, "indice should be int64"); - int max_seqlen_in_batch_q = max_seqlen_in_batch_k; + int max_seqlen_in_batch_q = max_seqlen_in_batch_k; int total_q = total_k; at::Tensor indices_q; - TORCH_CHECK(params.seqlen_q == params.seqlen_k, "now only support same seqlen for q and k"); + TORCH_CHECK(params.seqlen_q == params.seqlen_k, + "now only support same seqlen for q and k"); if (params.seqlen_q == params.seqlen_k) { cu_seqlens_q.copy_(cu_seqlens_k); @@ -120,7 +123,7 @@ void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream } else { // TODO:(wangtianxing.wtx) support different seqlen_q and seqlen_k indices_q = position_ids_to_indices(position_ids, max_seqlen_in_batch_q, - total_q, cu_seqlens_q,real_batch_size); + total_q, cu_seqlens_q, real_batch_size); } if (max_seqlen_in_batch_q == 1) { @@ -141,7 +144,9 @@ void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - at::Tensor pad_softmax_lse = at::empty({real_batch_size,params.h,max_seqlen_in_batch_q},torch::dtype(torch::kFloat).device(torch::kCUDA)); + at::Tensor pad_softmax_lse = + at::empty({real_batch_size, params.h, max_seqlen_in_batch_q}, + torch::dtype(torch::kFloat).device(torch::kCUDA)); Flash_fwd_params launch_params; @@ -242,16 +247,16 @@ void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream run_mha_fwd_<elem_type, kHeadDim>(launch_params, torch_stream); }); }); - softmax_lse.copy_(unpad_softmax_lse(pad_softmax_lse,cu_seqlens_q)); - + softmax_lse.copy_(unpad_softmax_lse(pad_softmax_lse, cu_seqlens_q)); + // TODO(wenting.swt): we should pad and unpad q,k,v when head_size_og % 8 != 0 // sync with cudaEvent cudaEventRecord(xla_wait_torch_event, torch_stream); cudaStreamWaitEvent(stream, xla_wait_torch_event); } -XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_position_ids_forward, - "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET( + custom_call_flash_attention_varlen_position_ids_forward, "CUDA"); std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsForward( const xla::XlaOp& q, const xla::XlaOp& k, const xla::XlaOp& v, @@ -310,7 +315,8 @@ torch::lazy::NodePtr FlashAttentionVarlenPositionIdsForward::Clone( } } -XlaOpVector FlashAttentionVarlenPositionIdsForward::Lower(LoweringContext* loctx) const { +XlaOpVector FlashAttentionVarlenPositionIdsForward::Lower( + LoweringContext* loctx) const { xla::XlaOp q = loctx->GetOutputOp(operand(0)); xla::XlaOp k = loctx->GetOutputOp(operand(1)); xla::XlaOp v = loctx->GetOutputOp(operand(2)); diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h index 2a37eec1c499..2e27ee973b0f 100644 --- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h @@ -9,17 +9,17 @@ namespace torch_xla { class FlashAttentionVarlenPositionIdsForward : public XlaNode { public: FlashAttentionVarlenPositionIdsForward(const torch::lazy::Value& q, - const torch::lazy::Value& k, - const torch::lazy::Value& v, - const torch::lazy::Value& position_ids, - const std::string params); + const 
torch::lazy::Value& k, + const torch::lazy::Value& v, + const torch::lazy::Value& position_ids, + const std::string params); FlashAttentionVarlenPositionIdsForward(const torch::lazy::Value& q, - const torch::lazy::Value& k, - const torch::lazy::Value& v, - const torch::lazy::Value& position_ids, - const torch::lazy::Value& alibi_slopes, - const std::string params); + const torch::lazy::Value& k, + const torch::lazy::Value& v, + const torch::lazy::Value& position_ids, + const torch::lazy::Value& alibi_slopes, + const std::string params); torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 3992c31289b3..44489c0c44e8 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -53,8 +53,8 @@ #include "torch_xla/csrc/ops/flash_attention_forward.h" #include "torch_xla/csrc/ops/flash_attention_varlen_backward.h" #include "torch_xla/csrc/ops/flash_attention_varlen_forward.h" -#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h" #include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h" +#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h" #include "torch_xla/csrc/ops/flip.h" #include "torch_xla/csrc/ops/gather.h" #include "torch_xla/csrc/ops/generic.h" @@ -681,22 +681,21 @@ std::vector<XLATensorPtr> flash_attention_varlen_position_ids_forward( const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v, const XLATensorPtr& position_ids, const XLATensorPtr& alibi_slopes, const std::string& params) { - if (alibi_slopes) { - torch::lazy::NodePtr node = - torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>( + if (alibi_slopes) { + torch::lazy::NodePtr node = + torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>( q->GetIrValue(), k->GetIrValue(), v->GetIrValue(), position_ids->GetIrValue(), alibi_slopes->GetIrValue(), params); - return q->MakeOutputTensors(node, /*inherit_logical_type=*/false); - } else { - torch::lazy::NodePtr node = - torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>( + return q->MakeOutputTensors(node, /*inherit_logical_type=*/false); + } else { + torch::lazy::NodePtr node = + torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>( q->GetIrValue(), k->GetIrValue(), v->GetIrValue(), position_ids->GetIrValue(), params); - return q->MakeOutputTensors(node, /*inherit_logical_type=*/false); - } + return q->MakeOutputTensors(node, /*inherit_logical_type=*/false); + } } - std::vector<XLATensorPtr> flash_attention_backward( const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v, const XLATensorPtr& out, @@ -742,7 +741,6 @@ std::vector<XLATensorPtr> flash_attention_varlen_backward( } } - std::vector<XLATensorPtr> flash_attention_varlen_position_ids_backward( const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v, const XLATensorPtr& out, From 74814401e2ab61603a2527a847612e5f4d0c194c Mon Sep 17 00:00:00 2001 From: tianxingwang <wangtianxing.wtx@alibaba-inc.com> Date: Mon, 2 Dec 2024 14:14:49 +0800 Subject: [PATCH 3/5] reformat files --- test/test_flash_attention_forward.py | 3 - test/test_flash_attention_varlen_backward.py | 57 ++++++++++--- test/test_flash_attention_varlen_forward.py | 90 ++++++++++++++------ 3 files changed, 105 insertions(+), 45 deletions(-) diff --git a/test/test_flash_attention_forward.py b/test/test_flash_attention_forward.py index 
6414855a9dd2..9e019ed29325 100644 --- a/test/test_flash_attention_forward.py +++ b/test/test_flash_attention_forward.py @@ -136,6 +136,3 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, assert torch.allclose(softmax_lse_xla, softmax_lse, rtol=1e-2, atol=1e-2) assert torch.allclose(out_xla, out_fa, rtol=1e-2, atol=1e-2) - - - diff --git a/test/test_flash_attention_varlen_backward.py b/test/test_flash_attention_varlen_backward.py index e71edec0a75f..2da5985b53d4 100644 --- a/test/test_flash_attention_varlen_backward.py +++ b/test/test_flash_attention_varlen_backward.py @@ -266,9 +266,10 @@ def test_flash_attn_varlen_backward(seqlen_q, seqlen_k, d, dropout_p, causal, ], ) @pytest.mark.parametrize("dropout_p", [0.0]) -def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, dropout_p, causal, - local, alibi, deterministic, mha_type, - dtype): +def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, + dropout_p, causal, local, + alibi, deterministic, mha_type, + dtype): if d % 8 != 0: pytest.skip(reason="Expected head_size_og % 8 == 0 to be true") @@ -368,14 +369,19 @@ def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, dropout_ indices_q = indices_q.cpu() indices_k = indices_k.cpu() - q_xla = q.flatten(0,1)[indices_q].unsqueeze(0).to(device) - k_xla = k.flatten(0,1)[indices_k].unsqueeze(0).to(device) - v_xla = v.flatten(0,1)[indices_k].unsqueeze(0).to(device) - do_xla = do.flatten(0,1)[indices_q].unsqueeze(0).to(device) + q_xla = q.flatten(0, 1)[indices_q].unsqueeze(0).to(device) + k_xla = k.flatten(0, 1)[indices_k].unsqueeze(0).to(device) + v_xla = v.flatten(0, 1)[indices_k].unsqueeze(0).to(device) + do_xla = do.flatten(0, 1)[indices_q].unsqueeze(0).to(device) def attention_mask_to_position_ids(attention_mask): seqlens = attention_mask.sum(dim=1).flatten() - position_ids = torch.cat([torch.arange(0,seqlen,dtype=torch.int32,device=attention_mask.device) for seqlen in seqlens],dim=0) + position_ids = torch.cat([ + torch.arange( + 0, seqlen, dtype=torch.int32, device=attention_mask.device) + for seqlen in seqlens + ], + dim=0) return position_ids.unsqueeze(0) position_ids_xla = attention_mask_to_position_ids(attention_mask).to(device) @@ -389,12 +395,36 @@ def attention_mask_to_position_ids(attention_mask): position_ids_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale, False, causal, window_size[0], window_size[1], True, None) - assert torch.allclose(q_xla.cpu(), q_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True) - assert torch.allclose(k_xla.cpu(), k_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True) - assert torch.allclose(v_xla.cpu(), v_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True) assert torch.allclose( - cu_seqlen_k_xla[:batch_size+1].cpu(), cu_seqlens_k.cpu(), rtol=1e-3, atol=1e-3, equal_nan=True) - assert torch.allclose(o_xla.cpu(), o_cuda.unsqueeze(0).cpu(), rtol=1e-2, atol=1e-2, equal_nan=True) + q_xla.cpu(), + q_cuda.cpu().unsqueeze(0), + rtol=1e-3, + atol=1e-3, + equal_nan=True) + assert torch.allclose( + k_xla.cpu(), + k_cuda.cpu().unsqueeze(0), + rtol=1e-3, + atol=1e-3, + equal_nan=True) + assert torch.allclose( + v_xla.cpu(), + v_cuda.cpu().unsqueeze(0), + rtol=1e-3, + atol=1e-3, + equal_nan=True) + assert torch.allclose( + cu_seqlen_k_xla[:batch_size + 1].cpu(), + cu_seqlens_k.cpu(), + rtol=1e-3, + atol=1e-3, + equal_nan=True) + assert torch.allclose( + o_xla.cpu(), + o_cuda.unsqueeze(0).cpu(), + rtol=1e-2, + atol=1e-2, + equal_nan=True) q_xla.requires_grad 
= True k_xla.requires_grad = True @@ -425,4 +455,3 @@ def attention_mask_to_position_ids(attention_mask): assert torch.allclose(dq_cuda, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True) assert torch.allclose(dk_cuda, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True) assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True) - diff --git a/test/test_flash_attention_varlen_forward.py b/test/test_flash_attention_varlen_forward.py index 119f90d6c6cd..843573436fb8 100644 --- a/test/test_flash_attention_varlen_forward.py +++ b/test/test_flash_attention_varlen_forward.py @@ -178,7 +178,6 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, return_attn_probs=True, ) - out_fa = pad_input(out_fa, indices_q, batch_size, seqlen_q) q = q.cpu().detach() @@ -227,8 +226,13 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, assert torch.allclose( cu_seqlen_k_xla, cu_seqlens_k, rtol=1e-3, atol=1e-3, equal_nan=True) for i in range(len(cu_seq_lens[0]) - 1): - seqlen = cu_seq_lens[0][i + 1] - cu_seq_lens[0][i] - assert torch.allclose(softmax_lse_xla[i,:,:seqlen],softmax_lse[i,:,:seqlen],rtol=1e-2,atol=1e-3,equal_nan=True) + seqlen = cu_seq_lens[0][i + 1] - cu_seq_lens[0][i] + assert torch.allclose( + softmax_lse_xla[i, :, :seqlen], + softmax_lse[i, :, :seqlen], + rtol=1e-2, + atol=1e-3, + equal_nan=True) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -248,9 +252,10 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, ], ) @pytest.mark.parametrize("dropout_p", [0.0]) -def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d, dropout_p, causal, - softmax_scale, local, alibi, deterministic, mha_type, - dtype): +def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d, + dropout_p, causal, softmax_scale, + local, alibi, deterministic, + mha_type, dtype): if d % 8 != 0: pytest.skip(reason="Expected head_size_og % 8 == 0 to be true") @@ -266,15 +271,17 @@ def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d, drop torch.randint(0, max_seqlen_k, (2,)).tolist()) torch.cuda.synchronize() - def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype, n_heads_q, n_heads_k, device, head_dim, use_same_seqlen): + def generate_qkv_and_position_ids(batch_size, max_seqlen_q, max_seqlen_k, + dtype, n_heads_q, n_heads_k, device, + head_dim, use_same_seqlen): ''' generate varlen qkv and postion_ids and pad total seqlen to 8 ''' - seq_len_q = torch.randint(1,max_seqlen_q+1,(batch_size,)) + seq_len_q = torch.randint(1, max_seqlen_q + 1, (batch_size,)) if use_same_seqlen: seq_len_k = seq_len_q.clone() else: - seq_len_k = torch.randint(1,max_seqlen_k+1,(batch_size,)) + seq_len_k = torch.randint(1, max_seqlen_k + 1, (batch_size,)) total_q = seq_len_q.sum().item() total_k = seq_len_k.sum().item() @@ -293,38 +300,57 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype, total_k += padd_k assert total_k % 8 == 0 - q = torch.randn((1,total_q,n_heads_q,head_dim),dtype=dtype,device=device) - k = torch.randn((1,total_k,n_heads_k,head_dim),dtype=dtype,device=device) - v = torch.randn((1,total_k,n_heads_k,head_dim),dtype=dtype,device=device) + q = torch.randn((1, total_q, n_heads_q, head_dim), + dtype=dtype, + device=device) + k = torch.randn((1, total_k, n_heads_k, head_dim), + dtype=dtype, + device=device) + v = torch.randn((1, total_k, n_heads_k, head_dim), + dtype=dtype, + device=device) assert torch.all(seq_len_q > 0) assert torch.all(seq_len_k 
> 0) - position_ids_q = torch.cat([torch.arange(0,seq_len,dtype=torch.int32,device=device) for seq_len in seq_len_q],dim=0).unsqueeze(0) - position_ids_k = torch.cat([torch.arange(0,seq_len,dtype=torch.int32,device=device) for seq_len in seq_len_k],dim=0).unsqueeze(0) + position_ids_q = torch.cat([ + torch.arange(0, seq_len, dtype=torch.int32, device=device) + for seq_len in seq_len_q + ], + dim=0).unsqueeze(0) + position_ids_k = torch.cat([ + torch.arange(0, seq_len, dtype=torch.int32, device=device) + for seq_len in seq_len_k + ], + dim=0).unsqueeze(0) assert position_ids_q.shape[1] % 8 == 0 assert position_ids_k.shape[1] % 8 == 0 - return q,k,v,position_ids_q,position_ids_k + return q, k, v, position_ids_q, position_ids_k - q,k,v,position_ids_q,position_ids_k = generate_qkv_and_position_ids(batch_size,max_seqlen_q,max_seqlen_k,dtype,nheads,nheads_k,device,d,max_seqlen_q == max_seqlen_k) + q, k, v, position_ids_q, position_ids_k = generate_qkv_and_position_ids( + batch_size, max_seqlen_q, max_seqlen_k, dtype, nheads, nheads_k, device, + d, max_seqlen_q == max_seqlen_k) - indices_q = torch.arange(0,position_ids_q.size(1),device=device,dtype=torch.int64) - indices_k = torch.arange(0,position_ids_k.size(1),device=device,dtype=torch.int64) + indices_q = torch.arange( + 0, position_ids_q.size(1), device=device, dtype=torch.int64) + indices_k = torch.arange( + 0, position_ids_k.size(1), device=device, dtype=torch.int64) seq_q_start_idx = indices_q[position_ids_q.squeeze() == 0] seq_k_start_idx = indices_k[position_ids_k.squeeze() == 0] - cu_seq_lens_q = F.pad(seq_q_start_idx,(0,1),value=position_ids_q.size(1)).to(torch.int32) - cu_seq_lens_k = F.pad(seq_k_start_idx,(0,1),value=position_ids_k.size(1)).to(torch.int32) + cu_seq_lens_q = F.pad( + seq_q_start_idx, (0, 1), value=position_ids_q.size(1)).to(torch.int32) + cu_seq_lens_k = F.pad( + seq_k_start_idx, (0, 1), value=position_ids_k.size(1)).to(torch.int32) max_seqlen_q = (cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]).max().item() max_seqlen_k = (cu_seq_lens_k[1:] - cu_seq_lens_k[:-1]).max().item() - q_cuda = q.reshape(-1,nheads,d).cuda() - k_cuda = k.reshape(-1,nheads_k,d).cuda() - v_cuda = v.reshape(-1,nheads_k,d).cuda() - + q_cuda = q.reshape(-1, nheads, d).cuda() + k_cuda = k.reshape(-1, nheads_k, d).cuda() + v_cuda = v.reshape(-1, nheads_k, d).cuda() if alibi: alibi_slopes = torch.rand( @@ -383,7 +409,6 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype, position_ids_q_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale, False, causal, window_size[0], window_size[1], True, None) - ta.mark_step(wait=True) q_xla = q_xla.cpu().detach() k_xla = k_xla.cpu().detach() @@ -404,10 +429,19 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype, assert torch.allclose(k_xla, k, rtol=1e-3, atol=1e-3, equal_nan=True) assert torch.allclose(v_xla, v, rtol=1e-3, atol=1e-3, equal_nan=True) assert torch.allclose( - cu_seq_len_k_xla[:batch_size+1], cu_seq_lens_k, rtol=1e-3, atol=1e-3, equal_nan=True) + cu_seq_len_k_xla[:batch_size + 1], + cu_seq_lens_k, + rtol=1e-3, + atol=1e-3, + equal_nan=True) assert torch.allclose(out_xla, out_fa, rtol=1e-2, atol=1e-2, equal_nan=True) for i in range(batch_size): start_idx = cu_seq_len_q_xla[i] - end_idx = cu_seq_len_q_xla[i+1] + end_idx = cu_seq_len_q_xla[i + 1] seqlen = end_idx - start_idx - assert torch.allclose(softmax_lse_xla[0,:,start_idx:end_idx],softmax_lse[i,:,:seqlen],rtol=1e-3,atol=1e-3,equal_nan=True) + assert torch.allclose( + softmax_lse_xla[0, :, 
start_idx:end_idx], + softmax_lse[i, :, :seqlen], + rtol=1e-3, + atol=1e-3, + equal_nan=True) From 7e73ee211eaf75788e3b83244d760dd635cc7f09 Mon Sep 17 00:00:00 2001 From: tianxingwang <wangtianxing.wtx@alibaba-inc.com> Date: Tue, 3 Dec 2024 14:46:28 +0800 Subject: [PATCH 4/5] refine code --- torch_xla/csrc/flash_attention_utils.cpp | 23 +++++++++---------- .../ops/flash_attention_varlen_forward.cpp | 2 -- ...attention_varlen_position_ids_backward.cpp | 4 ++-- ..._attention_varlen_position_ids_forward.cpp | 3 --- 4 files changed, 13 insertions(+), 19 deletions(-) diff --git a/torch_xla/csrc/flash_attention_utils.cpp b/torch_xla/csrc/flash_attention_utils.cpp index 950eaf50ce39..376f492a1bb9 100644 --- a/torch_xla/csrc/flash_attention_utils.cpp +++ b/torch_xla/csrc/flash_attention_utils.cpp @@ -3,8 +3,6 @@ #include <ATen/cuda/CUDAContext.h> #include <torch/extension.h> -#include <iostream> - #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -390,15 +388,15 @@ FlashAttentionBackwardParams get_flash_attention_backward_params( TORCH_CHECK(cu_seqlens_q.value().is_contiguous()); TORCH_CHECK(cu_seqlens_k.value().is_contiguous()); TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 || - batch_size == 1); + batch_size == 1); // now pack qkv batch size only support 1, maybe need to change in the future TORCH_CHECK( cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) || - cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q + 1}), - "cu_seqlens_q shape should be batch_size+1 or seqlen_q"); + cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q*batch_size + 1}), + "cu_seqlens_q shape should be batch_size+1 or seqlen_q+1"); TORCH_CHECK( cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) || - cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k + 1}), - "cu_seqlens_k shape should be batch_size+1 or seqlen_k"); + cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k*batch_size + 1}), + "cu_seqlens_k shape should be batch_size+1 or seqlen_k+1"); } int alibi_slopes_batch_stride = 0; @@ -451,11 +449,11 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size, auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA); torch::Tensor rows = torch::arange(shape[0], opts.dtype(torch::kInt32)) - .unsqueeze(1); // (batch_size,1) + .unsqueeze(1); torch::Tensor cols = torch::arange(shape[1], opts.dtype(torch::kInt32)) - .unsqueeze(0); // (1,seqlen) + .unsqueeze(0); torch::Tensor mask = - cols < nonzero_counts.unsqueeze(1); // (1,seqlen) < (batch_size, 1) + cols < nonzero_counts.unsqueeze(1); max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>(); torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32)); @@ -496,7 +494,7 @@ torch::Tensor unpad_softmax_lse( at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options()); result.copy_(pad_softmax_lse.transpose(1, 2) .reshape({batch_size * max_seqlen, nhead}) - .index({indices, torch::indexing::Slice()})); + .index({indices, torch::indexing::Slice()})); // if packed tensor's batch size > 1 is supported in the future, need to modify here in the future return result.transpose(0, 1).unsqueeze(0); } @@ -514,7 +512,7 @@ torch::Tensor pad_softmax_lse( "indice should be same size with softmax_lse") at::Tensor result = - at::empty({batch_size * max_seq_len, nheads}, softmax_lse.options()); + at::zeros({batch_size * max_seq_len, nheads}, softmax_lse.options()); 
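+  // Note: zero-initialization (instead of at::empty) keeps the rows that the
+  // index_put_ below never writes, i.e. the padding positions past each
+  // sequence's valid length, at 0 rather than leaving uninitialized memory
+  // in the padded softmax_lse buffer.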
result.index_put_({indices, torch::indexing::Slice()}, softmax_lse.squeeze(0).transpose(0, 1)); @@ -527,6 +525,7 @@ at::Tensor position_ids_to_indices(const at::Tensor& position_ids, int& max_seqlen_in_batch, int& total, at::Tensor& cu_seqlen, int& real_batch_size) { + cu_seqlen.fill_(-1); at::Tensor flatten_position_ids = position_ids.flatten(); at::Tensor indices = torch::arange(flatten_position_ids.size(0), diff --git a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp index 92b06f1addf5..832e6c280faf 100644 --- a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp +++ b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp @@ -4,8 +4,6 @@ #include <c10/cuda/CUDAGuard.h> #include <torch/extension.h> -#include <iostream> - #include "cutlass/numeric_types.h" #include "flash.h" #include "static_switch.h" diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp index 5cb54805a30f..132742921833 100644 --- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp @@ -87,9 +87,9 @@ void custom_call_flash_attention_varlen_position_ids_backward( at::Tensor softmax_lse = torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q}, opts.dtype(torch::kFloat)); - at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.seqlen_q + 1}, + at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.b*params.seqlen_q + 1}, opts.dtype(torch::kInt32)); - at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.seqlen_k + 1}, + at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.b*params.seqlen_k + 1}, opts.dtype(torch::kInt32)); // Outputs diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp index 945a807236cc..23d5d1c4cc6f 100644 --- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp @@ -4,8 +4,6 @@ #include <c10/cuda/CUDAGuard.h> #include <torch/extension.h> -#include <iostream> - #include "cutlass/numeric_types.h" #include "flash.h" #include "static_switch.h" @@ -96,7 +94,6 @@ void custom_call_flash_attention_varlen_position_ids_forward( torch::from_blob(buffers[6 + buf_offset], {2}, opts.dtype(torch::kInt64)); softmax_lse.fill_(0); o_output.fill_(0); - cu_seqlens_k.fill_(-1); int max_seqlen_in_batch_k = params.seqlen_k; int total_k = params.b * params.seqlen_k; From e531a14583ab5c602ce8aa3f2df4cc4689f666c4 Mon Sep 17 00:00:00 2001 From: tianxingwang <wangtianxing.wtx@alibaba-inc.com> Date: Tue, 3 Dec 2024 14:49:54 +0800 Subject: [PATCH 5/5] refine code --- torch_xla/csrc/flash_attention_utils.cpp | 31 ++++++++++++------- ...attention_varlen_position_ids_backward.cpp | 8 ++--- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/torch_xla/csrc/flash_attention_utils.cpp b/torch_xla/csrc/flash_attention_utils.cpp index 376f492a1bb9..6536dc96c6cd 100644 --- a/torch_xla/csrc/flash_attention_utils.cpp +++ b/torch_xla/csrc/flash_attention_utils.cpp @@ -388,14 +388,17 @@ FlashAttentionBackwardParams get_flash_attention_backward_params( TORCH_CHECK(cu_seqlens_q.value().is_contiguous()); TORCH_CHECK(cu_seqlens_k.value().is_contiguous()); TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 || - batch_size == 1); // now pack qkv batch size 
only support 1, maybe need to change in the future + batch_size == 1); // now pack qkv batch size only support 1, + // maybe need to change in the future TORCH_CHECK( cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) || - cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q*batch_size + 1}), + cu_seqlens_q.value().sizes() == + torch::IntArrayRef({seqlen_q * batch_size + 1}), "cu_seqlens_q shape should be batch_size+1 or seqlen_q+1"); TORCH_CHECK( cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) || - cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k*batch_size + 1}), + cu_seqlens_k.value().sizes() == + torch::IntArrayRef({seqlen_k * batch_size + 1}), "cu_seqlens_k shape should be batch_size+1 or seqlen_k+1"); } @@ -448,12 +451,11 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size, std::array<int64_t, 2> shape = {batch_size, seqlen}; auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA); - torch::Tensor rows = torch::arange(shape[0], opts.dtype(torch::kInt32)) - .unsqueeze(1); - torch::Tensor cols = torch::arange(shape[1], opts.dtype(torch::kInt32)) - .unsqueeze(0); - torch::Tensor mask = - cols < nonzero_counts.unsqueeze(1); + torch::Tensor rows = + torch::arange(shape[0], opts.dtype(torch::kInt32)).unsqueeze(1); + torch::Tensor cols = + torch::arange(shape[1], opts.dtype(torch::kInt32)).unsqueeze(0); + torch::Tensor mask = cols < nonzero_counts.unsqueeze(1); max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>(); torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32)); @@ -492,9 +494,14 @@ torch::Tensor unpad_softmax_lse( cu_seqlens_to_indices(valid_cu_seqlens, batch_size, max_seqlen, torch::kInt64, max_seqlen_in_batch, total); at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options()); - result.copy_(pad_softmax_lse.transpose(1, 2) - .reshape({batch_size * max_seqlen, nhead}) - .index({indices, torch::indexing::Slice()})); // if packed tensor's batch size > 1 is supported in the future, need to modify here in the future + result.copy_( + pad_softmax_lse.transpose(1, 2) + .reshape({batch_size * max_seqlen, nhead}) + .index({indices, + torch::indexing::Slice()})); // if packed tensor's batch size + // > 1 is supported in the + // future, need to modify here + // in the future return result.transpose(0, 1).unsqueeze(0); } diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp index 132742921833..0e508b80af6d 100644 --- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp +++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp @@ -87,10 +87,10 @@ void custom_call_flash_attention_varlen_position_ids_backward( at::Tensor softmax_lse = torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q}, opts.dtype(torch::kFloat)); - at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.b*params.seqlen_q + 1}, - opts.dtype(torch::kInt32)); - at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.b*params.seqlen_k + 1}, - opts.dtype(torch::kInt32)); + at::Tensor cu_seqlens_q = torch::from_blob( + buffers[6], {params.b * params.seqlen_q + 1}, opts.dtype(torch::kInt32)); + at::Tensor cu_seqlens_k = torch::from_blob( + buffers[7], {params.b * params.seqlen_k + 1}, opts.dtype(torch::kInt32)); // Outputs at::Tensor dq =