Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move exception handling logic to Route #2026

Merged
merged 31 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6470c9f
Move exception handling logic to endpoints
adriangb Jan 27, 2023
23c918e
lint
adriangb Jan 27, 2023
0737cc9
lint
adriangb Jan 27, 2023
bdb9d68
more checks
adriangb Jan 27, 2023
ce4a3d5
fix tests
adriangb Jan 27, 2023
62807bd
Update exceptions.py
adriangb Jan 28, 2023
5ca404d
reformat
adriangb Jan 28, 2023
9468718
Move wrapper into it's own file so it can be re-used
adriangb Jan 28, 2023
c63c8ff
Merge branch 'master' into cleanup-move-exc
adriangb Feb 6, 2023
7920133
Merge branch 'master' into cleanup-move-exc
adriangb Feb 8, 2023
aa10f20
Merge branch 'master' into cleanup-move-exc
adriangb Feb 13, 2023
2747c9b
Merge branch 'master' into cleanup-move-exc
adriangb Feb 21, 2023
b9f34f8
Merge branch 'master' into cleanup-move-exc
adriangb Mar 2, 2023
d6d8622
Merge branch 'master' into cleanup-move-exc
adriangb Mar 5, 2023
58138d1
Merge branch 'master' into cleanup-move-exc
adriangb Mar 10, 2023
6d1e6fd
Merge branch 'master' into cleanup-move-exc
adriangb Mar 11, 2023
25b6256
Merge branch 'master' into cleanup-move-exc
Kludex Mar 17, 2023
890dcc3
Merge branch 'master' into cleanup-move-exc
adriangb Mar 20, 2023
2d144c3
Merge branch 'master' into cleanup-move-exc
adriangb Mar 31, 2023
4e20aeb
Merge branch 'master' into cleanup-move-exc
Kludex Apr 14, 2023
6ddf952
Merge branch 'master' into cleanup-move-exc
adriangb May 4, 2023
fb379dd
Merge branch 'master' into cleanup-move-exc
adriangb May 5, 2023
cfbb38b
Merge branch 'master' into cleanup-move-exc
Kludex May 27, 2023
1ef9837
Merge branch 'master' into cleanup-move-exc
Kludex Jun 1, 2023
27d8bc8
Merge branch 'master' into cleanup-move-exc
adriangb Jun 1, 2023
3cf3881
rename type alias from ExcHandlers to ExceptionHandlers
adriangb Jun 1, 2023
c437933
Merge branch 'master' into cleanup-move-exc
Kludex Jun 6, 2023
6cc501e
Refactor exception handling structure
Kludex Jun 7, 2023
f354ae6
Remove implementation details from documentation
Kludex Jun 7, 2023
6801432
Add documentation back
Kludex Jun 7, 2023
3fb1118
Move retrieval of handlers to `wrap_app_handling_exceptions`
Kludex Jun 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 0 additions & 109 deletions starlette/middleware/exceptions.py

This file was deleted.

76 changes: 76 additions & 0 deletions starlette/middleware/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import typing

from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.exceptions._wrapper import (
ExcHandlers,
StatusHandlers,
wrap_app_handling_exceptions,
)
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket


class ExceptionMiddleware:
def __init__(
self,
app: ASGIApp,
handlers: typing.Optional[
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
] = None,
debug: bool = False,
) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers: StatusHandlers = {}
self._exception_handlers: ExcHandlers = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception, # type: ignore[dict-item]
}
if handlers is not None:
for key, value in handlers.items():
self.add_exception_handler(key, value)

def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable[[Request, Exception], Response],
) -> None:
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
else:
assert issubclass(exc_class_or_status_code, Exception)
self._exception_handlers[exc_class_or_status_code] = handler

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return

scope["starlette.exception_handlers"] = (
self._exception_handlers,
self._status_handlers,
)

conn: typing.Union[Request, WebSocket]
if scope["type"] == "http":
conn = Request(scope, receive, send)
else:
conn = WebSocket(scope, receive, send)

await wrap_app_handling_exceptions(
self.app, self._exception_handlers, self._status_handlers, conn
)(scope, receive, send)

def http_exception(self, request: Request, exc: Exception) -> Response:
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)

async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
72 changes: 72 additions & 0 deletions starlette/middleware/exceptions/_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

