diff --git a/src/griffe_fastapi/_extension.py b/src/griffe_fastapi/_extension.py index 64ddeb9..c29358c 100644 --- a/src/griffe_fastapi/_extension.py +++ b/src/griffe_fastapi/_extension.py @@ -4,6 +4,8 @@ from typing import Any from griffe import ( Decorator, + Docstring, + DocstringSectionText, ExprAttribute, ExprCall, ExprDict, @@ -84,14 +86,16 @@ def _process_responses( class FastAPIExtension(Extension): - def __init__(self, *, paths: list[str] | None = None): + def __init__(self, *, paths: list[str] | None = None, generate_table: bool = True): """Initialize the extension. - When paths are set, the extension will only process the modules of the given - path. + Args: + paths: A list of paths to select api functions + generate_table: Generate the table at the end of the function docstring? """ super().__init__() self._paths = paths or [] + self._generate_table = generate_table def on_function_instance( self, @@ -160,3 +164,20 @@ def on_function_instance( func.extra[self_namespace]["responses"] = {} for key, value in resolved_responses.items(): _process_responses(func, key, value) + + if self._generate_table: + table = [ + "| Status | Description |", + "|--------|-------------|", + ] + for http_code, response in func.extra[self_namespace]["responses"].items(): + table.append(f"| {http_code} | {response['description']} |") + if not func.docstring: + func.docstring = Docstring("", parent=func) + sections = func.docstring.parsed + sections.append( + DocstringSectionText( + f"This api can return the following HTTP codes:\n\n{table}", + title="HTTP Responses", + ) + ) diff --git a/tests/test_extension.py b/tests/test_extension.py index 00b11a2..1d26edf 100644 --- a/tests/test_extension.py +++ b/tests/test_extension.py @@ -12,14 +12,15 @@ def test_extension() -> None: @router.get("/", responses={200:{"description": "Ok"}}) def get_teams() -> list[str]: - '''Get the teams.''' + '''Get the teams. + ''' return [] """ with temporary_visited_package( "package", modules={"__init__.py": code}, - extensions=Extensions(FastAPIExtension()), + extensions=Extensions(FastAPIExtension(generate_table=False)), ) as package: assert package assert package.functions["get_teams"] @@ -45,7 +46,7 @@ def get_teams() -> list[str]: with temporary_visited_package( "package", modules={"__init__.py": code}, - extensions=Extensions(FastAPIExtension()), + extensions=Extensions(FastAPIExtension(generate_table=False)), ) as package: assert package assert package.functions["get_teams"] @@ -74,7 +75,7 @@ def get_teams() -> list[str]: with temporary_visited_package( "package", modules={"__init__.py": code}, - extensions=Extensions(FastAPIExtension()), + extensions=Extensions(FastAPIExtension(generate_table=False)), ) as package: assert package assert package.functions["get_teams"] @@ -104,7 +105,7 @@ def get_image() -> list[str]: with temporary_visited_package( "package", modules={"__init__.py": code}, - extensions=Extensions(FastAPIExtension()), + extensions=Extensions(FastAPIExtension(generate_table=False)), ) as package: assert package assert package.functions["get_image"] @@ -137,7 +138,7 @@ def get_teams() -> list[str]: with temporary_visited_package( "package", modules={"__init__.py": code}, - extensions=Extensions(FastAPIExtension()), + extensions=Extensions(FastAPIExtension(generate_table=False)), ) as package: assert package assert package.functions["get_teams"] @@ -168,7 +169,7 @@ def get_teams() -> list[str]: with temporary_visited_package( "package", modules={"__init__.py": code}, - extensions=Extensions(FastAPIExtension()), + extensions=Extensions(FastAPIExtension(generate_table=False)), ) as package: assert package assert package.functions["get_teams"] @@ -195,8 +196,41 @@ def get_teams() -> list[str]: with temporary_visited_package( "package", modules={"__init__.py": code}, - extensions=Extensions(FastAPIExtension(paths=["package"])), + extensions=Extensions( + FastAPIExtension(paths=["package"], generate_table=False) + ), ) as package: assert package assert package.functions["get_teams"] assert package.functions["get_teams"].extra["griffe_fastapi"] is not None + + +def test_extension_with_table() -> None: + code = """ + from fastapi import ApiRouter + + router = APIRouter() + + @router.get("/", responses={200:{"description": "Ok"}}) + def get_teams() -> list[str]: + '''Get the teams. + ''' + return [] + """ + + with temporary_visited_package( + "package", + modules={"__init__.py": code}, + extensions=Extensions(FastAPIExtension()), + ) as package: + assert package + assert package.functions["get_teams"] + assert package.functions["get_teams"].extra is not None + assert "griffe_fastapi" in package.functions["get_teams"].extra + extra = package.functions["get_teams"].extra["griffe_fastapi"] + assert extra["method"] == "get" + assert extra["responses"]["200"]["description"] == "Ok" + assert ( + package.functions["get_teams"].docstring.parsed[1].value + == "This api can return the following HTTP codes:\n\n['| Status | Description |', '|--------|-------------|', '| 200 | Ok |']" + )