Add interleaved sliding window by using the fusedsdpa kernel #725
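For context on the title: in interleaved sliding-window models, only some decoder layers restrict attention to a bounded window of recent tokens, while the remaining layers attend to the full context. That is why the backend change below branches on whether `self.sliding_window` is set for the current layer. A minimal illustrative sketch of such a per-layer pattern (the helper name and the alternating every-other-layer pattern are assumptions for illustration, not vLLM configuration):

from typing import Optional

def layer_sliding_window(layer_idx: int,
                         window_size: int = 4096,
                         every: int = 2) -> Optional[int]:
    # Illustrative only: even layers get a bounded window, odd layers
    # attend to the full context.
    return window_size if layer_idx % every == 0 else None

print([layer_sliding_window(i) for i in range(4)])
# [4096, None, 4096, None]

In the attention backend, a returned integer selects the new sliding-window bias path, while None keeps the existing full-attention behavior.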

Closed · wants to merge 6 commits
requirements-hpu.txt: 2 changes (1 addition & 1 deletion)
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@d4f37bb
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@9e45874
vllm/attention/backends/hpu_attn.py: 33 changes (31 additions & 2 deletions)
@@ -6,6 +6,7 @@
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import math
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.flags import enabled_flags
@@ -224,11 +225,16 @@
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
if attn_metadata is None or attn_metadata.block_list is None:
valid_seq_lengths = attn_metadata.seq_lens_tensor
if not self.prefill_use_fusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
attn_bias = attn_metadata.attn_bias
# Force the fused SDPA kernel for its performance/memory benefit;
# the non-fused path also has an accuracy issue that still needs fixing.
if self.sliding_window is None:
self.fused_scaled_dot_product_attention = None
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads,
@@ -238,6 +244,9 @@
attn_bias.add_(position_bias)
else:
attn_bias = None
if self.sliding_window:
attn_bias = _make_sliding_window_bias(
    batch_size, seq_len, attn_metadata.seq_lens_tensor,
    self.sliding_window, query.dtype)

# TODO: remove after fusedsdpa optimization is done
valid_seq_lengths = None


out = ops.prompt_attention(
query.view(query_shape),
@@ -249,7 +258,7 @@
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
valid_seq_lengths=valid_seq_lengths,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
else:
@@ -269,7 +278,7 @@
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
output = out.reshape(batch_size, seq_len, hidden_size)
else:

# Decoding run.
output = HPUPagedAttention.forward_decode(
query=query,
@@ -288,7 +297,7 @@
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

Check failure on line 300 in vllm/attention/backends/hpu_attn.py (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): Item "None" of "Optional[Any]" has no attribute "view" [union-attr]

def forward_encoder_decoder(
self,
@@ -432,3 +441,23 @@
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return bias

def _make_sliding_window_bias(
    batch_size: int,
    seq_len: int,
    query_lens_t: torch.Tensor,
    window_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    shift = 0
    query_lens_t = query_lens_t.reshape(batch_size, 1)
    # Start from a causal (lower-triangular) mask of ones with shape
    # [batch_size, 1, seq_len, seq_len].
    tensor = torch.full((batch_size, 1, seq_len, seq_len),
                        dtype=dtype,
                        fill_value=1)
    mask = torch.tril(tensor, diagonal=shift)

    # Mask out query rows that lie beyond each sequence's valid length
    # (i.e. padding rows).
    len_mask = torch.arange(0, seq_len,
                            device=query_lens_t.device,
                            dtype=torch.int32).view(seq_len, 1)
    len_mask = len_mask.ge(query_lens_t.unsqueeze(-1)).view(
        batch_size, 1, seq_len, 1)
    len_mask = torch.where(len_mask, 0, 1)
    mask = mask.logical_and(len_mask)

    # Keep only the last `window_size` key positions for each query row.
    mask = torch.triu(mask, diagonal=shift - window_size + 1)
    # Allowed positions get bias 0, masked positions get -inf.
    attn_bias = torch.where(mask, 0, -math.inf)
    return attn_bias
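A small usage sketch of the helper above, assuming it is in scope (toy sizes chosen for illustration; the printed shape follows from the construction):

import torch

# Batch of 2 prompts padded to length 5, valid lengths [5, 3], window of 3.
seq_lens = torch.tensor([5, 3], dtype=torch.int32)
bias = _make_sliding_window_bias(batch_size=2,
                                 seq_len=5,
                                 query_lens_t=seq_lens,
                                 window_size=3,
                                 dtype=torch.float32)
print(bias.shape)  # torch.Size([2, 1, 5, 5])
# First sequence: row i may attend to keys max(0, i - 2) .. i, so row 4
# sees only positions 2, 3 and 4; all other entries are -inf.
# Second sequence: rows 3 and 4 are padding and are fully masked.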