Handler = typing.Callable[..., typing.Any]
ExcHandlers = typing.Dict[typing.Any, Handler]
adriangb marked this conversation as resolved.
Show resolved Hide resolved
StatusHandlers = typing.Dict[int, Handler]


def _lookup_exception_handler(
exc_handlers: ExcHandlers, exc: Exception
) -> typing.Optional[Handler]:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
return None


def wrap_app_handling_exceptions(
app: ASGIApp,
exc_handlers: ExcHandlers,
status_handlers: StatusHandlers,
conn: typing.Union[Request, WebSocket],
) -> ASGIApp:
async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
response_started = False

async def sender(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True
await send(message)

try:
await app(scope, receive, sender)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = status_handlers.get(exc.status_code)

if handler is None:
handler = _lookup_exception_handler(exc_handlers, exc)

if handler is None:
raise exc

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

if scope["type"] == "http":
response: Response
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
if is_async_callable(handler):
await handler(conn, exc)
else:
await run_in_threadpool(handler, conn, exc)

return wrapped_app
40 changes: 33 additions & 7 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from starlette.datastructures import URL, Headers, URLPath
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.exceptions._wrapper import wrap_app_handling_exceptions
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
Expand Down Expand Up @@ -61,12 +62,23 @@ def request_response(func: typing.Callable) -> ASGIApp:
is_coroutine = is_async_callable(func)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive=receive, send=send)
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
await response(scope, receive, send)
request = Request(scope, receive, send)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
await response(scope, receive, send)

try:
exception_handlers, status_handlers = scope["starlette.exception_handlers"]
except KeyError:
exception_handlers, status_handlers = {}, {}

await wrap_app_handling_exceptions(
app, exception_handlers, status_handlers, request
)(scope, receive, send)

return app

Expand All @@ -79,7 +91,21 @@ def websocket_session(func: typing.Callable) -> ASGIApp:

async def app(scope: Scope, receive: Receive, send: Send) -> None:
session = WebSocket(scope, receive=receive, send=send)
await func(session)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
await func(session)

try:
exception_handlers, status_handlers = scope["starlette.exception_handlers"]
except KeyError:
exception_handlers, status_handlers = {}, {}

await wrap_app_handling_exceptions(
app,
exception_handlers,
status_handlers,
session,
)(scope, receive, send)

return app

Expand Down
35 changes: 33 additions & 2 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import PlainTextResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute


Expand All @@ -28,6 +29,22 @@ def with_headers(request):
raise HTTPException(status_code=200, headers={"x-potato": "always"})


class BadBodyException(HTTPException):
pass


async def read_body_and_raise_exc(request: Request):
await request.body()
raise BadBodyException(422)


async def handler_that_reads_body(
request: Request, exc: BadBodyException
) -> JSONResponse:
body = await request.body()
return JSONResponse(status_code=422, content={"body": body.decode()})


class HandledExcAfterResponse:
async def __call__(self, scope, receive, send):
response = PlainTextResponse("OK", status_code=200)
Expand All @@ -44,11 +61,19 @@ async def __call__(self, scope, receive, send):
Route("/with_headers", endpoint=with_headers),
Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
Route(
"/consume_body_in_endpoint_and_handler",
endpoint=read_body_and_raise_exc,
methods=["POST"],
),
]
)


app = ExceptionMiddleware(router)
app = ExceptionMiddleware(
router,
handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item]
)


@pytest.fixture
Expand Down Expand Up @@ -160,3 +185,9 @@ def test_exception_middleware_deprecation() -> None:

with pytest.warns(DeprecationWarning):
starlette.exceptions.ExceptionMiddleware


def test_request_in_app_and_handler_is_the_same_object(client) -> None:
response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!")
assert response.status_code == 422
assert response.json() == {"body": "Hello!"}
6 changes: 1 addition & 5 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,13 +1033,9 @@ async def modified_send(msg: Message) -> None:
assert resp.status_code == 200, resp.content
assert "X-Mounted" in resp.headers

# this is the "surprising" behavior bit
# the middleware on the mount never runs because there
# is nothing to catch the HTTPException
# since Mount middlweare is not wrapped by ExceptionMiddleware
resp = client.get("/mount/err")
assert resp.status_code == 403, resp.content
assert "X-Mounted" not in resp.headers
assert "X-Mounted" in resp.headers


def test_route_repr() -> None:
Expand Down