From 2f711193591843002e29b8a417e931f5e4e50993 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 20 Jul 2022 09:42:59 +0200 Subject: [PATCH] Rewrite without BaseHTTPMiddleware --- starlette_prometheus/middleware.py | 54 +++++++++++++++++++----------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/starlette_prometheus/middleware.py b/starlette_prometheus/middleware.py index 02aaaf0..76c98a9 100644 --- a/starlette_prometheus/middleware.py +++ b/starlette_prometheus/middleware.py @@ -1,13 +1,12 @@ +import functools from typing import Tuple import time from prometheus_client import Counter, Gauge, Histogram -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Match from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Message, Receive, Scope, Send REQUESTS = Counter( "starlette_requests_total", "Total count of requests by method and path.", ["method", "path_template"] @@ -34,38 +33,53 @@ ) -class PrometheusMiddleware(BaseHTTPMiddleware): +class PrometheusMiddleware: def __init__(self, app: ASGIApp, filter_unhandled_paths: bool = False) -> None: - super().__init__(app) + self.app = app self.filter_unhandled_paths = filter_unhandled_paths - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": # pragma: no cover + await self.app(scope, receive, send) + return + + request = Request(scope) method = request.method path_template, is_handled_path = self.get_path_template(request) if self._is_path_filtered(is_handled_path): - return await call_next(request) + await self.app(scope, receive, send) + return REQUESTS_IN_PROGRESS.labels(method=method, path_template=path_template).inc() REQUESTS.labels(method=method, path_template=path_template).inc() before_time = time.perf_counter() + + send = functools.partial( + self.send, send=send, scope=scope, path_template=path_template, before_time=before_time + ) try: - response = await call_next(request) + await self.app(scope, receive, send) except BaseException as e: - status_code = HTTP_500_INTERNAL_SERVER_ERROR EXCEPTIONS.labels(method=method, path_template=path_template, exception_type=type(e).__name__).inc() - raise e from None - else: - status_code = response.status_code - after_time = time.perf_counter() - REQUESTS_PROCESSING_TIME.labels(method=method, path_template=path_template).observe( - after_time - before_time - ) - finally: - RESPONSES.labels(method=method, path_template=path_template, status_code=status_code).inc() - REQUESTS_IN_PROGRESS.labels(method=method, path_template=path_template).dec() + self.write_response_metrics(method, path_template, HTTP_500_INTERNAL_SERVER_ERROR, before_time) + raise + + async def send(self, message: Message, send: Send, scope: Scope, *, path_template: str, before_time: float) -> None: + message_type = message["type"] + if message_type == "http.response.start": + request = Request(scope) + method = request.method + status_code = message["status"] + self.write_response_metrics(method, path_template, status_code, before_time) + + await send(message) - return response + def write_response_metrics(self, method: str, path_template: str, status_code: int, before_time: float) -> None: + after_time = time.perf_counter() + REQUESTS_PROCESSING_TIME.labels(method=method, path_template=path_template).observe(after_time - before_time) + RESPONSES.labels(method=method, path_template=path_template, status_code=status_code).inc() + REQUESTS_IN_PROGRESS.labels(method=method, path_template=path_template).dec() @staticmethod def get_path_template(request: Request) -> Tuple[str, bool]: