feat: expose pytorch api for block sparse attention (#375)
The block sparse attention kernels (for any block size (R, C)) already exist in
flashinfer's codebase but were never exposed explicitly in Python. As requested
in #367, this PR implements the PyTorch APIs for block sparse attention.
According to our experiments, it can greatly accelerate attention computation
at low density (10x for Tree Attention in Sequoia).
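
A minimal usage sketch of the new wrapper (the sizes, sparsity pattern, and float16 dtype below are illustrative assumptions, not values from this PR):

import torch
import flashinfer

# Illustrative setup: 16x16 blocks over a 1024x1024 attention matrix, ~10% block density.
M, N, R, C = 1024, 1024, 16, 16
num_qo_heads, num_kv_heads, head_dim = 32, 32, 128
MB, NB = (M + R - 1) // R, (N + C - 1) // C

# Block-level sparsity pattern in CSR-over-blocks form (indptr/indices).
block_mask = torch.rand(MB, NB) < 0.1
indptr = torch.zeros(MB + 1, dtype=torch.int32)
indptr[1:] = torch.cumsum(block_mask.sum(dim=1), dim=0)
indices = block_mask.nonzero(as_tuple=True)[1].to(torch.int32)

# 128MB workspace, as recommended by the wrapper's docstring.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BlockSparseAttentionWrapper(workspace, kv_layout="NHD")
wrapper.begin_forward(
    indptr.cuda(), indices.cuda(), M, N, R, C,
    num_qo_heads, num_kv_heads, head_dim,
)

q = torch.randn(M, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_data = torch.randn(N // C, 2, C, num_kv_heads, head_dim,
                      dtype=torch.float16, device="cuda")
o = wrapper.forward(q, kv_data)  # (M, num_qo_heads, head_dim)
wrapper.end_forward()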
yzh119 authored Jul 17, 2024
1 parent b2d5994 commit 4bba6fa
Showing 6 changed files with 419 additions and 7 deletions.
13 changes: 13 additions & 0 deletions docs/api/python/sparse.rst
@@ -0,0 +1,13 @@
.. _apisparse:

flashinfer.sparse
=================

Kernels for block-sparse FlashAttention.

.. currentmodule:: flashinfer.sparse

.. autoclass:: BlockSparseAttentionWrapper
:members:

.. automethod:: __init__
13 changes: 7 additions & 6 deletions include/flashinfer/attention/prefill.cuh
@@ -751,15 +751,16 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, uint32_t* v_smem_o
*v_smem_offset_r -= 16 * num_frags_z * channel_size_128b_in;
}

template <uint32_t num_frags_x, uint32_t num_frags_y>
__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], float (*d)[2]) {
template <uint32_t num_frags_x, uint32_t num_frags_y, typename DTypeQKAccum>
__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], DTypeQKAccum (*m)[2],
float (*d)[2]) {
float d_rcp[num_frags_x][2];
// compute reciprocal of d
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
d_rcp[fx][j] = math::ptx_rcp(d[fx][j]);
d_rcp[fx][j] = (m[fx][j] != DTypeQKAccum(-5e4)) ? math::ptx_rcp(d[fx][j]) : 0.f;
}
}

@@ -1161,7 +1162,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC
o_frag, (float*)smem, m, d, warp_idx, lane_idx);

// normalize d
normalize_d<num_frags_x, num_frags_y>(o_frag, d);
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);

// write back
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
@@ -1428,7 +1429,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
o_frag, (float*)smem, m, d, warp_idx, lane_idx);

// normalize d
normalize_d<num_frags_x, num_frags_y>(o_frag, d);
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);

const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size);

@@ -1719,7 +1720,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
o_frag, (float*)smem, m, d, warp_idx, lane_idx);

// normalize d
normalize_d<num_frags_x, num_frags_y>(o_frag, d);
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);

const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size);

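
The normalize_d change above passes the running row maximum m into the normalization: query rows that attended to no valid key (which can happen when every block in a row is masked out or empty) keep m at the -5e4 sentinel, so dividing their accumulator by d would produce inf/NaN; the kernel now writes zeros for those rows instead. A Python sketch of the numerical intent only, not of the CUDA register/fragment layout:

import numpy as np

NEG_INF_SENTINEL = -5e4  # sentinel value the kernel compares m against

def normalize_output(o, m, d):
    """Sketch: o is the unnormalized output (rows, head_dim); m and d are the
    per-row running max and softmax denominator. Rows whose max never left the
    sentinel saw no valid key and are zeroed instead of divided by d."""
    valid = m != NEG_INF_SENTINEL
    d_rcp = np.zeros_like(d)
    d_rcp[valid] = 1.0 / d[valid]
    return o * d_rcp[:, None]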
1 change: 1 addition & 0 deletions python/flashinfer/__init__.py
@@ -25,6 +25,7 @@
BatchPrefillWithRaggedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
from .sparse import BlockSparseAttentionWrapper
from .cascade import (
merge_state,
merge_state_in_place,
2 changes: 1 addition & 1 deletion python/flashinfer/sampling.py
@@ -254,7 +254,7 @@ def top_k_top_p_sampling_from_probs(
>>> samples
tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
>>> success
tensor([True, True, True, True], device='cuda:0')
tensor([True, True, True, True], device='cuda:0')
Notes
-----
292 changes: 292 additions & 0 deletions python/flashinfer/sparse.py
@@ -0,0 +1,292 @@
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
from typing import Optional
import torch
import logging
from .prefill import _compute_page_qk_indptr
from .quantization import segment_packbits
from .utils import (
check_pos_encoding_mode,
check_kv_layout,
is_float8,
expand_5d,
PosEncodingMode,
TensorLayout,
)

try:
from . import _kernels
except ImportError as e:
import os
import logging

if os.environ.get("BUILD_DOC", "0") == "1":
_kernels = None
logging.warning("Kernels are not loaded in documentation build mode.")
else:
raise e


class BlockSparseAttentionWrapper:
def __init__(
self,
workspace_buffer: torch.Tensor,
kv_layout: str = "NHD",
):
r"""Constructs of :class:`BlockSparseAttentionWrapper`.
Warning(Zihao): this is an experimental API and subject to change.
Parameters
----------
workspace_buffer : torch.Tensor
The user-reserved workspace buffer used to store auxiliary data structures.
The recommended size is 128MB; the workspace buffer should be on the same
device as the input tensors.
kv_layout : str
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
"""
check_kv_layout(kv_layout)
self._kv_layout = kv_layout
self._workspace_buffer = workspace_buffer
self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper(
TensorLayout[kv_layout].value,
False, # use_cuda_graph
)

def begin_forward(
self,
indptr: torch.Tensor,
indices: torch.Tensor,
M: int,
N: int,
R: int,
C: int,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
mask: Optional[torch.Tensor] = None,
packed_mask: Optional[torch.Tensor] = None,
q_data_type: str = "float16",
):
r"""Create auxiliary data structures for block sparse attention.
Parameters
----------
indptr : torch.Tensor
The indptr of the block-sparse matrix, shape (MB + 1,), where MB is the number of blocks in the row dimension.
indices: torch.Tensor
The indices of the block-sparse matrix, shape (nnz,), where nnz is the number of non-zero blocks.
M : int
The number of rows of the block-sparse matrix, MB = ceil_div(M, R).
N : int
The number of columns of the block-sparse matrix, NB = ceil_div(N, C).
R : int
The number of rows in each block.
C : int
The number of columns in each block.
num_qo_heads : int
The number of heads in the query/output tensor.
num_kv_heads : int
The number of heads in the key/value tensor.
head_dim : int
The dimension of each head.
mask : torch.Tensor, optional
The flattened mask tensor, shape (nnz * R * C,), where nnz is the number of non-zero blocks.
If every block is fully attended, the mask tensor does not need to be provided.
packed_mask : torch.Tensor, optional
The 1D packed mask tensor; if provided, :attr:`mask` will be ignored.
The packed mask tensor can be generated by :func:`flashinfer.quantization.packbits`.
q_data_type : str, optional
The data type of the query tensor.
The :meth:`begin_forward` method should be called before any :meth:`forward` or
:meth:`forward_return_lse` calls; the auxiliary data structures created during
this call are cached and reused across multiple forward calls.
``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is
not equal to ``num_kv_heads``, the wrapper will use
`grouped query attention <https://arxiv.org/abs/2305.13245>`_.
"""
num_rows = len(indptr) - 1
qo_indptr_host = R * torch.arange(num_rows + 1, dtype=torch.int32)
qo_indptr_host[-1] = M
self._qo_indptr = qo_indptr_host.to(indptr.device)
row_empty = indptr[1:] == indptr[:-1]  # rows that own no non-zero blocks
if indices.max().item() * C > N:
raise ValueError("indices out of bound")
last_block_pos = indices[torch.clamp(indptr[1:], min=1) - 1]
last_block_pos.masked_fill_(row_empty, 0)
last_block_len = torch.clamp(N - last_block_pos * C, max=C)

if mask is not None or packed_mask is not None:
qk_indptr = _compute_page_qk_indptr(
self._qo_indptr,
indptr, # paged_kv_indptr
last_block_len, # paged_kv_last_page_len
C, # page_size
)
if packed_mask is None and mask is not None:
# create packed mask from mask
packed_mask, qk_indptr = segment_packbits(
mask.contiguous().view(-1), qk_indptr, bitorder="little"
)

self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len = last_block_len
if packed_mask is not None:
self._packed_mask_buf = packed_mask
self._qk_indptr_buf = qk_indptr
else:
self._packed_mask_buf = None

empty_q_data = torch.empty(
0,
dtype=(
getattr(torch, q_data_type)
if isinstance(q_data_type, str)
else q_data_type
),
)

self._wrapper.begin_forward(
self._workspace_buffer,
self._qo_indptr,
self._paged_kv_indptr_buf,
num_rows,
num_qo_heads,
num_kv_heads,
head_dim,
C,
empty_q_data,
)

def end_forward(self):
r"""Clear the auxiliary data structures created by :meth:`begin_forward`."""
self._qo_indptr = None
self._paged_kv_indptr_buf = None
self._paged_kv_indices_buf = None
self._paged_kv_last_page_len = None
self._packed_mask_buf = None
self._qk_indptr_buf = None

def forward(
self,
q: torch.Tensor,
kv_data: torch.Tensor,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
logits_soft_cap: Optional[float] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
r"""Compute block-sparse attention between Q/K/V tensors.
Warning(Zihao): in the next release, kv_data will be decoupled into standalone k/v tensors, each
with shape (N, num_kv_heads, head_dim).
Parameters
----------
q : torch.Tensor
The query tensor, shape (M, num_qo_heads, head_dim).
kv_data : torch.Tensor
The key/value tensor, shape (N // C, 2, C, num_kv_heads, head_dim).
pos_encoding_mode : str, optional
The position encoding applied inside attention kernels, could be
``NONE``/``ROPE_LLAMA`` (LLaMA-style rotary embedding)/``ALIBI``.
Default is ``NONE``.
allow_fp16_qk_reduction : bool
Whether to use f16 for qk reduction (faster at the cost of slight precision
loss).
logits_soft_cap : Optional[float]
The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not
provided, will be set to ``0``. If greater than 0, the logits will be capped according to
formula:
:math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
where :math:`x` is the input logits.
sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to
``1.0 / sqrt(head_dim)``.
rope_scale : Optional[float]
The scale used in RoPE interpolation, if not provided, will be set to
``1.0``.
rope_theta : Optional[float]
The theta used in RoPE, if not provided, will be set to ``1e4``.
Returns
-------
torch.Tensor
The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``.
"""
check_pos_encoding_mode(pos_encoding_mode)
if logits_soft_cap is None:
logits_soft_cap = 0.0
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(q.size(-1))
if rope_scale is None:
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
if is_float8(q):
logging.warning(
"Our current prefill kernel implementation needs f16 input, the f8 inputs "
" are casted to f16, which could result in performance degradation."
)
q = q.to(torch.float16)
kv_data = kv_data.to(torch.float16)

kv_data = expand_5d(kv_data, self._kv_layout)

if self._packed_mask_buf is None:
return self._wrapper.forward(
q,
self._qo_indptr,
kv_data,
self._paged_kv_indptr_buf,
self._paged_kv_indices_buf,
self._paged_kv_last_page_len,
False, # causal
PosEncodingMode[pos_encoding_mode].value,
allow_fp16_qk_reduction,
logits_soft_cap,
sm_scale,
rope_scale,
rope_theta,
False, # return LSE
)[0]
else:
return self._wrapper.forward_custom_mask(
q,
self._qo_indptr,
kv_data,
self._paged_kv_indptr_buf,
self._paged_kv_indices_buf,
self._paged_kv_last_page_len,
self._packed_mask_buf,
self._qk_indptr_buf,
PosEncodingMode[pos_encoding_mode].value,
allow_fp16_qk_reduction,
logits_soft_cap,
sm_scale,
rope_scale,
rope_theta,
False, # return LSE
)[0]
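
For partially filled blocks, begin_forward also accepts a flattened per-block mask (or a pre-packed one produced by flashinfer.quantization.packbits). A hedged sketch that reuses the illustrative indptr/indices, q, and kv_data from the example in the commit message above; the random mask pattern is an assumption:

# One boolean entry per element of every retained (R, C) block, flattened in
# block order; True keeps an entry, False masks it out.
nnz = int(indices.numel())
mask = torch.rand(nnz * R * C, device="cuda") < 0.5  # illustrative pattern

wrapper.begin_forward(
    indptr.cuda(), indices.cuda(), M, N, R, C,
    num_qo_heads, num_kv_heads, head_dim,
    mask=mask,  # packed internally via segment_packbits
)
o = wrapper.forward(q, kv_data)
wrapper.end_forward()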
