diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ccfc6b254c1e7..914a3a86d846a 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -103,6 +103,10 @@ class AttentionMetadata: # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor + # Number of original input tokens (without any decoding). + # Some model (phi3-) need this info to decide model settings + num_orig_input_tokens_tensor: torch.Tensor + @property @abstractmethod def prefill_metadata(self) -> Optional["AttentionMetadata"]: @@ -184,7 +188,8 @@ def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: @abstractmethod def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int) -> T: + num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int, + batch_size: int) -> T: """Build attention metadata with on-device tensors.""" raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index d84a40890ebbd..4662fefab4722 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -220,6 +220,9 @@ def prefill_metadata( query_start_loc=self.query_start_loc[:self.num_prefills + 1], seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + num_orig_input_tokens_tensor=( + None if self.num_orig_input_tokens_tensor is None else + self.num_orig_input_tokens_tensor[:self.num_prefills]), block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, ) @@ -248,6 +251,9 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: query_start_loc=None, seq_start_loc=None, context_lens_tensor=None, + num_orig_input_tokens_tensor=( + None if self.num_orig_input_tokens_tensor is None else + self.num_orig_input_tokens_tensor[:self.num_prefills]), block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, ) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 30ce715d5d05a..a4fcf3644fd0f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -263,6 +263,8 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[:self.num_prefill_tokens], seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + num_orig_input_tokens_tensor=self. + num_orig_input_tokens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -291,6 +293,8 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + num_orig_input_tokens_tensor=self. + num_orig_input_tokens_tensor[:self.num_prefills], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -427,7 +431,8 @@ def _add_seq_group( self.block_size, inter_data.block_tables) def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): + num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int, + batch_size: int): """Build attention metadata with on-device tensors. 
Args: @@ -499,6 +504,9 @@ def build(self, seq_lens: List[int], query_lens: List[int], dim=0, dtype=query_start_loc.dtype, out=query_start_loc[1:]) + num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list, + dtype=torch.long, + device=device) return FlashAttentionMetadata( num_prefills=self.num_prefills, @@ -507,6 +515,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, max_query_len=max_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index aa9d4a71dbf87..21095b5a4ea96 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -142,6 +142,8 @@ def graph_capture(self, max_batch_size: int): device=self.runner.device) self._graph_block_tables = torch.from_numpy( self.runner.graph_block_tables).to(device=self.runner.device) + self._num_orig_input_tokens_tensor = torch.zeros( + max_batch_size, dtype=torch.int32, device=self.runner.device) self._graph_decode_workspace_buffer = self._get_workspace_buffer() self._graph_indices_buffer = torch.empty( max_batch_size * self.runner.cache_config.num_gpu_blocks, @@ -162,6 +164,7 @@ def graph_capture(self, max_batch_size: int): del self._graph_indptr_buffer del self._graph_last_page_len_buffer del self._graph_decode_wrapper + del self._num_orig_input_tokens_tensor def graph_clone(self, batch_size: int): assert self._is_graph_capturing @@ -211,6 +214,8 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): slot_mapping=self._graph_slot_mapping[:batch_size], num_prefill_tokens=0, num_decode_tokens=batch_size, + num_orig_input_tokens_tensor=self. + _num_orig_input_tokens_tensor[:batch_size], max_prefill_seq_len=0, block_tables=self._graph_block_tables, paged_kv_indptr=paged_kv_indptr_tensor_host, @@ -232,10 +237,15 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): def get_graph_input_buffers(self, attn_metadata): return { - "slot_mapping": attn_metadata.slot_mapping, + "slot_mapping": + attn_metadata.slot_mapping, + "num_orig_input_tokens_tensor": + attn_metadata.num_orig_input_tokens_tensor, } def prepare_graph_input_buffers(self, input_buffers, attn_metadata): + input_buffers["num_orig_input_tokens_tensor"].copy_( + attn_metadata.num_orig_input_tokens_tensor, non_blocking=True) return def begin_forward(self, model_input): @@ -506,7 +516,8 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): self.paged_kv_last_page_len.append(last_page_len) def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): + num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int, + batch_size: int): """Build attention metadata with on-device tensors. 
Args: @@ -576,6 +587,10 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) + num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list, + dtype=torch.long, + device=device) + if len(self.paged_kv_indptr) > 0: paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, device="cpu", @@ -602,6 +617,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], slot_mapping=slot_mapping_tensor, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, max_prefill_seq_len=max_prefill_seq_len, block_tables=block_tables, paged_kv_indptr=paged_kv_indptr_tensor, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b0f4d0530b7f0..2608a27fd2679 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -136,6 +136,8 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: slot_mapping=self.slot_mapping[:self.num_prefill_tokens], seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + num_orig_input_tokens_tensor=self. + num_orig_input_tokens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -164,6 +166,8 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + num_orig_input_tokens_tensor=self. + num_orig_input_tokens_tensor[:self.num_prefills], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 0375d3488eb15..87da574ef3634 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -190,7 +190,8 @@ def _add_seq_group( self.block_size, inter_data.block_tables) def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): + num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int, + batch_size: int): """Build attention metadata with on-device tensors. 
Args: @@ -258,6 +259,10 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) + num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list, + dtype=torch.long, + device=device) + return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -265,6 +270,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, max_query_len=max_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, @@ -294,11 +300,16 @@ def graph_capture(self, max_batch_size: int): device=self.runner.device) self._graph_block_tables = torch.from_numpy( self.runner.graph_block_tables).to(device=self.runner.device) + + self._num_orig_input_tokens_tensor = torch.zeros( + max_batch_size, dtype=torch.int32, device=self.runner.device) + yield self._is_graph_capturing = False del self._graph_slot_mapping del self._graph_seq_lens del self._graph_block_tables + del self._num_orig_input_tokens_tensor def graph_clone(self, batch_size: int) -> "CommonAttentionState": assert self._is_graph_capturing @@ -313,6 +324,8 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): slot_mapping=self._graph_slot_mapping[:batch_size], seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], + num_orig_input_tokens_tensor=self. + _num_orig_input_tokens_tensor[:batch_size], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.runner.max_seq_len_to_capture, @@ -326,9 +339,14 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]: return { - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, + "slot_mapping": + attn_metadata.slot_mapping, + "seq_lens_tensor": + attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": + attn_metadata.decode_metadata.block_tables, + "num_orig_input_tokens_tensor": + attn_metadata.num_orig_input_tokens_tensor, } def prepare_graph_input_buffers(self, input_buffers, @@ -337,6 +355,8 @@ def prepare_graph_input_buffers(self, input_buffers, attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) + input_buffers["num_orig_input_tokens_tensor"].copy_( + attn_metadata.num_orig_input_tokens_tensor, non_blocking=True) def begin_forward(self, model_input) -> None: return diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e073d616bf01d..36002b589882d 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -200,6 +200,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: self.seq_lens_tensor[:self.num_prefills]) context_lens_tensor = (None if self.context_lens_tensor is None else self.context_lens_tensor[:self.num_prefills]) + num_orig_input_tokens_tensor = ( + None if self.num_orig_input_tokens_tensor is None else + self.num_orig_input_tokens_tensor[:self.num_prefills]) block_tables = (None if self.block_tables is None else self.block_tables[:self.num_prefills]) @@ -211,6 +214,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: slot_mapping=slot_mapping, seq_lens=seq_lens, 
seq_lens_tensor=seq_lens_tensor, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -245,6 +249,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: self.seq_lens_tensor[self.num_prefills:]) block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) + num_orig_input_tokens_tensor = ( + None if self.num_orig_input_tokens_tensor is None else + self.num_orig_input_tokens_tensor[:self.num_prefills]) # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = XFormersMetadata( @@ -253,6 +260,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, seq_lens_tensor=seq_lens_tensor, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, block_tables=block_tables, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index c5a0278e485d4..ce88a7c9b7beb 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -540,8 +540,8 @@ def __init__( self.short_mscale = short_mscale self.long_mscale = long_mscale - short_cache = self._compute_cos_sin_cache( - original_max_position_embeddings, short_factor, short_mscale) + short_cache = self._compute_cos_sin_cache(max_position_embeddings, + short_factor, short_mscale) short_cache = short_cache.to(dtype) self.register_buffer("short_cos_sin_cache", short_cache, @@ -586,13 +586,18 @@ def forward( query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, + *, + num_orig_input_tokens_tensor: torch.Tensor = None, ) -> Tuple[torch.Tensor, torch.Tensor]: query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) k = self.original_max_position_embeddings - long_prompt_offset = (torch.any(positions > k).float() * - torch.full_like(positions, k)).long() + long_prompt_offset = torch.where( + num_orig_input_tokens_tensor <= k, + torch.zeros_like(num_orig_input_tokens_tensor), + torch.full_like(num_orig_input_tokens_tensor, + self.max_position_embeddings)) idx = (torch.add(positions, long_prompt_offset) if long_prompt_offset is not None else positions) self.long_short_cos_sin_cache: torch.Tensor = ( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index e30370596496a..fb7adf3be8c21 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -49,7 +49,7 @@ "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), - "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), + "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py new file mode 100644 index 0000000000000..d1dbc32f679ab --- /dev/null +++ b/vllm/model_executor/models/phi3.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Adapted from llama.py +"""Inference-only Phi3 model code inherit from Llama.py""" + +from typing import Optional + +import torch +from transformers import Phi3Config + +from 
vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models.llama import (LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, LlamaModel) + +from .utils import make_layers + + +class Phi3Attention(LlamaAttention): + + def __init__( + self, + config: Phi3Config, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=config.rope_theta, + rope_scaling=config.rope_scaling, + max_position_embeddings=config.max_position_embeddings, + quant_config=quant_config, + bias=bias, + cache_config=cache_config, + prefix=prefix) + + self.rope_scaling = config.rope_scaling + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) \ + if self.rope_scaling is None \ + else self.rotary_emb( + positions, + q, + k, + num_orig_input_tokens_tensor=attn_metadata.num_orig_input_tokens_tensor) + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Phi3DecoderLayer(LlamaDecoderLayer): + + def __init__( + self, + config: Phi3Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix) + self.self_attn = Phi3Attention( + config=config, + quant_config=quant_config, + bias=getattr(config, "attention_bias", False) + or getattr(config, "bias", False), + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + + +class Phi3Model(LlamaModel): + + def __init__( + self, + config: Phi3Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config, + prefix=prefix) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Phi3DecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + + +class Phi3ForCausalLM(LlamaForCausalLM): + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def 
__init__( + self, + config: Phi3Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__(config=config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + + self.model = Phi3Model(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + + if get_pp_group().is_last_rank and config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index afc6fe9844ad6..1a094d78d4474 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -240,7 +240,14 @@ def forward( k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) - q, k = self.rotary_emb(positions, q, k) + q, k = self.rotary_emb(positions, q, k) \ + if getattr(self.config, "rope_scaling", None) is None \ + else self.rotary_emb( + positions, + q, + k, + num_orig_input_tokens_tensor=attn_metadata.num_orig_input_tokens_tensor) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata) output, _ = self.dense(attn_output) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 25bc0590c745c..12f7723c4dd49 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -355,7 +355,14 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) + q, k = self.rotary_emb(positions, q, k) \ + if self.rope_scaling is None \ + else self.rotary_emb( + positions, + q, + k, + num_orig_input_tokens_tensor=attn_metadata.num_orig_input_tokens_tensor) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7205b1a7beb8d..032ebab9ee7d6 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -138,6 +138,8 @@ def _prepare_prompt( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + # The number of original input tokens of each sequence + num_orig_input_tokens_list: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] @@ -160,6 +162,8 @@ def _prepare_prompt( # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. 
input_positions.extend(list(range(computed_len, seq_len))) + num_orig_input_tokens_list.extend([seq_data.get_prompt_len()] * + (seq_len - computed_len)) mm_data = seq_group_metadata.multi_modal_data if mm_data: @@ -196,6 +200,9 @@ def _prepare_prompt( input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) # type: ignore + num_orig_input_tokens_tensor = torch.tensor( + num_orig_input_tokens_list, dtype=torch.long, + device=self.device) # type: ignore slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore @@ -204,6 +211,7 @@ def _prepare_prompt( is_prompt=True, seq_lens=seq_lens, seq_lens_tensor=torch.tensor([]), + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, max_decode_seq_len=0, num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, @@ -224,6 +232,8 @@ def _prepare_decode( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + # The number of original input tokens of each sequence + num_orig_input_tokens_list: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] @@ -242,6 +252,7 @@ def _prepare_decode( seq_len = seq_data.get_len() position = seq_len - 1 input_positions.append(position) + num_orig_input_tokens_list.append(seq_data.get_prompt_len()) seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -267,6 +278,9 @@ def _prepare_decode( input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) + num_orig_input_tokens_tensor = torch.tensor( + num_orig_input_tokens_list, dtype=torch.long, + device=self.device) # type: ignore slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) @@ -291,6 +305,7 @@ def _prepare_decode( num_decode_tokens=len(input_tokens), num_prefills=0, block_tables=block_tables, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, ) return ( input_tokens, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8a3c99a45b149..a241f6551b923 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -179,6 +179,7 @@ class InterDataForSeqGroup: def simple_reinit(self): self.input_tokens[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore + self.num_orig_input_tokens_list[0].clear() # type: ignore self.seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore self.query_lens[0] = 0 # type: ignore @@ -205,6 +206,8 @@ def __init__( input_tokens: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None, + # The number of original input tokens of each sequence + num_orig_input_tokens_list: Optional[List[List[int]]] = None, # The sequence length (may be capped to the sliding window). seq_lens: Optional[List[int]] = None, # The original sequence length (before applying sliding window). 
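The prompt-length bookkeeping added to the model runners follows one pattern everywhere: during prefill, every uncomputed prompt token of a sequence records that sequence's original prompt length, and during decode each sequence contributes a single entry for its one new token. A minimal sketch of the pattern, using illustrative names rather than the runner internals:

import torch

def prompt_len_entries(prompt_len: int, computed_len: int, seq_len: int,
                       is_prompt: bool) -> list:
    # Per-token "original prompt length" entries for one sequence in one step.
    if is_prompt:
        # Prefill: one entry per prompt token processed in this step.
        return [prompt_len] * (seq_len - computed_len)
    # Decode: exactly one new token for this sequence.
    return [prompt_len]

# Two prompts of length 5 and 3 scheduled in the same prefill batch:
flat = prompt_len_entries(5, 0, 5, True) + prompt_len_entries(3, 0, 3, True)
assert flat == [5, 5, 5, 5, 5, 3, 3, 3]
num_orig_input_tokens_tensor = torch.tensor(flat, dtype=torch.long)

In the decode step for the same two sequences the list collapses to [5, 3], one entry per sequence, which is what the attention metadata carries alongside seq_lens_tensor.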
@@ -264,6 +267,13 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.input_positions[seq_id].clear() + if num_orig_input_tokens_list: + self.num_orig_input_tokens_list = \ + num_orig_input_tokens_list + else: + for seq_id in range(len(self.seq_ids)): + self.num_orig_input_tokens_list[seq_id].clear() + if seq_lens: self.seq_lens = seq_lens else: @@ -325,6 +335,8 @@ def __init__( else: self.input_tokens = input_tokens or [] self.input_positions = input_positions or [] + self.num_orig_input_tokens_list = \ + num_orig_input_tokens_list or [] self.seq_lens = seq_lens or [] self.orig_seq_lens = orig_seq_lens or [] self.query_lens = query_lens or [] @@ -355,6 +367,7 @@ def __post_init__(self): self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)] + self.num_orig_input_tokens_list = [[] for _ in range(self.n_seqs)] self.seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs self.query_lens = [0] * self.n_seqs @@ -488,6 +501,13 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, inter_data.input_positions[seq_idx].extend( range(context_len, seq_len)) + if (seq_len - context_len) == 1: + inter_data.num_orig_input_tokens_list[seq_idx].append( + seq_data.get_prompt_len()) + else: + inter_data.num_orig_input_tokens_list[seq_idx].extend( + [seq_data.get_prompt_len()] * (seq_len - context_len)) + inter_data.query_lens[ seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 @@ -531,9 +551,10 @@ def _compute_for_prefix_cache_hit( inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ seq_idx][uncomputed_start:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][uncomputed_start:] - context_len = prefix_cache_len - + seq_idx][context_len:] + inter_data.num_orig_input_tokens_list[ + seq_idx] = inter_data.num_orig_input_tokens_list[seq_idx][ + context_len:] inter_data.context_lens[seq_idx] = context_len inter_data.query_lens[ seq_idx] = inter_data.seq_lens[seq_idx] - context_len @@ -687,6 +708,13 @@ def build(self) -> ModelInputForGPU: for cur_input_positions in inter_data.input_positions: input_positions.extend(cur_input_positions) + num_orig_input_tokens_list = [] + for inter_data in self.inter_data_list: + for cur_num_orig_input_tokens_list \ + in inter_data.num_orig_input_tokens_list: + num_orig_input_tokens_list.extend( + cur_num_orig_input_tokens_list) + seq_lens = [] max_decode_seq_len = 0 for inter_data in self.inter_data_list: @@ -723,6 +751,9 @@ def build(self) -> ModelInputForGPU: if cuda_graph_pad_size: input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) + num_orig_input_tokens_list.extend( + itertools.repeat(0, cuda_graph_pad_size)) + assert self.runner.device is not None input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, self.runner.device, @@ -737,7 +768,8 @@ def build(self) -> ModelInputForGPU: # Attention metadata. attn_metadata = self.attn_metadata_builder.build( - seq_lens, query_lens, cuda_graph_pad_size, batch_size) + seq_lens, query_lens, num_orig_input_tokens_list, + cuda_graph_pad_size, batch_size) # LoRA data. 
lora_requests = set() diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index a0498315516b8..158482f2ef19e 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -164,6 +164,10 @@ def _dummy_run( position_ids = torch.zeros((batch_size, seq_len), dtype=torch.int32, device=self.device) + num_orig_input_tokens_tensor = torch.full((batch_size, seq_len), + seq_len, + dtype=torch.int32, + device=self.device) slot_mapping = torch.zeros((batch_size, seq_len), dtype=torch.int64, device=self.device) @@ -172,6 +176,7 @@ def _dummy_run( num_prefill_tokens=batch_size * seq_len, num_decode_tokens=0, slot_mapping=slot_mapping, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, block_tables=None, context_lens=None, ) @@ -186,6 +191,9 @@ def _dummy_run( position_ids = torch.zeros((batch_size, seq_len), dtype=torch.int32, device=self.device) + num_orig_input_tokens_tensor = torch.ones((batch_size, seq_len), + dtype=torch.int32, + device=self.device) slot_mapping = torch.zeros((batch_size, seq_len), dtype=torch.int64, device=self.device) @@ -204,6 +212,7 @@ def _dummy_run( num_prefill_tokens=0, num_decode_tokens=batch_size * seq_len, slot_mapping=slot_mapping, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, block_tables=block_tables, context_lens=context_lens, ) @@ -291,6 +300,8 @@ def _prepare_prompt( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + # The number of original input tokens of each sequence + num_orig_input_tokens_list: List[int] = [] prompt_lens: List[int] = [] slot_mapping: List[int] = [] @@ -308,6 +319,8 @@ def _prepare_prompt( input_tokens.extend(prompt_tokens) input_positions.extend(list(range(prompt_len))) + num_orig_input_tokens_list.extend([seq_data.get_prompt_len()] * + prompt_len) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] @@ -327,6 +340,7 @@ def _prepare_prompt( num_paddings = padded_prompt_len - prompt_len input_tokens += [0] * num_paddings input_positions += [0] * num_paddings + num_orig_input_tokens_list += [0] * num_paddings slot_mapping += [_PAD_SLOT_ID] * num_paddings assert len(prompt_lens) > 0 @@ -337,6 +351,9 @@ def _prepare_prompt( input_positions = torch.tensor(input_positions, dtype=torch.int32, device="cpu") + num_orig_input_tokens_tensor = torch.tensor( + num_orig_input_tokens_list, dtype=torch.long, + device=self.device) # type: ignore slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64, device="cpu") @@ -348,6 +365,7 @@ def _prepare_prompt( num_prefill_tokens=0, # NOTE: This is not used. 
num_decode_tokens=0, slot_mapping=slot_mapping, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, block_tables=None, context_lens=None, ) @@ -360,6 +378,8 @@ def _prepare_decode( assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] + # The number of original input tokens of each sequence + num_orig_input_tokens_list: List[List[int]] = [] slot_mapping: List[List[int]] = [] context_lens: List[int] = [] @@ -375,6 +395,7 @@ def _prepare_decode( seq_len = seq_data.get_len() position = seq_len - 1 input_positions.append([position]) + num_orig_input_tokens_list.append([seq_data.get_prompt_len()]) context_lens.append(seq_len) assert seq_group_metadata.block_tables is not None @@ -391,6 +412,9 @@ def _prepare_decode( num_paddings = batch_size - batch_idx input_tokens = input_tokens + [[0]] * num_paddings input_positions = input_positions + [[0]] * num_paddings + num_orig_input_tokens_list = num_orig_input_tokens_list + [[ + 0 + ]] * num_paddings slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings context_lens = context_lens + [0] * num_paddings @@ -400,6 +424,9 @@ def _prepare_decode( input_positions = torch.tensor(input_positions, dtype=torch.int32, device="cpu") + num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list, + dtype=torch.long, + device="cpu") slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64, device="cpu") @@ -419,6 +446,7 @@ def _prepare_decode( slot_mapping=slot_mapping, block_tables=block_tables, context_lens=context_lens, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, ) return input_tokens, input_positions, attn_metadata, input_lens diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index f9037625d4af9..552bf268c97f7 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -156,6 +156,8 @@ def _prepare_prompt( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + # The number of original input tokens of each sequence + num_orig_input_tokens_list: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] @@ -179,6 +181,9 @@ def _prepare_prompt( # is always the first token in the sequence. input_positions.extend(list(range(computed_len, seq_len))) + num_orig_input_tokens_list.extend([seq_data.get_prompt_len()] * + (seq_len - computed_len)) + if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. 
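The TPU and XPU runners additionally pad this list to a fixed shape before materializing it on device, mirroring the existing padding of input_tokens and input_positions; padded slots are filled with zeros, which keeps them on the short-cache side of the rotary embedding's threshold. A sketch of that step, with an assumed helper name:

import torch

def pad_and_materialize(prompt_lens, padded_len, device, dtype=torch.long):
    # Pad the per-token prompt-length list and materialize it as a tensor.
    padded = list(prompt_lens) + [0] * (padded_len - len(prompt_lens))
    return torch.tensor(padded, dtype=dtype, device=device)

# Example: 8 real tokens padded up to a batch slot count of 16.
t = pad_and_materialize([5] * 5 + [3] * 3, 16, torch.device("cpu"))
assert t.shape == (16,) and int(t[8]) == 0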
@@ -215,6 +220,9 @@ def _prepare_prompt( input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) # type: ignore + num_orig_input_tokens_tensor = torch.tensor( + num_orig_input_tokens_list, dtype=torch.long, + device=self.device) # type: ignore slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore @@ -237,6 +245,7 @@ def _prepare_prompt( num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, block_tables=torch.tensor([], device=self.device, dtype=torch.int), + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) @@ -251,6 +260,8 @@ def _prepare_decode( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + # The number of original input tokens of each sequence + num_orig_input_tokens_list: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] @@ -269,6 +280,7 @@ def _prepare_decode( seq_len = seq_data.get_len() position = seq_len - 1 input_positions.append(position) + num_orig_input_tokens_list.append(seq_data.get_prompt_len()) seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -300,6 +312,9 @@ def _prepare_decode( seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=self.device) + num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list, + dtype=torch.long, + device=self.device) block_tables = make_tensor_with_pad( block_tables, @@ -320,6 +335,7 @@ def _prepare_decode( num_decode_tokens=len(input_tokens), num_prefills=0, block_tables=block_tables, + num_orig_input_tokens_tensor=num_orig_input_tokens_tensor, ) return ( input_tokens,
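Taken together, these changes let Phi3SuScaledRotaryEmbedding pick the short or long cos/sin cache per sequence from that sequence's original prompt length, instead of flipping the whole batch to the long cache as soon as torch.any(positions > k) fires. A minimal sketch of the selection logic, assuming, consistent with the offset used in the diff, that the long cache is concatenated after a short cache of length max_position_embeddings; the standalone function and its name are illustrative:

import torch

def rope_cache_indices(positions: torch.Tensor,
                       num_orig_input_tokens: torch.Tensor,
                       original_max_position_embeddings: int,
                       max_position_embeddings: int) -> torch.Tensor:
    # Index into a [short_cache; long_cache] concatenation, per token.
    long_prompt_offset = torch.where(
        num_orig_input_tokens <= original_max_position_embeddings,
        torch.zeros_like(num_orig_input_tokens),
        torch.full_like(num_orig_input_tokens, max_position_embeddings))
    return positions + long_prompt_offset

# A prefill batch mixing a short prompt (len 3) and a long one (len 6),
# with an original window of 4: only the long prompt's tokens are offset.
pos = torch.arange(9)
plen = torch.tensor([3, 3, 3, 6, 6, 6, 6, 6, 6])
idx = rope_cache_indices(pos, plen, 4, 8)
assert idx.tolist() == [0, 1, 2, 11, 12, 13, 14, 15, 16]

Compared with the old heuristic, a batch that mixes short and long prompts no longer forces every sequence onto the long-context scaling factors.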
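Because the new tensor also has to survive CUDA graph replay, the attention state classes preallocate a max-batch-size buffer at capture time, expose it via get_graph_input_buffers, and refresh it in place in prepare_graph_input_buffers. The buffer discipline in isolation, with assumed class and method names:

import torch

class StaticGraphBuffer:
    # A buffer whose storage address is baked into a captured graph and is
    # only ever updated in place between replays.

    def __init__(self, max_batch_size: int, device: torch.device):
        self.buf = torch.zeros(max_batch_size, dtype=torch.int32, device=device)

    def capture_view(self, batch_size: int) -> torch.Tensor:
        # The slice handed to the model while the graph is being captured.
        return self.buf[:batch_size]

    def refresh(self, values: torch.Tensor) -> None:
        # Copy this step's values into the captured storage before replay.
        self.buf[:values.shape[0]].copy_(values, non_blocking=True)

Zero-initializing the buffer matches the cuda_graph_pad_size padding in the GPU builder, so padded slots look like short prompts during replay.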