From 141832e24c5932bc7326fa56322ff7781f4399af Mon Sep 17 00:00:00 2001
From: tianxingwang <wangtianxing.wtx@alibaba-inc.com>
Date: Mon, 2 Dec 2024 13:43:36 +0800
Subject: [PATCH 1/5] support position_ids input for flash attention

---
 test/test_flash_attention_forward.py          |   3 +
 test/test_flash_attention_varlen_backward.py  | 180 ++++++++
 test/test_flash_attention_varlen_forward.py   | 193 +++++++-
 torch_xla/csrc/flash_attention_utils.cpp      |  98 +++-
 torch_xla/csrc/flash_attention_utils.h        |  14 +-
 torch_xla/csrc/init_python_bindings.cpp       |  99 +++-
 .../ops/flash_attention_varlen_forward.cpp    |  27 +-
 ...attention_varlen_position_ids_backward.cpp | 430 ++++++++++++++++++
 ...h_attention_varlen_position_ids_backward.h |  38 ++
 ..._attention_varlen_position_ids_forward.cpp | 325 +++++++++++++
 ...sh_attention_varlen_position_ids_forward.h |  34 ++
 torch_xla/csrc/tensor_methods.cpp             |  48 ++
 torch_xla/csrc/tensor_methods.h               |  12 +
 13 files changed, 1470 insertions(+), 31 deletions(-)
 create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
 create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h
 create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
 create mode 100644 torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h

diff --git a/test/test_flash_attention_forward.py b/test/test_flash_attention_forward.py
index 9e019ed29325..6414855a9dd2 100644
--- a/test/test_flash_attention_forward.py
+++ b/test/test_flash_attention_forward.py
@@ -136,3 +136,6 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal,
 
   assert torch.allclose(softmax_lse_xla, softmax_lse, rtol=1e-2, atol=1e-2)
   assert torch.allclose(out_xla, out_fa, rtol=1e-2, atol=1e-2)
+
+
+
diff --git a/test/test_flash_attention_varlen_backward.py b/test/test_flash_attention_varlen_backward.py
index 1ee13118ee20..751b9aafd74f 100644
--- a/test/test_flash_attention_varlen_backward.py
+++ b/test/test_flash_attention_varlen_backward.py
@@ -249,3 +249,183 @@ def test_flash_attn_varlen_backward(seqlen_q, seqlen_k, d, dropout_p, causal,
   assert torch.allclose(dq_cuda, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
   assert torch.allclose(dk_cuda, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
   assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+@pytest.mark.parametrize("deterministic", [False])
+@pytest.mark.parametrize("alibi", [False])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [8])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (128, 128),
+        (1024, 1024),
+    ],
+)
+@pytest.mark.parametrize("dropout_p", [0.0])
+def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, dropout_p, causal,
+                                    local, alibi, deterministic, mha_type,
+                                    dtype):
+  if d % 8 != 0:
+    pytest.skip(reason="Expected head_size_og % 8 == 0 to be true")
+
+  device = "cuda"
+  # set seed
+  torch.random.manual_seed(101)
+  batch_size = 4
+  nheads = 9
+  nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+
+  assert nheads % nheads_k == 0
+  window_size = (-1, -1) if not local else tuple(
+      torch.randint(0, seqlen_k, (2,)).tolist())
+  torch.cuda.synchronize()
+  q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+  softmax_scale = q.shape[-1]**(-0.5)
+  k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+  v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+  do = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+  rng_state = torch.Tensor([0, 0]).to(torch.int64).to(device)
+  dq = torch.empty_like(q)
+  dk = torch.empty_like(k)
+  dv = torch.empty_like(v)
+
+  attention_mask = torch.zeros(
+      batch_size, seqlen_k, dtype=torch.int32).to(device)
+  k_lengths = torch.randint(low=2, high=seqlen_k, size=(batch_size,))
+  for i in range(batch_size):
+    k_len = k_lengths[i].item()
+    attention_mask[i, :k_len] = 1
+    q[i, k_len:, :, :] = 0
+    k[i, k_len:, :, :] = 0
+    v[i, k_len:, :, :] = 0
+    do[i, k_len:, :, :] = 0
+  q.requires_grad = True
+  k.requires_grad = True
+  v.requires_grad = True
+
+  q_cuda, k_cuda, v_cuda, do_cuda, dq_cuda, dk_cuda, dv_cuda, \
+  indices_q, indices_k, cu_seq_lens, max_seq_lens = _unpad_input(
+      q, k, v, do, dq, dk, dv, attention_mask, seqlen_q, nheads
+  )
+  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+  max_seqlen_q, max_seqlen_k = max_seq_lens
+
+  if alibi:
+    alibi_slopes = torch.rand(
+        batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+  else:
+    alibi_slopes = None
+
+  o_cuda, softmax_lse_cuda, _ = flash_attn_varlen_func(
+      q_cuda.contiguous(),
+      k_cuda.contiguous(),
+      v_cuda.contiguous(),
+      cu_seqlens_q.contiguous(),
+      cu_seqlens_k.contiguous(),
+      max_seqlen_q,
+      max_seqlen_k,
+      dropout_p=dropout_p,
+      softmax_scale=softmax_scale,
+      causal=causal,
+      window_size=window_size,
+      alibi_slopes=alibi_slopes,
+      deterministic=deterministic,
+      return_attn_probs=True,
+  )
+  dq_cuda, dk_cuda, dv_cuda, softmax_d_cuda = flash_attn_cuda.varlen_bwd(
+      do_cuda.contiguous(), q_cuda.contiguous(), k_cuda.contiguous(),
+      v_cuda.contiguous(), o_cuda.contiguous(), softmax_lse_cuda.contiguous(),
+      dq_cuda, dk_cuda, dv_cuda, cu_seqlens_q, cu_seqlens_k, alibi_slopes,
+      max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal,
+      window_size[0], window_size[1], deterministic, None, rng_state)
+
+  dq_cuda = pad_input(dq_cuda, indices_q, batch_size, seqlen_q)
+  dk_cuda = pad_input(dk_cuda, indices_k, batch_size, seqlen_k)
+  dv_cuda = pad_input(dv_cuda, indices_k, batch_size, seqlen_k)
+  softmax_d_cuda = softmax_d_cuda[:, :, :seqlen_q]
+
+  q = q.cpu().detach()
+  k = k.cpu().detach()
+  v = v.cpu().detach()
+  do = do.cpu().detach()
+  rng_state = rng_state.cpu().detach()
+
+  dq_cuda = dq_cuda.cpu().detach()
+  dk_cuda = dk_cuda.cpu().detach()
+  dv_cuda = dv_cuda.cpu().detach()
+  softmax_d_cuda = softmax_d_cuda.cpu().detach()
+  if alibi:
+    alibi_slopes = alibi_slopes.cpu()
+  torch.cuda.synchronize()
+
+  device = ta.lazy_device()
+  torch.random.manual_seed(101)
+
+  indices_q = indices_q.cpu()
+  indices_k = indices_k.cpu()
+
+  q_xla = q.flatten(0, 1)[indices_q].unsqueeze(0).to(device)
+  k_xla = k.flatten(0, 1)[indices_k].unsqueeze(0).to(device)
+  v_xla = v.flatten(0, 1)[indices_k].unsqueeze(0).to(device)
+  do_xla = do.flatten(0, 1)[indices_q].unsqueeze(0).to(device)
+
+  def attention_mask_to_position_ids(attention_mask):
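+    # e.g. attention_mask [[1,1,1,0],[1,1,0,0]] -> seqlens [3, 2]
+    # -> position_ids [[0,1,2,0,1]]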
+    seqlens = attention_mask.sum(dim=1).flatten()
+    position_ids = torch.cat([
+        torch.arange(0, seqlen, dtype=torch.int32, device=attention_mask.device)
+        for seqlen in seqlens
+    ], dim=0)
+    return position_ids.unsqueeze(0)
+
+  position_ids_xla = attention_mask_to_position_ids(attention_mask).to(device)
+
+  rng_state_xla = rng_state.to(device)
+  if alibi:
+    alibi_slopes = alibi_slopes.cpu().to(device)
+
+  softmax_lse_xla, o_xla, _, cu_seqlen_q_xla, cu_seqlen_k_xla = torch_xla._XLAC._flash_attention_position_ids_forward(
+      q_xla.contiguous(), k_xla.contiguous(), v_xla.contiguous(),
+      position_ids_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale,
+      False, causal, window_size[0], window_size[1], True, None)
+
+  assert torch.allclose(q_xla.cpu(), q_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True)
+  assert torch.allclose(k_xla.cpu(), k_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True)
+  assert torch.allclose(v_xla.cpu(), v_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True)
+  assert torch.allclose(
+      cu_seqlen_k_xla[:batch_size+1].cpu(), cu_seqlens_k.cpu(), rtol=1e-3, atol=1e-3, equal_nan=True)
+  assert torch.allclose(o_xla.cpu(), o_cuda.unsqueeze(0).cpu(), rtol=1e-2, atol=1e-2, equal_nan=True)
+
+  q_xla.requires_grad = True
+  k_xla.requires_grad = True
+  v_xla.requires_grad = True
+  o_xla.requires_grad = True
+  softmax_lse_xla.requires_grad = True
+
+  dq_xla, dk_xla, dv_xla, softmax_d_xla = torch_xla._XLAC._flash_attention_position_ids_backward(
+      do_xla.contiguous(), q_xla.contiguous(), k_xla.contiguous(),
+      v_xla.contiguous(), o_xla.contiguous(), softmax_lse_xla.contiguous(),
+      cu_seqlen_q_xla, cu_seqlen_k_xla, alibi_slopes, dropout_p, softmax_scale,
+      False, causal, window_size[0], window_size[1], deterministic, None,
+      rng_state_xla)
+
+  ta.mark_step(wait=True)
+  torch.cuda.synchronize()
+
+  dq_xla = dq_xla.cpu().detach().squeeze(0)
+  dk_xla = dk_xla.cpu().detach().squeeze(0)
+  dv_xla = dv_xla.cpu().detach().squeeze(0)
+  do_xla = do_xla.cpu().detach().squeeze(0)
+  softmax_d_xla = softmax_d_xla.cpu().detach()
+
+  dq_xla = pad_input(dq_xla, indices_q, batch_size, seqlen_q)
+  dk_xla = pad_input(dk_xla, indices_k, batch_size, seqlen_k)
+  dv_xla = pad_input(dv_xla, indices_k, batch_size, seqlen_k)
+
+  assert torch.allclose(dq_cuda, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
+  assert torch.allclose(dk_cuda, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
+  assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
+
diff --git a/test/test_flash_attention_varlen_forward.py b/test/test_flash_attention_varlen_forward.py
index b17d0c4e91a2..30e4a4aa5287 100644
--- a/test/test_flash_attention_varlen_forward.py
+++ b/test/test_flash_attention_varlen_forward.py
@@ -178,6 +178,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal,
       return_attn_probs=True,
   )
 
+
   out_fa = pad_input(out_fa, indices_q, batch_size, seqlen_q)
 
   q = q.cpu().detach()
@@ -225,7 +226,193 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal,
       cu_seqlen_q_xla, cu_seqlens_q, rtol=1e-3, atol=1e-3, equal_nan=True)
   assert torch.allclose(
       cu_seqlen_k_xla, cu_seqlens_k, rtol=1e-3, atol=1e-3, equal_nan=True)
-  softmax_lse_xla = softmax_lse_xla[:, :, :max_seqlen_in_batch_q]
-  softmax_lse = softmax_lse[:, :, :max_seqlen_in_batch_q]
+  for i in range(len(cu_seq_lens[0]) - 1):
+    seqlen = cu_seq_lens[0][i + 1] - cu_seq_lens[0][i]
+    assert torch.allclose(
+        softmax_lse_xla[i, :, :seqlen],
+        softmax_lse[i, :, :seqlen],
+        rtol=1e-2,
+        atol=1e-3,
+        equal_nan=True)
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+@pytest.mark.parametrize("deterministic", [True])
+@pytest.mark.parametrize("alibi", [False])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [32])
+@pytest.mark.parametrize("softmax_scale", [0.25])
+@pytest.mark.parametrize(
+    "max_seqlen_q,max_seqlen_k",
+    [
+        (8, 8),
+        (128, 128),
+        (2048, 2048),
+    ],
+)
+@pytest.mark.parametrize("dropout_p", [0.0])
+def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d, dropout_p, causal,
+                           softmax_scale, local, alibi, deterministic, mha_type,
+                           dtype):
+  if d % 8 != 0:
+    pytest.skip(reason="Expected head_size_og % 8 == 0 to be true")
+
+  device = "cuda"
+  # set seed
+  torch.random.manual_seed(0)
+  batch_size = 4
+  nheads = 9
+  nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) 
+
+  assert nheads % nheads_k == 0
+  window_size = (-1, -1) if not local else tuple(
+      torch.randint(0, max_seqlen_k, (2,)).tolist())
+  torch.cuda.synchronize()
+
+  def generate_qkv_and_position_ids(batch_size, max_seqlen_q, max_seqlen_k,
+                                    dtype, n_heads_q, n_heads_k, device,
+                                    head_dim, use_same_seqlen):
+    '''
+    Generate varlen q/k/v and position_ids, padding the last sequence so that
+    total_q and total_k are multiples of 8.
+    '''
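+    # e.g. seq_len_q = [3, 5, 2, 4] -> total_q = 14, so the last sequence is
+    # extended by 2 to make total_q = 16, a multiple of 8.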
+    seq_len_q = torch.randint(1, max_seqlen_q + 1, (batch_size,))
+    if use_same_seqlen:
+      seq_len_k = seq_len_q.clone()
+    else:
+      seq_len_k = torch.randint(1, max_seqlen_k + 1, (batch_size,))
+    total_q = seq_len_q.sum().item()
+    total_k = seq_len_k.sum().item()
+
+    padd_q = 0 if total_q % 8 == 0 else 8 - total_q % 8
+    padd_k = 0 if total_k % 8 == 0 else 8 - total_k % 8
+    if use_same_seqlen:
+      assert torch.all(seq_len_k == seq_len_q)
+
+    # pad the last q and k sequences so the totals are multiples of 8
+    if padd_q:
+      seq_len_q[-1] += padd_q
+      total_q += padd_q
+    assert total_q % 8 == 0
+    if padd_k:
+      seq_len_k[-1] += padd_k
+      total_k += padd_k
+    assert total_k % 8 == 0
+
+    q = torch.randn((1, total_q, n_heads_q, head_dim), dtype=dtype, device=device)
+    k = torch.randn((1, total_k, n_heads_k, head_dim), dtype=dtype, device=device)
+    v = torch.randn((1, total_k, n_heads_k, head_dim), dtype=dtype, device=device)
+
+    assert torch.all(seq_len_q > 0)
+    assert torch.all(seq_len_k > 0)
+
+    position_ids_q = torch.cat(
+        [torch.arange(0, seq_len, dtype=torch.int32, device=device)
+         for seq_len in seq_len_q],
+        dim=0).unsqueeze(0)
+    position_ids_k = torch.cat(
+        [torch.arange(0, seq_len, dtype=torch.int32, device=device)
+         for seq_len in seq_len_k],
+        dim=0).unsqueeze(0)
+    assert position_ids_q.shape[1] % 8 == 0
+    assert position_ids_k.shape[1] % 8 == 0
+
+    return q, k, v, position_ids_q, position_ids_k
+
+  q, k, v, position_ids_q, position_ids_k = generate_qkv_and_position_ids(
+      batch_size, max_seqlen_q, max_seqlen_k, dtype, nheads, nheads_k, device,
+      d, max_seqlen_q == max_seqlen_k)
+
+  indices_q = torch.arange(
+      0, position_ids_q.size(1), device=device, dtype=torch.int64)
+  indices_k = torch.arange(
+      0, position_ids_k.size(1), device=device, dtype=torch.int64)
+
+  seq_q_start_idx = indices_q[position_ids_q.squeeze() == 0]
+  seq_k_start_idx = indices_k[position_ids_k.squeeze() == 0]
+
+  cu_seq_lens_q = F.pad(
+      seq_q_start_idx, (0, 1), value=position_ids_q.size(1)).to(torch.int32)
+  cu_seq_lens_k = F.pad(
+      seq_k_start_idx, (0, 1), value=position_ids_k.size(1)).to(torch.int32)
+
+  max_seqlen_q = (cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]).max().item()
+  max_seqlen_k = (cu_seq_lens_k[1:] - cu_seq_lens_k[:-1]).max().item()
+
+  q_cuda = q.reshape(-1, nheads, d).cuda()
+  k_cuda = k.reshape(-1, nheads_k, d).cuda()
+  v_cuda = v.reshape(-1, nheads_k, d).cuda()
+
+
+  if alibi:
+    alibi_slopes = torch.rand(
+        batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+  else:
+    alibi_slopes = None
+  
+
+  out_fa, softmax_lse, _ = flash_attn_varlen_func(
+      q_cuda.contiguous(),
+      k_cuda.contiguous(),
+      v_cuda.contiguous(),
+      cu_seq_lens_q.contiguous(),
+      cu_seq_lens_k.contiguous(),
+      max_seqlen_q,
+      max_seqlen_k,
+      dropout_p=dropout_p,
+      softmax_scale=softmax_scale,
+      causal=causal,
+      window_size=window_size,
+      alibi_slopes=alibi_slopes,
+      deterministic=deterministic,
+      return_attn_probs=True,
+  )
+
+  out_fa = pad_input(out_fa, indices_q, batch_size, max_seqlen_q)
+
+  q = q.cpu().detach()
+  k = k.cpu().detach()
+  v = v.cpu().detach()
+
+  out_fa = out_fa.cpu().detach()
+  softmax_lse = softmax_lse.cpu().detach()
+  cu_seq_lens_q = cu_seq_lens_q.cpu().detach()
+  cu_seq_lens_k = cu_seq_lens_k.cpu().detach()
+
+  if alibi:
+    alibi_slopes = alibi_slopes.cpu()
+  torch.cuda.synchronize()
+
+  device = ta.lazy_device()
+  torch.random.manual_seed(0)
+  q_xla = q.to(device)
+  k_xla = k.to(device)
+  v_xla = v.to(device)
+
+  position_ids_q_xla = position_ids_q.to(device)
+
+  q_xla.requires_grad = False
+  k_xla.requires_grad = False
+  v_xla.requires_grad = False
+  if alibi:
+    alibi_slopes = alibi_slopes.cpu().to(device)
+
+  softmax_lse_xla, out_xla, _, cu_seq_len_q_xla, cu_seq_len_k_xla = torch_xla._XLAC._flash_attention_position_ids_forward(
+      q_xla.contiguous(), k_xla.contiguous(), v_xla.contiguous(),
+      position_ids_q_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale,
+      False, causal, window_size[0], window_size[1], True, None)
+
+
+  ta.mark_step(wait=True)
+  q_xla = q_xla.cpu().detach()
+  k_xla = k_xla.cpu().detach()
+  v_xla = v_xla.cpu().detach()
+  out_xla = out_xla.cpu().detach()
+  cu_seq_len_q_xla = cu_seq_len_q_xla.cpu().detach()
+  cu_seq_len_k_xla = cu_seq_len_k_xla.cpu().detach()
+  softmax_lse_xla = softmax_lse_xla.cpu().detach()
+  position_ids_q_xla = position_ids_q_xla.cpu().detach()
+
+  out_xla = out_xla.squeeze(0)
+
+  out_xla = pad_input(out_xla, indices_q, batch_size, max_seqlen_q)
+
+  assert out_xla.shape == out_fa.shape
+
+  assert torch.allclose(q_xla, q, rtol=1e-3, atol=1e-3, equal_nan=True)
+  assert torch.allclose(k_xla, k, rtol=1e-3, atol=1e-3, equal_nan=True)
+  assert torch.allclose(v_xla, v, rtol=1e-3, atol=1e-3, equal_nan=True)
   assert torch.allclose(
-      softmax_lse_xla, softmax_lse, rtol=1e-2, atol=1e-2, equal_nan=True)
+      cu_seq_len_k_xla[:batch_size+1], cu_seq_lens_k, rtol=1e-3, atol=1e-3, equal_nan=True)
+  assert torch.allclose(out_xla, out_fa, rtol=1e-2, atol=1e-2, equal_nan=True)
+  for i in range(batch_size):
+    start_idx = cu_seq_len_q_xla[i]
+    end_idx = cu_seq_len_q_xla[i + 1]
+    seqlen = end_idx - start_idx
+    assert torch.allclose(
+        softmax_lse_xla[0, :, start_idx:end_idx],
+        softmax_lse[i, :, :seqlen],
+        rtol=1e-3,
+        atol=1e-3,
+        equal_nan=True)
diff --git a/torch_xla/csrc/flash_attention_utils.cpp b/torch_xla/csrc/flash_attention_utils.cpp
index 2c3331c8d1d9..546a00c71947 100644
--- a/torch_xla/csrc/flash_attention_utils.cpp
+++ b/torch_xla/csrc/flash_attention_utils.cpp
@@ -2,6 +2,7 @@
 
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
+#include <iostream>
 
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
@@ -236,7 +237,8 @@ void set_backward_params(FlashAttentionBackwardParams& params, const size_t b,
 
 FlashAttentionForwardParams get_flash_attention_forward_params(
     const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-    c10::optional<at::Tensor>& attention_mask,  // (batch_size, seqlen)
+    c10::optional<at::Tensor> attention_mask,  // (batch_size, seqlen)
+    c10::optional<at::Tensor> position_ids,     // (1, seqlen_q)
     c10::optional<at::Tensor>& alibi_slopes_, const float p_dropout,
     const float softmax_scale, const bool zero_tensors, const bool is_causal,
     int window_size_left, int window_size_right, const bool return_softmax) {
@@ -258,7 +260,7 @@ FlashAttentionForwardParams get_flash_attention_forward_params(
 
   const auto sizes = q.sizes();
   const int batch_size = sizes[0];
   const int seqlen_q = sizes[1];
   const int num_heads = sizes[2];
   const int head_size_og = sizes[3];
   const int seqlen_k = k.size(1);
@@ -274,10 +276,15 @@ FlashAttentionForwardParams get_flash_attention_forward_params(
   CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
   CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
   CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
+
   if (attention_mask.has_value()) {
     TORCH_CHECK(attention_mask.value().dtype() == torch::kInt32);
     CHECK_SHAPE(attention_mask.value(), batch_size, seqlen_k);
   }
+  if (position_ids.has_value()) {
+    TORCH_CHECK(position_ids.value().dtype() == torch::kInt32);
+    CHECK_SHAPE(position_ids.value(), 1, seqlen_q);
+  }
 
   auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
   const int head_size = round_multiple(head_size_og, 8);
@@ -308,7 +315,6 @@ FlashAttentionForwardParams get_flash_attention_forward_params(
                      p_dropout, softmax_scale, is_causal, window_size_left,
                      window_size_right, alibi_slopes_batch_stride,
                      enable_alibi_slopes, /*seqlenq_ngroups_swapped*/ false);
-
   return params;
 }
 
@@ -382,9 +388,13 @@ FlashAttentionBackwardParams get_flash_attention_backward_params(
     TORCH_CHECK(cu_seqlens_k.value().dtype() == torch::kInt32);
     TORCH_CHECK(cu_seqlens_q.value().is_contiguous());
     TORCH_CHECK(cu_seqlens_k.value().is_contiguous());
-    TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1);
-    CHECK_SHAPE(cu_seqlens_q.value(), batch_size + 1);
-    CHECK_SHAPE(cu_seqlens_k.value(), batch_size + 1);
+    TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 ||
+                batch_size == 1);
+    TORCH_CHECK(cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
+                    cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q + 1}),
+                "cu_seqlens_q shape should be batch_size+1 or seqlen_q+1");
+    TORCH_CHECK(cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
+                    cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k + 1}),
+                "cu_seqlens_k shape should be batch_size+1 or seqlen_k+1");
   }
 
   int alibi_slopes_batch_stride = 0;
@@ -412,6 +422,18 @@ FlashAttentionBackwardParams get_flash_attention_backward_params(
   return params;
 }
 
+
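+// Overload used by the position_ids path: `padded_cu_seqlens` has a fixed size
+// of seqlen+1 with unused slots set to -1. Recovers the real batch size and the
+// max sequence length, and returns indices [0, total_q) since the packed layout
+// has no padding between sequences.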
+at::Tensor cu_seqlens_to_indices(const at::Tensor& padded_cu_seqlens,
+                                 int& max_seqlen_in_batch, int& total_q,
+                                 int& real_batch_size) {
+  const at::Tensor valid_cu_seqlens =
+      padded_cu_seqlens.index({padded_cu_seqlens > -1});
+  real_batch_size = valid_cu_seqlens.size(0) - 1;
+  at::Tensor seqs_len = valid_cu_seqlens.slice(0, 1, valid_cu_seqlens.size(0)) -
+                        valid_cu_seqlens.slice(0, 0, valid_cu_seqlens.size(0) - 1);
+  max_seqlen_in_batch = seqs_len.max().item<int>();
+  total_q = valid_cu_seqlens[-1].item<int>();
+  return torch::arange(total_q, torch::dtype(torch::kInt64).device(torch::kCUDA));
+}
+
 at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
                                  int seqlen, torch::Dtype scalar_type,
                                  int& max_seqlen_in_batch, int& total) {
@@ -422,10 +444,10 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
 
   auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
   torch::Tensor rows =
-      torch::arange(shape[0], opts.dtype(torch::kInt32)).unsqueeze(1);
+      torch::arange(shape[0], opts.dtype(torch::kInt32)).unsqueeze(1); // (batch_size,1)
   torch::Tensor cols =
-      torch::arange(shape[1], opts.dtype(torch::kInt32)).unsqueeze(0);
-  torch::Tensor mask = cols < nonzero_counts.unsqueeze(1);
+      torch::arange(shape[1], opts.dtype(torch::kInt32)).unsqueeze(0); // (1,seqlen)
+  torch::Tensor mask = cols < nonzero_counts.unsqueeze(1); // (1,seqlen) < (batch_size, 1)
   max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>();
 
   torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32));
@@ -451,19 +473,69 @@ at::Tensor mask_to_indices(const at::Tensor& attention_mask,
   return indices;
 }
 
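+// Inverse of pad_softmax_lse: gather the valid (non-padded) lse entries of each
+// sequence from the dense (batch_size, nhead, max_seqlen) layout and pack them
+// back into (1, nhead, total_seqlen).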
+torch::Tensor unpad_softmax_lse(
+    const torch::Tensor& pad_softmax_lse,  // (batch_size, nhead, max_seqlen)
+    const torch::Tensor& cu_seqlens) {     // (total_seqlen + 1)
+  int batch_size = pad_softmax_lse.size(0);
+  int nhead = pad_softmax_lse.size(1);
+  int max_seqlen = pad_softmax_lse.size(2);
+  int total, max_seqlen_in_batch;
+  at::Tensor valid_cu_seqlens = cu_seqlens.slice(0, 0, batch_size + 1);
+  at::Tensor indices = cu_seqlens_to_indices(valid_cu_seqlens, batch_size,
+                                             max_seqlen, torch::kInt64,
+                                             max_seqlen_in_batch, total);
+  at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options());
+  result.copy_(pad_softmax_lse.transpose(1, 2)
+                   .reshape({batch_size * max_seqlen, nhead})
+                   .index({indices, torch::indexing::Slice()}));
+  return result.transpose(0, 1).unsqueeze(0);
+}
+
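+// Scatter a packed (1, nheads, total_seqlen) lse into the dense
+// (batch_size, nheads, max_seq_len) layout that the varlen kernel indexes per
+// batch entry.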
+torch::Tensor pad_softmax_lse(
+    const at::Tensor& softmax_lse,  // (1, nheads, total_seqlen)
+    const at::Tensor& cu_seqlens, const int max_seq_len, const int batch_size) {
+  const int nheads = softmax_lse.size(1);
+  int max_seqlen_in_batch;
+  int total;
+  at::Tensor valid_cu_seqlens = cu_seqlens.slice(0, 0, batch_size + 1);
+  at::Tensor indices = cu_seqlens_to_indices(valid_cu_seqlens, batch_size,
+                                             max_seq_len, torch::kInt32,
+                                             max_seqlen_in_batch, total);
+  TORCH_CHECK(indices.size(0) == softmax_lse.size(2),
+              "indices should have the same size as softmax_lse");
+
+  at::Tensor result =
+      at::empty({batch_size * max_seq_len, nheads}, softmax_lse.options());
+  result.index_put_({indices, torch::indexing::Slice()},
+                    softmax_lse.squeeze(0).transpose(0, 1));
+  return result.reshape({batch_size, max_seq_len, nheads})
+      .transpose(1, 2)
+      .contiguous();
+}
+
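+// Derives cu_seqlen from packed position_ids: a new sequence starts wherever
+// position_ids == 0, e.g. position_ids = [0,1,2,0,1,2,3] gives starts {0, 3} and
+// cu_seqlen = [0, 3, 7, ...] (slots beyond real_batch_size+1 keep the caller's
+// padding value). Returns indices [0, total) since the packed input has no padding.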
+at::Tensor position_ids_to_indices(const at::Tensor& position_ids,
+                                   int& max_seqlen_in_batch, int& total,
+                                   at::Tensor& cu_seqlen, int& real_batch_size) {
+  at::Tensor flatten_position_ids = position_ids.flatten();
+  at::Tensor indices = torch::arange(
+      flatten_position_ids.size(0),
+      torch::dtype(torch::kInt64).device(torch::kCUDA));
+
+  at::Tensor batch_seq_start_idx = indices.index({flatten_position_ids == 0});
+  real_batch_size = batch_seq_start_idx.size(0);
+  at::Tensor batch_seqlen_cumsum = at::empty(
+      {real_batch_size + 1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+  batch_seqlen_cumsum.index({torch::indexing::Slice(0, real_batch_size)}) =
+      batch_seq_start_idx;
+  total = flatten_position_ids.size(0);
+  batch_seqlen_cumsum.index({-1}) = total;
+
+  at::Tensor batch_seqlen =
+      batch_seqlen_cumsum.slice(0, 1, batch_seqlen_cumsum.size(0)) -
+      batch_seqlen_cumsum.slice(0, 0, batch_seqlen_cumsum.size(0) - 1);
+  max_seqlen_in_batch = batch_seqlen.max().item<int>();
+  cu_seqlen.narrow(0, 0, real_batch_size + 1) = batch_seqlen_cumsum;
+  return indices;
+}
+
+
+
 at::Tensor index_first_axis(const at::Tensor& input,
                             const at::Tensor& indices) {
   torch::IntArrayRef sizes = input.sizes();
-  int64_t first_axis_dim = sizes[0];
-  auto other_shape = sizes.slice(1, sizes.size() - 1);
+  int64_t first_axis_dim = sizes[0];                    // total tokens (batch * seqlen)
+  auto other_shape = sizes.slice(1, sizes.size() - 1);  // remaining dims, e.g. (nheads, head_dim)
 
   int64_t second_dim = 1;
   for (auto dim : other_shape) {
     second_dim *= dim;
   }
 
-  at::Tensor flat_input = torch::flatten(input, 1);
+  at::Tensor flat_input = torch::flatten(input, 1);  // (first_axis_dim, second_dim)
   torch::Tensor repeated_indices =
       indices.unsqueeze(1).expand({indices.size(0), second_dim});
   at::Tensor gather_input = torch::gather(flat_input, 0, repeated_indices);
   std::vector<int64_t> reshaped_size = {-1};
diff --git a/torch_xla/csrc/flash_attention_utils.h b/torch_xla/csrc/flash_attention_utils.h
index a2696f25f82a..1692e1bd8678 100644
--- a/torch_xla/csrc/flash_attention_utils.h
+++ b/torch_xla/csrc/flash_attention_utils.h
@@ -112,7 +112,8 @@ void set_backward_params(FlashAttentionBackwardParams& params, const size_t b,
 
 FlashAttentionForwardParams get_flash_attention_forward_params(
     const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-    c10::optional<at::Tensor>& attention_mask,
+    c10::optional<at::Tensor> attention_mask,
+    c10::optional<at::Tensor> position_ids,
     c10::optional<at::Tensor>& alibi_slopes_, const float p_dropout,
     const float softmax_scale, const bool zero_tensors, const bool is_causal,
     int window_size_left, int window_size_right, const bool return_softmax);
@@ -130,10 +131,17 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
                                  int seqlen, torch::Dtype scalar_type,
                                  int& max_seqlen_in_batch, int& total);
 
+at::Tensor cu_seqlens_to_indices(const at::Tensor& padded_cu_seqlens,
+                                 int& max_seqlen_in_batch, int& total_q,
+                                 int& real_batch_size);
+
 at::Tensor mask_to_indices(const at::Tensor& attention_mask,
                            int& max_seqlen_in_batch, int& total,
                            at::Tensor& cu_seqlen);
 
+at::Tensor position_ids_to_indices(const at::Tensor& position_ids,
+                                   int& max_seqlen_in_batch, int& total,
+                                   at::Tensor& cu_seqlen, int& real_batch_size);
+
 at::Tensor index_first_axis(const at::Tensor& input, const at::Tensor& indices);
 
 // make xla::Shape like input, same elemet_type and dimension as of input,
@@ -142,5 +150,9 @@ xla::Shape shape_like(const torch::lazy::Value& input);
 
 xla::Shape shape_like(const xla::XlaBuilder* builder, const xla::XlaOp& input);
 
+torch::Tensor unpad_softmax_lse(const torch::Tensor& pad_softmax_lse,
+                                const torch::Tensor& cu_seqlens);
+
+torch::Tensor pad_softmax_lse(const at::Tensor& softmax_lse,
+                              const at::Tensor& cu_seqlens,
+                              const int max_seq_len, const int batch_size);
+
 }  // namespace torch_xla
 #endif  // XLA_TORCH_XLA_CSRC_FLASH_ATTENTION_UTILS_H
\ No newline at end of file
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index bce6795ae8db..55d70bc03ed6 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2409,7 +2409,7 @@ void InitXlaModuleBindings(py::module m) {
            c10::optional<at::Generator> gen_) {
           // get launch params on at::Tensor
           auto params = get_flash_attention_forward_params(
-              q, k, v, attention_mask, alibi_slopes, p_dropout, softmax_scale,
+              q, k, v, attention_mask, c10::nullopt, alibi_slopes, p_dropout, softmax_scale,
               zero_tensors, is_causal, window_size_left, window_size_right,
               return_softmax);
           // call flash attention forward
@@ -2440,6 +2440,50 @@ void InitXlaModuleBindings(py::module m) {
           }
           return results;
         });
+
+  m.def("_flash_attention_position_ids_forward",
+        [](const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
+           c10::optional<at::Tensor>& position_ids,
+           c10::optional<at::Tensor>& alibi_slopes, const float p_dropout,
+           const float softmax_scale, const bool zero_tensors,
+           const bool is_causal, const int window_size_left,
+           const int window_size_right, const bool return_softmax,
+           c10::optional<at::Generator> gen_) {
+          // get launch params on at::Tensor
+          auto params = get_flash_attention_forward_params(
+              q, k, v, c10::nullopt, position_ids, alibi_slopes, p_dropout,
+              softmax_scale, zero_tensors, is_causal, window_size_left,
+              window_size_right, return_softmax);
+          // call flash attention forward
+          XLATensorPtr q_xla = bridge::GetXlaTensor(q);
+          XLATensorPtr k_xla = bridge::GetXlaTensor(k);
+          XLATensorPtr v_xla = bridge::GetXlaTensor(v);
+          XLATensorPtr alibi_slopes_xla =
+              alibi_slopes.has_value()
+                  ? bridge::GetXlaTensor(alibi_slopes.value())
+                  : XLATensorPtr();
+
+          std::vector<XLATensorPtr> xresults;
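+          // When position_ids is provided, dispatch to the varlen
+          // position_ids forward; otherwise fall back to the dense forward.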
+          if (position_ids.has_value()) {
+            XLATensorPtr position_ids_xla =
+                bridge::GetXlaTensor(position_ids.value());
+            xresults = tensor_methods::flash_attention_varlen_position_ids_forward(
+                q_xla, k_xla, v_xla, position_ids_xla, alibi_slopes_xla,
+                params.ToString());
+          } else {
+            xresults = tensor_methods::flash_attention_forward(
+                q_xla, k_xla, v_xla, alibi_slopes_xla, params.ToString());
+          }
+          std::vector<at::Tensor> results;
+          for (auto& xresult : xresults) {
+            at::Tensor tensor = bridge::AtenFromXlaTensor(std::move(xresult));
+            results.push_back(torch::autograd::make_variable(
+                tensor, /*requires_grad=*/false));
+          }
+          return results;
+        });
+
   m.def(
       "_flash_attention_backward",
       [](const at::Tensor& dout, const at::Tensor& q, const at::Tensor& k,
@@ -2492,6 +2536,59 @@ void InitXlaModuleBindings(py::module m) {
         }
         return results;
       });
+  m.def(
+      "_flash_attention_position_ids_backward",
+      [](const at::Tensor& dout, const at::Tensor& q, const at::Tensor& k,
+         const at::Tensor& v, const at::Tensor& out,
+         const at::Tensor& softmax_lse, c10::optional<at::Tensor>& cu_seqlens_q,
+         c10::optional<at::Tensor>& cu_seqlens_k,
+         c10::optional<at::Tensor>& alibi_slopes, const float p_dropout,
+         const float softmax_scale, const bool zero_tensors,
+         const bool is_causal, const int window_size_left,
+         const int window_size_right, const bool deterministic,
+         c10::optional<at::Generator> gen_, const at::Tensor& rng_state) {
+        // get launch params on at::Tensor
+        auto params = get_flash_attention_backward_params(
+            dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
+            alibi_slopes, p_dropout, softmax_scale, zero_tensors, is_causal,
+            window_size_left, window_size_right, deterministic);
+        // call flash attention backward
+        XLATensorPtr dout_xla = bridge::GetXlaTensor(dout);
+        XLATensorPtr q_xla = bridge::GetXlaTensor(q);
+        XLATensorPtr k_xla = bridge::GetXlaTensor(k);
+        XLATensorPtr v_xla = bridge::GetXlaTensor(v);
+        XLATensorPtr out_xla = bridge::GetXlaTensor(out);
+        XLATensorPtr softmax_lse_xla = bridge::GetXlaTensor(softmax_lse);
+        XLATensorPtr rng_state_xla = bridge::GetXlaTensor(rng_state);
+        XLATensorPtr alibi_slopes_xla =
+            alibi_slopes.has_value()
+                ? bridge::GetXlaTensor(alibi_slopes.value())
+                : XLATensorPtr();
+
+        std::vector<XLATensorPtr> xresults;
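+        // cu_seqlens tensors produced by the position_ids forward select the
+        // varlen backward path; otherwise fall back to the dense backward.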
+        if (cu_seqlens_q.has_value() && cu_seqlens_k.has_value()) {
+          XLATensorPtr cu_seqlens_q_xla =
+              bridge::GetXlaTensor(cu_seqlens_q.value());
+          XLATensorPtr cu_seqlens_k_xla =
+              bridge::GetXlaTensor(cu_seqlens_k.value());
+          xresults = tensor_methods::flash_attention_varlen_position_ids_backward(
+              dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla,
+              cu_seqlens_q_xla, cu_seqlens_k_xla, rng_state_xla,
+              alibi_slopes_xla, params.ToString());
+        } else {
+          xresults = tensor_methods::flash_attention_backward(
+              dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla,
+              rng_state_xla, alibi_slopes_xla, params.ToString());
+        }
+        std::vector<at::Tensor> results;
+        for (auto& xresult : xresults) {
+          at::Tensor tensor = bridge::AtenFromXlaTensor(std::move(xresult));
+          results.push_back(
+              torch::autograd::make_variable(tensor, /*requires_grad=*/false));
+        }
+        return results;
+      });
+
   // -------------FlashAttention Integration API End-------------------
 
   // -------------Dynamo Integration API Start-------------------------
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
index 478f9f114df2..1a7cfed5bfa3 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
@@ -3,6 +3,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
+#include <iostream>
 
 #include "cutlass/numeric_types.h"
 #include "flash.h"
@@ -16,7 +17,6 @@
 
 namespace torch_xla {
 namespace {
-
 xla::Shape NodeOutputShape(const torch::lazy::Value& q) {
   auto q_shape = xla::SpanToVector(GetXlaShape(q).dimensions());
   xla::Shape softmax_lse_shape = xla::ShapeUtil::MakeShape(
@@ -36,7 +36,7 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& q) {
 //  buffers[0] = q
 //  buffers[1] = k
 //  buffers[2] = v
 //  buffers[3] = attention_mask
 //  buffers[4] = alibi_slopes
 //  buffers[5] = softmax_lse  // this is output
 //  buffers[6] = out_for_output // this is output
@@ -47,6 +47,7 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream,
                                                 void** buffers,
                                                 const char* opaque,
                                                 size_t opaque_len) {
   std::string opaque_str(opaque, opaque_len);
   TF_VLOG(3) << "custom_call_flash_attention_varlen_forward opaque str: "
              << opaque_str;
@@ -98,21 +99,22 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream,
 
   int max_seqlen_in_batch_k = params.seqlen_k;
   int total_k = params.b * params.seqlen_k;
 
   at::Tensor indices_k = mask_to_indices(attention_mask, max_seqlen_in_batch_k,
                                          total_k, cu_seqlens_k);
   auto unpad_k = index_first_axis(k, indices_k);
   auto unpad_v = index_first_axis(v, indices_k);
 
   int max_seqlen_in_batch_q = max_seqlen_in_batch_k;
   int total_q = total_k;
   at::Tensor indices_q;
 
   if (params.seqlen_q == params.seqlen_k) {
     cu_seqlens_q.copy_(cu_seqlens_k);
     indices_q = indices_k;
   } else if (params.seqlen_q == 1) {
     max_seqlen_in_batch_q = 1;
+    cu_seqlens_q = torch::arange(0, params.b + 1, opts);
     indices_q = cu_seqlens_q.slice(/*dim=*/0, /*start=*/0, /*end=*/params.b);
     total_q = params.b;
   } else {
@@ -120,15 +122,14 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream,
         /*dim=*/1, /*start=*/-params.seqlen_q, /*end=*/torch::indexing::None);
     indices_q = mask_to_indices(attention_mask_slice, max_seqlen_in_batch_q,
                                 total_q, cu_seqlens_q);
   }
-
   at::Tensor unpad_q = index_first_axis(q, indices_q);
 
   at::Tensor unpad_output =
       torch::zeros({total_q, params.h * params.d}, opts.dtype(scalar_type));
   at::Tensor unpad_softmax_lse = torch::zeros(
       {params.b, params.h, max_seqlen_in_batch_q}, opts.dtype(torch::kFloat));
 
   if (max_seqlen_in_batch_q == 1) {
     params.is_causal = false;
   }
@@ -241,7 +242,6 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream,
   }
 
   TF_VLOG(2) << "Running FlashAttention Forward.";
-
   FP16_SWITCH(!launch_params.is_bf16, [&] {
     HEADDIM_SWITCH(launch_params.d, [&] {
       // TODO(wenting.swt): support split_kv
@@ -261,6 +261,7 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream,
   cudaEventRecord(xla_wait_torch_event, torch_stream);
   cudaStreamWaitEvent(stream, xla_wait_torch_event);
 }
+
 XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_forward,
                                 "CUDA");
 
@@ -321,7 +322,7 @@ torch::lazy::NodePtr FlashAttentionVarlenForward::Clone(
   }
 }
 
 XlaOpVector FlashAttentionVarlenForward::Lower(LoweringContext* loctx) const {
   xla::XlaOp q = loctx->GetOutputOp(operand(0));
   xla::XlaOp k = loctx->GetOutputOp(operand(1));
   xla::XlaOp v = loctx->GetOutputOp(operand(2));
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
new file mode 100644
index 000000000000..435b40e822cb
--- /dev/null
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
@@ -0,0 +1,430 @@
+#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h"
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/extension.h>
+
+#include "cutlass/numeric_types.h"
+#include "flash.h"
+#include "static_switch.h"
+#include "torch_xla/csrc/flash_attention_utils.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/xla_ops.h"
+#include "torch_xla/csrc/runtime/tf_logging.h"
+#include "torch_xla/csrc/xla_lower_util.h"
+#include "xla/service/custom_call_target_registry.h"
+
+namespace torch_xla {
+namespace {
+
+xla::Shape NodeOutputShape(const torch::lazy::Value& q,
+                           const torch::lazy::Value& k,
+                           const torch::lazy::Value& v,
+                           const torch::lazy::Value& softmax_lse) {
+  return xla::ShapeUtil::MakeTupleShape(
+      {shape_like(q), shape_like(k), shape_like(v), shape_like(softmax_lse)});
+}
+
+void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream,
+                 const bool configure) {
+  FP16_SWITCH(!params.is_bf16, [&] {
+    HEADDIM_SWITCH(params.d,
+                   [&] { run_mha_bwd_<elem_type, kHeadDim>(params, stream); });
+  });
+}
+
+// Layout of `buffers`:
+//  buffers[0] = dout
+//  buffers[1] = q
+//  buffers[2] = k
+//  buffers[3] = v
+//  buffers[4] = out
+//  buffers[5] = softmax_lse
+//  buffers[6] = cu_seqlens_q
+//  buffers[7] = cu_seqlens_k
+//  buffers[8] = rng_state
+//  buffers[9] = alibi_slopes
+//  buffers[10] = dq  // this is output
+//  buffers[11] = dk  // this is output
+//  buffers[12] = dv  // this is output
+//  buffers[13] = softmax_d  // this is output
+void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t stream,
+                                                 void** buffers,
+                                                 const char* opaque,
+                                                 size_t opaque_len) {
+  std::string opaque_str(opaque, opaque_len);
+  TF_VLOG(3) << "custom_call_flash_attention_varlen_position_ids_backward opaque str: "
+             << opaque_str;
+  FlashAttentionBackwardParams params;
+  params.FromString(std::move(opaque_str));
+  int buf_offset = params.enable_alibi_slopes;
+  auto scalar_type = params.is_bf16 ? torch::kBFloat16 : torch::kFloat16;
+
+  cudaStream_t torch_stream = at::cuda::getCurrentCUDAStream().stream();
+  cudaEvent_t torch_wait_xla_event;
+  cudaEventCreateWithFlags(&torch_wait_xla_event, cudaEventDisableTiming);
+  cudaEvent_t xla_wait_torch_event;
+  cudaEventCreateWithFlags(&xla_wait_torch_event, cudaEventDisableTiming);
+  cudaEventRecord(torch_wait_xla_event, stream);
+  cudaStreamWaitEvent(torch_stream, torch_wait_xla_event);
+
+  auto cuda_stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::CUDAStreamGuard guard(cuda_stream);
+
+  auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
+
+  // Inputs
+  at::Tensor do_ = torch::from_blob(
+      buffers[0], {params.b * params.seqlen_q, params.h, params.d}, opts);
+  at::Tensor q = torch::from_blob(
+      buffers[1], {params.b * params.seqlen_q, params.h, params.d}, opts);
+  at::Tensor k = torch::from_blob(
+      buffers[2], {params.b * params.seqlen_k, params.h_k, params.d}, opts);
+  at::Tensor v = torch::from_blob(
+      buffers[3], {params.b * params.seqlen_k, params.h_k, params.d}, opts);
+  at::Tensor o = torch::from_blob(
+      buffers[4], {params.b * params.seqlen_q, params.h, params.d}, opts);
+  at::Tensor softmax_lse =
+      torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q},
+                       opts.dtype(torch::kFloat));
+  at::Tensor cu_seqlens_q =
+      torch::from_blob(buffers[6], {params.seqlen_q + 1}, opts.dtype(torch::kInt32));
+  at::Tensor cu_seqlens_k =
+      torch::from_blob(buffers[7], {params.seqlen_k + 1}, opts.dtype(torch::kInt32));
+
+  // Outputs
+  at::Tensor dq =
+      torch::from_blob(buffers[9 + buf_offset],
+                       {params.b * params.seqlen_q, params.h, params.d}, opts);
+
+  at::Tensor dk = torch::from_blob(
+      buffers[10 + buf_offset],
+      {params.b * params.seqlen_k, params.h_k, params.d}, opts);
+
+  at::Tensor dv = torch::from_blob(
+      buffers[11 + buf_offset],
+      {params.b * params.seqlen_k, params.h_k, params.d}, opts);
+
+  at::Tensor dsoftmax_sum = torch::from_blob(
+      buffers[12 + buf_offset], {params.b, params.h, params.seqlen_q},
+      opts.dtype(torch::kFloat));
+
+  // Fill zeros for outputs.
+  dq.fill_(0);
+  dk.fill_(0);
+  dv.fill_(0);
+  dsoftmax_sum.fill_(0);
+
+  int max_seqlen_in_batch_q = params.seqlen_q;
+  int max_seqlen_in_batch_k = params.seqlen_k;
+  int total_q = params.b * params.seqlen_q;
+  int total_k = params.b * params.seqlen_k;
+  int real_batch_size;
+  at::Tensor indices_q = cu_seqlens_to_indices(
+      cu_seqlens_q, max_seqlen_in_batch_q, total_q, real_batch_size);
+
+  at::Tensor indices_k;
+  if (params.seqlen_q == params.seqlen_k) {
+    indices_k = indices_q;
+    max_seqlen_in_batch_k = max_seqlen_in_batch_q;
+    total_k = total_q;
+  } else {
+    indices_k = cu_seqlens_to_indices(cu_seqlens_k, max_seqlen_in_batch_k,
+                                      total_k, real_batch_size);
+  }
+
+  auto padded_softmax_lse = pad_softmax_lse(
+      softmax_lse, cu_seqlens_q, max_seqlen_in_batch_q, real_batch_size);
+
+  Flash_bwd_params launch_params;
+
+  // Reset the parameters
+  memset(&launch_params, 0, sizeof(launch_params));
+
+  launch_params.is_bf16 = params.is_bf16;
+
+  // Set the pointers and strides.
+  launch_params.q_ptr = q.data_ptr();
+  launch_params.k_ptr = k.data_ptr();
+  launch_params.v_ptr = v.data_ptr();
+  launch_params.o_ptr = o.data_ptr();
+
+  // All stride are in elements, not bytes.
+  launch_params.q_row_stride = params.q_row_stride;
+  launch_params.k_row_stride = params.k_row_stride;
+  launch_params.v_row_stride = params.v_row_stride;
+  launch_params.q_head_stride = params.q_head_stride;
+  launch_params.k_head_stride = params.k_head_stride;
+
+  launch_params.v_head_stride = params.v_head_stride;
+  launch_params.o_row_stride = params.o_row_stride;
+  launch_params.o_head_stride = params.o_head_stride;
+
+  launch_params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q.data_ptr());
+  launch_params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k.data_ptr());
+  launch_params.softmax_lse_ptr = padded_softmax_lse.data_ptr();
+
+  launch_params.alibi_slopes_ptr = buf_offset > 0 ? buffers[9] : nullptr;
+
+  launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride;
+
+  // P = softmax(QK^T)
+  launch_params.p_ptr = nullptr;  // no softmax returned always
+
+  // Set the dimensions.
+  launch_params.b = real_batch_size;
+  launch_params.h = params.h;
+  launch_params.h_k = params.h_k;
+  launch_params.h_h_k_ratio = params.h_h_k_ratio;
+  launch_params.seqlen_q = max_seqlen_in_batch_q;
+  launch_params.seqlen_k = max_seqlen_in_batch_k;
+  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+  launch_params.seqlen_q_rounded = round_multiple(max_seqlen_in_batch_q, 128);
+  launch_params.seqlen_k_rounded = round_multiple(max_seqlen_in_batch_k, 128);
+  launch_params.d = params.d;
+  launch_params.d_rounded = params.d_rounded;
+
+  // Set the different scale values.
+  launch_params.scale_softmax = params.scale_softmax;
+  launch_params.scale_softmax_log2 = params.scale_softmax_log2;
+
+  launch_params.p_dropout = params.p_dropout;
+  launch_params.p_dropout_in_uint8_t = params.p_dropout_in_uint8_t;
+  launch_params.rp_dropout = params.rp_dropout;
+  launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout;
+
+  if (max_seqlen_in_batch_q == 1) {
+    params.is_causal = false;
+  }
+  if (params.is_causal) {
+    params.window_size_right = 0;
+  }
+
+  if (params.window_size_left >= max_seqlen_in_batch_k) {
+    params.window_size_left = -1;
+  }
+  if (params.window_size_right >= max_seqlen_in_batch_k) {
+    params.window_size_right = -1;
+  }
+
+  launch_params.is_causal =
+      params.window_size_left < 0 && params.window_size_right == 0;
+
+  if (params.window_size_left < 0 && params.window_size_right >= 0) {
+    params.window_size_left = max_seqlen_in_batch_k;
+  }
+  if (params.window_size_left >= 0 && params.window_size_right < 0) {
+    params.window_size_right = max_seqlen_in_batch_k;
+  }
+
+  launch_params.window_size_left = params.window_size_left;
+  launch_params.window_size_right = params.window_size_right;
+
+  launch_params.is_seqlens_k_cumulative = true;
+
+  launch_params.do_row_stride = params.do_row_stride;
+  launch_params.do_head_stride = params.do_head_stride;
+
+  launch_params.dq_row_stride = params.dq_row_stride;
+  launch_params.dk_row_stride = params.dk_row_stride;
+  launch_params.dv_row_stride = params.dv_row_stride;
+  launch_params.dq_head_stride = params.dq_head_stride;
+  launch_params.dk_head_stride = params.dk_head_stride;
+  launch_params.dv_head_stride = params.dv_head_stride;
+
+  at::Tensor rounded_dsoftmax_sum =
+      at::zeros({real_batch_size, params.h, launch_params.seqlen_q_rounded},
+                opts.dtype(torch::kFloat));
+
+  launch_params.do_ptr = do_.data_ptr();
+  launch_params.dq_ptr = dq.data_ptr();
+  launch_params.dk_ptr = dk.data_ptr();
+  launch_params.dv_ptr = dv.data_ptr();
+  launch_params.dsoftmax_sum = rounded_dsoftmax_sum.data_ptr();
+
+  // bool loop = max_seqlen_k > blocksize_c;
+  // TODO: change later, for now set to true for simplicity
+  bool loop = true;
+
+  at::Tensor dq_accum;
+  if (loop) {
+    if (!params.deterministic) {
+      dq_accum = torch::empty({total_q + 128 * launch_params.b, launch_params.h,
+                               launch_params.d_rounded},
+                              opts.dtype(at::kFloat));
+    } else {
+      auto dprops = at::cuda::getCurrentDeviceProperties();
+      const int nsplits = (dprops->multiProcessorCount +
+                           launch_params.b * launch_params.h - 1) /
+                          (launch_params.b * launch_params.h);
+      dq_accum = torch::zeros({nsplits, total_q + 128 * launch_params.b,
+                               launch_params.h, launch_params.d_rounded},
+                              opts.dtype(at::kFloat));
+    }
+  }
+
+  at::Tensor dk_expanded, dv_expanded;
+
+  if (launch_params.h_k != launch_params.h) {  // MQA / GQA
+    TF_VLOG(2) << "Running FlashAttention Backward as MQA/GQA";
+    dk_expanded =
+        torch::empty({total_k, launch_params.h, launch_params.d}, opts);
+    dv_expanded =
+        torch::empty({total_k, launch_params.h, launch_params.d}, opts);
+
+    launch_params.dk_ptr = dk_expanded.data_ptr();
+    launch_params.dv_ptr = dv_expanded.data_ptr();
+    launch_params.dk_row_stride = dk_expanded.stride(-3);
+    launch_params.dv_row_stride = dv_expanded.stride(-3);
+    launch_params.dk_head_stride = dk_expanded.stride(-2);
+    launch_params.dv_head_stride = dv_expanded.stride(-2);
+  } else {
+    TF_VLOG(2) << "Running FlashAttention Backward";
+    dk_expanded = dk;
+    dv_expanded = dv;
+  }
+
+  launch_params.dq_accum_ptr = loop ? dq_accum.data_ptr() : nullptr;
+  launch_params.dk_accum_ptr = nullptr;
+  launch_params.dv_accum_ptr = nullptr;
+
+  launch_params.deterministic = params.deterministic;
+  launch_params.dq_accum_split_stride =
+      !launch_params.deterministic ? 0 : dq_accum.stride(0);
+
+  auto launch = &run_mha_bwd;
+
+  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+      c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
+
+  // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+  int64_t counter_offset = launch_params.b * launch_params.h * 32;
+
+  bool is_dropout = (1.f - launch_params.p_dropout) > 0.0;
+
+  // TODO(wenting.swt): According to the implementation in
+  // `flash_attn_varlen_func` of flash-attn v2.5.6, the forward generates
+  // `rng_state` which is passed as ctx to the backward. Hence, for simplifying
+  // the logic, the redundant branch where `rng_state` is None has been omitted.
+  launch_params.rng_state = reinterpret_cast<uint64_t*>(buffers[8]);
+
+  launch(launch_params, torch_stream, /*configure=*/false);
+
+  // For MQA/GQA we need to sum dK and dV across the groups
+  if (launch_params.h_k != launch_params.h) {
+    at::sum_out(dk,
+                at::reshape(dk_expanded, {total_k, launch_params.h_k,
+                                          launch_params.h / launch_params.h_k,
+                                          launch_params.d}),
+                {2});
+    at::sum_out(dv,
+                at::reshape(dv_expanded, {total_k, launch_params.h_k,
+                                          launch_params.h / launch_params.h_k,
+                                          launch_params.d}),
+                {2});
+  }
+
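+  // The kernel wrote softmax_d into the padded (b, h, seqlen_q_rounded) buffer;
+  // copy only the valid entries back into the packed output layout.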
+  dsoftmax_sum.copy_(unpad_softmax_lse(rounded_dsoftmax_sum, cu_seqlens_q));
+
+  cudaEventRecord(xla_wait_torch_event, torch_stream);
+  cudaStreamWaitEvent(stream, xla_wait_torch_event);
+}
+XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_position_ids_backward,
+                                "CUDA");
+
+std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsBackward(
+    const xla::XlaOp& dout, const xla::XlaOp& q, const xla::XlaOp& k,
+    const xla::XlaOp& v, const xla::XlaOp& out, const xla::XlaOp& softmax_lse,
+    const xla::XlaOp& cu_seqlens_q, const xla::XlaOp& cu_seqlens_k,
+    const xla::XlaOp& rng_state, const xla::XlaOp& alibi_slopes,
+    const std::string& opaque, const xla::Shape& output_shape) {
+  auto builder = q.builder();
+  std::vector<xla::XlaOp> operands{
+      dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state};
+  std::vector<xla::Shape> operand_shapes_with_layout{
+      shape_like(builder, dout),
+      shape_like(builder, q),
+      shape_like(builder, k),
+      shape_like(builder, v),
+      shape_like(builder, out),
+      shape_like(builder, softmax_lse),
+      builder->GetShape(cu_seqlens_q).value(),
+      builder->GetShape(cu_seqlens_k).value(),
+      builder->GetShape(rng_state).value()};
+  if (alibi_slopes.valid()) {
+    operands.push_back(alibi_slopes);
+    operand_shapes_with_layout.push_back(shape_like(builder, alibi_slopes));
+  }
+  xla::XlaOp result = xla::CustomCallWithLayout(
+      builder, "custom_call_flash_attention_varlen_position_ids_backward",
+      std::move(operands), output_shape, std::move(operand_shapes_with_layout),
+      opaque);
+  return {xla::GetTupleElement(result, 0), xla::GetTupleElement(result, 1),
+          xla::GetTupleElement(result, 2), xla::GetTupleElement(result, 3)};
+}
+
+}  // namespace
+
+FlashAttentionVarlenPositionIdsBackward::FlashAttentionVarlenPositionIdsBackward(
+    const torch::lazy::Value& dout, const torch::lazy::Value& q,
+    const torch::lazy::Value& k, const torch::lazy::Value& v,
+    const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
+    const torch::lazy::Value& cu_seqlens_q,
+    const torch::lazy::Value& cu_seqlens_k, const torch::lazy::Value& rng_state,
+    const std::string params)
+    : XlaNode(xla_flash_attention_backward,
+              {dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
+               rng_state},
+              NodeOutputShape(q, k, v, softmax_lse),
+              /*num_outputs=*/4, torch::lazy::MHash(params)),
+      params_(params) {}
+
+FlashAttentionVarlenPositionIdsBackward::FlashAttentionVarlenPositionIdsBackward(
+    const torch::lazy::Value& dout, const torch::lazy::Value& q,
+    const torch::lazy::Value& k, const torch::lazy::Value& v,
+    const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
+    const torch::lazy::Value& cu_seqlens_q,
+    const torch::lazy::Value& cu_seqlens_k, const torch::lazy::Value& rng_state,
+    const torch::lazy::Value& alibi_slopes, const std::string params)
+    : XlaNode(xla_flash_attention_backward,
+              {dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
+               rng_state, alibi_slopes},
+              NodeOutputShape(q, k, v, softmax_lse),
+              /*num_outputs=*/4, torch::lazy::MHash(params)),
+      params_(params) {}
+
+torch::lazy::NodePtr FlashAttentionVarlenPositionIdsBackward::Clone(
+    torch::lazy::OpList operands) const {
+  if (operands.size() > 9) {
+    return torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsBackward>(
+        operands.at(0), operands.at(1), operands.at(2), operands.at(3),
+        operands.at(4), operands.at(5), operands.at(6), operands.at(7),
+        operands.at(8), operands.at(9), params_);
+  } else {
+    return torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsBackward>(
+        operands.at(0), operands.at(1), operands.at(2), operands.at(3),
+        operands.at(4), operands.at(5), operands.at(6), operands.at(7),
+        operands.at(8), params_);
+  }
+}
+
+XlaOpVector FlashAttentionVarlenPositionIdsBackward::Lower(LoweringContext* loctx) const {
+  xla::XlaOp dout = loctx->GetOutputOp(operand(0));
+  xla::XlaOp q = loctx->GetOutputOp(operand(1));
+  xla::XlaOp k = loctx->GetOutputOp(operand(2));
+  xla::XlaOp v = loctx->GetOutputOp(operand(3));
+  xla::XlaOp out = loctx->GetOutputOp(operand(4));
+  xla::XlaOp softmax_lse = loctx->GetOutputOp(operand(5));
+  xla::XlaOp cu_seqlens_q = loctx->GetOutputOp(operand(6));
+  xla::XlaOp cu_seqlens_k = loctx->GetOutputOp(operand(7));
+  xla::XlaOp rng_state = loctx->GetOutputOp(operand(8));
+  xla::XlaOp alibi_slopes =
+      operands().size() > 9 ? loctx->GetOutputOp(operand(9)) : xla::XlaOp();
+  std::vector<xla::XlaOp> result = BuildFlashAttentionVarlenPositionIdsBackward(
+      dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state,
+      alibi_slopes, params_, xla_shape());
+
+  return ReturnOps({result}, loctx);
+}
+
+}  // namespace torch_xla
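
Note on the cu_seqlens convention this op relies on: the forward emits cu_seqlens_q/cu_seqlens_k as fixed-length tensors of size seqlen + 1 whose unused tail is filled with -1, and the backward recovers the real batch size and maximum sequence length from that padding (see cu_seqlens_to_indices in flash_attention_utils.cpp). A minimal PyTorch sketch of that recovery, with a hypothetical helper name, not the in-tree implementation:

    import torch

    def recover_from_padded_cu_seqlens(padded_cu_seqlens):
      # Illustrative only: mirrors cu_seqlens_to_indices for -1-padded cu_seqlens.
      valid = padded_cu_seqlens[padded_cu_seqlens > -1]
      real_batch_size = valid.numel() - 1
      seq_lens = valid[1:] - valid[:-1]
      max_seqlen_in_batch = int(seq_lens.max())
      total = int(valid[-1])
      # Packed layout: every token is valid, so the gather indices are 0..total-1.
      indices = torch.arange(total, dtype=torch.int64)
      return indices, real_batch_size, max_seqlen_in_batch, total

    # Two sequences of length 3 and 5 packed into a seqlen_k = 8 buffer:
    padded = torch.full((9,), -1, dtype=torch.int32)
    padded[:3] = torch.tensor([0, 3, 8], dtype=torch.int32)
    print(recover_from_padded_cu_seqlens(padded)[1:])  # (2, 5, 8)
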
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h
new file mode 100644
index 000000000000..42ba728ceb02
--- /dev/null
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h
@@ -0,0 +1,38 @@
+#ifndef XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_BACKWARD_H_
+#define XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_BACKWARD_H_
+
+#include "torch_xla/csrc/flash_attention_utils.h"
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+
+class FlashAttentionVarlenPositionIdsBackward : public XlaNode {
+ public:
+  FlashAttentionVarlenPositionIdsBackward(
+      const torch::lazy::Value& dout, const torch::lazy::Value& q,
+      const torch::lazy::Value& k, const torch::lazy::Value& v,
+      const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
+      const torch::lazy::Value& cu_seqlens_q,
+      const torch::lazy::Value& cu_seqlens_k,
+      const torch::lazy::Value& rng_state, const std::string params);
+
+  FlashAttentionVarlenPositionIdsBackward(
+      const torch::lazy::Value& dout, const torch::lazy::Value& q,
+      const torch::lazy::Value& k, const torch::lazy::Value& v,
+      const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
+      const torch::lazy::Value& cu_seqlens_q,
+      const torch::lazy::Value& cu_seqlens_k,
+      const torch::lazy::Value& rng_state,
+      const torch::lazy::Value& alibi_slopes, const std::string params);
+
+  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+ private:
+  const std::string params_;
+};
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_BACKWARD_H_
\ No newline at end of file
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
new file mode 100644
index 000000000000..9403879add54
--- /dev/null
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
@@ -0,0 +1,325 @@
+#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h"
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/extension.h>
+#include <iostream>
+
+#include "cutlass/numeric_types.h"
+#include "flash.h"
+#include "static_switch.h"
+#include "torch_xla/csrc/flash_attention_utils.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/xla_ops.h"
+#include "torch_xla/csrc/runtime/tf_logging.h"
+#include "torch_xla/csrc/xla_lower_util.h"
+#include "xla/service/custom_call_target_registry.h"
+
+namespace torch_xla {
+namespace {
+
+xla::Shape NodeOutputShape(const torch::lazy::Value& q) {
+  auto q_shape = xla::SpanToVector(GetXlaShape(q).dimensions());
+  xla::Shape softmax_lse_shape = xla::ShapeUtil::MakeShape(
+      xla::PrimitiveType::F32,
+      {q_shape[0], q_shape[2],
+       q_shape[1]});  // 1, num_heads, total_q
+  xla::Shape rng_state_shape =
+      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2});
+  xla::Shape cu_seqlens_shape =
+      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {q_shape[1] + 1}); // q.shape [1,total_q,num_head,head_dim]
+  return xla::ShapeUtil::MakeTupleShape({softmax_lse_shape, shape_like(q),
+                                         rng_state_shape, cu_seqlens_shape,
+                                         cu_seqlens_shape});
+}
+
+// Layout of `buffers` listed above:
+//  buffers[0] = q
+//  buffers[1] = k
+//  buffers[2] = v
+//  buffers[3] = attention_mask or position_ids
+//  buffers[4] = alibi_slopes
+//  buffers[5] = softmax_lse  // this is output
+//  buffers[6] = out_for_output // this is output
+//  buffers[7] = rng_state // this is output
+//  buffers[8] = cu_seqlen_q // this is output
+//  buffers[9] = cu_seqlen_k // this is output
+void custom_call_flash_attention_varlen_position_ids_forward(
+    cudaStream_t stream, void** buffers, const char* opaque,
+    size_t opaque_len) {
+  std::string opaque_str(opaque, opaque_len);
+  TF_VLOG(3) << "custom_call_flash_attention_varlen_position_ids_forward opaque str: "
+             << opaque_str;
+  FlashAttentionForwardParams params;
+  params.FromString(std::move(opaque_str));
+  int buf_offset = params.enable_alibi_slopes;
+  auto scalar_type = params.is_bf16 ? torch::kBFloat16 : torch::kFloat16;
+
+  cudaStream_t torch_stream = at::cuda::getCurrentCUDAStream().stream();
+  cudaEvent_t torch_wait_xla_event;
+  cudaEventCreateWithFlags(&torch_wait_xla_event, cudaEventDisableTiming);
+  cudaEvent_t xla_wait_torch_event;
+  cudaEventCreateWithFlags(&xla_wait_torch_event, cudaEventDisableTiming);
+  cudaEventRecord(torch_wait_xla_event, stream);
+  cudaStreamWaitEvent(torch_stream, torch_wait_xla_event);
+
+  auto cuda_stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::CUDAStreamGuard guard(cuda_stream);
+
+  auto opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
+
+  at::Tensor q = torch::from_blob(
+      buffers[0], {params.b * params.seqlen_q, params.h, params.d},
+      opts.dtype(scalar_type));
+  at::Tensor k = torch::from_blob(
+      buffers[1], {params.b * params.seqlen_k, params.h_k, params.d},
+      opts.dtype(scalar_type));
+  at::Tensor v = torch::from_blob(
+      buffers[2], {params.b * params.seqlen_k, params.h_k, params.d},
+      opts.dtype(scalar_type));
+  at::Tensor position_ids =
+      torch::from_blob(buffers[3], {params.b, params.seqlen_k}, opts);
+  at::Tensor softmax_lse = torch::from_blob(
+      buffers[4 + buf_offset], {params.b, params.h, params.seqlen_q},
+      opts.dtype(torch::kFloat));
+  at::Tensor o_output =
+      torch::from_blob(buffers[5 + buf_offset],
+                       {params.b * params.seqlen_q, params.h * params.d},
+                       opts.dtype(scalar_type));
+  at::Tensor cu_seqlens_q =
+      torch::from_blob(buffers[7 + buf_offset], {params.seqlen_q + 1}, opts);
+  at::Tensor cu_seqlens_k =
+      torch::from_blob(buffers[8 + buf_offset], {params.seqlen_k + 1}, opts);
+  at::Tensor rng_state =
+      torch::from_blob(buffers[6 + buf_offset], {2}, opts.dtype(torch::kInt64));
+  softmax_lse.fill_(0);
+  o_output.fill_(0);
+  cu_seqlens_k.fill_(-1);
+
+  int max_seqlen_in_batch_k = params.seqlen_k;
+  int total_k = params.b * params.seqlen_k;
+  at::Tensor indices_k;
+
+  // The packed qkv has batch size 1; flash attention still needs the real
+  // (pre-packing) batch size, which position_ids_to_indices recovers.
+  int real_batch_size;
+  indices_k = position_ids_to_indices(position_ids, max_seqlen_in_batch_k,
+                                      total_k, cu_seqlens_k, real_batch_size);
+  TORCH_CHECK(cu_seqlens_k.size(0) == params.seqlen_k + 1,
+              "cu_seqlens_k should have shape [params.seqlen_k + 1]");
+  TORCH_CHECK(indices_k.dtype() == torch::kInt64, "indices should be int64");
+
+  int max_seqlen_in_batch_q = max_seqlen_in_batch_k;
+  int total_q = total_k;
+  at::Tensor indices_q;
+
+  TORCH_CHECK(params.seqlen_q == params.seqlen_k,
+              "currently only identical seqlen_q and seqlen_k are supported");
+
+  if (params.seqlen_q == params.seqlen_k) {
+    cu_seqlens_q.copy_(cu_seqlens_k);
+    indices_q = indices_k;
+  } else {
+    // TODO:(wangtianxing.wtx) support different seqlen_q and seqlen_k
+    indices_q = position_ids_to_indices(position_ids, max_seqlen_in_batch_q,
+                                        total_q, cu_seqlens_q, real_batch_size);
+  }
+
+  if (max_seqlen_in_batch_q == 1) {
+    params.is_causal = false;
+  }
+  if (params.is_causal) {
+    params.window_size_right = 0;
+  }
+
+  if (params.window_size_left >= max_seqlen_in_batch_k) {
+    params.window_size_left = -1;
+  }
+  if (params.window_size_right >= max_seqlen_in_batch_k) {
+    params.window_size_right = -1;
+  }
+
+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+  at::Tensor pad_softmax_lse =
+      at::empty({real_batch_size, params.h, max_seqlen_in_batch_q},
+                torch::dtype(torch::kFloat).device(torch::kCUDA));
+
+  Flash_fwd_params launch_params;
+
+  // Reset the parameters
+  memset(&launch_params, 0, sizeof(launch_params));
+
+  launch_params.is_bf16 = params.is_bf16;
+
+  // Set the pointers and strides.
+  launch_params.q_ptr = q.data_ptr();
+  launch_params.k_ptr = k.data_ptr();
+  launch_params.v_ptr = v.data_ptr();
+  // All stride are in elements, not bytes.
+  launch_params.q_row_stride = params.q_row_stride;
+  launch_params.k_row_stride = params.k_row_stride;
+  launch_params.v_row_stride = params.v_row_stride;
+  launch_params.q_head_stride = params.q_head_stride;
+  launch_params.k_head_stride = params.k_head_stride;
+  launch_params.v_head_stride = params.v_head_stride;
+  launch_params.o_ptr = o_output.data_ptr();
+  launch_params.o_row_stride = params.o_row_stride;
+  launch_params.o_head_stride = params.o_head_stride;
+
+  launch_params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q.data_ptr());
+  launch_params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k.data_ptr());
+
+  launch_params.seqused_k = static_cast<int*>(nullptr);
+
+  // P = softmax(QK^T)
+  launch_params.p_ptr = nullptr;  // no softmax returned always
+
+  // Softmax sum
+  launch_params.softmax_lse_ptr = pad_softmax_lse.data_ptr();
+
+  // Set the dimensions.
+  launch_params.b = real_batch_size;
+  launch_params.h = params.h;
+  launch_params.h_k = params.h_k;
+  launch_params.h_h_k_ratio = params.h_h_k_ratio;
+  launch_params.seqlen_q = max_seqlen_in_batch_q;
+  launch_params.seqlen_k = max_seqlen_in_batch_k;
+  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+  launch_params.seqlen_q_rounded = round_multiple(max_seqlen_in_batch_q, 128);
+  launch_params.seqlen_k_rounded = round_multiple(max_seqlen_in_batch_k, 128);
+  launch_params.d = params.d;
+  launch_params.d_rounded = params.d_rounded;
+
+  // Set the different scale values.
+  launch_params.scale_softmax = params.scale_softmax;
+  launch_params.scale_softmax_log2 = params.scale_softmax_log2;
+
+  launch_params.p_dropout = params.p_dropout;
+  launch_params.p_dropout_in_uint8_t = params.p_dropout_in_uint8_t;
+  launch_params.rp_dropout = params.rp_dropout;
+  launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout;
+
+  launch_params.is_causal =
+      params.window_size_left < 0 && params.window_size_right == 0;
+
+  if (params.window_size_left < 0 && params.window_size_right >= 0) {
+    params.window_size_left = max_seqlen_in_batch_k;
+  }
+  if (params.window_size_left >= 0 && params.window_size_right < 0) {
+    params.window_size_right = max_seqlen_in_batch_k;
+  }
+
+  launch_params.window_size_left = params.window_size_left;
+  launch_params.window_size_right = params.window_size_right;
+
+  launch_params.is_seqlens_k_cumulative = params.is_seqlens_k_cumulative;
+
+  launch_params.alibi_slopes_ptr = buf_offset > 0 ? buffers[4] : nullptr;
+  launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride;
+
+  // set params splitkv
+  launch_params.num_splits = params.num_splits;
+
+  // Forward kernel will populate memory with the seed and offset.
+  launch_params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+
+  if ((1.f - launch_params.p_dropout) > 0.0) {
+    // Number of random values generated per thread, used to offset the philox
+    // counter in the CUDA RNG state. We use a custom RNG that increases the
+    // offset by batch_size * nheads * 32.
+    int64_t counter_offset = launch_params.b * launch_params.h * 32;
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+        c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    launch_params.philox_args = gen->philox_cuda_state(counter_offset);
+  }
+  TF_VLOG(2) << "Running FlashAttention Forward.";
+  FP16_SWITCH(!launch_params.is_bf16, [&] {
+    HEADDIM_SWITCH(launch_params.d, [&] {
+      // TODO(wenting.swt): support split_kv
+      run_mha_fwd_<elem_type, kHeadDim>(launch_params, torch_stream);
+    });
+  });
+  softmax_lse.copy_(unpad_softmax_lse(pad_softmax_lse, cu_seqlens_q));
+
+  // TODO(wenting.swt): we should pad and unpad q,k,v when head_size_og % 8 != 0
+  // sync with cudaEvent
+  cudaEventRecord(xla_wait_torch_event, torch_stream);
+  cudaStreamWaitEvent(stream, xla_wait_torch_event);
+}
+
+XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_position_ids_forward,
+                                "CUDA");
+
+std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsForward(
+    const xla::XlaOp& q, const xla::XlaOp& k, const xla::XlaOp& v,
+    const xla::XlaOp& position_ids, const xla::XlaOp& alibi_slopes,
+    const std::string& opaque, const xla::Shape& output_shape) {
+  auto builder = q.builder();
+  std::vector<xla::XlaOp> operands{q, k, v, position_ids};
+  std::vector<xla::Shape> operand_shapes_with_layout{
+      shape_like(builder, q), shape_like(builder, k), shape_like(builder, v),
+      shape_like(builder, position_ids)};
+  if (alibi_slopes.valid()) {
+    operands.push_back(alibi_slopes);
+    operand_shapes_with_layout.push_back(shape_like(builder, alibi_slopes));
+  }
+  xla::XlaOp result = xla::CustomCallWithLayout(
+      builder, "custom_call_flash_attention_varlen_position_ids_forward",
+      std::move(operands), output_shape, std::move(operand_shapes_with_layout),
+      opaque);
+  return {/*softmax_lse*/ xla::GetTupleElement(result, 0),
+          /*output*/ xla::GetTupleElement(result, 1),
+          /*rng_state*/ xla::GetTupleElement(result, 2),
+          /*cu_seqlen_q*/ xla::GetTupleElement(result, 3),
+          /*cu_seqlen_k*/ xla::GetTupleElement(result, 4)};
+}
+
+}  // namespace
+
+FlashAttentionVarlenPositionIdsForward::FlashAttentionVarlenPositionIdsForward(
+    const torch::lazy::Value& q, const torch::lazy::Value& k,
+    const torch::lazy::Value& v, const torch::lazy::Value& position_ids,
+    const std::string params)
+    : XlaNode(xla_flash_attention_forward, {q, k, v, position_ids},
+              NodeOutputShape(q),
+              /*num_outputs=*/5, torch::lazy::MHash(params)),
+      params_(params) {}
+
+FlashAttentionVarlenPositionIdsForward::FlashAttentionVarlenPositionIdsForward(
+    const torch::lazy::Value& q, const torch::lazy::Value& k,
+    const torch::lazy::Value& v, const torch::lazy::Value& position_ids,
+    const torch::lazy::Value& alibi_slopes, const std::string params)
+    : XlaNode(xla_flash_attention_forward,
+              {q, k, v, position_ids, alibi_slopes}, NodeOutputShape(q),
+              /*num_outputs=*/5, torch::lazy::MHash(params)),
+      params_(params) {}
+
+torch::lazy::NodePtr FlashAttentionVarlenPositionIdsForward::Clone(
+    torch::lazy::OpList operands) const {
+  if (operands.size() > 4) {
+    return torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
+        operands.at(0), operands.at(1), operands.at(2), operands.at(3),
+        operands.at(4), params_);
+  } else {
+    return torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
+        operands.at(0), operands.at(1), operands.at(2), operands.at(3),
+        params_);
+  }
+}
+
+XlaOpVector FlashAttentionVarlenPositionIdsForward::Lower(LoweringContext* loctx) const {
+  xla::XlaOp q = loctx->GetOutputOp(operand(0));
+  xla::XlaOp k = loctx->GetOutputOp(operand(1));
+  xla::XlaOp v = loctx->GetOutputOp(operand(2));
+  xla::XlaOp position_ids = loctx->GetOutputOp(operand(3));
+  xla::XlaOp alibi_slopes =
+      operands().size() > 4 ? loctx->GetOutputOp(operand(4)) : xla::XlaOp();
+  std::vector<xla::XlaOp> result = BuildFlashAttentionVarlenPositionIdsForward(
+      q, k, v, position_ids, alibi_slopes, params_, xla_shape());
+  return ReturnOps({result}, loctx);
+}
+
+}  // namespace torch_xla
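
For reference, the forward path derives its varlen metadata from position_ids rather than from an attention mask: every position where position_ids equals 0 starts a new packed sequence. The sketch below is a rough PyTorch rendering of that derivation (the function name is illustrative; the in-tree implementation is position_ids_to_indices in flash_attention_utils.cpp), assuming position_ids of shape (1, total_tokens) and the -1-padded cu_seqlens convention described earlier:

    import torch

    def cu_seqlens_from_position_ids(position_ids, padded_len):
      # Illustrative sketch of position_ids_to_indices: a 0 marks a sequence start.
      flat = position_ids.flatten()
      total = flat.numel()
      indices = torch.arange(total, dtype=torch.int64)
      seq_starts = indices[flat == 0]
      real_batch_size = seq_starts.numel()
      cu_seqlens = torch.full((padded_len,), -1, dtype=torch.int32)
      cu_seqlens[:real_batch_size] = seq_starts.to(torch.int32)
      cu_seqlens[real_batch_size] = total
      seq_lens = cu_seqlens[1:real_batch_size + 1] - cu_seqlens[:real_batch_size]
      max_seqlen_in_batch = int(seq_lens.max())
      return indices, cu_seqlens, real_batch_size, max_seqlen_in_batch

    # Two packed sequences [0,1,2] and [0,1,2,3,4]:
    pos = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]], dtype=torch.int32)
    _, cu, b, m = cu_seqlens_from_position_ids(pos, pos.size(1) + 1)
    print(cu.tolist(), b, m)  # [0, 3, 8, -1, -1, -1, -1, -1, -1] 2 5
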
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h
new file mode 100644
index 000000000000..2a37eec1c499
--- /dev/null
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h
@@ -0,0 +1,34 @@
+#ifndef XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_FORWARD_H_
+#define XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_FORWARD_H_
+
+#include "torch_xla/csrc/flash_attention_utils.h"
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+
+class FlashAttentionVarlenPositionIdsForward : public XlaNode {
+ public:
+  FlashAttentionVarlenPositionIdsForward(const torch::lazy::Value& q,
+                              const torch::lazy::Value& k,
+                              const torch::lazy::Value& v,
+                              const torch::lazy::Value& position_ids,
+                              const std::string params);
+
+  FlashAttentionVarlenPositionIdsForward(const torch::lazy::Value& q,
+                              const torch::lazy::Value& k,
+                              const torch::lazy::Value& v,
+                              const torch::lazy::Value& position_ids,
+                              const torch::lazy::Value& alibi_slopes,
+                              const std::string params);
+
+  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+ private:
+  const std::string params_;
+};
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_POSITION_IDS_FORWARD_H_
\ No newline at end of file
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 41ddfce80438..3992c31289b3 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -53,6 +53,8 @@
 #include "torch_xla/csrc/ops/flash_attention_forward.h"
 #include "torch_xla/csrc/ops/flash_attention_varlen_backward.h"
 #include "torch_xla/csrc/ops/flash_attention_varlen_forward.h"
+#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h"
+#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h"
 #include "torch_xla/csrc/ops/flip.h"
 #include "torch_xla/csrc/ops/gather.h"
 #include "torch_xla/csrc/ops/generic.h"
@@ -675,6 +677,26 @@ std::vector<XLATensorPtr> flash_attention_varlen_forward(
   }
 }
 
+std::vector<XLATensorPtr> flash_attention_varlen_position_ids_forward(
+    const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v,
+    const XLATensorPtr& position_ids, const XLATensorPtr& alibi_slopes,
+    const std::string& params) {
+  if (alibi_slopes) {
+    torch::lazy::NodePtr node =
+        torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
+            q->GetIrValue(), k->GetIrValue(), v->GetIrValue(),
+            position_ids->GetIrValue(), alibi_slopes->GetIrValue(), params);
+    return q->MakeOutputTensors(node, /*inherit_logical_type=*/false);
+  } else {
+    torch::lazy::NodePtr node =
+        torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
+            q->GetIrValue(), k->GetIrValue(), v->GetIrValue(),
+            position_ids->GetIrValue(), params);
+    return q->MakeOutputTensors(node, /*inherit_logical_type=*/false);
+  }
+}
+
 std::vector<XLATensorPtr> flash_attention_backward(
     const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
     const XLATensorPtr& v, const XLATensorPtr& out,
@@ -720,6 +742,32 @@ std::vector<XLATensorPtr> flash_attention_varlen_backward(
   }
 }
 
+std::vector<XLATensorPtr> flash_attention_varlen_position_ids_backward(
+    const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
+    const XLATensorPtr& v, const XLATensorPtr& out,
+    const XLATensorPtr& softmax_lse, const XLATensorPtr& cu_seqlens_q,
+    const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state,
+    const XLATensorPtr& alibi_slopes, const std::string& params) {
+  if (alibi_slopes) {
+    torch::lazy::NodePtr node =
+        torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsBackward>(
+            dout->GetIrValue(), q->GetIrValue(), k->GetIrValue(),
+            v->GetIrValue(), out->GetIrValue(), softmax_lse->GetIrValue(),
+            cu_seqlens_q->GetIrValue(), cu_seqlens_k->GetIrValue(),
+            rng_state->GetIrValue(), alibi_slopes->GetIrValue(), params);
+    return dout->MakeOutputTensors(node, /*inherit_logical_type=*/false);
+  } else {
+    torch::lazy::NodePtr node =
+        torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsBackward>(
+            dout->GetIrValue(), q->GetIrValue(), k->GetIrValue(),
+            v->GetIrValue(), out->GetIrValue(), softmax_lse->GetIrValue(),
+            cu_seqlens_q->GetIrValue(), cu_seqlens_k->GetIrValue(),
+            rng_state->GetIrValue(), params);
+    return dout->MakeOutputTensors(node, /*inherit_logical_type=*/false);
+  }
+}
+
 std::vector<XLATensorPtr> user_computation(
     const std::string& opname, absl::Span<const XLATensorPtr> inputs,
     runtime::ComputationClient::ComputationPtr computation) {
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index fb574640045c..a04a9538463d 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -123,6 +123,11 @@ std::vector<XLATensorPtr> flash_attention_varlen_forward(
     const XLATensorPtr& attention_mask, const XLATensorPtr& alibi_slopes,
     const std::string& params);
 
+std::vector<XLATensorPtr> flash_attention_varlen_position_ids_forward(
+    const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v,
+    const XLATensorPtr& position_ids, const XLATensorPtr& alibi_slopes,
+    const std::string& params);
+
 std::vector<XLATensorPtr> flash_attention_backward(
     const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
     const XLATensorPtr& v, const XLATensorPtr& out,
@@ -136,6 +141,13 @@ std::vector<XLATensorPtr> flash_attention_varlen_backward(
     const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state,
     const XLATensorPtr& alibi_slopes, const std::string& params);
 
+std::vector<XLATensorPtr> flash_attention_varlen_position_ids_backward(
+    const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
+    const XLATensorPtr& v, const XLATensorPtr& out,
+    const XLATensorPtr& softmax_lse, const XLATensorPtr& cu_seqlens_q,
+    const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state,
+    const XLATensorPtr& alibi_slopes, const std::string& params);
+
 std::vector<XLATensorPtr> user_computation(
     const std::string& opname, absl::Span<const XLATensorPtr> inputs,
     runtime::ComputationClient::ComputationPtr computation);

From 09067b6b7ca8c21134c7c793ed3a1b7b78358856 Mon Sep 17 00:00:00 2001
From: tianxingwang <wangtianxing.wtx@alibaba-inc.com>
Date: Mon, 2 Dec 2024 14:11:05 +0800
Subject: [PATCH 2/5] reformat files

---
 test/test_flash_attention_varlen_backward.py  |   9 +-
 test/test_flash_attention_varlen_forward.py   |  21 +--
 torch_xla/csrc/flash_attention_utils.cpp      | 148 ++++++++++--------
 torch_xla/csrc/flash_attention_utils.h        |  17 +-
 torch_xla/csrc/init_python_bindings.cpp       |  31 ++--
 .../ops/flash_attention_varlen_forward.cpp    |  24 +--
 ...attention_varlen_position_ids_backward.cpp |  76 ++++-----
 ..._attention_varlen_position_ids_forward.cpp |  54 ++++---
 ...sh_attention_varlen_position_ids_forward.h |  18 +--
 torch_xla/csrc/tensor_methods.cpp             |  22 ++-
 10 files changed, 225 insertions(+), 195 deletions(-)

diff --git a/test/test_flash_attention_varlen_backward.py b/test/test_flash_attention_varlen_backward.py
index 751b9aafd74f..e71edec0a75f 100644
--- a/test/test_flash_attention_varlen_backward.py
+++ b/test/test_flash_attention_varlen_backward.py
@@ -251,8 +251,6 @@ def test_flash_attn_varlen_backward(seqlen_q, seqlen_k, d, dropout_p, causal,
   assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
 
 
-
-
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 @pytest.mark.parametrize("deterministic", [False])
@@ -366,7 +364,7 @@ def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, dropout_
 
   device = ta.lazy_device()
   torch.random.manual_seed(101)
-  
+
   indices_q = indices_q.cpu()
   indices_k = indices_k.cpu()
 
@@ -374,7 +372,7 @@ def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, dropout_
   k_xla = k.flatten(0,1)[indices_k].unsqueeze(0).to(device)
   v_xla = v.flatten(0,1)[indices_k].unsqueeze(0).to(device)
   do_xla = do.flatten(0,1)[indices_q].unsqueeze(0).to(device)
-  
+
   def attention_mask_to_position_ids(attention_mask):
     seqlens = attention_mask.sum(dim=1).flatten()
     position_ids = torch.cat([torch.arange(0,seqlen,dtype=torch.int32,device=attention_mask.device) for seqlen in seqlens],dim=0)
@@ -385,7 +383,6 @@ def attention_mask_to_position_ids(attention_mask):
   rng_state_xla = rng_state.to(device)
   if alibi:
     alibi_slopes = alibi_slopes.cpu().to(device)
-  
 
   softmax_lse_xla, o_xla, _, cu_seqlen_q_xla, cu_seqlen_k_xla = torch_xla._XLAC._flash_attention_position_ids_forward(
       q_xla.contiguous(), k_xla.contiguous(), v_xla.contiguous(),
@@ -420,7 +417,7 @@ def attention_mask_to_position_ids(attention_mask):
   dv_xla = dv_xla.cpu().detach().squeeze(0)
   do_xla = do_xla.cpu().detach().squeeze(0)
   softmax_d_xla = softmax_d_xla.cpu().detach()
-  
+
   dq_xla = pad_input(dq_xla, indices_q, batch_size, seqlen_q)
   dk_xla = pad_input(dk_xla, indices_k, batch_size, seqlen_k)
   dv_xla = pad_input(dv_xla, indices_k, batch_size, seqlen_k)
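
The XLA op returns gradients in the packed (total_tokens, nheads, head_dim) layout, so the test scatters them back to the padded (batch, seqlen, ...) layout with flash_attn's pad_input before comparing against the CUDA reference. A rough pure-PyTorch equivalent of that scatter, with a hypothetical name, for readers without flash_attn at hand:

    import torch

    def pad_input_sketch(packed, indices, batch, seqlen):
      # Illustrative equivalent of flash_attn.bert_padding.pad_input: place each
      # packed row at its slot in the flattened (batch * seqlen) layout.
      out = packed.new_zeros((batch * seqlen,) + tuple(packed.shape[1:]))
      out[indices] = packed
      return out.view(batch, seqlen, *packed.shape[1:])
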
diff --git a/test/test_flash_attention_varlen_forward.py b/test/test_flash_attention_varlen_forward.py
index 30e4a4aa5287..119f90d6c6cd 100644
--- a/test/test_flash_attention_varlen_forward.py
+++ b/test/test_flash_attention_varlen_forward.py
@@ -228,10 +228,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal,
       cu_seqlen_k_xla, cu_seqlens_k, rtol=1e-3, atol=1e-3, equal_nan=True)
   for i in range(len(cu_seq_lens[0]) - 1):
     seqlen = cu_seq_lens[0][i + 1] -  cu_seq_lens[0][i]
-    assert torch.allclose(softmax_lse_xla[i,:,:seqlen],softmax_lse[i,:,:seqlen],rtol=1e-2,atol=1e-3,equal_nan=True)  
-  
-
-
+    assert torch.allclose(softmax_lse_xla[i,:,:seqlen],softmax_lse[i,:,:seqlen],rtol=1e-2,atol=1e-3,equal_nan=True)
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -262,7 +259,7 @@ def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d, drop
   torch.random.manual_seed(0)
   batch_size = 4
   nheads = 9
-  nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) 
+  nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
 
   assert nheads % nheads_k == 0
   window_size = (-1, -1) if not local else tuple(
@@ -280,7 +277,7 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype,
       seq_len_k = torch.randint(1,max_seqlen_k+1,(batch_size,))
     total_q = seq_len_q.sum().item()
     total_k = seq_len_k.sum().item()
-    
+
     padd_q = 0 if total_q % 8 == 0 else 8 - total_q % 8
     padd_k = 0 if total_k % 8 == 0 else 8 - total_k % 8
     if use_same_seqlen:
@@ -301,24 +298,23 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype,
     v = torch.randn((1,total_k,n_heads_k,head_dim),dtype=dtype,device=device)
 
     assert torch.all(seq_len_q > 0)
-    assert torch.all(seq_len_k > 0) 
+    assert torch.all(seq_len_k > 0)
 
     position_ids_q = torch.cat([torch.arange(0,seq_len,dtype=torch.int32,device=device) for seq_len in seq_len_q],dim=0).unsqueeze(0)
     position_ids_k = torch.cat([torch.arange(0,seq_len,dtype=torch.int32,device=device) for seq_len in seq_len_k],dim=0).unsqueeze(0)
     assert position_ids_q.shape[1] % 8 == 0
     assert position_ids_k.shape[1] % 8 == 0
-    
+
     return q,k,v,position_ids_q,position_ids_k
-  
 
   q,k,v,position_ids_q,position_ids_k = generate_qkv_and_position_ids(batch_size,max_seqlen_q,max_seqlen_k,dtype,nheads,nheads_k,device,d,max_seqlen_q == max_seqlen_k)
- 
+
   indices_q = torch.arange(0,position_ids_q.size(1),device=device,dtype=torch.int64)
   indices_k = torch.arange(0,position_ids_k.size(1),device=device,dtype=torch.int64)
 
   seq_q_start_idx = indices_q[position_ids_q.squeeze() == 0]
   seq_k_start_idx = indices_k[position_ids_k.squeeze() == 0]
-  
+
   cu_seq_lens_q = F.pad(seq_q_start_idx,(0,1),value=position_ids_q.size(1)).to(torch.int32)
   cu_seq_lens_k = F.pad(seq_k_start_idx,(0,1),value=position_ids_k.size(1)).to(torch.int32)
 
@@ -335,7 +331,6 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype,
         batch_size, nheads, device=device, dtype=torch.float32) * 0.3
   else:
     alibi_slopes = None
-  
 
   out_fa, softmax_lse, _ = flash_attn_varlen_func(
       q_cuda.contiguous(),
@@ -415,4 +410,4 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype,
     start_idx = cu_seq_len_q_xla[i]
     end_idx = cu_seq_len_q_xla[i+1]
     seqlen = end_idx - start_idx
-    assert torch.allclose(softmax_lse_xla[0,:,start_idx:end_idx],softmax_lse[i,:,:seqlen],rtol=1e-3,atol=1e-3,equal_nan=True)  
+    assert torch.allclose(softmax_lse_xla[0,:,start_idx:end_idx],softmax_lse[i,:,:seqlen],rtol=1e-3,atol=1e-3,equal_nan=True)
diff --git a/torch_xla/csrc/flash_attention_utils.cpp b/torch_xla/csrc/flash_attention_utils.cpp
index 546a00c71947..950eaf50ce39 100644
--- a/torch_xla/csrc/flash_attention_utils.cpp
+++ b/torch_xla/csrc/flash_attention_utils.cpp
@@ -2,6 +2,7 @@
 
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
+
 #include <iostream>
 
 #include "absl/strings/numbers.h"
@@ -238,7 +239,7 @@ void set_backward_params(FlashAttentionBackwardParams& params, const size_t b,
 FlashAttentionForwardParams get_flash_attention_forward_params(
     const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
     c10::optional<at::Tensor> attention_mask,  // (batch_size, seqlen)
-    c10::optional<at::Tensor> position_ids, // (1,seqlen_q) 
+    c10::optional<at::Tensor> position_ids,    // (1,seqlen_q)
     c10::optional<at::Tensor>& alibi_slopes_, const float p_dropout,
     const float softmax_scale, const bool zero_tensors, const bool is_causal,
     int window_size_left, int window_size_right, const bool return_softmax) {
@@ -260,7 +261,7 @@ FlashAttentionForwardParams get_flash_attention_forward_params(
 
   const auto sizes = q.sizes();
   const int batch_size = sizes[0];
-  const int seqlen_q = sizes[1]; 
+  const int seqlen_q = sizes[1];
   const int num_heads = sizes[2];
   const int head_size_og = sizes[3];
   const int seqlen_k = k.size(1);
@@ -276,14 +277,14 @@ FlashAttentionForwardParams get_flash_attention_forward_params(
   CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
   CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
   CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
-  
+
   if (attention_mask.has_value()) {
     TORCH_CHECK(attention_mask.value().dtype() == torch::kInt32);
     CHECK_SHAPE(attention_mask.value(), batch_size, seqlen_k);
   }
-  if (position_ids.has_value()){
+  if (position_ids.has_value()) {
     TORCH_CHECK(position_ids.value().dtype() == torch::kInt32);
-    CHECK_SHAPE(position_ids.value(),1, seqlen_q);
+    CHECK_SHAPE(position_ids.value(), 1, seqlen_q);
   }
 
   auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
@@ -388,13 +389,16 @@ FlashAttentionBackwardParams get_flash_attention_backward_params(
     TORCH_CHECK(cu_seqlens_k.value().dtype() == torch::kInt32);
     TORCH_CHECK(cu_seqlens_q.value().is_contiguous());
     TORCH_CHECK(cu_seqlens_k.value().is_contiguous());
-    TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 || batch_size == 1);
-    TORCH_CHECK(cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size+1}) || 
-                 cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q+1}),
-                 "cu_seqlens_q shape should be batch_size+1 or seqlen_q");
-    TORCH_CHECK(cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size+1}) || 
-                 cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k+1}),
-                 "cu_seqlens_k shape should be batch_size+1 or seqlen_k");    
+    TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 ||
+                batch_size == 1);
+    TORCH_CHECK(
+        cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
+            cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q + 1}),
+        "cu_seqlens_q shape should be batch_size+1 or seqlen_q");
+    TORCH_CHECK(
+        cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
+            cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k + 1}),
+        "cu_seqlens_k shape should be batch_size+1 or seqlen_k");
   }
 
   int alibi_slopes_batch_stride = 0;
@@ -422,16 +426,19 @@ FlashAttentionBackwardParams get_flash_attention_backward_params(
   return params;
 }
 
-
 at::Tensor cu_seqlens_to_indices(const at::Tensor& padded_cu_seqlens,
-                                 int& max_seqlen_in_batch,int& total_q, int& real_batch_size) {
-  const at::Tensor valid_cu_seqlens = padded_cu_seqlens.index({padded_cu_seqlens > -1});
-  real_batch_size = valid_cu_seqlens.size(0)-1;
-  at::Tensor seqs_len = valid_cu_seqlens.slice(0,1,valid_cu_seqlens.size(0)) - 
-                          valid_cu_seqlens.slice(0,0,valid_cu_seqlens.size(0) - 1);
+                                 int& max_seqlen_in_batch, int& total_q,
+                                 int& real_batch_size) {
+  const at::Tensor valid_cu_seqlens =
+      padded_cu_seqlens.index({padded_cu_seqlens > -1});
+  real_batch_size = valid_cu_seqlens.size(0) - 1;
+  at::Tensor seqs_len =
+      valid_cu_seqlens.slice(0, 1, valid_cu_seqlens.size(0)) -
+      valid_cu_seqlens.slice(0, 0, valid_cu_seqlens.size(0) - 1);
   max_seqlen_in_batch = seqs_len.max().item<int>();
   total_q = valid_cu_seqlens[-1].item<int>();
-  return torch::arange(total_q,torch::dtype(torch::kInt64).device(torch::kCUDA));                                     
+  return torch::arange(total_q,
+                       torch::dtype(torch::kInt64).device(torch::kCUDA));
 }
 
 at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
@@ -443,11 +450,12 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
   std::array<int64_t, 2> shape = {batch_size, seqlen};
 
   auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
-  torch::Tensor rows =
-      torch::arange(shape[0], opts.dtype(torch::kInt32)).unsqueeze(1); // (batch_size,1)
-  torch::Tensor cols =
-      torch::arange(shape[1], opts.dtype(torch::kInt32)).unsqueeze(0); // (1,seqlen)
-  torch::Tensor mask = cols < nonzero_counts.unsqueeze(1); // (1,seqlen) < (batch_size, 1)
+  torch::Tensor rows = torch::arange(shape[0], opts.dtype(torch::kInt32))
+                           .unsqueeze(1);  // (batch_size,1)
+  torch::Tensor cols = torch::arange(shape[1], opts.dtype(torch::kInt32))
+                           .unsqueeze(0);  // (1,seqlen)
+  torch::Tensor mask =
+      cols < nonzero_counts.unsqueeze(1);  // (1,seqlen) < (batch_size, 1)
   max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>();
 
   torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32));
@@ -475,67 +483,85 @@ at::Tensor mask_to_indices(const at::Tensor& attention_mask,
 
 torch::Tensor unpad_softmax_lse(
     const torch::Tensor& pad_softmax_lse,  // (batch_size, nhead, max_seqlen)
-    const torch::Tensor& cu_seqlens)        // (total_seqlen + 1)
+    const torch::Tensor& cu_seqlens)       // (total_seqlen + 1)
 {
-    int batch_size = pad_softmax_lse.size(0);
-    int nhead = pad_softmax_lse.size(1);
-    int max_seqlen = pad_softmax_lse.size(2);
-    int total, max_seqlen_in_batch;
-    at::Tensor valid_cu_seqlens = cu_seqlens.slice(0,0,batch_size+1);
-    at::Tensor indices = cu_seqlens_to_indices(valid_cu_seqlens,batch_size,max_seqlen,torch::kInt64,max_seqlen_in_batch,total);
-    at::Tensor result = at::empty({total,nhead},pad_softmax_lse.options());
-    result.copy_(pad_softmax_lse.transpose(1,2).reshape({batch_size*max_seqlen,nhead}).index({indices,torch::indexing::Slice()}));
-    return result.transpose(0,1).unsqueeze(0);
+  int batch_size = pad_softmax_lse.size(0);
+  int nhead = pad_softmax_lse.size(1);
+  int max_seqlen = pad_softmax_lse.size(2);
+  int total, max_seqlen_in_batch;
+  at::Tensor valid_cu_seqlens = cu_seqlens.slice(0, 0, batch_size + 1);
+  at::Tensor indices =
+      cu_seqlens_to_indices(valid_cu_seqlens, batch_size, max_seqlen,
+                            torch::kInt64, max_seqlen_in_batch, total);
+  at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options());
+  result.copy_(pad_softmax_lse.transpose(1, 2)
+                   .reshape({batch_size * max_seqlen, nhead})
+                   .index({indices, torch::indexing::Slice()}));
+  return result.transpose(0, 1).unsqueeze(0);
 }
 
-torch::Tensor pad_softmax_lse(const at::Tensor& softmax_lse, // (1,nheads,total_seqlen)
-                         const at::Tensor& cu_seqlens,const int max_seq_len, const int batch_size){
-    const int nheads = softmax_lse.size(1);
-    int max_seqlen_in_batch;
-    int total;
-    at::Tensor valid_cu_seqlens = cu_seqlens.slice(0,0,batch_size+1);
-    at::Tensor indices = cu_seqlens_to_indices(valid_cu_seqlens,batch_size,max_seq_len,
-                                               torch::kInt32,max_seqlen_in_batch,total);                                 
-    TORCH_CHECK(indices.size(0) == softmax_lse.size(2),"indice should be same size with softmax_lse")
-                                            
-    at::Tensor result = at::empty({batch_size * max_seq_len,nheads},softmax_lse.options());
-
-    result.index_put_({indices,torch::indexing::Slice()},softmax_lse.squeeze(0).transpose(0,1));
-    return result.reshape({batch_size,max_seq_len,nheads}).transpose(1,2).contiguous();
+torch::Tensor pad_softmax_lse(
+    const at::Tensor& softmax_lse,  // (1,nheads,total_seqlen)
+    const at::Tensor& cu_seqlens, const int max_seq_len, const int batch_size) {
+  const int nheads = softmax_lse.size(1);
+  int max_seqlen_in_batch;
+  int total;
+  at::Tensor valid_cu_seqlens = cu_seqlens.slice(0, 0, batch_size + 1);
+  at::Tensor indices =
+      cu_seqlens_to_indices(valid_cu_seqlens, batch_size, max_seq_len,
+                            torch::kInt32, max_seqlen_in_batch, total);
+  TORCH_CHECK(indices.size(0) == softmax_lse.size(2),
+              "indices should have the same size as softmax_lse");
+
+  at::Tensor result =
+      at::empty({batch_size * max_seq_len, nheads}, softmax_lse.options());
+
+  result.index_put_({indices, torch::indexing::Slice()},
+                    softmax_lse.squeeze(0).transpose(0, 1));
+  return result.reshape({batch_size, max_seq_len, nheads})
+      .transpose(1, 2)
+      .contiguous();
 }
 
-at::Tensor position_ids_to_indices(const at::Tensor& position_ids, int& max_seqlen_in_batch, int& total, at::Tensor& cu_seqlen, int& real_batch_size){
+at::Tensor position_ids_to_indices(const at::Tensor& position_ids,
+                                   int& max_seqlen_in_batch, int& total,
+                                   at::Tensor& cu_seqlen,
+                                   int& real_batch_size) {
   at::Tensor flatten_position_ids = position_ids.flatten();
-  at::Tensor indices = torch::arange(flatten_position_ids.size(0), torch::dtype(torch::kInt64).device(torch::kCUDA));
-  
+  at::Tensor indices =
+      torch::arange(flatten_position_ids.size(0),
+                    torch::dtype(torch::kInt64).device(torch::kCUDA));
+
   at::Tensor batch_seq_start_idx = indices.index({flatten_position_ids == 0});
   real_batch_size = batch_seq_start_idx.size(0);
-  at::Tensor batch_seqlen_cumsum = at::empty({real_batch_size+1},torch::dtype(torch::kInt32).device(torch::kCUDA));
-  batch_seqlen_cumsum.index({torch::indexing::Slice(0,real_batch_size)}) = batch_seq_start_idx;
+  at::Tensor batch_seqlen_cumsum = at::empty(
+      {real_batch_size + 1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+  batch_seqlen_cumsum.index({torch::indexing::Slice(0, real_batch_size)}) =
+      batch_seq_start_idx;
   total = flatten_position_ids.size(0);
   batch_seqlen_cumsum.index({-1}) = total;
 
-  at::Tensor batch_seqlen = batch_seqlen_cumsum.slice(0,1,batch_seqlen_cumsum.size(0)) - batch_seqlen_cumsum.slice(0,0,batch_seqlen_cumsum.size(0)-1);
+  at::Tensor batch_seqlen =
+      batch_seqlen_cumsum.slice(0, 1, batch_seqlen_cumsum.size(0)) -
+      batch_seqlen_cumsum.slice(0, 0, batch_seqlen_cumsum.size(0) - 1);
   max_seqlen_in_batch = batch_seqlen.max().item<int>();
-  cu_seqlen.narrow(0,0,real_batch_size+1) = batch_seqlen_cumsum;
+  cu_seqlen.narrow(0, 0, real_batch_size + 1) = batch_seqlen_cumsum;
   return indices;
 }
 
-
-
 at::Tensor index_first_axis(const at::Tensor& input,
                             const at::Tensor& indices) {
   torch::IntArrayRef sizes = input.sizes();
-  int64_t first_axis_dim = sizes[0]; // bs
-  auto other_shape = sizes.slice(1, sizes.size() - 1); // [a,h]
+  int64_t first_axis_dim = sizes[0];                    // bs
+  auto other_shape = sizes.slice(1, sizes.size() - 1);  // [a,h]
 
   int64_t second_dim = 1;
   for (auto dim : other_shape) {
     second_dim *= dim;
   }
 
-  at::Tensor flat_input = torch::flatten(input, 1); // [bs,ah]
-  torch::Tensor repeated_indices = 
+  at::Tensor flat_input = torch::flatten(input, 1);  // [bs,ah]
+  torch::Tensor repeated_indices =
       indices.unsqueeze(1).expand({indices.size(0), second_dim});
   at::Tensor gather_input = torch::gather(flat_input, 0, repeated_indices);
   std::vector<int64_t> reshaped_size = {-1};
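
index_first_axis above selects rows of a (first_dim, ...) tensor by index: it flattens the trailing dimensions, expands the indices to that width, gathers along dim 0, and reshapes back. A rough PyTorch sketch of the same operation (illustrative name, not the in-tree helper):

    import torch

    def index_first_axis_sketch(x, indices):
      # Illustrative Python equivalent of the C++ index_first_axis helper.
      flat = x.flatten(1)                                  # (first_dim, prod(rest))
      idx = indices.unsqueeze(1).expand(-1, flat.size(1))  # (num_indices, prod(rest))
      return torch.gather(flat, 0, idx).reshape(-1, *x.shape[1:])
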
diff --git a/torch_xla/csrc/flash_attention_utils.h b/torch_xla/csrc/flash_attention_utils.h
index 1692e1bd8678..d231f2e5165c 100644
--- a/torch_xla/csrc/flash_attention_utils.h
+++ b/torch_xla/csrc/flash_attention_utils.h
@@ -132,15 +132,16 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
                                  int& max_seqlen_in_batch, int& total);
 
 at::Tensor cu_seqlens_to_indices(const at::Tensor& padded_cu_seqlens,
-                                 int& max_seqlen_in_batch,int& total_q,int& real_batch_size);
+                                 int& max_seqlen_in_batch, int& total_q,
+                                 int& real_batch_size);
 
 at::Tensor mask_to_indices(const at::Tensor& attention_mask,
                            int& max_seqlen_in_batch, int& total,
                            at::Tensor& cu_seqlen);
 
-at::Tensor position_ids_to_indices(const at::Tensor& position_ids, 
-                           int& max_seqlen_in_batch, int& total,
-                           at::Tensor& cu_seqlen, int& real_batch_size);
+at::Tensor position_ids_to_indices(const at::Tensor& position_ids,
+                                   int& max_seqlen_in_batch, int& total,
+                                   at::Tensor& cu_seqlen, int& real_batch_size);
 
 at::Tensor index_first_axis(const at::Tensor& input, const at::Tensor& indices);
 
@@ -150,9 +151,11 @@ xla::Shape shape_like(const torch::lazy::Value& input);
 
 xla::Shape shape_like(const xla::XlaBuilder* builder, const xla::XlaOp& input);
 
-torch::Tensor unpad_softmax_lse(const torch::Tensor& pad_softmax_lse, const torch::Tensor& cu_seqlens); 
+torch::Tensor unpad_softmax_lse(const torch::Tensor& pad_softmax_lse,
+                                const torch::Tensor& cu_seqlens);
 
-torch::Tensor pad_softmax_lse(const at::Tensor& softmax_lse,const at::Tensor& cu_seqlens,
-                               const int max_seq_len, const int batch_size);
+torch::Tensor pad_softmax_lse(const at::Tensor& softmax_lse,
+                              const at::Tensor& cu_seqlens,
+                              const int max_seq_len, const int batch_size);
 }  // namespace torch_xla
 #endif  // XLA_TORCH_XLA_CSRC_FLASH_ATTENTION_UTILS_H
\ No newline at end of file
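
The pad_softmax_lse / unpad_softmax_lse pair declared here converts the log-sum-exp between the packed (1, nheads, total_tokens) layout used by the XLA op and the padded (batch, nheads, max_seqlen) layout the kernel writes. A rough PyTorch sketch of the round trip (loop-based for clarity; the real helpers use index gathers):

    import torch

    def pad_lse_sketch(lse, cu_seqlens, max_seqlen):
      # (1, nheads, total) -> (batch, nheads, max_seqlen), zero-padded per sequence.
      batch = cu_seqlens.numel() - 1
      out = lse.new_zeros(batch, lse.size(1), max_seqlen)
      for b in range(batch):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        out[b, :, :end - start] = lse[0, :, start:end]
      return out

    def unpad_lse_sketch(padded, cu_seqlens):
      # (batch, nheads, max_seqlen) -> (1, nheads, total), dropping the padding.
      batch = padded.size(0)
      chunks = [padded[b, :, :int(cu_seqlens[b + 1]) - int(cu_seqlens[b])]
                for b in range(batch)]
      return torch.cat(chunks, dim=1).unsqueeze(0)
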
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 55d70bc03ed6..2ca6b5140716 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2409,9 +2409,9 @@ void InitXlaModuleBindings(py::module m) {
            c10::optional<at::Generator> gen_) {
           // get launch params on at::Tensor
           auto params = get_flash_attention_forward_params(
-              q, k, v, attention_mask, c10::nullopt, alibi_slopes, p_dropout, softmax_scale,
-              zero_tensors, is_causal, window_size_left, window_size_right,
-              return_softmax);
+              q, k, v, attention_mask, c10::nullopt, alibi_slopes, p_dropout,
+              softmax_scale, zero_tensors, is_causal, window_size_left,
+              window_size_right, return_softmax);
           // call flash attention forward
           XLATensorPtr q_xla = bridge::GetXlaTensor(q);
           XLATensorPtr k_xla = bridge::GetXlaTensor(k);
@@ -2451,9 +2451,9 @@ void InitXlaModuleBindings(py::module m) {
            c10::optional<at::Generator> gen_) {
           // get launch params on at::Tensor
           auto params = get_flash_attention_forward_params(
-              q, k, v, c10::nullopt, position_ids, alibi_slopes, p_dropout, softmax_scale,
-              zero_tensors, is_causal, window_size_left, window_size_right,
-              return_softmax);
+              q, k, v, c10::nullopt, position_ids, alibi_slopes, p_dropout,
+              softmax_scale, zero_tensors, is_causal, window_size_left,
+              window_size_right, return_softmax);
           // call flash attention forward
           XLATensorPtr q_xla = bridge::GetXlaTensor(q);
           XLATensorPtr k_xla = bridge::GetXlaTensor(k);
@@ -2465,12 +2465,12 @@ void InitXlaModuleBindings(py::module m) {
 
           std::vector<XLATensorPtr> xresults;
           if (position_ids.has_value()) {
-
             XLATensorPtr position_ids_xla =
                 bridge::GetXlaTensor(position_ids.value());
-            xresults = tensor_methods::flash_attention_varlen_position_ids_forward(
-                q_xla, k_xla, v_xla, position_ids_xla, alibi_slopes_xla,
-                params.ToString());
+            xresults =
+                tensor_methods::flash_attention_varlen_position_ids_forward(
+                    q_xla, k_xla, v_xla, position_ids_xla, alibi_slopes_xla,
+                    params.ToString());
           } else {
             xresults = tensor_methods::flash_attention_forward(
                 q_xla, k_xla, v_xla, alibi_slopes_xla, params.ToString());
@@ -2483,7 +2483,7 @@ void InitXlaModuleBindings(py::module m) {
           }
           return results;
         });
-  
+
   m.def(
       "_flash_attention_backward",
       [](const at::Tensor& dout, const at::Tensor& q, const at::Tensor& k,
@@ -2571,10 +2571,11 @@ void InitXlaModuleBindings(py::module m) {
               bridge::GetXlaTensor(cu_seqlens_q.value());
           XLATensorPtr cu_seqlens_k_xla =
               bridge::GetXlaTensor(cu_seqlens_k.value());
-          xresults = tensor_methods::flash_attention_varlen_position_ids_backward(
-              dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla,
-              cu_seqlens_q_xla, cu_seqlens_k_xla, rng_state_xla,
-              alibi_slopes_xla, params.ToString());
+          xresults =
+              tensor_methods::flash_attention_varlen_position_ids_backward(
+                  dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla,
+                  cu_seqlens_q_xla, cu_seqlens_k_xla, rng_state_xla,
+                  alibi_slopes_xla, params.ToString());
         } else {
           xresults = tensor_methods::flash_attention_backward(
               dout_xla, q_xla, k_xla, v_xla, out_xla, softmax_lse_xla,
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
index 1a7cfed5bfa3..92b06f1addf5 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
@@ -3,6 +3,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
+
 #include <iostream>
 
 #include "cutlass/numeric_types.h"
@@ -36,7 +37,7 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& q) {
 //  buffers[0] = q
 //  buffers[1] = k
 //  buffers[2] = v
-//  buffers[3] = attention_mask 
+//  buffers[3] = attention_mask
 //  buffers[4] = alibi_slopes
 //  buffers[5] = softmax_lse  // this is output
 //  buffers[6] = out_for_output // this is output
@@ -47,7 +48,6 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream,
                                                 void** buffers,
                                                 const char* opaque,
                                                 size_t opaque_len) {
-                                             
   std::string opaque_str(opaque, opaque_len);
   TF_VLOG(3) << "custom_call_flash_attention_varlen_forward opaque str: "
              << opaque_str;
@@ -100,21 +100,21 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream,
   int max_seqlen_in_batch_k = params.seqlen_k;
   int total_k = params.b * params.seqlen_k;
   at::Tensor indices_k = mask_to_indices(attention_mask, max_seqlen_in_batch_k,
-                                total_k,cu_seqlens_k);
-  
-  auto unpad_k = index_first_axis(k, indices_k); 
+                                         total_k, cu_seqlens_k);
+
+  auto unpad_k = index_first_axis(k, indices_k);
   auto unpad_v = index_first_axis(v, indices_k);
 
-  int max_seqlen_in_batch_q = max_seqlen_in_batch_k; 
+  int max_seqlen_in_batch_q = max_seqlen_in_batch_k;
   int total_q = total_k;
   at::Tensor indices_q;
 
   if (params.seqlen_q == params.seqlen_k) {
     cu_seqlens_q.copy_(cu_seqlens_k);
     indices_q = indices_k;
-  } else if (params.seqlen_q == 1){
+  } else if (params.seqlen_q == 1) {
     max_seqlen_in_batch_q = 1;
-    cu_seqlens_q = torch::arange(0,params.b+1,opts);
+    cu_seqlens_q = torch::arange(0, params.b + 1, opts);
     indices_q = cu_seqlens_q.slice(/*dim=*/0, /*start=*/0, /*end=*/params.b);
     total_q = params.b;
   } else {
@@ -122,14 +122,14 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream,
         /*dim=*/1, /*start=*/-params.seqlen_q, /*end=*/torch::indexing::None);
     indices_q = mask_to_indices(attention_mask_slice, max_seqlen_in_batch_q,
                                 total_q, cu_seqlens_q);
-  } 
+  }
   at::Tensor unpad_q = index_first_axis(q, indices_q);
 
   at::Tensor unpad_output =
       torch::zeros({total_q, params.h * params.d}, opts.dtype(scalar_type));
   at::Tensor unpad_softmax_lse = torch::zeros(
-      {params.b, params.h, max_seqlen_in_batch_q}, opts.dtype(torch::kFloat)); 
-      
+      {params.b, params.h, max_seqlen_in_batch_q}, opts.dtype(torch::kFloat));
+
   if (max_seqlen_in_batch_q == 1) {
     params.is_causal = false;
   }
@@ -322,7 +322,7 @@ torch::lazy::NodePtr FlashAttentionVarlenForward::Clone(
   }
 }
 
-XlaOpVector FlashAttentionVarlenForward::Lower(LoweringContext* loctx) const{
+XlaOpVector FlashAttentionVarlenForward::Lower(LoweringContext* loctx) const {
   xla::XlaOp q = loctx->GetOutputOp(operand(0));
   xla::XlaOp k = loctx->GetOutputOp(operand(1));
   xla::XlaOp v = loctx->GetOutputOp(operand(2));
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
index 435b40e822cb..5cb54805a30f 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
@@ -34,7 +34,7 @@ void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream,
 }
 
 // Layout of `buffers` listed above:
-//  buffers[0] = dout 
+//  buffers[0] = dout
 //  buffers[1] = q
 //  buffers[2] = k
 //  buffers[3] = v
@@ -48,13 +48,13 @@ void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream,
 //  buffers[11] = dk  // this is output
 //  buffers[12] = dv  // this is output
 //  buffers[13] = softmax_d  // this is output
-void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t stream,
-                                                 void** buffers,
-                                                 const char* opaque,
-                                                 size_t opaque_len) {
+void custom_call_flash_attention_varlen_position_ids_backward(
+    cudaStream_t stream, void** buffers, const char* opaque,
+    size_t opaque_len) {
   std::string opaque_str(opaque, opaque_len);
-  TF_VLOG(3) << "custom_call_flash_attention_varlen_position_ids_backward opaque str: "
-             << opaque_str;
+  TF_VLOG(3)
+      << "custom_call_flash_attention_varlen_position_ids_backward opaque str: "
+      << opaque_str;
   FlashAttentionBackwardParams params;
   params.FromString(std::move(opaque_str));
   int buf_offset = params.enable_alibi_slopes;
@@ -87,10 +87,10 @@ void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t strea
   at::Tensor softmax_lse =
       torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q},
                        opts.dtype(torch::kFloat));
-  at::Tensor cu_seqlens_q =
-      torch::from_blob(buffers[6], {params.seqlen_q + 1}, opts.dtype(torch::kInt32));
-  at::Tensor cu_seqlens_k =
-      torch::from_blob(buffers[7], {params.seqlen_k + 1}, opts.dtype(torch::kInt32));
+  at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.seqlen_q + 1},
+                                             opts.dtype(torch::kInt32));
+  at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.seqlen_k + 1},
+                                             opts.dtype(torch::kInt32));
 
   // Outputs
   at::Tensor dq =
@@ -120,8 +120,8 @@ void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t strea
   int total_q = params.b * params.seqlen_q;
   int total_k = params.b * params.seqlen_k;
   int real_batch_size;
-  at::Tensor indices_q =
-     cu_seqlens_to_indices(cu_seqlens_q,max_seqlen_in_batch_q,total_q,real_batch_size);
+  at::Tensor indices_q = cu_seqlens_to_indices(
+      cu_seqlens_q, max_seqlen_in_batch_q, total_q, real_batch_size);
 
   at::Tensor indices_k;
   if (params.seqlen_q == params.seqlen_k) {
@@ -129,12 +129,13 @@ void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t strea
     max_seqlen_in_batch_k = max_seqlen_in_batch_q;
     total_k = total_q;
   } else {
-    indices_k =
-        cu_seqlens_to_indices(cu_seqlens_k, max_seqlen_in_batch_k, total_k,real_batch_size);
+    indices_k = cu_seqlens_to_indices(cu_seqlens_k, max_seqlen_in_batch_k,
+                                      total_k, real_batch_size);
   }
 
-  auto padded_softmax_lse = pad_softmax_lse(softmax_lse,cu_seqlens_q,max_seqlen_in_batch_q,real_batch_size);
-   
+  auto padded_softmax_lse = pad_softmax_lse(
+      softmax_lse, cu_seqlens_q, max_seqlen_in_batch_q, real_batch_size);
+
   Flash_bwd_params launch_params;
 
   // Reset the parameters
@@ -308,7 +309,6 @@ void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t strea
   launch_params.rng_state = reinterpret_cast<uint64_t*>(buffers[8]);
 
   launch(launch_params, torch_stream, /*configure=*/false);
-  
 
   // For MQA/GQA we need to sum dK and dV across the groups
   if (launch_params.h_k != launch_params.h) {
@@ -324,13 +324,13 @@ void custom_call_flash_attention_varlen_position_ids_backward(cudaStream_t strea
                 {2});
   }
 
-  dsoftmax_sum.copy_(unpad_softmax_lse(rounded_dsoftmax_sum,cu_seqlens_q));
+  dsoftmax_sum.copy_(unpad_softmax_lse(rounded_dsoftmax_sum, cu_seqlens_q));
 
   cudaEventRecord(xla_wait_torch_event, torch_stream);
   cudaStreamWaitEvent(stream, xla_wait_torch_event);
 }
-XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_position_ids_backward,
-                                "CUDA");
+XLA_REGISTER_CUSTOM_CALL_TARGET(
+    custom_call_flash_attention_varlen_position_ids_backward, "CUDA");
 
 std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsBackward(
     const xla::XlaOp& dout, const xla::XlaOp& q, const xla::XlaOp& k,
@@ -365,13 +365,14 @@ std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsBackward(
 
 }  // namespace
 
-FlashAttentionVarlenPositionIdsBackward::FlashAttentionVarlenPositionIdsBackward(
-    const torch::lazy::Value& dout, const torch::lazy::Value& q,
-    const torch::lazy::Value& k, const torch::lazy::Value& v,
-    const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
-    const torch::lazy::Value& cu_seqlens_q,
-    const torch::lazy::Value& cu_seqlens_k, const torch::lazy::Value& rng_state,
-    const std::string params)
+FlashAttentionVarlenPositionIdsBackward::
+    FlashAttentionVarlenPositionIdsBackward(
+        const torch::lazy::Value& dout, const torch::lazy::Value& q,
+        const torch::lazy::Value& k, const torch::lazy::Value& v,
+        const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
+        const torch::lazy::Value& cu_seqlens_q,
+        const torch::lazy::Value& cu_seqlens_k,
+        const torch::lazy::Value& rng_state, const std::string params)
     : XlaNode(xla_flash_attention_backward,
               {dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
                rng_state},
@@ -379,13 +380,15 @@ FlashAttentionVarlenPositionIdsBackward::FlashAttentionVarlenPositionIdsBackward
               /*num_outputs=*/4, torch::lazy::MHash(params)),
       params_(params) {}
 
-FlashAttentionVarlenPositionIdsBackward::FlashAttentionVarlenPositionIdsBackward(
-    const torch::lazy::Value& dout, const torch::lazy::Value& q,
-    const torch::lazy::Value& k, const torch::lazy::Value& v,
-    const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
-    const torch::lazy::Value& cu_seqlens_q,
-    const torch::lazy::Value& cu_seqlens_k, const torch::lazy::Value& rng_state,
-    const torch::lazy::Value& alibi_slopes, const std::string params)
+FlashAttentionVarlenPositionIdsBackward::
+    FlashAttentionVarlenPositionIdsBackward(
+        const torch::lazy::Value& dout, const torch::lazy::Value& q,
+        const torch::lazy::Value& k, const torch::lazy::Value& v,
+        const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
+        const torch::lazy::Value& cu_seqlens_q,
+        const torch::lazy::Value& cu_seqlens_k,
+        const torch::lazy::Value& rng_state,
+        const torch::lazy::Value& alibi_slopes, const std::string params)
     : XlaNode(xla_flash_attention_backward,
               {dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
                rng_state, alibi_slopes},
@@ -408,7 +411,8 @@ torch::lazy::NodePtr FlashAttentionVarlenPositionIdsBackward::Clone(
   }
 }
 
-XlaOpVector FlashAttentionVarlenPositionIdsBackward::Lower(LoweringContext* loctx) const {
+XlaOpVector FlashAttentionVarlenPositionIdsBackward::Lower(
+    LoweringContext* loctx) const {
   xla::XlaOp dout = loctx->GetOutputOp(operand(0));
   xla::XlaOp q = loctx->GetOutputOp(operand(1));
   xla::XlaOp k = loctx->GetOutputOp(operand(2));
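
To illustrate the MQA/GQA reduction performed in the backward custom call above (the at::sum over the group dimension of dk_expanded/dv_expanded), here is a rough standalone PyTorch sketch; the shapes and tensor names are made up for illustration and are not taken from the patch.

import torch

# Toy shapes: total_k packed key tokens, h_k kv heads, h = h_k * groups query heads.
total_k, h_k, groups, d = 16, 3, 3, 8
h = h_k * groups
dk_expanded = torch.randn(total_k, h, d)  # per-query-head gradients
dv_expanded = torch.randn(total_k, h, d)

# Collapse the per-group gradients back onto the h_k kv heads, mirroring the
# sum over the group dimension after the reshape in the backward custom call.
dk = dk_expanded.view(total_k, h_k, groups, d).sum(dim=2)
dv = dv_expanded.view(total_k, h_k, groups, d).sum(dim=2)
assert dk.shape == (total_k, h_k, d) and dv.shape == (total_k, h_k, d)
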
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
index 9403879add54..945a807236cc 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
@@ -3,6 +3,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
+
 #include <iostream>
 
 #include "cutlass/numeric_types.h"
@@ -22,12 +23,12 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& q) {
   auto q_shape = xla::SpanToVector(GetXlaShape(q).dimensions());
   xla::Shape softmax_lse_shape = xla::ShapeUtil::MakeShape(
       xla::PrimitiveType::F32,
-      {q_shape[0], q_shape[2],
-       q_shape[1]});  // 1, num_heads, total_q
+      {q_shape[0], q_shape[2], q_shape[1]});  // 1, num_heads, total_q
   xla::Shape rng_state_shape =
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2});
-  xla::Shape cu_seqlens_shape =
-      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {q_shape[1] + 1}); // q.shape [1,total_q,num_head,head_dim]
+  xla::Shape cu_seqlens_shape = xla::ShapeUtil::MakeShape(
+      xla::PrimitiveType::S32,
+      {q_shape[1] + 1});  // q.shape [1,total_q,num_head,head_dim]
   return xla::ShapeUtil::MakeTupleShape({softmax_lse_shape, shape_like(q),
                                          rng_state_shape, cu_seqlens_shape,
                                          cu_seqlens_shape});
@@ -44,14 +45,13 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& q) {
 //  buffers[7] = rng_state // this is output
 //  buffers[8] = cu_seqlen_q // this is output
 //  buffers[9] = cu_seqlen_k // this is output
-void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream,
-                                                void** buffers,
-                                                const char* opaque,
-                                                size_t opaque_len) {
-                                             
+void custom_call_flash_attention_varlen_position_ids_forward(
+    cudaStream_t stream, void** buffers, const char* opaque,
+    size_t opaque_len) {
   std::string opaque_str(opaque, opaque_len);
-  TF_VLOG(3) << "custom_call_flash_attention_varlen_position_ids_forward opaque str: "
-             << opaque_str;
+  TF_VLOG(3)
+      << "custom_call_flash_attention_varlen_position_ids_forward opaque str: "
+      << opaque_str;
   FlashAttentionForwardParams params;
   params.FromString(std::move(opaque_str));
   int buf_offset = params.enable_alibi_slopes;
@@ -102,17 +102,20 @@ void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream
   int total_k = params.b * params.seqlen_k;
   at::Tensor indices_k;
 
-  int real_batch_size; // packed qkv's batch size is 1, but fa need to know real batch size before packing.
-  indices_k = position_ids_to_indices(position_ids, max_seqlen_in_batch_k,total_k,
-                                      cu_seqlens_k,real_batch_size);
-  TORCH_CHECK(cu_seqlens_k.size(0) == params.seqlen_k + 1, "cu_seqlen'shape should be params.seqlen_k");
+  int real_batch_size;  // packed qkv's batch size is 1, but FA needs to know
+                        // the real batch size before packing.
+  indices_k = position_ids_to_indices(position_ids, max_seqlen_in_batch_k,
+                                      total_k, cu_seqlens_k, real_batch_size);
+  TORCH_CHECK(cu_seqlens_k.size(0) == params.seqlen_k + 1,
+              "cu_seqlens_k's shape should be params.seqlen_k + 1");
   TORCH_CHECK(indices_k.dtype() == torch::kInt64, "indice should be int64");
 
-  int max_seqlen_in_batch_q = max_seqlen_in_batch_k; 
+  int max_seqlen_in_batch_q = max_seqlen_in_batch_k;
   int total_q = total_k;
   at::Tensor indices_q;
 
-  TORCH_CHECK(params.seqlen_q == params.seqlen_k, "now only support same seqlen for q and k");
+  TORCH_CHECK(params.seqlen_q == params.seqlen_k,
+              "now only support same seqlen for q and k");
 
   if (params.seqlen_q == params.seqlen_k) {
     cu_seqlens_q.copy_(cu_seqlens_k);
@@ -120,7 +123,7 @@ void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream
   } else {
     // TODO:(wangtianxing.wtx) support different seqlen_q and seqlen_k
     indices_q = position_ids_to_indices(position_ids, max_seqlen_in_batch_q,
-                              total_q, cu_seqlens_q,real_batch_size);
+                                        total_q, cu_seqlens_q, real_batch_size);
   }
 
   if (max_seqlen_in_batch_q == 1) {
@@ -141,7 +144,9 @@ void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream
   // Cast to char to avoid compiler warning about narrowing
   at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
-  at::Tensor pad_softmax_lse = at::empty({real_batch_size,params.h,max_seqlen_in_batch_q},torch::dtype(torch::kFloat).device(torch::kCUDA));
+  at::Tensor pad_softmax_lse =
+      at::empty({real_batch_size, params.h, max_seqlen_in_batch_q},
+                torch::dtype(torch::kFloat).device(torch::kCUDA));
 
   Flash_fwd_params launch_params;
 
@@ -242,16 +247,16 @@ void custom_call_flash_attention_varlen_position_ids_forward(cudaStream_t stream
       run_mha_fwd_<elem_type, kHeadDim>(launch_params, torch_stream);
     });
   });
-  softmax_lse.copy_(unpad_softmax_lse(pad_softmax_lse,cu_seqlens_q));
-  
+  softmax_lse.copy_(unpad_softmax_lse(pad_softmax_lse, cu_seqlens_q));
+
   // TODO(wenting.swt): we should pad and unpad q,k,v when head_size_og % 8 != 0
   // sync with cudaEvent
   cudaEventRecord(xla_wait_torch_event, torch_stream);
   cudaStreamWaitEvent(stream, xla_wait_torch_event);
 }
 
-XLA_REGISTER_CUSTOM_CALL_TARGET(custom_call_flash_attention_varlen_position_ids_forward,
-                                "CUDA");
+XLA_REGISTER_CUSTOM_CALL_TARGET(
+    custom_call_flash_attention_varlen_position_ids_forward, "CUDA");
 
 std::vector<xla::XlaOp> BuildFlashAttentionVarlenPositionIdsForward(
     const xla::XlaOp& q, const xla::XlaOp& k, const xla::XlaOp& v,
@@ -310,7 +315,8 @@ torch::lazy::NodePtr FlashAttentionVarlenPositionIdsForward::Clone(
   }
 }
 
-XlaOpVector FlashAttentionVarlenPositionIdsForward::Lower(LoweringContext* loctx) const {
+XlaOpVector FlashAttentionVarlenPositionIdsForward::Lower(
+    LoweringContext* loctx) const {
   xla::XlaOp q = loctx->GetOutputOp(operand(0));
   xla::XlaOp k = loctx->GetOutputOp(operand(1));
   xla::XlaOp v = loctx->GetOutputOp(operand(2));
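
For context on the real_batch_size comment above: a minimal PyTorch sketch of what the position_ids path has to recover before the varlen kernel can run, assuming three packed sequences of lengths 3, 2 and 4. This mirrors how the tests derive cu_seqlens from position_ids; the actual position_ids_to_indices helper additionally produces the gather indices and writes into the fixed-size cu_seqlens output buffer.

import torch
import torch.nn.functional as F

# Packed position_ids for sequences of lengths 3, 2, 4 (the packed batch dim is 1).
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]], dtype=torch.int32)

flat = position_ids.flatten()
starts = (flat == 0).nonzero().flatten()      # start offset of every sequence: [0, 3, 5]
real_batch_size = starts.numel()              # 3 sequences packed into a batch of 1
cu_seqlens = F.pad(starts, (0, 1), value=flat.numel()).to(torch.int32)  # [0, 3, 5, 9]
max_seqlen_in_batch = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())     # 4
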
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h
index 2a37eec1c499..2e27ee973b0f 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h
@@ -9,17 +9,17 @@ namespace torch_xla {
 class FlashAttentionVarlenPositionIdsForward : public XlaNode {
  public:
   FlashAttentionVarlenPositionIdsForward(const torch::lazy::Value& q,
-                              const torch::lazy::Value& k,
-                              const torch::lazy::Value& v,
-                              const torch::lazy::Value& position_ids,
-                              const std::string params);
+                                         const torch::lazy::Value& k,
+                                         const torch::lazy::Value& v,
+                                         const torch::lazy::Value& position_ids,
+                                         const std::string params);
 
   FlashAttentionVarlenPositionIdsForward(const torch::lazy::Value& q,
-                              const torch::lazy::Value& k,
-                              const torch::lazy::Value& v,
-                              const torch::lazy::Value& position_ids,
-                              const torch::lazy::Value& alibi_slopes,
-                              const std::string params);
+                                         const torch::lazy::Value& k,
+                                         const torch::lazy::Value& v,
+                                         const torch::lazy::Value& position_ids,
+                                         const torch::lazy::Value& alibi_slopes,
+                                         const std::string params);
 
   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
 
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 3992c31289b3..44489c0c44e8 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -53,8 +53,8 @@
 #include "torch_xla/csrc/ops/flash_attention_forward.h"
 #include "torch_xla/csrc/ops/flash_attention_varlen_backward.h"
 #include "torch_xla/csrc/ops/flash_attention_varlen_forward.h"
-#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h"
 #include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.h"
+#include "torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.h"
 #include "torch_xla/csrc/ops/flip.h"
 #include "torch_xla/csrc/ops/gather.h"
 #include "torch_xla/csrc/ops/generic.h"
@@ -681,22 +681,21 @@ std::vector<XLATensorPtr> flash_attention_varlen_position_ids_forward(
     const XLATensorPtr& q, const XLATensorPtr& k, const XLATensorPtr& v,
     const XLATensorPtr& position_ids, const XLATensorPtr& alibi_slopes,
     const std::string& params) {
-    if (alibi_slopes) {
-      torch::lazy::NodePtr node = 
-          torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
+  if (alibi_slopes) {
+    torch::lazy::NodePtr node =
+        torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
             q->GetIrValue(), k->GetIrValue(), v->GetIrValue(),
             position_ids->GetIrValue(), alibi_slopes->GetIrValue(), params);
-      return q->MakeOutputTensors(node, /*inherit_logical_type=*/false);
-    } else {
-      torch::lazy::NodePtr node = 
-          torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
+    return q->MakeOutputTensors(node, /*inherit_logical_type=*/false);
+  } else {
+    torch::lazy::NodePtr node =
+        torch::lazy::MakeNode<FlashAttentionVarlenPositionIdsForward>(
             q->GetIrValue(), k->GetIrValue(), v->GetIrValue(),
             position_ids->GetIrValue(), params);
-      return q->MakeOutputTensors(node, /*inherit_logical_type=*/false);
-    } 
+    return q->MakeOutputTensors(node, /*inherit_logical_type=*/false);
+  }
 }
 
-
 std::vector<XLATensorPtr> flash_attention_backward(
     const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
     const XLATensorPtr& v, const XLATensorPtr& out,
@@ -742,7 +741,6 @@ std::vector<XLATensorPtr> flash_attention_varlen_backward(
   }
 }
 
-
 std::vector<XLATensorPtr> flash_attention_varlen_position_ids_backward(
     const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
     const XLATensorPtr& v, const XLATensorPtr& out,

From 74814401e2ab61603a2527a847612e5f4d0c194c Mon Sep 17 00:00:00 2001
From: tianxingwang <wangtianxing.wtx@alibaba-inc.com>
Date: Mon, 2 Dec 2024 14:14:49 +0800
Subject: [PATCH 3/5] reformat files

---
 test/test_flash_attention_forward.py         |  3 -
 test/test_flash_attention_varlen_backward.py | 57 ++++++++++---
 test/test_flash_attention_varlen_forward.py  | 90 ++++++++++++++------
 3 files changed, 105 insertions(+), 45 deletions(-)

diff --git a/test/test_flash_attention_forward.py b/test/test_flash_attention_forward.py
index 6414855a9dd2..9e019ed29325 100644
--- a/test/test_flash_attention_forward.py
+++ b/test/test_flash_attention_forward.py
@@ -136,6 +136,3 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal,
 
   assert torch.allclose(softmax_lse_xla, softmax_lse, rtol=1e-2, atol=1e-2)
   assert torch.allclose(out_xla, out_fa, rtol=1e-2, atol=1e-2)
-
-
-
diff --git a/test/test_flash_attention_varlen_backward.py b/test/test_flash_attention_varlen_backward.py
index e71edec0a75f..2da5985b53d4 100644
--- a/test/test_flash_attention_varlen_backward.py
+++ b/test/test_flash_attention_varlen_backward.py
@@ -266,9 +266,10 @@ def test_flash_attn_varlen_backward(seqlen_q, seqlen_k, d, dropout_p, causal,
     ],
 )
 @pytest.mark.parametrize("dropout_p", [0.0])
-def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, dropout_p, causal,
-                                    local, alibi, deterministic, mha_type,
-                                    dtype):
+def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d,
+                                                 dropout_p, causal, local,
+                                                 alibi, deterministic, mha_type,
+                                                 dtype):
   if d % 8 != 0:
     pytest.skip(reason="Expected head_size_og % 8 == 0 to be true")
 
@@ -368,14 +369,19 @@ def test_flash_attn_varlen_position_ids_backward(seqlen_q, seqlen_k, d, dropout_
   indices_q = indices_q.cpu()
   indices_k = indices_k.cpu()
 
-  q_xla = q.flatten(0,1)[indices_q].unsqueeze(0).to(device)
-  k_xla = k.flatten(0,1)[indices_k].unsqueeze(0).to(device)
-  v_xla = v.flatten(0,1)[indices_k].unsqueeze(0).to(device)
-  do_xla = do.flatten(0,1)[indices_q].unsqueeze(0).to(device)
+  q_xla = q.flatten(0, 1)[indices_q].unsqueeze(0).to(device)
+  k_xla = k.flatten(0, 1)[indices_k].unsqueeze(0).to(device)
+  v_xla = v.flatten(0, 1)[indices_k].unsqueeze(0).to(device)
+  do_xla = do.flatten(0, 1)[indices_q].unsqueeze(0).to(device)
 
   def attention_mask_to_position_ids(attention_mask):
     seqlens = attention_mask.sum(dim=1).flatten()
-    position_ids = torch.cat([torch.arange(0,seqlen,dtype=torch.int32,device=attention_mask.device) for seqlen in seqlens],dim=0)
+    position_ids = torch.cat([
+        torch.arange(
+            0, seqlen, dtype=torch.int32, device=attention_mask.device)
+        for seqlen in seqlens
+    ],
+                             dim=0)
     return position_ids.unsqueeze(0)
 
   position_ids_xla = attention_mask_to_position_ids(attention_mask).to(device)
@@ -389,12 +395,36 @@ def attention_mask_to_position_ids(attention_mask):
       position_ids_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale,
       False, causal, window_size[0], window_size[1], True, None)
 
-  assert torch.allclose(q_xla.cpu(), q_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True)
-  assert torch.allclose(k_xla.cpu(), k_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True)
-  assert torch.allclose(v_xla.cpu(), v_cuda.cpu().unsqueeze(0), rtol=1e-3, atol=1e-3, equal_nan=True)
   assert torch.allclose(
-      cu_seqlen_k_xla[:batch_size+1].cpu(), cu_seqlens_k.cpu(), rtol=1e-3, atol=1e-3, equal_nan=True)
-  assert torch.allclose(o_xla.cpu(), o_cuda.unsqueeze(0).cpu(), rtol=1e-2, atol=1e-2, equal_nan=True)
+      q_xla.cpu(),
+      q_cuda.cpu().unsqueeze(0),
+      rtol=1e-3,
+      atol=1e-3,
+      equal_nan=True)
+  assert torch.allclose(
+      k_xla.cpu(),
+      k_cuda.cpu().unsqueeze(0),
+      rtol=1e-3,
+      atol=1e-3,
+      equal_nan=True)
+  assert torch.allclose(
+      v_xla.cpu(),
+      v_cuda.cpu().unsqueeze(0),
+      rtol=1e-3,
+      atol=1e-3,
+      equal_nan=True)
+  assert torch.allclose(
+      cu_seqlen_k_xla[:batch_size + 1].cpu(),
+      cu_seqlens_k.cpu(),
+      rtol=1e-3,
+      atol=1e-3,
+      equal_nan=True)
+  assert torch.allclose(
+      o_xla.cpu(),
+      o_cuda.unsqueeze(0).cpu(),
+      rtol=1e-2,
+      atol=1e-2,
+      equal_nan=True)
 
   q_xla.requires_grad = True
   k_xla.requires_grad = True
@@ -425,4 +455,3 @@ def attention_mask_to_position_ids(attention_mask):
   assert torch.allclose(dq_cuda, dq_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
   assert torch.allclose(dk_cuda, dk_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
   assert torch.allclose(dv_cuda, dv_xla, rtol=1e-2, atol=1e-2, equal_nan=True)
-
diff --git a/test/test_flash_attention_varlen_forward.py b/test/test_flash_attention_varlen_forward.py
index 119f90d6c6cd..843573436fb8 100644
--- a/test/test_flash_attention_varlen_forward.py
+++ b/test/test_flash_attention_varlen_forward.py
@@ -178,7 +178,6 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal,
       return_attn_probs=True,
   )
 
-
   out_fa = pad_input(out_fa, indices_q, batch_size, seqlen_q)
 
   q = q.cpu().detach()
@@ -227,8 +226,13 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal,
   assert torch.allclose(
       cu_seqlen_k_xla, cu_seqlens_k, rtol=1e-3, atol=1e-3, equal_nan=True)
   for i in range(len(cu_seq_lens[0]) - 1):
-    seqlen = cu_seq_lens[0][i + 1] -  cu_seq_lens[0][i]
-    assert torch.allclose(softmax_lse_xla[i,:,:seqlen],softmax_lse[i,:,:seqlen],rtol=1e-2,atol=1e-3,equal_nan=True)
+    seqlen = cu_seq_lens[0][i + 1] - cu_seq_lens[0][i]
+    assert torch.allclose(
+        softmax_lse_xla[i, :, :seqlen],
+        softmax_lse[i, :, :seqlen],
+        rtol=1e-2,
+        atol=1e-3,
+        equal_nan=True)
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -248,9 +252,10 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal,
     ],
 )
 @pytest.mark.parametrize("dropout_p", [0.0])
-def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d, dropout_p, causal,
-                           softmax_scale, local, alibi, deterministic, mha_type,
-                           dtype):
+def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d,
+                                             dropout_p, causal, softmax_scale,
+                                             local, alibi, deterministic,
+                                             mha_type, dtype):
   if d % 8 != 0:
     pytest.skip(reason="Expected head_size_og % 8 == 0 to be true")
 
@@ -266,15 +271,17 @@ def test_flash_attn_varlen_from_position_ids(max_seqlen_q, max_seqlen_k, d, drop
       torch.randint(0, max_seqlen_k, (2,)).tolist())
   torch.cuda.synchronize()
 
-  def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype, n_heads_q, n_heads_k, device, head_dim, use_same_seqlen):
+  def generate_qkv_and_position_ids(batch_size, max_seqlen_q, max_seqlen_k,
+                                    dtype, n_heads_q, n_heads_k, device,
+                                    head_dim, use_same_seqlen):
     '''
     generate varlen qkv and postion_ids and pad total seqlen to 8
     '''
-    seq_len_q = torch.randint(1,max_seqlen_q+1,(batch_size,))
+    seq_len_q = torch.randint(1, max_seqlen_q + 1, (batch_size,))
     if use_same_seqlen:
       seq_len_k = seq_len_q.clone()
     else:
-      seq_len_k = torch.randint(1,max_seqlen_k+1,(batch_size,))
+      seq_len_k = torch.randint(1, max_seqlen_k + 1, (batch_size,))
     total_q = seq_len_q.sum().item()
     total_k = seq_len_k.sum().item()
 
@@ -293,38 +300,57 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype,
       total_k += padd_k
     assert total_k % 8 == 0
 
-    q = torch.randn((1,total_q,n_heads_q,head_dim),dtype=dtype,device=device)
-    k = torch.randn((1,total_k,n_heads_k,head_dim),dtype=dtype,device=device)
-    v = torch.randn((1,total_k,n_heads_k,head_dim),dtype=dtype,device=device)
+    q = torch.randn((1, total_q, n_heads_q, head_dim),
+                    dtype=dtype,
+                    device=device)
+    k = torch.randn((1, total_k, n_heads_k, head_dim),
+                    dtype=dtype,
+                    device=device)
+    v = torch.randn((1, total_k, n_heads_k, head_dim),
+                    dtype=dtype,
+                    device=device)
 
     assert torch.all(seq_len_q > 0)
     assert torch.all(seq_len_k > 0)
 
-    position_ids_q = torch.cat([torch.arange(0,seq_len,dtype=torch.int32,device=device) for seq_len in seq_len_q],dim=0).unsqueeze(0)
-    position_ids_k = torch.cat([torch.arange(0,seq_len,dtype=torch.int32,device=device) for seq_len in seq_len_k],dim=0).unsqueeze(0)
+    position_ids_q = torch.cat([
+        torch.arange(0, seq_len, dtype=torch.int32, device=device)
+        for seq_len in seq_len_q
+    ],
+                               dim=0).unsqueeze(0)
+    position_ids_k = torch.cat([
+        torch.arange(0, seq_len, dtype=torch.int32, device=device)
+        for seq_len in seq_len_k
+    ],
+                               dim=0).unsqueeze(0)
     assert position_ids_q.shape[1] % 8 == 0
     assert position_ids_k.shape[1] % 8 == 0
 
-    return q,k,v,position_ids_q,position_ids_k
+    return q, k, v, position_ids_q, position_ids_k
 
-  q,k,v,position_ids_q,position_ids_k = generate_qkv_and_position_ids(batch_size,max_seqlen_q,max_seqlen_k,dtype,nheads,nheads_k,device,d,max_seqlen_q == max_seqlen_k)
+  q, k, v, position_ids_q, position_ids_k = generate_qkv_and_position_ids(
+      batch_size, max_seqlen_q, max_seqlen_k, dtype, nheads, nheads_k, device,
+      d, max_seqlen_q == max_seqlen_k)
 
-  indices_q = torch.arange(0,position_ids_q.size(1),device=device,dtype=torch.int64)
-  indices_k = torch.arange(0,position_ids_k.size(1),device=device,dtype=torch.int64)
+  indices_q = torch.arange(
+      0, position_ids_q.size(1), device=device, dtype=torch.int64)
+  indices_k = torch.arange(
+      0, position_ids_k.size(1), device=device, dtype=torch.int64)
 
   seq_q_start_idx = indices_q[position_ids_q.squeeze() == 0]
   seq_k_start_idx = indices_k[position_ids_k.squeeze() == 0]
 
-  cu_seq_lens_q = F.pad(seq_q_start_idx,(0,1),value=position_ids_q.size(1)).to(torch.int32)
-  cu_seq_lens_k = F.pad(seq_k_start_idx,(0,1),value=position_ids_k.size(1)).to(torch.int32)
+  cu_seq_lens_q = F.pad(
+      seq_q_start_idx, (0, 1), value=position_ids_q.size(1)).to(torch.int32)
+  cu_seq_lens_k = F.pad(
+      seq_k_start_idx, (0, 1), value=position_ids_k.size(1)).to(torch.int32)
 
   max_seqlen_q = (cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]).max().item()
   max_seqlen_k = (cu_seq_lens_k[1:] - cu_seq_lens_k[:-1]).max().item()
 
-  q_cuda = q.reshape(-1,nheads,d).cuda()
-  k_cuda = k.reshape(-1,nheads_k,d).cuda()
-  v_cuda = v.reshape(-1,nheads_k,d).cuda()
-
+  q_cuda = q.reshape(-1, nheads, d).cuda()
+  k_cuda = k.reshape(-1, nheads_k, d).cuda()
+  v_cuda = v.reshape(-1, nheads_k, d).cuda()
 
   if alibi:
     alibi_slopes = torch.rand(
@@ -383,7 +409,6 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype,
       position_ids_q_xla.contiguous(), alibi_slopes, dropout_p, softmax_scale,
       False, causal, window_size[0], window_size[1], True, None)
 
-
   ta.mark_step(wait=True)
   q_xla = q_xla.cpu().detach()
   k_xla = k_xla.cpu().detach()
@@ -404,10 +429,19 @@ def generate_qkv_and_position_ids(batch_size,max_seqlen_q, max_seqlen_k, dtype,
   assert torch.allclose(k_xla, k, rtol=1e-3, atol=1e-3, equal_nan=True)
   assert torch.allclose(v_xla, v, rtol=1e-3, atol=1e-3, equal_nan=True)
   assert torch.allclose(
-      cu_seq_len_k_xla[:batch_size+1], cu_seq_lens_k, rtol=1e-3, atol=1e-3, equal_nan=True)
+      cu_seq_len_k_xla[:batch_size + 1],
+      cu_seq_lens_k,
+      rtol=1e-3,
+      atol=1e-3,
+      equal_nan=True)
   assert torch.allclose(out_xla, out_fa, rtol=1e-2, atol=1e-2, equal_nan=True)
   for i in range(batch_size):
     start_idx = cu_seq_len_q_xla[i]
-    end_idx = cu_seq_len_q_xla[i+1]
+    end_idx = cu_seq_len_q_xla[i + 1]
     seqlen = end_idx - start_idx
-    assert torch.allclose(softmax_lse_xla[0,:,start_idx:end_idx],softmax_lse[i,:,:seqlen],rtol=1e-3,atol=1e-3,equal_nan=True)
+    assert torch.allclose(
+        softmax_lse_xla[0, :, start_idx:end_idx],
+        softmax_lse[i, :, :seqlen],
+        rtol=1e-3,
+        atol=1e-3,
+        equal_nan=True)
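
The per-sequence softmax_lse assertions above compare a packed (1, nheads, total_q) layout against the reference padded (batch, nheads, max_seqlen) layout. A toy sketch of that relationship, with made-up lengths 3, 2 and 4:

import torch

nheads, max_seqlen = 2, 4
seqlens = [3, 2, 4]
cu_seqlens = [0, 3, 5, 9]
lse_padded = torch.randn(len(seqlens), nheads, max_seqlen)  # reference layout

# Packed layout used by the position_ids path: concatenate the valid prefix of
# every sequence along the token dimension.
lse_packed = torch.cat(
    [lse_padded[i, :, :s] for i, s in enumerate(seqlens)], dim=1).unsqueeze(0)

for i, s in enumerate(seqlens):
    start, end = cu_seqlens[i], cu_seqlens[i + 1]
    assert torch.equal(lse_packed[0, :, start:end], lse_padded[i, :, :s])
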

From 7e73ee211eaf75788e3b83244d760dd635cc7f09 Mon Sep 17 00:00:00 2001
From: tianxingwang <wangtianxing.wtx@alibaba-inc.com>
Date: Tue, 3 Dec 2024 14:46:28 +0800
Subject: [PATCH 4/5] refine code

---
 torch_xla/csrc/flash_attention_utils.cpp      | 23 +++++++++----------
 .../ops/flash_attention_varlen_forward.cpp    |  2 --
 ...attention_varlen_position_ids_backward.cpp |  4 ++--
 ..._attention_varlen_position_ids_forward.cpp |  3 ---
 4 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/torch_xla/csrc/flash_attention_utils.cpp b/torch_xla/csrc/flash_attention_utils.cpp
index 950eaf50ce39..376f492a1bb9 100644
--- a/torch_xla/csrc/flash_attention_utils.cpp
+++ b/torch_xla/csrc/flash_attention_utils.cpp
@@ -3,8 +3,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
-#include <iostream>
-
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
@@ -390,15 +388,15 @@ FlashAttentionBackwardParams get_flash_attention_backward_params(
     TORCH_CHECK(cu_seqlens_q.value().is_contiguous());
     TORCH_CHECK(cu_seqlens_k.value().is_contiguous());
     TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 ||
-                batch_size == 1);
+                batch_size == 1); // now pack qkv batch size only support 1, maybe need to change in the future
     TORCH_CHECK(
         cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
-            cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q + 1}),
-        "cu_seqlens_q shape should be batch_size+1 or seqlen_q");
+            cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q*batch_size + 1}),
+        "cu_seqlens_q shape should be batch_size+1 or seqlen_q+1");
     TORCH_CHECK(
         cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
-            cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k + 1}),
-        "cu_seqlens_k shape should be batch_size+1 or seqlen_k");
+            cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k*batch_size + 1}),
+        "cu_seqlens_k shape should be batch_size+1 or seqlen_k+1");
   }
 
   int alibi_slopes_batch_stride = 0;
@@ -451,11 +449,11 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
 
   auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
   torch::Tensor rows = torch::arange(shape[0], opts.dtype(torch::kInt32))
-                           .unsqueeze(1);  // (batch_size,1)
+                           .unsqueeze(1);
   torch::Tensor cols = torch::arange(shape[1], opts.dtype(torch::kInt32))
-                           .unsqueeze(0);  // (1,seqlen)
+                           .unsqueeze(0);
   torch::Tensor mask =
-      cols < nonzero_counts.unsqueeze(1);  // (1,seqlen) < (batch_size, 1)
+      cols < nonzero_counts.unsqueeze(1);
   max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>();
 
   torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32));
@@ -496,7 +494,7 @@ torch::Tensor unpad_softmax_lse(
   at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options());
   result.copy_(pad_softmax_lse.transpose(1, 2)
                    .reshape({batch_size * max_seqlen, nhead})
-                   .index({indices, torch::indexing::Slice()}));
+                   .index({indices, torch::indexing::Slice()})); // if packed tensor's batch size > 1 is supported in the future, need to modify here in the future
   return result.transpose(0, 1).unsqueeze(0);
 }
 
@@ -514,7 +512,7 @@ torch::Tensor pad_softmax_lse(
               "indice should be same size with softmax_lse")
 
   at::Tensor result =
-      at::empty({batch_size * max_seq_len, nheads}, softmax_lse.options());
+      at::zeros({batch_size * max_seq_len, nheads}, softmax_lse.options());
 
   result.index_put_({indices, torch::indexing::Slice()},
                     softmax_lse.squeeze(0).transpose(0, 1));
@@ -527,6 +525,7 @@ at::Tensor position_ids_to_indices(const at::Tensor& position_ids,
                                    int& max_seqlen_in_batch, int& total,
                                    at::Tensor& cu_seqlen,
                                    int& real_batch_size) {
+  cu_seqlen.fill_(-1);
   at::Tensor flatten_position_ids = position_ids.flatten();
   at::Tensor indices =
       torch::arange(flatten_position_ids.size(0),
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
index 92b06f1addf5..832e6c280faf 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
@@ -4,8 +4,6 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
 
-#include <iostream>
-
 #include "cutlass/numeric_types.h"
 #include "flash.h"
 #include "static_switch.h"
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
index 5cb54805a30f..132742921833 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
@@ -87,9 +87,9 @@ void custom_call_flash_attention_varlen_position_ids_backward(
   at::Tensor softmax_lse =
       torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q},
                        opts.dtype(torch::kFloat));
-  at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.seqlen_q + 1},
+  at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.b*params.seqlen_q + 1},
                                              opts.dtype(torch::kInt32));
-  at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.seqlen_k + 1},
+  at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.b*params.seqlen_k + 1},
                                              opts.dtype(torch::kInt32));
 
   // Outputs
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
index 945a807236cc..23d5d1c4cc6f 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
@@ -4,8 +4,6 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
 
-#include <iostream>
-
 #include "cutlass/numeric_types.h"
 #include "flash.h"
 #include "static_switch.h"
@@ -96,7 +94,6 @@ void custom_call_flash_attention_varlen_position_ids_forward(
       torch::from_blob(buffers[6 + buf_offset], {2}, opts.dtype(torch::kInt64));
   softmax_lse.fill_(0);
   o_output.fill_(0);
-  cu_seqlens_k.fill_(-1);
 
   int max_seqlen_in_batch_k = params.seqlen_k;
   int total_k = params.b * params.seqlen_k;
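
Moving cu_seqlens_k.fill_(-1) into position_ids_to_indices keeps the invariant that only the first real_batch_size + 1 entries of the fixed-size cu_seqlens output buffer are valid while the tail stays at -1 (which is why the tests slice cu_seqlen_k_xla[:batch_size + 1]). A rough sketch of that buffer layout, with made-up values; recovering real_batch_size from the sentinel is shown only as an illustration, not as what the helper does.

import torch

total_tokens = 9
cu_seqlens_buf = torch.empty(total_tokens + 1, dtype=torch.int32)
cu_seqlens_buf.fill_(-1)                        # sentinel for the unused tail
valid = torch.tensor([0, 3, 5, 9], dtype=torch.int32)
cu_seqlens_buf[: valid.numel()] = valid         # [0, 3, 5, 9, -1, -1, -1, -1, -1, -1]
real_batch_size = int((cu_seqlens_buf >= 0).sum()) - 1   # 3
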

From e531a14583ab5c602ce8aa3f2df4cc4689f666c4 Mon Sep 17 00:00:00 2001
From: tianxingwang <wangtianxing.wtx@alibaba-inc.com>
Date: Tue, 3 Dec 2024 14:49:54 +0800
Subject: [PATCH 5/5] refine code

---
 torch_xla/csrc/flash_attention_utils.cpp      | 31 ++++++++++++-------
 ...attention_varlen_position_ids_backward.cpp |  8 ++---
 2 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/torch_xla/csrc/flash_attention_utils.cpp b/torch_xla/csrc/flash_attention_utils.cpp
index 376f492a1bb9..6536dc96c6cd 100644
--- a/torch_xla/csrc/flash_attention_utils.cpp
+++ b/torch_xla/csrc/flash_attention_utils.cpp
@@ -388,14 +388,17 @@ FlashAttentionBackwardParams get_flash_attention_backward_params(
     TORCH_CHECK(cu_seqlens_q.value().is_contiguous());
     TORCH_CHECK(cu_seqlens_k.value().is_contiguous());
     TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 ||
-                batch_size == 1); // now pack qkv batch size only support 1, maybe need to change in the future
+                batch_size == 1);  // packed qkv currently only supports batch
+                                   // size 1; may need to change in the future
     TORCH_CHECK(
         cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
-            cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q*batch_size + 1}),
+            cu_seqlens_q.value().sizes() ==
+                torch::IntArrayRef({seqlen_q * batch_size + 1}),
         "cu_seqlens_q shape should be batch_size+1 or seqlen_q+1");
     TORCH_CHECK(
         cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
-            cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k*batch_size + 1}),
+            cu_seqlens_k.value().sizes() ==
+                torch::IntArrayRef({seqlen_k * batch_size + 1}),
         "cu_seqlens_k shape should be batch_size+1 or seqlen_k+1");
   }
 
@@ -448,12 +451,11 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
   std::array<int64_t, 2> shape = {batch_size, seqlen};
 
   auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
-  torch::Tensor rows = torch::arange(shape[0], opts.dtype(torch::kInt32))
-                           .unsqueeze(1);
-  torch::Tensor cols = torch::arange(shape[1], opts.dtype(torch::kInt32))
-                           .unsqueeze(0);
-  torch::Tensor mask =
-      cols < nonzero_counts.unsqueeze(1);
+  torch::Tensor rows =
+      torch::arange(shape[0], opts.dtype(torch::kInt32)).unsqueeze(1);
+  torch::Tensor cols =
+      torch::arange(shape[1], opts.dtype(torch::kInt32)).unsqueeze(0);
+  torch::Tensor mask = cols < nonzero_counts.unsqueeze(1);
   max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>();
 
   torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32));
@@ -492,9 +494,14 @@ torch::Tensor unpad_softmax_lse(
       cu_seqlens_to_indices(valid_cu_seqlens, batch_size, max_seqlen,
                             torch::kInt64, max_seqlen_in_batch, total);
   at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options());
-  result.copy_(pad_softmax_lse.transpose(1, 2)
-                   .reshape({batch_size * max_seqlen, nhead})
-                   .index({indices, torch::indexing::Slice()})); // if packed tensor's batch size > 1 is supported in the future, need to modify here in the future
+  result.copy_(
+      pad_softmax_lse.transpose(1, 2)
+          .reshape({batch_size * max_seqlen, nhead})
+          .index({indices,
+                  torch::indexing::Slice()}));  // if a packed batch size > 1
+                                                // is supported in the future,
+                                                // this indexing will need to
+                                                // be updated accordingly
   return result.transpose(0, 1).unsqueeze(0);
 }
 
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
index 132742921833..0e508b80af6d 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
@@ -87,10 +87,10 @@ void custom_call_flash_attention_varlen_position_ids_backward(
   at::Tensor softmax_lse =
       torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q},
                        opts.dtype(torch::kFloat));
-  at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.b*params.seqlen_q + 1},
-                                             opts.dtype(torch::kInt32));
-  at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.b*params.seqlen_k + 1},
-                                             opts.dtype(torch::kInt32));
+  at::Tensor cu_seqlens_q = torch::from_blob(
+      buffers[6], {params.b * params.seqlen_q + 1}, opts.dtype(torch::kInt32));
+  at::Tensor cu_seqlens_k = torch::from_blob(
+      buffers[7], {params.b * params.seqlen_k + 1}, opts.dtype(torch::kInt32));
 
   // Outputs
   at::Tensor dq =