diff --git a/src/blacksmith/service/_async/route_proxy.py b/src/blacksmith/service/_async/route_proxy.py index 1a61f7e0..f41c16d6 100644 --- a/src/blacksmith/service/_async/route_proxy.py +++ b/src/blacksmith/service/_async/route_proxy.py @@ -1,12 +1,12 @@ +from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, Type, Union + try: - from types import NoneType, UnionType + from types import UnionType except ImportError: # coverage: ignore # python 3.7 compat - NoneType = type(None) # coverage: ignore - UnionType = object() # coverage: ignore - -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union + UnionType = Union # type: ignore +from pydantic import ValidationError from result import Err, Ok, Result from typing_extensions import get_origin @@ -54,7 +54,7 @@ def build_timeout(timeout: ClientTimeout) -> HTTPTimeout: def is_union(typ: Type[Any]) -> bool: type_origin = get_origin(typ) if type_origin: - if type_origin is Union: # Optional[T] + if type_origin is Union: # Union[T, U] or even Optional[T] return True if type_origin is UnionType: # T | U @@ -62,6 +62,19 @@ def is_union(typ: Type[Any]) -> bool: return False +def build_request(typ: Type[Any], params: Mapping[str, Any]) -> Request: + if is_union(typ): + err: Optional[Exception] = None + for t in typ.__args__: # type: ignore + try: + return build_request(t, params) # type: ignore + except ValidationError as e: + err = e + if err: + raise err + return typ(**params) + + class AsyncRouteProxy(Generic[TCollectionResponse, TResponse, TError_co]): """Proxy from resource to its associate routes.""" @@ -109,10 +122,11 @@ def _prepare_request( raise NoContractException(method, self.name, self.client_name) param_schema, return_schema = resource.contract[method] + build_params: Request if isinstance(params, dict): - params = param_schema(**params) + build_params = build_request(param_schema, params) elif params is None: - params = param_schema() + build_params = param_schema() elif not isinstance(params, param_schema): raise WrongRequestTypeException( params.__class__, # type: ignore @@ -120,7 +134,9 @@ def _prepare_request( self.name, self.client_name, ) - req = params.to_http_request(method, self.endpoint + resource.path) + else: + build_params = params + req = build_params.to_http_request(method, self.endpoint + resource.path) return (resource.path, req, return_schema) def _prepare_response( diff --git a/src/blacksmith/service/_sync/route_proxy.py b/src/blacksmith/service/_sync/route_proxy.py index 4a6212c9..5c29fe0b 100644 --- a/src/blacksmith/service/_sync/route_proxy.py +++ b/src/blacksmith/service/_sync/route_proxy.py @@ -1,6 +1,14 @@ -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, Type, Union +try: + from types import UnionType +except ImportError: # coverage: ignore + # python 3.7 compat + UnionType = Union # type: ignore + +from pydantic import ValidationError from result import Err, Ok, Result +from typing_extensions import get_origin from blacksmith.domain.error import AbstractErrorParser, TError_co from blacksmith.domain.exceptions import ( @@ -43,6 +51,30 @@ def build_timeout(timeout: ClientTimeout) -> HTTPTimeout: return timeout +def is_union(typ: Type[Any]) -> bool: + type_origin = get_origin(typ) + if type_origin: + if type_origin is Union: # Union[T, U] or even Optional[T] + return True + + if type_origin is UnionType: # T | U + return True + return False + + +def build_request(typ: Type[Any], params: Mapping[str, Any]) -> Request: + if is_union(typ): + err: Optional[Exception] = None + for t in typ.__args__: # type: ignore + try: + return build_request(t, params) # type: ignore + except ValidationError as e: + err = e + if err: + raise err + return typ(**params) + + class SyncRouteProxy(Generic[TCollectionResponse, TResponse, TError_co]): """Proxy from resource to its associate routes.""" @@ -90,10 +122,11 @@ def _prepare_request( raise NoContractException(method, self.name, self.client_name) param_schema, return_schema = resource.contract[method] + build_params: Request if isinstance(params, dict): - params = param_schema(**params) + build_params = build_request(param_schema, params) elif params is None: - params = param_schema() + build_params = param_schema() elif not isinstance(params, param_schema): raise WrongRequestTypeException( params.__class__, # type: ignore @@ -101,7 +134,9 @@ def _prepare_request( self.name, self.client_name, ) - req = params.to_http_request(method, self.endpoint + resource.path) + else: + build_params = params + req = build_params.to_http_request(method, self.endpoint + resource.path) return (resource.path, req, return_schema) def _prepare_response( @@ -127,7 +162,6 @@ def _prepare_collection_response( response_schema: Optional[Type[Response]], collection_parser: Optional[Type[AbstractCollectionParser]], ) -> Result[CollectionIterator[TCollectionResponse], TError_co]: - if result.is_err(): return Err(self.error_parser(result.unwrap_err())) else: diff --git a/tests/functionals/test_api_client.py b/tests/functionals/test_api_client.py index b81316ff..7a3940ed 100644 --- a/tests/functionals/test_api_client.py +++ b/tests/functionals/test_api_client.py @@ -36,6 +36,11 @@ class CreateItem(Request): size: SizeEnum = PostBodyField(SizeEnum.m) +class CreateItemIntSize(Request): + name: str = PostBodyField() + size: int = PostBodyField(2) + + class ListItem(Request): name: Optional[str] = QueryStringField(None) @@ -60,7 +65,7 @@ class UpdateItem(GetItem): collection_path="/items", collection_contract={ "GET": (ListItem, Item), - "POST": (CreateItem, None), + "POST": (CreateItem | CreateItemIntSize, None), }, path="/items/{item_name}", contract={ diff --git a/tests/unittests/_async/test_service_route_proxy.py b/tests/unittests/_async/test_service_route_proxy.py index 5ff71ab4..1f0ef507 100644 --- a/tests/unittests/_async/test_service_route_proxy.py +++ b/tests/unittests/_async/test_service_route_proxy.py @@ -3,6 +3,7 @@ import pytest from pydantic import BaseModel, Field from result import Result +from typing_extensions import Literal from blacksmith import Request from blacksmith.domain.exceptions import ( @@ -24,6 +25,7 @@ from blacksmith.service._async.base import AsyncAbstractTransport from blacksmith.service._async.route_proxy import ( AsyncRouteProxy, + build_request, build_timeout, is_union, ) @@ -79,6 +81,48 @@ def test_is_union(params: Mapping[str, Any]): assert is_union(params["type"]) is params["expected"] +class Foo(BaseModel): + typ: Literal["foo"] + + +class Bar(BaseModel): + typ: Literal["bar"] + + +class Foobar(BaseModel): + obj: Union[Foo, Bar] + + +@pytest.mark.parametrize( + "params", + [ + pytest.param( + {"type": Foo, "params": {"typ": "foo"}, "expected": Foo(typ="foo")}, + id="simple", + ), + pytest.param( + { + "type": Foobar, + "params": {"obj": {"typ": "foo"}}, + "expected": Foobar(obj=Foo(typ="foo")), + }, + id="union", + ), + pytest.param( + { + "type": Foobar, + "params": {"obj": {"typ": "bar"}}, + "expected": Foobar(obj=Bar(typ="bar")), + }, + id="union", + ), + ], +) +def test_build_request(params: Mapping[str, Any]): + req = build_request(params["type"], params["params"]) + assert req == params["expected"] + + async def test_route_proxy_prepare_middleware( dummy_http_request: HTTPRequest, echo_middleware: AsyncAbstractTransport ): diff --git a/tests/unittests/_sync/test_service_route_proxy.py b/tests/unittests/_sync/test_service_route_proxy.py index 261428cc..16392a55 100644 --- a/tests/unittests/_sync/test_service_route_proxy.py +++ b/tests/unittests/_sync/test_service_route_proxy.py @@ -1,8 +1,9 @@ -from typing import Any +from typing import Any, Mapping, Union import pytest from pydantic import BaseModel, Field from result import Result +from typing_extensions import Literal from blacksmith import Request from blacksmith.domain.exceptions import ( @@ -22,7 +23,12 @@ from blacksmith.middleware._sync.auth import SyncHTTPAuthorizationMiddleware from blacksmith.middleware._sync.base import SyncHTTPAddHeadersMiddleware from blacksmith.service._sync.base import SyncAbstractTransport -from blacksmith.service._sync.route_proxy import SyncRouteProxy, build_timeout +from blacksmith.service._sync.route_proxy import ( + SyncRouteProxy, + build_request, + build_timeout, + is_union, +) from blacksmith.typing import ClientName, Path from tests.unittests.dummy_registry import GetParam, GetResponse, PostParam @@ -48,7 +54,6 @@ def __call__( path: Path, timeout: HTTPTimeout, ) -> HTTPResponse: - if self.resp.status_code >= 400: raise HTTPError(f"{self.resp.status_code} blah", req, self.resp) return self.resp @@ -63,6 +68,61 @@ def test_build_timeout() -> None: assert timeout == HTTPTimeout(5.0, 2.0) +@pytest.mark.parametrize( + "params", + [ + pytest.param({"type": int, "expected": False}, id="int"), + pytest.param({"type": str, "expected": False}, id="str"), + pytest.param({"type": int | str, "expected": True}, id="int | str"), + pytest.param({"type": Union[int, str], "expected": True}, id="Union[int, str]"), + ], +) +def test_is_union(params: Mapping[str, Any]): + assert is_union(params["type"]) is params["expected"] + + +class Foo(BaseModel): + typ: Literal["foo"] + + +class Bar(BaseModel): + typ: Literal["bar"] + + +class Foobar(BaseModel): + obj: Union[Foo, Bar] + + +@pytest.mark.parametrize( + "params", + [ + pytest.param( + {"type": Foo, "params": {"typ": "foo"}, "expected": Foo(typ="foo")}, + id="simple", + ), + pytest.param( + { + "type": Foobar, + "params": {"obj": {"typ": "foo"}}, + "expected": Foobar(obj=Foo(typ="foo")), + }, + id="union", + ), + pytest.param( + { + "type": Foobar, + "params": {"obj": {"typ": "bar"}}, + "expected": Foobar(obj=Bar(typ="bar")), + }, + id="union", + ), + ], +) +def test_build_request(params: Mapping[str, Any]): + req = build_request(params["type"], params["params"]) + assert req == params["expected"] + + def test_route_proxy_prepare_middleware( dummy_http_request: HTTPRequest, echo_middleware: SyncAbstractTransport ): diff --git a/tests/unittests/test_registry.py b/tests/unittests/test_registry.py index 59ae8b9d..3be47db1 100644 --- a/tests/unittests/test_registry.py +++ b/tests/unittests/test_registry.py @@ -1,9 +1,12 @@ +from typing import Literal + import pytest import blacksmith from blacksmith.domain import registry from blacksmith.domain.exceptions import ConfigurationError, UnregisteredClientException from blacksmith.domain.model import PathInfoField, PostBodyField, Request, Response +from blacksmith.domain.model.params import QueryStringField from blacksmith.domain.registry import Registry @@ -111,6 +114,41 @@ class DummyRequest(Request): assert api["dummies"].resource.contract["GET"][1] is None +def test_registry_with_union_type() -> None: + class FooRequest(Request): + name: str = PathInfoField() + type: Literal["foo"] = QueryStringField() + + class BarRequest(Request): + name: str = PathInfoField() + type: Literal["bar"] = QueryStringField() + + registry = Registry() + registry.register( + "dummies_api", + "dummies", + "api", + "v5", + path="/dummies/{name}", + contract={ + "GET": (FooRequest | BarRequest, None), + }, + ) + + assert registry.client_service == {"dummies_api": ("api", "v5")} + assert set(registry.clients.keys()) == {"dummies_api"} + + assert set(registry.clients["dummies_api"].keys()) == {"dummies"} + + api = registry.clients["dummies_api"] + assert api["dummies"].resource is not None + assert api["dummies"].resource.contract is not None + assert api["dummies"].resource.path == "/dummies/{name}" + assert set(api["dummies"].resource.contract.keys()) == {"GET"} + assert api["dummies"].resource.contract["GET"][0] == FooRequest | BarRequest + assert api["dummies"].resource.contract["GET"][1] is None + + def test_registry_only_collection() -> None: class DummyRequest(Request): pass