Fix auto prefix bug (#3239)
ElizaWszola authored Mar 8, 2024
1 parent 8cbba46 commit b35cc93
Showing 3 changed files with 51 additions and 12 deletions.
34 changes: 34 additions & 0 deletions tests/engine/test_computed_prefix_blocks.py
@@ -0,0 +1,34 @@
+import pytest
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("block_size", [16])
+def test_computed_prefix_blocks(model: str, block_size: int):
+    # This test checks if we are able to run the engine to completion
+    # without triggering asserts.
+    # We are in a scenario where all blocks from the second request's prompt
+    # are full and already computed when the second request arrives.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+    prompt2 = (
+        " Please recommend to me some resources where I can learn not only to "
+        "handle technical difficulties of building a car, but also "
+        "decoration.")
+
+    engine_args = EngineArgs(model=model,
+                             block_size=block_size,
+                             enable_prefix_caching=True)
+
+    engine = LLMEngine.from_engine_args(engine_args)
+    sampling_params = SamplingParams()
+
+    engine.add_request("0", prompt + prompt2, sampling_params)
+    engine.step()
+    engine.add_request("1", prompt, sampling_params)
+    engine.step()
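
For context, the scenario the test exercises: request "0" is scheduled first with `prompt + prompt2`, so every block covering `prompt` becomes full and marked as computed; request "1" then arrives with `prompt` alone, and all of its prompt blocks are already cached. The sketch below illustrates the block arithmetic with made-up token counts (the real counts depend on the tokenizer); none of these names are vLLM APIs. The test itself can be run directly with `pytest tests/engine/test_computed_prefix_blocks.py`.

```python
# Illustrative only: why a fully cached prompt is problematic with block_size=16.
# Token counts are invented for this sketch.
block_size = 16
prompt_tokens = 48                                # assume `prompt` fills exactly 3 blocks
full_prompt_blocks = prompt_tokens // block_size  # 3 blocks, all computed by request "0"

# If all 3 blocks were reported as cached, request "1" would have zero uncached
# prompt tokens left for the model runner to process.
tokens_left_if_all_cached = prompt_tokens - full_prompt_blocks * block_size         # 0
tokens_left_excluding_last = prompt_tokens - (full_prompt_blocks - 1) * block_size  # 16
print(tokens_left_if_all_cached, tokens_left_excluding_last)
```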
28 changes: 16 additions & 12 deletions vllm/core/block_manager.py
@@ -1,6 +1,6 @@
"""A block manager that manages token blocks."""
import enum
from itertools import count
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional, Set, Tuple

@@ -426,23 +426,29 @@ def access_all_blocks_in_seq(
         for block in block_table:
             block.last_accessed = access_time
 
-    def compute_last_full_block_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
         max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
-        block_table[max_full_block].computed = True
+        for i in reversed(range(max_full_block)):
+            if block_table[i].computed:
+                break
+            block_table[i].computed = True
 
-    def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
+    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
         if seq.seq_id not in self.block_tables:
             return []
         block_table = self.block_tables[seq.seq_id]
-        for block_idx in reversed(range(len(block_table))):
-            if block_table[block_idx].computed:
-                return [b.block_number for b in block_table[:block_idx + 1]]
-        return []
+        # NOTE We exclude the last block to avoid the case where the entire
+        # prompt is cached. This would cause erroneous behavior in model
+        # runner.
+        return [
+            b.block_number
+            for b in takewhile(lambda b: b.computed, block_table[:-1])
+        ]
 
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
@@ -451,14 +457,12 @@ def get_common_computed_block_ids(self,
             return []
 
         ids_list = [
-            self.get_all_block_ids_till_computed(seq)
+            self.get_all_computed_blocks(seq)
             for seq in iter(seq_group.seqs_dict.values())
         ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # NOTE: We only mark the last full block because with prefix caching,
-        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
-                self.compute_last_full_block_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq)
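
The heart of the fix is how the computed prefix is collected: `compute_full_blocks_in_seq` now walks backwards over the sequence's full blocks and marks them as computed, stopping once it reaches a block that is already marked, while `get_all_computed_blocks` returns only the leading run of computed blocks and always excludes the final block of the table. A self-contained sketch of that collection logic, using a stand-in block class rather than vLLM's own block type:

```python
from dataclasses import dataclass
from itertools import takewhile
from typing import List


@dataclass
class FakeBlock:
    """Stand-in for vLLM's block objects, only for this sketch."""
    block_number: int
    computed: bool


def computed_prefix(block_table: List[FakeBlock]) -> List[int]:
    # Mirrors the new get_all_computed_blocks: take the leading run of
    # computed blocks, and never include the last block so the whole prompt
    # can never be reported as cached.
    return [
        b.block_number
        for b in takewhile(lambda b: b.computed, block_table[:-1])
    ]


blocks = [FakeBlock(0, True), FakeBlock(1, True),
          FakeBlock(2, False), FakeBlock(3, True)]
print(computed_prefix(blocks))      # [0, 1] -- stops at the first uncomputed block
print(computed_prefix(blocks[:2]))  # [0]    -- the last block is excluded even if computed
```

Unlike the previous reverse scan, which returned every block up to the last computed one, this version never reports a block that sits behind an uncomputed block and can never claim that an entire prompt is cached.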
1 change: 1 addition & 0 deletions vllm/worker/model_runner.py
@@ -215,6 +215,7 @@ def _prepare_prompt(
                 slot_mapping[-1].append(slot)
 
         max_prompt_len = max(subquery_lens)
+        assert max_prompt_len > 0
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,
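
The new assert makes the corresponding invariant explicit on the model-runner side: after prefix caching has been accounted for, at least one prompt token must remain to be computed, otherwise the batch tensors would be padded to length zero. A hypothetical sketch of the padding step (`pad_to_max` is only an illustration, not vLLM's `_make_tensor_with_pad`):

```python
from typing import List


def pad_to_max(seqs: List[List[int]], max_len: int, pad: int = 0) -> List[List[int]]:
    # Illustrative stand-in: right-pad each token list to the common maximum length.
    return [s + [pad] * (max_len - len(s)) for s in seqs]


subquery_lens = [3, 1]               # prompt tokens still to be computed per sequence
max_prompt_len = max(subquery_lens)
assert max_prompt_len > 0            # the guard added in this commit

print(pad_to_max([[11, 12, 13], [21]], max_prompt_len))
# [[11, 12, 13], [21, 0, 0]]
```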
