From bc5ac0b24382a93bc4fad7a9ca15db21af0e7b03 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Mon, 20 May 2024 10:58:13 -0700
Subject: [PATCH] :sparkles: log all errored requests (#30)

This PR logs all errors that occur during validation or generation of
a request, in the same way TGIS does.

Signed-off-by: Joe Runde
---
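As an illustration of the new logging behavior (not part of the diff), a
minimal sketch that drives the log_error helper added below. The request
contents, error string, and logger setup are hypothetical, and the
GenerationRequest message name is assumed from the generation.proto used
elsewhere in this patch:

    import logging

    from vllm.entrypoints.grpc.pb.generation_pb2 import (
        BatchedGenerationRequest, GenerationRequest)
    from vllm.tgis_utils import logs

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # A hypothetical single-prompt unary request that failed validation
    request = BatchedGenerationRequest(
        requests=[GenerationRequest(text="Hello world")])
    logs.log_error(request=request,
                   exception_str="ValueError: max_new_tokens must be > 0",
                   logger=logger)
    # Logs a TGIS-style error line of roughly this shape:
    #   generate{input=['Hello world'] prefix_id= input_chars=[11]
    #   params=: ValueError: max_new_tokens must be > 0
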
 vllm/entrypoints/grpc/grpc_server.py | 72 +++++++++-----------
 vllm/tgis_utils/logs.py              | 99 ++++++++++++++++++++++++----
 2 files changed, 115 insertions(+), 56 deletions(-)

diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py
index bf85db877..9aead23aa 100644
--- a/vllm/entrypoints/grpc/grpc_server.py
+++ b/vllm/entrypoints/grpc/grpc_server.py
@@ -49,19 +49,34 @@ def with_default(value: Any, default: Any) -> Any:
 
 
 async def _handle_exception(e: Exception, func, *args, **kwargs):
-    # We don't log AbortErrors since these correspond to gRPC errors
-    # intentionally raised during handling of requests.
+    context = kwargs.get("context", None) or args[-1]
+    is_generate_fn = "generate" in func.__name__.lower()
+
+    # First just try to replicate the TGIS-style log messages
+    # for generate_* rpcs
+    if is_generate_fn:
+        if isinstance(e, AbortError):
+            # For things that we've already aborted, the relevant error
+            # string is already in the grpc context.
+            error_message = context.details()
+        else:
+            error_message = str(e)
+        request = kwargs.get("request", None) or args[-2]
+        logs.log_error(request=request,
+                       exception_str=error_message,
+                       logger=logger)
+
+    # AbortErrors likely correspond to things we've already explicitly handled,
+    # So we only add special handling for other types of errors
     if not isinstance(e, AbortError):
         if type(e).__name__ == "torch.cuda.OutOfMemoryError":  #TODO check
-            context = kwargs.get("context", None) or args[-1]
             logger.exception("%s caused GPU OOM error", func.__name__)
             service_metrics.count_request_failure(FailureReasonLabel.OOM)
             await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
+        elif is_generate_fn:
+            service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
         else:
-            if "generate" in func.__name__.lower():
-                service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
-            else:
-                service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
+            service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
         logger.exception("%s failed", func.__name__)
     raise e
 
@@ -173,14 +188,9 @@ async def Generate(self, request: BatchedGenerationRequest,
             response = self._convert_input_details(res, resp_options,
                                                    sampling_params, response)
-            if request_count == 1:
-                kind_log = "Request"
-            else:
-                kind_log = f"Sub-request {i} from batch of {request_count}"
-
-            self._log_unary_response(request=request, response=response,
-                                     start_time=start_time, engine_response=res,
-                                     kind_log=kind_log)
+            logs.log_response(request=request, response=response,
+                              start_time=start_time, engine_metrics=res.metrics,
+                              sub_request_num=i, logger=logger)
             responses[i] = response
 
         return BatchedGenerationResponse(responses=responses)
 
@@ -254,9 +264,11 @@ async def GenerateStream(
                 return
         first_response.text = full_output
         first_response.generated_token_count = last_token_count
-        self._log_streaming_response(request=request, response=first_response,
-                                     start_time=start_time,
-                                     engine_response=last_engine_response)
+        logs.log_response(request=request, response=first_response,
+                          start_time=start_time,
+                          engine_metrics=last_engine_response.metrics
+                          if last_engine_response else None,
+                          logger=logger)
 
     def _convert_input_details(
             self, result: RequestOutput, resp_options: ResponseOptions,
@@ -538,30 +550,6 @@ async def _validate_prompt_and_tokenize(
 
         return input_ids, max_is_token_limit
 
-    @staticmethod
-    def _log_unary_response(request: BatchedGenerationRequest,
-                            response: GenerationResponse,
-                            engine_response: RequestOutput,
-                            start_time: float, kind_log: str):
-        logs.log_response(inputs=[r.text for r in request.requests],
-                          response=response, params=request.params,
-                          prefix_id=request.prefix_id,
-                          engine_response=engine_response,
-                          start_time=start_time, kind_log=kind_log,
-                          method_str="generate", logger=logger)
-
-    @staticmethod
-    def _log_streaming_response(request: SingleGenerationRequest,
-                                response: GenerationResponse,
-                                engine_response: RequestOutput,
-                                start_time: float):
-        logs.log_response(inputs=[request.request.text], response=response,
-                          params=request.params, prefix_id=request.prefix_id,
-                          engine_response=engine_response,
-                          start_time=start_time, kind_log="Streaming response",
-                          method_str="generate_stream", logger=logger)
-
-
     @log_rpc_handler_errors
     async def Tokenize(self, request: BatchedTokenizeRequest,
                        context: ServicerContext) -> BatchedTokenizeResponse:
diff --git a/vllm/tgis_utils/logs.py b/vllm/tgis_utils/logs.py
index 6cf81acca..7da32ad00 100644
--- a/vllm/tgis_utils/logs.py
+++ b/vllm/tgis_utils/logs.py
@@ -1,26 +1,97 @@
 """Some methods for producing logs similar to TGIS"""
 import logging
-from typing import List
+from typing import List, Optional, Union
 
 from google.protobuf import text_format
 
-from vllm import RequestOutput
-from vllm.entrypoints.grpc.pb.generation_pb2 import (GenerationResponse,
-                                                     Parameters, StopReason)
+from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
+                                                     GenerationResponse,
+                                                     Parameters,
+                                                     SingleGenerationRequest,
+                                                     StopReason)
+from vllm.sequence import RequestMetrics
 
 
-def log_response(inputs: List[str], params: Parameters, prefix_id: str,
-                 response: GenerationResponse, engine_response: RequestOutput,
-                 start_time: float, kind_log: str, method_str: str,
-                 logger: logging.Logger):
+def log_response(
+    request: Union[BatchedGenerationRequest, SingleGenerationRequest],
+    response: GenerationResponse,
+    engine_metrics: Optional[RequestMetrics],
+    start_time: float,
+    logger: logging.Logger,
+    sub_request_num: int = 0,
+):
+    if isinstance(request, BatchedGenerationRequest):
+        # unary case
+        request_count = len(request.requests)
+        if request_count == 1:
+            kind_log = "Request"
+        else:
+            kind_log = (f"Sub-request {sub_request_num} from batch of "
+                        f"{request_count}")
+        inputs = [r.text for r in request.requests]
+        method_str = "generate"
+    else:
+        # streaming case
+        inputs = [request.request.text]
+        kind_log = "Streaming response"
+        method_str = "generate_stream"
+
+    _log_response(
+        inputs=inputs,
+        response=response,
+        params=request.params,
+        prefix_id=request.prefix_id,
+        engine_metrics=engine_metrics,
+        start_time=start_time,
+        kind_log=kind_log,
+        method_str=method_str,
+        logger=logger,
+    )
+
+
+def log_error(request: Union[BatchedGenerationRequest,
+                             SingleGenerationRequest], exception_str: str,
+              logger: logging.Logger):
+    """Logs errors similar to how the TGIS server does"""
+    # NB: We don't actually log the `Exception` here to match the TGIS behavior
+    # of just logging the simple string representation of the error
+    param_str = text_format.MessageToString(request.params, as_one_line=True)
+    prefix_id = request.prefix_id
+
+    if isinstance(request, BatchedGenerationRequest):
+        method_str = "generate"
+        inputs = [r.text for r in request.requests]
+    else:
+        method_str = "generate_stream"
"generate_stream" + inputs = [request.request.text] + + short_input = [_truncate(input_, 32) for input_ in inputs] + input_chars = sum(len(input_) for input_ in inputs) + + span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} " + f"input_chars=[{input_chars}] params={param_str}") + + logger.error("%s: %s", span_str, exception_str) + + +def _log_response(inputs: List[str], params: Parameters, prefix_id: str, + response: GenerationResponse, + engine_metrics: Optional[RequestMetrics], start_time: float, + kind_log: str, method_str: str, logger: logging.Logger): """Logs responses similar to how the TGIS server does""" # This time contains both request validation and tokenization - tokenization_time = engine_response.metrics.arrival_time - start_time - inference_time = (engine_response.metrics.last_token_time - - engine_response.metrics.first_scheduled_time) - queue_time = engine_response.metrics.time_in_queue - time_per_token = _safe_div(inference_time, response.generated_token_count) - total_time = engine_response.metrics.last_token_time - start_time + if engine_metrics is not None: + tokenization_time = engine_metrics.arrival_time - start_time + inference_time = (engine_metrics.last_token_time - + engine_metrics.first_scheduled_time) + queue_time = engine_metrics.time_in_queue + time_per_token = _safe_div(inference_time, + response.generated_token_count) + total_time = engine_metrics.last_token_time - start_time + else: + logger.warning("No engine metrics for request, cannot log timing info") + tokenization_time = inference_time = queue_time = time_per_token =\ + total_time = 0 output_len = len(response.text) short_output = _truncate(response.text, 32) short_input = [_truncate(input_, 32) for input_ in inputs]