Skip to content

Commit

Permalink
[Performance] Optimize e2e overheads: Reduce python allocations (vllm…
Browse files Browse the repository at this point in the history
  • Loading branch information
alexm-redhat authored Aug 9, 2024
1 parent 73388c0 commit e02ac55
Show file tree
Hide file tree
Showing 11 changed files with 550 additions and 125 deletions.
6 changes: 5 additions & 1 deletion vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,11 @@ def _add_seq_group(
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)

# Compute slot mapping.
Expand Down
12 changes: 10 additions & 2 deletions vllm/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,21 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
block_table = block_tables[seq_id]
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len):

def add_slot(i):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)

if start_idx == 0 and (seq_len - context_len) == 1:
# Optimization for common-case of decoding next token
add_slot(seq_len - 1)
else:
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len):
add_slot(i)


TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')

Expand Down
48 changes: 45 additions & 3 deletions vllm/block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Token blocks."""
from typing import List
from typing import List, Optional

from vllm.utils import Device

Expand Down Expand Up @@ -37,5 +37,47 @@ def __repr__(self) -> str:
f'computed={self.computed})')


# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]
class BlockTable:
    """Holds a list of blocks with caching of their associated block_ids.

    The cached id list lets callers fetch ``[b.block_number for b in blocks]``
    without rebuilding it on every request (see ``ids()``), which is the hot
    path when constructing model inputs each step.
    """

    def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None):
        """Create a table, optionally pre-populated from *blocks*.

        Args:
            blocks: Optional iterable of blocks to copy in. The blocks are
                shared (not copied); only the list containers are new.
        """
        if blocks is None:
            self._blocks: List[PhysicalTokenBlock] = []
            self._block_ids: List[int] = []
        else:
            # Bulk-build both lists instead of N append() calls.
            self._blocks = list(blocks)
            self._block_ids = [b.block_number for b in self._blocks]

    def append(self, block: PhysicalTokenBlock):
        """Append *block*, keeping the cached id list in sync."""
        self._blocks.append(block)
        self._block_ids.append(block.block_number)

    def __len__(self) -> int:
        return len(self._blocks)

    def __getitem__(self, key):
        # Note: a slice returns a plain list of blocks, not a BlockTable,
        # matching the original list-based behavior.
        return self._blocks[key]

    def __iter__(self):
        # Explicit iterator: faster than the implicit __getitem__ protocol.
        return iter(self._blocks)

    def __setitem__(self, key, value):
        """Assign a block (int key) or list of blocks (slice key),
        keeping the cached id list in sync."""
        if isinstance(key, slice):
            blocks = value
            self._blocks[key] = blocks
            self._block_ids[key] = [b.block_number for b in blocks]
        else:
            block = value
            self._blocks[key] = block
            self._block_ids[key] = block.block_number

    def reset(self):
        """Drop all blocks (and cached ids)."""
        self._blocks = []
        self._block_ids = []

    def copy(self) -> "BlockTable":
        """Return a shallow copy sharing the same block objects.

        Copies both cached lists directly instead of re-appending each
        block (which would re-read block.block_number per element).
        """
        new_table = BlockTable()
        new_table._blocks = list(self._blocks)
        new_table._block_ids = list(self._block_ids)
        return new_table

    def list(self) -> List[PhysicalTokenBlock]:
        """Return the underlying (live, not copied) list of blocks."""
        return self._blocks

    def ids(self) -> List[int]:
        """Return the cached (live, not copied) list of block numbers."""
        return self._block_ids
27 changes: 15 additions & 12 deletions vllm/core/block_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(
self.num_blocks = num_blocks

# Initialize the free blocks.
self.free_blocks: BlockTable = []
self.free_blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks):
block = PhysicalTokenBlock(device=device,
block_number=i,
Expand Down Expand Up @@ -256,6 +256,7 @@ def __init__(
Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}

# Mapping: req_id -> BlockTable
# Note that each SequenceGroup has a unique
# request ID
Expand Down Expand Up @@ -299,7 +300,7 @@ def _allocate_sequence(self, \
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = seq.n_blocks

block_table: BlockTable = []
block_table: BlockTable = BlockTable()
for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window):
Expand All @@ -326,15 +327,19 @@ def allocate(self, seq_group: SequenceGroup) -> None:
#
# NOTE: Here we assume that all sequences in the group have the same
# decoder prompt.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
seq = wait_seqs[0]
block_table: BlockTable = \
self._allocate_sequence(seq,
seq_group.num_seqs(),
is_encoder_decoder)

# Assign the self-attention block tables for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
if len(wait_seqs) == 1:
self.block_tables[wait_seqs[0].seq_id] = block_table
else:
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()

# Allocate encoder sequence
if is_encoder_decoder:
Expand Down Expand Up @@ -476,6 +481,7 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
return
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy()

# When using a sliding window, blocks will be eventually reused.
# In this case the block tables will contain repeated blocks.
# When forking, we must make sure that each block's `ref_count`
Expand Down Expand Up @@ -527,7 +533,7 @@ def _swap_block_table(
dest_allocator: BlockAllocatorBase,
mapping: Dict[PhysicalTokenBlock,
PhysicalTokenBlock]) -> BlockTable:
new_block_table = []
new_block_table: BlockTable = BlockTable()

for from_block in block_table:
if from_block in mapping:
Expand All @@ -553,8 +559,7 @@ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.cpu_allocator,
self.gpu_allocator,
self.cpu_allocator, self.gpu_allocator,
mapping)

if seq_group.is_encoder_decoder():
Expand All @@ -580,8 +585,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.gpu_allocator,
self.cpu_allocator,
self.gpu_allocator, self.cpu_allocator,
mapping)

if seq_group.is_encoder_decoder():
Expand Down Expand Up @@ -636,8 +640,7 @@ def reset(self) -> None:
self.cross_block_tables.clear()

def get_block_table(self, seq: Sequence) -> List[int]:
block_table = self.block_tables[seq.seq_id]
return [block.block_number for block in block_table]
return self.block_tables[seq.seq_id].ids()

def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
block_table = self.cross_block_tables[seq_group.request_id]
Expand Down
Loading

0 comments on commit e02ac55

Please sign in to comment.