Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Forward route class when it's different from APIRoute #53

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions examples/simple.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# mypy: disable-error-code="no-any-return"
# flake8: noqa: A003

from typing import List, Any, Dict
from fastapi import FastAPI, APIRouter
from typing import Any, Callable, Coroutine, Dict, List
from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.routing import APIRoute
from pydantic import BaseModel

from fastapi_versionizer.versionizer import Versionizer, api_version
Expand Down Expand Up @@ -36,6 +37,25 @@ def __init__(self) -> None:
self.items: Dict[int, Any] = {}


class CounterDb:
def __init__(self) -> None:
self.counter = 0

def increment(self) -> None:
self.counter += 1

counter_db = CounterDb()

class CounterRoute(APIRoute):
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
original_route_handler = super().get_route_handler()

async def custom_route_handler(request: Request) -> Response:
counter_db.increment()
return await original_route_handler(request)

return custom_route_handler

db = DB()
app = FastAPI(
title='test',
Expand All @@ -47,7 +67,8 @@ def __init__(self) -> None:
)
users_router = APIRouter(
prefix='/users',
tags=['Users']
tags=['Users'],
route_class=CounterRoute,
)
items_router = APIRouter(
prefix='/items',
Expand Down
2 changes: 2 additions & 0 deletions fastapi_versionizer/versionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,8 @@ def _add_route_to_router(
version: Tuple[int, int]
) -> None:
kwargs = dict(route.__dict__)
if route.__class__ != APIRoute:
kwargs['route_class_override'] = route.__class__
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't you accomplish the same thing with this:

kwargs['route_class_override'] = type(route)

?


deprecated_in_version = getattr(route.endpoint, '_deprecate_in_version', None)
if deprecated_in_version is not None:
Expand Down
6 changes: 5 additions & 1 deletion tests/test_simple.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi.testclient import TestClient

from unittest import TestCase
from examples.simple import app, versions
from examples.simple import app, counter_db, versions


class TestSimpleExample(TestCase):
Expand Down Expand Up @@ -30,6 +30,8 @@ def test_simple_example(self) -> None:
self.assertEqual(404, test_client.get('/v2/versions').status_code)
self.assertEqual(404, test_client.get('/latest/versions').status_code)

self.assertEqual(0, counter_db.counter)

# versions route
self.assertDictEqual(
{
Expand Down Expand Up @@ -129,6 +131,8 @@ def test_simple_example(self) -> None:
test_client.get('/latest/users/3').json()
)

self.assertEqual(9, counter_db.counter)

# docs
self.assertEqual(200, test_client.get('/swagger').status_code)
self.assertEqual(200, test_client.get('/v1/swagger').status_code)
Expand Down
Loading