Skip to content

Commit

Permalink
feat: add generate_table option
Browse files Browse the repository at this point in the history
  • Loading branch information
zumuta committed Nov 19, 2024
1 parent f746230 commit f22e6ec
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 11 deletions.
27 changes: 24 additions & 3 deletions src/griffe_fastapi/_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import Any
from griffe import (
Decorator,
Docstring,
DocstringSectionText,
ExprAttribute,
ExprCall,
ExprDict,
Expand Down Expand Up @@ -84,14 +86,16 @@ def _process_responses(


class FastAPIExtension(Extension):
def __init__(self, *, paths: list[str] | None = None):
def __init__(self, *, paths: list[str] | None = None, generate_table: bool = True):
"""Initialize the extension.
When paths are set, the extension will only process the modules of the given
path.
Args:
paths: A list of paths to select api functions
generate_table: Generate the table at the end of the function docstring?
"""
super().__init__()
self._paths = paths or []
self._generate_table = generate_table

def on_function_instance(
self,
Expand Down Expand Up @@ -160,3 +164,20 @@ def on_function_instance(
func.extra[self_namespace]["responses"] = {}
for key, value in resolved_responses.items():
_process_responses(func, key, value)

if self._generate_table:
table = [
"| Status | Description |",
"|--------|-------------|",
]
for http_code, response in func.extra[self_namespace]["responses"].items():
table.append(f"| {http_code} | {response['description']} |")
if not func.docstring:
func.docstring = Docstring("", parent=func)
sections = func.docstring.parsed
sections.append(
DocstringSectionText(
f"This api can return the following HTTP codes:\n\n{table}",
title="HTTP Responses",
)
)
50 changes: 42 additions & 8 deletions tests/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ def test_extension() -> None:
@router.get("/", responses={200:{"description": "Ok"}})
def get_teams() -> list[str]:
'''Get the teams.'''
'''Get the teams.
'''
return []
"""

with temporary_visited_package(
"package",
modules={"__init__.py": code},
extensions=Extensions(FastAPIExtension()),
extensions=Extensions(FastAPIExtension(generate_table=False)),
) as package:
assert package
assert package.functions["get_teams"]
Expand All @@ -45,7 +46,7 @@ def get_teams() -> list[str]:
with temporary_visited_package(
"package",
modules={"__init__.py": code},
extensions=Extensions(FastAPIExtension()),
extensions=Extensions(FastAPIExtension(generate_table=False)),
) as package:
assert package
assert package.functions["get_teams"]
Expand Down Expand Up @@ -74,7 +75,7 @@ def get_teams() -> list[str]:
with temporary_visited_package(
"package",
modules={"__init__.py": code},
extensions=Extensions(FastAPIExtension()),
extensions=Extensions(FastAPIExtension(generate_table=False)),
) as package:
assert package
assert package.functions["get_teams"]
Expand Down Expand Up @@ -104,7 +105,7 @@ def get_image() -> list[str]:
with temporary_visited_package(
"package",
modules={"__init__.py": code},
extensions=Extensions(FastAPIExtension()),
extensions=Extensions(FastAPIExtension(generate_table=False)),
) as package:
assert package
assert package.functions["get_image"]
Expand Down Expand Up @@ -137,7 +138,7 @@ def get_teams() -> list[str]:
with temporary_visited_package(
"package",
modules={"__init__.py": code},
extensions=Extensions(FastAPIExtension()),
extensions=Extensions(FastAPIExtension(generate_table=False)),
) as package:
assert package
assert package.functions["get_teams"]
Expand Down Expand Up @@ -168,7 +169,7 @@ def get_teams() -> list[str]:
with temporary_visited_package(
"package",
modules={"__init__.py": code},
extensions=Extensions(FastAPIExtension()),
extensions=Extensions(FastAPIExtension(generate_table=False)),
) as package:
assert package
assert package.functions["get_teams"]
Expand All @@ -195,8 +196,41 @@ def get_teams() -> list[str]:
with temporary_visited_package(
"package",
modules={"__init__.py": code},
extensions=Extensions(FastAPIExtension(paths=["package"])),
extensions=Extensions(
FastAPIExtension(paths=["package"], generate_table=False)
),
) as package:
assert package
assert package.functions["get_teams"]
assert package.functions["get_teams"].extra["griffe_fastapi"] is not None


def test_extension_with_table() -> None:
code = """
from fastapi import ApiRouter
router = APIRouter()
@router.get("/", responses={200:{"description": "Ok"}})
def get_teams() -> list[str]:
'''Get the teams.
'''
return []
"""

with temporary_visited_package(
"package",
modules={"__init__.py": code},
extensions=Extensions(FastAPIExtension()),
) as package:
assert package
assert package.functions["get_teams"]
assert package.functions["get_teams"].extra is not None
assert "griffe_fastapi" in package.functions["get_teams"].extra
extra = package.functions["get_teams"].extra["griffe_fastapi"]
assert extra["method"] == "get"
assert extra["responses"]["200"]["description"] == "Ok"
assert (
package.functions["get_teams"].docstring.parsed[1].value
== "This api can return the following HTTP codes:\n\n['| Status | Description |', '|--------|-------------|', '| 200 | Ok |']"
)

0 comments on commit f22e6ec

Please sign in to comment.