Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Annotated types in OpenAPIHandler #472

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions blacksheep/server/openapi/v3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections.abc as collections_abc
import inspect
import sys
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, is_dataclass
Expand All @@ -10,6 +11,9 @@
from typing import get_type_hints
from uuid import UUID

if sys.version_info >= (3, 9): # pragma: no cover
from typing import _AnnotatedAlias as AnnotatedAlias

from openapidocs.common import Format
from openapidocs.v3 import (
APIKeySecurity,
Expand Down Expand Up @@ -458,6 +462,30 @@ def get_paths(self, app: Application, path_prefix: str = "") -> Dict[str, PathIt

return own_paths

def get_type_name_for_annotated(
self,
object_type: "AnnotatedAlias",
context_type_args: Optional[Dict[Any, Type]] = None,
) -> str:
"""
This method returns a type name for an annotated type.
"""
assert isinstance(
object_type, AnnotatedAlias
), "This method requires an annotated type"
# Note: by default returns a string respectful of this requirement:
# $ref values must be RFC3986-compliant percent-encoded URIs
# Therefore, a generic that would be expressed in Python: Example[Foo, Bar]
# and C# or TypeScript Example<Foo, Bar>
# Becomes here represented as: ExampleOfFooAndBar
origin = get_origin(object_type)
annotations = object_type.__metadata__
annotations_repr = "And".join(
self.get_type_name(annotation, context_type_args)
for annotation in annotations
)
return f"{self.get_type_name(origin)}Of{annotations_repr}"

def get_type_name_for_generic(
self,
object_type: GenericAlias,
Expand All @@ -484,6 +512,9 @@ def get_type_name(
) -> str:
if context_type_args and object_type in context_type_args:
object_type = context_type_args.get(object_type)
if sys.version_info >= (3, 9): # pragma: no cover
if isinstance(object_type, AnnotatedAlias):
return self.get_type_name_for_annotated(object_type, context_type_args)
if isinstance(object_type, GenericAlias):
return self.get_type_name_for_generic(object_type, context_type_args)
if hasattr(object_type, "__name__"):
Expand Down Expand Up @@ -661,6 +692,12 @@ def _get_schema_by_type(
if schema:
return schema

if sys.version_info >= (3, 9): # pragma: no cover
if isinstance(object_type, AnnotatedAlias):
schema = self._try_get_schema_for_annotated(object_type, type_args)
if schema:
return schema

if isinstance(object_type, GenericAlias):
schema = self._try_get_schema_for_generic(object_type, type_args)
if schema:
Expand Down Expand Up @@ -732,6 +769,24 @@ def get_fields(self, object_type: Any) -> List[FieldInfo]:

return []

def _try_get_schema_for_annotated(
self, object_type: Type, context_type_args: Optional[Dict[Any, Type]] = None
) -> Optional[Union[Schema, Reference]]:
annotations = [child for child in getattr(object_type, "__metadata__", [])]
if len(annotations) == 1:
return self.get_schema_by_type(annotations[0], context_type_args)
assert (
None not in annotations
), "None is not a valid type for an annotated type with multiple annotations"
schema = Schema(
ValueType.OBJECT,
any_of=[
self.get_schema_by_type(annotation, context_type_args)
for annotation in annotations
],
)
return self._handle_object_type_schema(object_type, context_type_args, schema)

def _try_get_schema_for_generic(
self, object_type: Type, context_type_args: Optional[Dict[Any, Type]] = None
) -> Optional[Reference]:
Expand Down Expand Up @@ -1006,6 +1061,13 @@ def get_responses(self, handler: Any) -> Optional[Dict[str, ResponseDoc]]:
if data is None:
data = {}

if sys.version_info >= (3, 9): # pragma: no cover
if (
isinstance(return_type, AnnotatedAlias)
and return_type.__metadata__[0] is None
):
return_type = None

if return_type is None:
# the user explicitly marked the request handler as returning None,
# document therefore HTTP 204 No Content
Expand Down
Loading
Loading