Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move exception handling logic to Route #2026

Merged
merged 31 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6470c9f
Move exception handling logic to endpoints
adriangb Jan 27, 2023
23c918e
lint
adriangb Jan 27, 2023
0737cc9
lint
adriangb Jan 27, 2023
bdb9d68
more checks
adriangb Jan 27, 2023
ce4a3d5
fix tests
adriangb Jan 27, 2023
62807bd
Update exceptions.py
adriangb Jan 28, 2023
5ca404d
reformat
adriangb Jan 28, 2023
9468718
Move wrapper into it's own file so it can be re-used
adriangb Jan 28, 2023
c63c8ff
Merge branch 'master' into cleanup-move-exc
adriangb Feb 6, 2023
7920133
Merge branch 'master' into cleanup-move-exc
adriangb Feb 8, 2023
aa10f20
Merge branch 'master' into cleanup-move-exc
adriangb Feb 13, 2023
2747c9b
Merge branch 'master' into cleanup-move-exc
adriangb Feb 21, 2023
b9f34f8
Merge branch 'master' into cleanup-move-exc
adriangb Mar 2, 2023
d6d8622
Merge branch 'master' into cleanup-move-exc
adriangb Mar 5, 2023
58138d1
Merge branch 'master' into cleanup-move-exc
adriangb Mar 10, 2023
6d1e6fd
Merge branch 'master' into cleanup-move-exc
adriangb Mar 11, 2023
25b6256
Merge branch 'master' into cleanup-move-exc
Kludex Mar 17, 2023
890dcc3
Merge branch 'master' into cleanup-move-exc
adriangb Mar 20, 2023
2d144c3
Merge branch 'master' into cleanup-move-exc
adriangb Mar 31, 2023
4e20aeb
Merge branch 'master' into cleanup-move-exc
Kludex Apr 14, 2023
6ddf952
Merge branch 'master' into cleanup-move-exc
adriangb May 4, 2023
fb379dd
Merge branch 'master' into cleanup-move-exc
adriangb May 5, 2023
cfbb38b
Merge branch 'master' into cleanup-move-exc
Kludex May 27, 2023
1ef9837
Merge branch 'master' into cleanup-move-exc
Kludex Jun 1, 2023
27d8bc8
Merge branch 'master' into cleanup-move-exc
adriangb Jun 1, 2023
3cf3881
rename type alias from ExcHandlers to ExceptionHandlers
adriangb Jun 1, 2023
c437933
Merge branch 'master' into cleanup-move-exc
Kludex Jun 6, 2023
6cc501e
Refactor exception handling structure
Kludex Jun 7, 2023
f354ae6
Remove implementation details from documentation
Kludex Jun 7, 2023
6801432
Add documentation back
Kludex Jun 7, 2023
3fb1118
Move retrieval of handlers to `wrap_app_handling_exceptions`
Kludex Jun 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions starlette/_exception_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

Handler = typing.Callable[..., typing.Any]
ExceptionHandlers = typing.Dict[typing.Any, Handler]
StatusHandlers = typing.Dict[int, Handler]


def _lookup_exception_handler(
exc_handlers: ExceptionHandlers, exc: Exception
) -> typing.Optional[Handler]:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
return None


def wrap_app_handling_exceptions(
app: ASGIApp, conn: typing.Union[Request, WebSocket]
) -> ASGIApp:
exception_handlers: ExceptionHandlers
status_handlers: StatusHandlers
try:
exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
except KeyError:
exception_handlers, status_handlers = {}, {}

async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
response_started = False

