✨ log all errored requests (#30)
This PR logs all errors that occur during validation or generation
of a request, in the same way TGIS does.

Signed-off-by: Joe Runde <[email protected]>
joerunde committed May 20, 2024
1 parent 64b18a7 commit bc5ac0b
Showing 2 changed files with 115 additions and 56 deletions.
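
For reference, here is a minimal sketch (not part of this commit) of the TGIS-style error line that the new log_error helper in vllm/tgis_utils/logs.py produces. The prompt, prefix_id, Parameters dump, and exception text are hypothetical, and the plain slice stands in for the module's _truncate helper:

# Illustrative sketch only: approximates the span string that log_error builds.
# The prompt, prefix_id, params string, and exception text are hypothetical.
import logging

logging.basicConfig(format="%(levelname)s %(name)s: %(message)s")
logger = logging.getLogger("grpc_server")

method_str = "generate"
inputs = ["Describe the weather"]                  # hypothetical request text
prefix_id = ""                                     # no prompt prefix attached
param_str = "stopping { max_new_tokens: 20 }"      # hypothetical Parameters dump
exception_str = "ValueError: max_new_tokens must be > 0"

short_input = [input_[:32] for input_ in inputs]   # stand-in for logs._truncate
input_chars = sum(len(input_) for input_ in inputs)

# Same span format that log_error uses in vllm/tgis_utils/logs.py
span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
            f"input_chars=[{input_chars}] params={param_str}")
logger.error("%s: %s", span_str, exception_str)
# ERROR grpc_server: generate{input=['Describe the weather'] prefix_id=
#     input_chars=[20] params=stopping { max_new_tokens: 20 }: ValueError: ...
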
72 changes: 30 additions & 42 deletions vllm/entrypoints/grpc/grpc_server.py
@@ -49,19 +49,34 @@ def with_default(value: Any, default: Any) -> Any:
 
 
 async def _handle_exception(e: Exception, func, *args, **kwargs):
-    # We don't log AbortErrors since these correspond to gRPC errors
-    # intentionally raised during handling of requests.
+    context = kwargs.get("context", None) or args[-1]
+    is_generate_fn = "generate" in func.__name__.lower()
+
+    # First just try to replicate the TGIS-style log messages
+    # for generate_* rpcs
+    if is_generate_fn:
+        if isinstance(e, AbortError):
+            # For things that we've already aborted, the relevant error
+            # string is already in the grpc context.
+            error_message = context.details()
+        else:
+            error_message = str(e)
+        request = kwargs.get("request", None) or args[-2]
+        logs.log_error(request=request,
+                       exception_str=error_message,
+                       logger=logger)
+
+    # AbortErrors likely correspond to things we've already explicitly handled,
+    # So we only add special handling for other types of errors
     if not isinstance(e, AbortError):
         if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check
-            context = kwargs.get("context", None) or args[-1]
             logger.exception("%s caused GPU OOM error", func.__name__)
             service_metrics.count_request_failure(FailureReasonLabel.OOM)
             await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
+        elif is_generate_fn:
+            service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
         else:
-            if "generate" in func.__name__.lower():
-                service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
-            else:
-                service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
+            service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
         logger.exception("%s failed", func.__name__)
     raise e

@@ -173,14 +188,9 @@ async def Generate(self, request: BatchedGenerationRequest,
             response = self._convert_input_details(res, resp_options,
                                                    sampling_params,
                                                    response)
-            if request_count == 1:
-                kind_log = "Request"
-            else:
-                kind_log = f"Sub-request {i} from batch of {request_count}"
-
-            self._log_unary_response(request=request, response=response,
-                                     start_time=start_time, engine_response=res,
-                                     kind_log=kind_log)
+            logs.log_response(request=request, response=response,
+                              start_time=start_time, engine_metrics=res.metrics,
+                              sub_request_num=i, logger=logger)
             responses[i] = response
 
         return BatchedGenerationResponse(responses=responses)
@@ -254,9 +264,11 @@ async def GenerateStream(
             return
         first_response.text = full_output
         first_response.generated_token_count = last_token_count
-        self._log_streaming_response(request=request, response=first_response,
-                                     start_time=start_time,
-                                     engine_response=last_engine_response)
+        logs.log_response(request=request, response=first_response,
+                          start_time=start_time,
+                          engine_metrics=last_engine_response.metrics
+                          if last_engine_response else None,
+                          logger=logger)
 
     def _convert_input_details(
             self, result: RequestOutput, resp_options: ResponseOptions,
@@ -538,30 +550,6 @@ async def _validate_prompt_and_tokenize(
 
         return input_ids, max_is_token_limit
 
-    @staticmethod
-    def _log_unary_response(request: BatchedGenerationRequest,
-                            response: GenerationResponse,
-                            engine_response: RequestOutput,
-                            start_time: float, kind_log: str):
-        logs.log_response(inputs=[r.text for r in request.requests],
-                          response=response, params=request.params,
-                          prefix_id=request.prefix_id,
-                          engine_response=engine_response,
-                          start_time=start_time, kind_log=kind_log,
-                          method_str="generate", logger=logger)
-
-    @staticmethod
-    def _log_streaming_response(request: SingleGenerationRequest,
-                                response: GenerationResponse,
-                                engine_response: RequestOutput,
-                                start_time: float):
-        logs.log_response(inputs=[request.request.text], response=response,
-                          params=request.params, prefix_id=request.prefix_id,
-                          engine_response=engine_response,
-                          start_time=start_time, kind_log="Streaming response",
-                          method_str="generate_stream", logger=logger)
-
-
     @log_rpc_handler_errors
     async def Tokenize(self, request: BatchedTokenizeRequest,
                        context: ServicerContext) -> BatchedTokenizeResponse:
99 changes: 85 additions & 14 deletions vllm/tgis_utils/logs.py
@@ -1,26 +1,97 @@
"""Some methods for producing logs similar to TGIS"""
import logging
from typing import List
from typing import List, Optional, Union

from google.protobuf import text_format

from vllm import RequestOutput
from vllm.entrypoints.grpc.pb.generation_pb2 import (GenerationResponse,
Parameters, StopReason)
from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
GenerationResponse,
Parameters,
SingleGenerationRequest,
StopReason)
from vllm.sequence import RequestMetrics


def log_response(inputs: List[str], params: Parameters, prefix_id: str,
response: GenerationResponse, engine_response: RequestOutput,
start_time: float, kind_log: str, method_str: str,
logger: logging.Logger):
def log_response(
request: Union[BatchedGenerationRequest, SingleGenerationRequest],
response: GenerationResponse,
engine_metrics: Optional[RequestMetrics],
start_time: float,
logger: logging.Logger,
sub_request_num: int = 0,
):
if isinstance(request, BatchedGenerationRequest):
# unary case
request_count = len(request.requests)
if request_count == 1:
kind_log = "Request"
else:
kind_log = (f"Sub-request {sub_request_num} from batch of "
f"{request_count}")
inputs = [r.text for r in request.requests]
method_str = "generate"
else:
# streaming case
inputs = [request.request.text]
kind_log = "Streaming response"
method_str = "generate_stream"

_log_response(
inputs=inputs,
response=response,
params=request.params,
prefix_id=request.prefix_id,
engine_metrics=engine_metrics,
start_time=start_time,
kind_log=kind_log,
method_str=method_str,
logger=logger,
)


def log_error(request: Union[BatchedGenerationRequest,
SingleGenerationRequest], exception_str: str,
logger: logging.Logger):
"""Logs errors similar to how the TGIS server does"""
# NB: We don't actually log the `Exception` here to match the TGIS behavior
# of just logging the simple string representation of the error
param_str = text_format.MessageToString(request.params, as_one_line=True)
prefix_id = request.prefix_id

if isinstance(request, BatchedGenerationRequest):
method_str = "generate"
inputs = [r.text for r in request.requests]
else:
method_str = "generate_stream"
inputs = [request.request.text]

short_input = [_truncate(input_, 32) for input_ in inputs]
input_chars = sum(len(input_) for input_ in inputs)

span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"input_chars=[{input_chars}] params={param_str}")

logger.error("%s: %s", span_str, exception_str)


def _log_response(inputs: List[str], params: Parameters, prefix_id: str,
response: GenerationResponse,
engine_metrics: Optional[RequestMetrics], start_time: float,
kind_log: str, method_str: str, logger: logging.Logger):
"""Logs responses similar to how the TGIS server does"""
# This time contains both request validation and tokenization
tokenization_time = engine_response.metrics.arrival_time - start_time
inference_time = (engine_response.metrics.last_token_time -
engine_response.metrics.first_scheduled_time)
queue_time = engine_response.metrics.time_in_queue
time_per_token = _safe_div(inference_time, response.generated_token_count)
total_time = engine_response.metrics.last_token_time - start_time
if engine_metrics is not None:
tokenization_time = engine_metrics.arrival_time - start_time
inference_time = (engine_metrics.last_token_time -
engine_metrics.first_scheduled_time)
queue_time = engine_metrics.time_in_queue
time_per_token = _safe_div(inference_time,
response.generated_token_count)
total_time = engine_metrics.last_token_time - start_time
else:
logger.warning("No engine metrics for request, cannot log timing info")
tokenization_time = inference_time = queue_time = time_per_token =\
total_time = 0
output_len = len(response.text)
short_output = _truncate(response.text, 32)
short_input = [_truncate(input_, 32) for input_ in inputs]
Expand Down
