Skip to content

Commit

Permalink
fix: ExprAttribute is not hashable, so resolve http code before using…
Browse files Browse the repository at this point in the history
… it as key
  • Loading branch information
zumuta committed Nov 16, 2024
1 parent 94b2f04 commit b6b461a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 24 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "griffe-fastapi"
version = "0.1.0"
version = "0.1.1"
description = "Griffe extension for FastAPI."
authors = ["fbraem <[email protected]>"]
readme = "README.md"
Expand Down
33 changes: 16 additions & 17 deletions src/griffe_fastapi/_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,24 @@ def _search_decorator(decorators: list[Decorator]) -> Decorator | None:
return None


def _resolve_http_code(func, http_code: str | ExprAttribute):
if isinstance(http_code, ExprAttribute):
if http_code.canonical_path.startswith("fastapi.status."):
return http_code.last.name.split("_")[1]
logger.warning(
f"Could not resolve http code {http_code.canonical_path} "
f"for function {func.canonical_path}"
)
return
return http_code


def _process_responses(
func: Function,
http_code_attribute: str | ExprAttribute,
http_code: str,
open_api_response: ExprName | ExprDict,
):
"""Process the response code and the response object."""
http_code = None
# When a constant is used, resolve the value
if isinstance(http_code_attribute, ExprAttribute):
if http_code_attribute.canonical_path.startswith("fastapi.status."):
http_code = http_code_attribute.last.name.split("_")[1]
if http_code is None:
logger.warning(
f"Could not resolve http code {http_code_attribute.canonical_path} "
f"for function {func.canonical_path}"
)
return
else:
http_code = http_code_attribute

func.extra[self_namespace]["responses"][http_code] = {
ast.literal_eval(str(key)): ast.literal_eval(str(value))
for key, value in zip(
Expand Down Expand Up @@ -140,15 +138,16 @@ def on_function_instance(
resolved_responses = {
**resolved_responses,
**{
k: v
k: _resolve_http_code(func, v)
for k, v in zip(
module_attribute.value.keys,
module_attribute.value.values,
)
},
}
else:
resolved_responses[http_code_variable] = open_api_response_obj
http_code = _resolve_http_code(func, http_code_variable)
resolved_responses[http_code] = open_api_response_obj

func.extra[self_namespace]["responses"] = {}
for key, value in resolved_responses.items():
Expand Down
38 changes: 32 additions & 6 deletions tests/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def test_extension() -> None:
code = """
from fastapi import FastAPI
from fastapi import ApiRouter
router = APIRouter()
Expand All @@ -30,9 +30,35 @@ def get_teams() -> list[str]:
assert extra["responses"]["200"]["description"] == "Ok"


def test_extension_with_constant() -> None:
code = """
from fastapi import ApiRouter, status
router = APIRouter()
@router.get("/", responses={status.HTTP_200_OK:{"description": "Ok"}})
def get_teams() -> list[str]:
'''Get the teams.'''
return []
"""

with temporary_visited_package(
"package",
modules={"__init__.py": code},
extensions=Extensions(FastAPIExtension()),
) as package:
assert package
assert package.functions["get_teams"]
assert package.functions["get_teams"].extra is not None
assert "griffe_fastapi" in package.functions["get_teams"].extra
extra = package.functions["get_teams"].extra["griffe_fastapi"]
assert extra["method"] == "get"
assert extra["responses"]["200"]["description"] == "Ok"


def test_extension_with_multiple_responses() -> None:
code = """
from fastapi import FastAPI
from fastapi import ApiRouter
router = APIRouter()
Expand Down Expand Up @@ -62,7 +88,7 @@ def get_teams() -> list[str]:

def test_extension_with_a_response_with_headers() -> None:
code = """
from fastapi import FastAPI
from fastapi import ApiRouter
router = APIRouter()
Expand Down Expand Up @@ -93,7 +119,7 @@ def get_image() -> list[str]:

def test_extension_with_a_dict() -> None:
code = """
from fastapi import FastAPI
from fastapi import ApiRouter
router = APIRouter()
Expand Down Expand Up @@ -125,7 +151,7 @@ def get_teams() -> list[str]:

def test_extension_mixed() -> None:
code = """
from fastapi import FastAPI
from fastapi import ApiRouter
router = APIRouter()
Expand Down Expand Up @@ -156,7 +182,7 @@ def get_teams() -> list[str]:

def test_with_paths() -> None:
code = """
from fastapi import FastAPI
from fastapi import ApiRouter
router = APIRouter()
Expand Down

0 comments on commit b6b461a

Please sign in to comment.