async def sender(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True
await send(message)

try:
await app(scope, receive, sender)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = status_handlers.get(exc.status_code)

if handler is None:
handler = _lookup_exception_handler(exception_handlers, exc)

if handler is None:
raise exc

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

if scope["type"] == "http":
response: Response
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
if is_async_callable(handler):
await handler(conn, exc)
else:
await run_in_threadpool(handler, conn, exc)

return wrapped_app
83 changes: 24 additions & 59 deletions starlette/middleware/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette._exception_handler import (
ExceptionHandlers,
StatusHandlers,
wrap_app_handling_exceptions,
)
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket


Expand All @@ -20,12 +23,10 @@ def __init__(
) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
] = {
self._status_handlers: StatusHandlers = {}
self._exception_handlers: ExceptionHandlers = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
WebSocketException: self.websocket_exception, # type: ignore[dict-item]
}
if handlers is not None:
for key, value in handlers.items():
Expand All @@ -42,68 +43,32 @@ def add_exception_handler(
assert issubclass(exc_class_or_status_code, Exception)
self._exception_handlers[exc_class_or_status_code] = handler

def _lookup_exception_handler(
self, exc: Exception
) -> typing.Optional[typing.Callable]:
for cls in type(exc).__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return

response_started = False

async def sender(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True
await send(message)

try:
await self.app(scope, receive, sender)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = self._status_handlers.get(exc.status_code)

if handler is None:
handler = self._lookup_exception_handler(exc)

if handler is None:
raise exc
scope["starlette.exception_handlers"] = (
self._exception_handlers,
self._status_handlers,
)

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc
conn: typing.Union[Request, WebSocket]
if scope["type"] == "http":
conn = Request(scope, receive, send)
else:
conn = WebSocket(scope, receive, send)

if scope["type"] == "http":
request = Request(scope, receive=receive)
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive=receive, send=send)
if is_async_callable(handler):
await handler(websocket, exc)
else:
await run_in_threadpool(handler, websocket, exc)
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)

def http_exception(self, request: Request, exc: HTTPException) -> Response:
def http_exception(self, request: Request, exc: Exception) -> Response:
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)

async def websocket_exception(
self, websocket: WebSocket, exc: WebSocketException
) -> None:
await websocket.close(code=exc.code, reason=exc.reason)
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
23 changes: 16 additions & 7 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from contextlib import asynccontextmanager
from enum import Enum

from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
Expand Down Expand Up @@ -61,12 +62,16 @@ def request_response(func: typing.Callable) -> ASGIApp:
is_coroutine = is_async_callable(func)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive=receive, send=send)
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
await response(scope, receive, send)
request = Request(scope, receive, send)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
await response(scope, receive, send)

await wrap_app_handling_exceptions(app, request)(scope, receive, send)

return app

Expand All @@ -79,7 +84,11 @@ def websocket_session(func: typing.Callable) -> ASGIApp:

async def app(scope: Scope, receive: Receive, send: Send) -> None:
session = WebSocket(scope, receive=receive, send=send)
await func(session)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
await func(session)

await wrap_app_handling_exceptions(app, session)(scope, receive, send)

return app

Expand Down
35 changes: 33 additions & 2 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import PlainTextResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute


Expand All @@ -28,6 +29,22 @@ def with_headers(request):
raise HTTPException(status_code=200, headers={"x-potato": "always"})


class BadBodyException(HTTPException):
pass


async def read_body_and_raise_exc(request: Request):
await request.body()
raise BadBodyException(422)


async def handler_that_reads_body(
request: Request, exc: BadBodyException
) -> JSONResponse:
body = await request.body()
return JSONResponse(status_code=422, content={"body": body.decode()})


class HandledExcAfterResponse:
async def __call__(self, scope, receive, send):
response = PlainTextResponse("OK", status_code=200)
Expand All @@ -44,11 +61,19 @@ async def __call__(self, scope, receive, send):
Route("/with_headers", endpoint=with_headers),
Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
Route(
"/consume_body_in_endpoint_and_handler",
endpoint=read_body_and_raise_exc,
methods=["POST"],
),
]
)


app = ExceptionMiddleware(router)
app = ExceptionMiddleware(
router,
handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item]
)


@pytest.fixture
Expand Down Expand Up @@ -160,3 +185,9 @@ def test_exception_middleware_deprecation() -> None:

with pytest.warns(DeprecationWarning):
starlette.exceptions.ExceptionMiddleware


def test_request_in_app_and_handler_is_the_same_object(client) -> None:
response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!")
assert response.status_code == 422
assert response.json() == {"body": "Hello!"}
6 changes: 1 addition & 5 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,13 +1033,9 @@ async def modified_send(msg: Message) -> None:
assert resp.status_code == 200, resp.content
assert "X-Mounted" in resp.headers

# this is the "surprising" behavior bit
# the middleware on the mount never runs because there
# is nothing to catch the HTTPException
# since Mount middlweare is not wrapped by ExceptionMiddleware
resp = client.get("/mount/err")
assert resp.status_code == 403, resp.content
assert "X-Mounted" not in resp.headers
assert "X-Mounted" in resp.headers


def test_route_repr() -> None:
Expand Down