Gargamit/fix long seq bug #6

Draft · wants to merge 16 commits into base: main
7 changes: 6 additions & 1 deletion vllm/attention/backends/abstract.py
@@ -103,6 +103,10 @@ class AttentionMetadata:
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor

# Number of original input tokens (without any decoding).
# Some models (e.g. Phi-3) need this info to decide model settings.
num_orig_input_tokens_tensor: torch.Tensor

@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
@@ -184,7 +188,8 @@ def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:

@abstractmethod
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> T:
num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int,
batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError

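For context on the new field: each entry is a sequence's prompt length before any decoded tokens were appended, and the updated build() signature threads it through as num_orig_input_tokens_list. A minimal sketch of how a builder might materialize it on device (the helper name is illustrative, not part of this diff):

```python
# Minimal sketch, not part of this diff: turning the per-sequence original
# input lengths into the on-device tensor the updated build() signature
# expects. The helper name is illustrative.
from typing import List

import torch


def build_num_orig_input_tokens_tensor(
        num_orig_input_tokens_list: List[int],
        device: torch.device) -> torch.Tensor:
    # One entry per sequence: the prompt length before any decoded tokens.
    return torch.tensor(num_orig_input_tokens_list,
                        dtype=torch.long,
                        device=device)
```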
6 changes: 6 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -220,6 +220,9 @@ def prefill_metadata(
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
num_orig_input_tokens_tensor=(
None if self.num_orig_input_tokens_tensor is None else
self.num_orig_input_tokens_tensor[:self.num_prefills]),
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
@@ -248,6 +251,9 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
num_orig_input_tokens_tensor=(
None if self.num_orig_input_tokens_tensor is None else
self.num_orig_input_tokens_tensor[self.num_prefills:]),
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
11 changes: 10 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -263,6 +263,8 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
num_orig_input_tokens_tensor=self.
num_orig_input_tokens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
@@ -291,6 +293,8 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
num_orig_input_tokens_tensor=self.
num_orig_input_tokens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
@@ -427,7 +431,8 @@ def _add_seq_group(
self.block_size, inter_data.block_tables)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int,
batch_size: int):
"""Build attention metadata with on-device tensors.

Args:
@@ -499,6 +504,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
dtype=torch.long,
device=device)

return FlashAttentionMetadata(
num_prefills=self.num_prefills,
@@ -507,6 +515,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
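The cached prefill/decode metadata above follows the backend's usual ordering: per-sequence tensors list all prefill sequences first, then all decode sequences, so prefill views slice [:num_prefills] and decode views slice [num_prefills:]. A rough illustration with made-up values:

```python
# Rough illustration of the prefill/decode slicing convention; the values
# below are made up and not taken from this diff.
import torch

num_prefills = 2
num_orig_input_tokens_tensor = torch.tensor([7, 5, 9, 3], dtype=torch.long)

prefill_view = num_orig_input_tokens_tensor[:num_prefills]   # tensor([7, 5])
decode_view = num_orig_input_tokens_tensor[num_prefills:]    # tensor([9, 3])
```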
20 changes: 18 additions & 2 deletions vllm/attention/backends/flashinfer.py
@@ -142,6 +142,8 @@ def graph_capture(self, max_batch_size: int):
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
self._num_orig_input_tokens_tensor = torch.zeros(
max_batch_size, dtype=torch.int32, device=self.runner.device)
self._graph_decode_workspace_buffer = self._get_workspace_buffer()
self._graph_indices_buffer = torch.empty(
max_batch_size * self.runner.cache_config.num_gpu_blocks,
@@ -162,6 +164,7 @@ def graph_capture(self, max_batch_size: int):
del self._graph_indptr_buffer
del self._graph_last_page_len_buffer
del self._graph_decode_wrapper
del self._num_orig_input_tokens_tensor

def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
@@ -211,6 +214,8 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
slot_mapping=self._graph_slot_mapping[:batch_size],
num_prefill_tokens=0,
num_decode_tokens=batch_size,
num_orig_input_tokens_tensor=self.
_num_orig_input_tokens_tensor[:batch_size],
max_prefill_seq_len=0,
block_tables=self._graph_block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
@@ -232,10 +237,15 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):

def get_graph_input_buffers(self, attn_metadata):
return {
"slot_mapping": attn_metadata.slot_mapping,
"slot_mapping":
attn_metadata.slot_mapping,
"num_orig_input_tokens_tensor":
attn_metadata.num_orig_input_tokens_tensor,
}

def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
input_buffers["num_orig_input_tokens_tensor"].copy_(
attn_metadata.num_orig_input_tokens_tensor, non_blocking=True)
return

def begin_forward(self, model_input):
@@ -506,7 +516,8 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
self.paged_kv_last_page_len.append(last_page_len)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int,
batch_size: int):
"""Build attention metadata with on-device tensors.

Args:
@@ -576,6 +587,10 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
dtype=torch.long,
device=device)

if len(self.paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
@@ -602,6 +617,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
max_prefill_seq_len=max_prefill_seq_len,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor,
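The graph-capture changes above follow the CUDA-graph static-buffer pattern: a replayed graph only reads the tensors registered at capture time, so per-step values such as num_orig_input_tokens_tensor must be copied into those buffers in place rather than reassigned. A minimal sketch under that assumption (sizes and values are made up):

```python
# Minimal sketch of the static-buffer pattern behind the graph-capture
# changes: a CUDA graph replays fixed tensor addresses, so fresh per-step
# values are copied into the captured buffers in place. Sizes and values
# here are illustrative only.
import torch

max_batch_size = 8
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocated once, at capture time, and registered as a graph input buffer.
input_buffers = {
    "num_orig_input_tokens_tensor":
    torch.zeros(max_batch_size, dtype=torch.long, device=device),
}

# Before each replay: overwrite the captured buffer in place, never reassign.
new_values = torch.tensor([7, 5, 9, 3, 0, 0, 0, 0],
                          dtype=torch.long,
                          device=device)
input_buffers["num_orig_input_tokens_tensor"].copy_(new_values,
                                                    non_blocking=True)
```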
4 changes: 4 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -136,6 +136,8 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
num_orig_input_tokens_tensor=self.
num_orig_input_tokens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
@@ -164,6 +166,8 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
num_orig_input_tokens_tensor=self.
num_orig_input_tokens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
28 changes: 24 additions & 4 deletions vllm/attention/backends/utils.py
@@ -190,7 +190,8 @@ def _add_seq_group(
self.block_size, inter_data.block_tables)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int,
batch_size: int):
"""Build attention metadata with on-device tensors.

Args:
@@ -258,13 +259,18 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
dtype=torch.long,
device=device)

return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
@@ -294,11 +300,16 @@ def graph_capture(self, max_batch_size: int):
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)

self._num_orig_input_tokens_tensor = torch.zeros(
max_batch_size, dtype=torch.int32, device=self.runner.device)

yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._num_orig_input_tokens_tensor

def graph_clone(self, batch_size: int) -> "CommonAttentionState":
assert self._is_graph_capturing
@@ -313,6 +324,8 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
slot_mapping=self._graph_slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
num_orig_input_tokens_tensor=self.
_num_orig_input_tokens_tensor[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
@@ -326,9 +339,14 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):

def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
return {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
"slot_mapping":
attn_metadata.slot_mapping,
"seq_lens_tensor":
attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables":
attn_metadata.decode_metadata.block_tables,
"num_orig_input_tokens_tensor":
attn_metadata.num_orig_input_tokens_tensor,
}

def prepare_graph_input_buffers(self, input_buffers,
@@ -337,6 +355,8 @@ def prepare_graph_input_buffers(self, input_buffers,
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
input_buffers["num_orig_input_tokens_tensor"].copy_(
attn_metadata.num_orig_input_tokens_tensor, non_blocking=True)

def begin_forward(self, model_input) -> None:
return
8 changes: 8 additions & 0 deletions vllm/attention/backends/xformers.py
@@ -200,6 +200,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]:
self.seq_lens_tensor[:self.num_prefills])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
num_orig_input_tokens_tensor = (
None if self.num_orig_input_tokens_tensor is None else
self.num_orig_input_tokens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])

@@ -211,6 +214,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]:
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
@@ -245,6 +249,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]:
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
num_orig_input_tokens_tensor = (
None if self.num_orig_input_tokens_tensor is None else
self.num_orig_input_tokens_tensor[self.num_prefills:])

# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = XFormersMetadata(
Expand All @@ -253,6 +260,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
seq_lens_tensor=seq_lens_tensor,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
block_tables=block_tables,
Expand Down
13 changes: 9 additions & 4 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -540,8 +540,8 @@ def __init__(
self.short_mscale = short_mscale
self.long_mscale = long_mscale

short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = self._compute_cos_sin_cache(max_position_embeddings,
short_factor, short_mscale)
short_cache = short_cache.to(dtype)
self.register_buffer("short_cos_sin_cache",
short_cache,
@@ -586,13 +586,18 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
*,
num_orig_input_tokens_tensor: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)

k = self.original_max_position_embeddings
long_prompt_offset = (torch.any(positions > k).float() *
torch.full_like(positions, k)).long()
long_prompt_offset = torch.where(
num_orig_input_tokens_tensor <= k,
torch.zeros_like(num_orig_input_tokens_tensor),
torch.full_like(num_orig_input_tokens_tensor,
self.max_position_embeddings))
idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions)
self.long_short_cos_sin_cache: torch.Tensor = (
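This hunk is the core of the long-sequence fix: the old code switched to the long-RoPE cache whenever any position in the flattened batch exceeded original_max_position_embeddings, so a single long sequence, or simply decoding past that length, affected every token in the batch; the new code decides per token from the originating sequence's original input length. A rough contrast with made-up lengths (tensor shapes are simplified):

```python
# Hedged sketch contrasting the old batch-global check with the new
# per-sequence check in the diff above; lengths and shapes are made up.
import torch

original_max_position_embeddings = 4096   # k
max_position_embeddings = 131072          # offset of the long-RoPE cache

positions = torch.tensor([10, 11, 5000])                     # flattened batch
num_orig_input_tokens_tensor = torch.tensor([10, 10, 5000])  # per-token origin

k = original_max_position_embeddings

# Old behaviour: one long sequence anywhere in the batch pushes *every*
# position into the long-RoPE cache.
old_offset = (torch.any(positions > k).float() *
              torch.full_like(positions, k)).long()

# New behaviour: only tokens whose originating prompt exceeds k use the
# long cache, indexed at max_position_embeddings.
new_offset = torch.where(
    num_orig_input_tokens_tensor <= k,
    torch.zeros_like(num_orig_input_tokens_tensor),
    torch.full_like(num_orig_input_tokens_tensor, max_position_embeddings))

print(old_offset)  # tensor([4096, 4096, 4096]) -- short tokens mis-scaled
print(new_offset)  # tensor([0, 0, 131072])
```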
2 changes: 1 addition & 1 deletion vllm/model_executor/models/__init__.py
@@ -49,7 +49,7 @@
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),