Skip to content

Commit

Permalink
[V1][3/N] API Server: Reduce Task Switching + Handle Abort Properly (v…
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgshaw2-redhat authored Dec 27, 2024
1 parent eb881ed commit 1b875a0
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 153 deletions.
159 changes: 62 additions & 97 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_stream import AsyncStream
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
Expand Down Expand Up @@ -54,10 +53,8 @@ def __init__(
lora_config=vllm_config.lora_config)
self.tokenizer.ping()

# Request streams (map of request_id -> AsyncStream).
self.request_streams: Dict[str, AsyncStream] = {}
# List of cancelled request ids to be aborted.
self.client_aborted_requests: List[str] = []
# Request streams (map of request_id -> queue).
self.rid_to_queue: Dict[str, asyncio.Queue] = {}

# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
Expand Down Expand Up @@ -153,14 +150,13 @@ async def add_request(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
) -> asyncio.Queue[RequestOutput]:
"""Add new request to the AsyncLLM."""

if self.detokenizer.is_request_active(request_id):
raise ValueError(f"Request {request_id} already exists.")

# 1) Create a new AsyncStream for the request.
stream = self._add_request_to_streams(request_id)
# 1) Create a new output queue for the request.
if request_id in self.rid_to_queue:
raise ValueError(f"Request id {request_id} already running.")
self.rid_to_queue[request_id] = asyncio.Queue()

# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req, engine_core_req = self.processor.process_inputs(
Expand All @@ -173,8 +169,10 @@ async def add_request(
# 4) Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(engine_core_req)

# 5) Return the generator.
return stream.generator()
if self.log_requests:
logger.info("Added request %s.", request_id)

return self.rid_to_queue[request_id]

# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
Expand All @@ -194,7 +192,7 @@ async def generate(
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
# 2) Processing the Input.
* 2) Processing the Input.
* 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process).
Expand All @@ -206,94 +204,58 @@ async def generate(
returning the RequestOutput back to the caller.
"""

# We start the output_handler on the first call to generate() so that
# we can call __init__ before the event loop starts, which enables us
# to handle startup failure gracefully in the OpenAI server.
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())

async for output in await self.add_request(
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())

q = await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield output

def _finish_stream(self, request_id: str):
stream = self.request_streams.pop(request_id, None)
if stream is not None:
stream.finish()

def _add_request_to_streams(
self,
request_id: str,
) -> AsyncStream:

if request_id in self.request_streams:
raise ValueError(f"Request id {request_id} already running.")

# Avoid streams having circular ref to parent AsyncLLM object.
aborted_reqs = self.client_aborted_requests
stream = AsyncStream(request_id, aborted_reqs.append)
self.request_streams[request_id] = stream

if self.log_requests:
logger.info("Added request %s.", request_id)
)

return stream

async def _process_cancellations(self) -> None:
"""
Process requests cancelled from user disconnecting.
When a client disconnects, AsyncStream._cancel() is called.
We passed a callback to AsyncStream(), which appends to
self.client_aborted_requests.
As a result, if any requests are canceled from the user side
the request_id will show up in self.client_aborted_requests.
"""

# Avoid streams having circular ref to parent AsyncLLM object.
if not self.client_aborted_requests:
return
reqs_to_abort = self.client_aborted_requests.copy()
self.client_aborted_requests.clear()

# Remove from Detokenizer.
self.detokenizer.abort_requests(reqs_to_abort)

# Remove from RequestStreams.
for request_id in reqs_to_abort:
if self.log_requests:
logger.info("User-cancelled request %s.", request_id)
self._finish_stream(request_id)

# Remove from EngineCore.
await self.engine_core.abort_requests_async(reqs_to_abort)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
while True:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() if q.qsize() > 0 else await q.get()

# Note: both Detokenizer and EngineCore handle their
# own request cleanup based on finished.
if out.finished:
del self.rid_to_queue[request_id]
yield out
break

yield out

# If the request is disconnected by the client, the
# generate() task will be canceled. So, we abort the
# request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
raise

def _process_request_outputs(self, request_outputs: List[RequestOutput]):
"""Process outputs by putting them into per-request AsyncStreams."""
"""Process outputs by putting them into per-request queues."""

for request_output in request_outputs:
request_id = request_output.request_id
assert request_id in self.request_streams

# Each request in the API server pulls from the per-request stream.
stream = self.request_streams.get(request_id)
if stream is not None:
stream.put(request_output)

# If finished, remove from the tracker.
if request_output.finished:
if self.log_requests:
logger.info("Finished request %s.", request_id)
self._finish_stream(request_id)
# Note: it is possible a request was aborted and removed from
# the state due to client cancellations, so if we encounter a
# request id not in the state, we skip.
if request_id in self.rid_to_queue:
self.rid_to_queue[request_id].put_nowait(request_output)

async def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
Expand All @@ -306,24 +268,27 @@ async def _run_output_handler(self):
# 2) Detokenize based on the output.
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)

# 3) Put the RequestOutputs into the per-request AsyncStreams.
# 3) Put the RequestOutputs into the per-request queues.
self._process_request_outputs(request_outputs)

# 4) Abort any requests that finished due to stop strings.
await self.engine_core.abort_requests_async(reqs_to_abort)

# 5) Abort any requests due to client cancellations.
await self._process_cancellations()

except BaseException as e:
logger.error(e)
raise e

# TODO: can we eliminate these?

async def abort(self, request_id: str) -> None:
# Note: Who Calls this? I dont think this is actually used.
raise ValueError("Not Supported on V1 yet.")
"""Abort RequestId in self, detokenizer, and engine core."""

request_ids = [request_id]
await self.engine_core.abort_requests_async(request_ids)
self.detokenizer.abort_requests(request_ids)

# If a request finishes while we await then the request_id
# will be removed from the tracked queues before we get here.
if request_id in self.rid_to_queue:
del self.rid_to_queue[request_id]

def encode(
self,
Expand Down
55 changes: 0 additions & 55 deletions vllm/v1/engine/async_stream.py

This file was deleted.

2 changes: 1 addition & 1 deletion vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
LOGGING_TIME_S = 5000
LOGGING_TIME_S = POLLING_TIMEOUT_S


class EngineCore:
Expand Down

0 comments on commit 1b875a0

Please sign in to comment.