✨ log all errored requests
Signed-off-by: Joe Runde <[email protected]>
joerunde committed May 17, 2024
1 parent 7a0007c commit c0973fd
Showing 2 changed files with 100 additions and 49 deletions.
48 changes: 13 additions & 35 deletions vllm/entrypoints/grpc/grpc_server.py
@@ -52,6 +52,11 @@ async def _handle_exception(e: Exception, func, *args, **kwargs):
     # We don't log AbortErrors since these correspond to gRPC errors
     # intentionally raised during handling of requests.
     if not isinstance(e, AbortError):
+        # try to replicate TGIS logs for when errors occur
+        if "generate" in func.__name__.lower():
+            request = kwargs.get("request", None) or args[-2]
+            logs.log_error(request=request, exception=e, logger=logger)
+
         if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check
             context = kwargs.get("context", None) or args[-1]
             logger.exception("%s caused GPU OOM error", func.__name__)
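Note: the args[-2] / args[-1] indexing relies on the gRPC servicer handlers taking (self, request, context) as their trailing positional arguments. A hypothetical minimal sketch of how a wrapper such as log_rpc_handler_errors (whose body is not part of this diff) could route failures into _handle_exception:

    import functools

    def log_rpc_handler_errors(func):
        """Hypothetical sketch: forward handler exceptions to _handle_exception."""

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                # args is (self, request, context), so _handle_exception can
                # recover the request via args[-2] and the context via args[-1]
                return await _handle_exception(e, func, *args, **kwargs)

        return wrapper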
@@ -173,14 +178,9 @@ async def Generate(self, request: BatchedGenerationRequest,
             response = self._convert_input_details(res, resp_options,
                                                    sampling_params,
                                                    response)
-            if request_count == 1:
-                kind_log = "Request"
-            else:
-                kind_log = f"Sub-request {i} from batch of {request_count}"
-
-            self._log_unary_response(request=request, response=response,
-                                     start_time=start_time, engine_response=res,
-                                     kind_log=kind_log)
+            logs.log_response(request=request, response=response,
+                              start_time=start_time, engine_metrics=res.metrics,
+                              sub_request_num=i, logger=logger)
             responses[i] = response
 
         return BatchedGenerationResponse(responses=responses)
@@ -254,9 +254,11 @@ async def GenerateStream(
                 return
             first_response.text = full_output
             first_response.generated_token_count = last_token_count
-            self._log_streaming_response(request=request, response=first_response,
-                                         start_time=start_time,
-                                         engine_response=last_engine_response)
+            logs.log_response(request=request, response=first_response,
+                              start_time=start_time,
+                              engine_metrics=last_engine_response.metrics
+                              if last_engine_response else None,
+                              logger=logger)
 
     def _convert_input_details(
             self, result: RequestOutput, resp_options: ResponseOptions,
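In the streaming path, last_engine_response can be None if the stream ends before the engine produces any output, so the call site unwraps .metrics conditionally instead of passing the whole RequestOutput. The pattern in isolation:

    # Only unwrap .metrics when the engine actually produced a response;
    # log_response treats engine_metrics=None as "no timing info available".
    engine_metrics = (last_engine_response.metrics
                      if last_engine_response else None)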
@@ -538,30 +540,6 @@ async def _validate_prompt_and_tokenize(
 
         return input_ids, max_is_token_limit
 
-    @staticmethod
-    def _log_unary_response(request: BatchedGenerationRequest,
-                            response: GenerationResponse,
-                            engine_response: RequestOutput,
-                            start_time: float, kind_log: str):
-        logs.log_response(inputs=[r.text for r in request.requests],
-                          response=response, params=request.params,
-                          prefix_id=request.prefix_id,
-                          engine_response=engine_response,
-                          start_time=start_time, kind_log=kind_log,
-                          method_str="generate", logger=logger)
-
-    @staticmethod
-    def _log_streaming_response(request: SingleGenerationRequest,
-                                response: GenerationResponse,
-                                engine_response: RequestOutput,
-                                start_time: float):
-        logs.log_response(inputs=[request.request.text], response=response,
-                          params=request.params, prefix_id=request.prefix_id,
-                          engine_response=engine_response,
-                          start_time=start_time, kind_log="Streaming response",
-                          method_str="generate_stream", logger=logger)
-
-
     @log_rpc_handler_errors
     async def Tokenize(self, request: BatchedTokenizeRequest,
                        context: ServicerContext) -> BatchedTokenizeResponse:
101 changes: 87 additions & 14 deletions vllm/tgis_utils/logs.py
@@ -1,26 +1,99 @@
 """Some methods for producing logs similar to TGIS"""
 import logging
-from typing import List
+from typing import List, Optional, Union
 
 from google.protobuf import text_format
 
-from vllm import RequestOutput
-from vllm.entrypoints.grpc.pb.generation_pb2 import (GenerationResponse,
-                                                     Parameters, StopReason)
+from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
+                                                     GenerationResponse,
+                                                     Parameters,
+                                                     SingleGenerationRequest,
+                                                     StopReason)
+from vllm.sequence import RequestMetrics
 
 
-def log_response(inputs: List[str], params: Parameters, prefix_id: str,
-                 response: GenerationResponse, engine_response: RequestOutput,
-                 start_time: float, kind_log: str, method_str: str,
-                 logger: logging.Logger):
+def log_response(
+    request: Union[BatchedGenerationRequest, SingleGenerationRequest],
+    response: GenerationResponse,
+    engine_metrics: Optional[RequestMetrics],
+    start_time: float,
+    logger: logging.Logger,
+    sub_request_num: int = 0,
+):
+    if isinstance(request, BatchedGenerationRequest):
+        # unary case
+        request_count = len(request.requests)
+        if request_count == 1:
+            kind_log = "Request"
+        else:
+            kind_log = (f"Sub-request {sub_request_num} from batch of "
+                        f"{request_count}")
+        _log_response(inputs=[r.text for r in request.requests],
+                      response=response,
+                      params=request.params,
+                      prefix_id=request.prefix_id,
+                      engine_metrics=engine_metrics,
+                      start_time=start_time,
+                      kind_log=kind_log,
+                      method_str="generate",
+                      logger=logger)
+    else:
+        # streaming case
+        _log_response(inputs=[request.request.text],
+                      response=response,
+                      params=request.params,
+                      prefix_id=request.prefix_id,
+                      engine_metrics=engine_metrics,
+                      start_time=start_time,
+                      kind_log="Streaming response",
+                      method_str="generate_stream",
+                      logger=logger)
+
+
+def log_error(request: Union[BatchedGenerationRequest,
+                             SingleGenerationRequest], exception: Exception,
+              logger: logging.Logger):
+    """Logs errors similar to how the TGIS server does"""
+    params = request.params
+    paramstr = text_format.MessageToString(params, as_one_line=True)
+    prefix_id = request.prefix_id
+
+    if isinstance(request, BatchedGenerationRequest):
+        method_str = "generate"
+        inputs = [r.text for r in request.requests]
+    else:
+        method_str = "generate_stream"
+        inputs = [request.request.text]
+
+    short_input = [_truncate(input_, 32) for input_ in inputs]
+    input_chars = sum(len(input_) for input_ in inputs)
+
+    span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
+                f"input_chars=[{input_chars}] params={paramstr} ")
+
+    # Using %s to format the exception to only print the exception's message
+    # like TGIS does. (This is intentionally not using exc_info=True)
+    logger.error("%s: %s", span_str, exception)
+
+
+def _log_response(inputs: List[str], params: Parameters, prefix_id: str,
+                  response: GenerationResponse,
+                  engine_metrics: Optional[RequestMetrics], start_time: float,
+                  kind_log: str, method_str: str, logger: logging.Logger):
     """Logs responses similar to how the TGIS server does"""
     # This time contains both request validation and tokenization
-    tokenization_time = engine_response.metrics.arrival_time - start_time
-    inference_time = (engine_response.metrics.last_token_time -
-                      engine_response.metrics.first_scheduled_time)
-    queue_time = engine_response.metrics.time_in_queue
-    time_per_token = _safe_div(inference_time, response.generated_token_count)
-    total_time = engine_response.metrics.last_token_time - start_time
+    if engine_metrics is not None:
+        tokenization_time = engine_metrics.arrival_time - start_time
+        inference_time = (engine_metrics.last_token_time -
+                          engine_metrics.first_scheduled_time)
+        queue_time = engine_metrics.time_in_queue
+        time_per_token = _safe_div(inference_time,
+                                   response.generated_token_count)
+        total_time = engine_metrics.last_token_time - start_time
+    else:
+        logger.warning("No engine metrics for request, cannot log timing info")
+        tokenization_time = inference_time = queue_time = time_per_token = \
+            total_time = 0
     output_len = len(response.text)
     short_output = _truncate(response.text, 32)
     short_input = [_truncate(input_, 32) for input_ in inputs]
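A minimal, hypothetical exercise of the two new code paths (the GenerationRequest message name and its text field are inferred from the r.text usage above; values are illustrative):

    import logging

    from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
                                                         GenerationRequest,
                                                         GenerationResponse)
    from vllm.tgis_utils import logs

    logger = logging.getLogger("tgis-logs-demo")
    request = BatchedGenerationRequest(
        requests=[GenerationRequest(text="tell me a story")])
    response = GenerationResponse(text="once upon a time",
                                  generated_token_count=4)

    # New fallback: engine_metrics=None now logs a warning and zeroed timings
    # instead of failing on a missing RequestOutput.
    logs.log_response(request=request, response=response, engine_metrics=None,
                      start_time=0.0, logger=logger)

    # Errored requests get a TGIS-style span line, shaped roughly like:
    #   generate{input=['tell me a story'] prefix_id= input_chars=[15]
    #   params= : boom
    logs.log_error(request=request, exception=RuntimeError("boom"),
                   logger=logger)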
