diff --git a/blacksheep/server/openapi/v3.py b/blacksheep/server/openapi/v3.py index 920d041..b607c70 100644 --- a/blacksheep/server/openapi/v3.py +++ b/blacksheep/server/openapi/v3.py @@ -3,8 +3,10 @@ import sys import warnings from abc import ABC, abstractmethod +from collections import OrderedDict, defaultdict from dataclasses import dataclass, fields, is_dataclass from datetime import date, datetime +from decimal import Decimal from enum import Enum, IntEnum from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union @@ -666,6 +668,11 @@ def _get_schema_by_type( if schema: return schema + # Dict, OrderedDict, defaultdict are handled first than GenericAlias + schema = self._try_get_schema_for_mapping(object_type, type_args) + if schema: + return schema + # List, Set, Tuple are handled first than GenericAlias schema = self._try_get_schema_for_iterable(object_type, type_args) if schema: @@ -733,6 +740,44 @@ def _try_get_schema_for_iterable( items=self.get_schema_by_type(item_type, context_type_args), ) + def _try_get_schema_for_mapping( + self, object_type: Type, context_type_args: Optional[Dict[Any, Type]] = None + ) -> Optional[Schema]: + if object_type in {dict, defaultdict, OrderedDict}: + # the user didn't specify the key and value types + return Schema( + type=ValueType.OBJECT, + additional_properties=Schema( + type=ValueType.STRING, + ), + ) + + origin = get_origin(object_type) + + if not origin or origin not in { + dict, + Dict, + collections_abc.Mapping, + }: + return None + + # can be Dict, Dict[str, str] or dict[str, str] (Python 3.9), + # note: it could also be union if it wasn't handled above for dataclasses + try: + _, value_type = object_type.__args__ # type: ignore + except AttributeError: # pragma: no cover + value_type = str + + if context_type_args and value_type in context_type_args: + value_type = context_type_args.get(value_type, value_type) + + return Schema( + type=ValueType.OBJECT, + additional_properties=self.get_schema_by_type( + value_type, context_type_args + ), + ) + def get_fields(self, object_type: Any) -> List[FieldInfo]: for handler in self._object_types_handlers: if handler.handles_type(object_type): diff --git a/tests/test_openapi_v3.py b/tests/test_openapi_v3.py index cbddde4..62094db 100644 --- a/tests/test_openapi_v3.py +++ b/tests/test_openapi_v3.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from datetime import date, datetime from enum import IntEnum -from typing import Generic, List, Optional, Sequence, TypeVar, Union +from typing import Generic, List, Mapping, Optional, Sequence, TypeVar, Union from uuid import UUID import pytest @@ -1513,6 +1513,66 @@ def home() -> Sequence[Cat]: ... ) +@pytest.mark.asyncio +async def test_handling_of_mapping(docs: OpenAPIHandler, serializer: Serializer): + app = get_app() + + @app.router.route("/") + def home() -> Mapping[str, Mapping[int, List[Cat]]]: + ... + + docs.bind_app(app) + await app.start() + + yaml = serializer.to_yaml(docs.generate_documentation(app)) + + assert ( + yaml.strip() + == r""" +openapi: 3.0.3 +info: + title: Example + version: 0.0.1 +paths: + /: + get: + responses: + '200': + description: Success response + content: + application/json: + schema: + type: object + additionalProperties: + type: object + additionalProperties: + type: array + nullable: false + items: + $ref: '#/components/schemas/Cat' + nullable: false + nullable: false + operationId: home +components: + schemas: + Cat: + type: object + required: + - id + - name + properties: + id: + type: integer + format: int64 + nullable: false + name: + type: string + nullable: false +tags: [] +""".strip() + ) + + def test_handling_of_generic_with_forward_references(docs: OpenAPIHandler): with pytest.warns(UserWarning): docs.register_schema_for_type(GenericWithForwardRefExample[Cat])