Skip to content

Commit

Permalink
Serialize dict to proper union type in case it has to
Browse files Browse the repository at this point in the history
  • Loading branch information
Guillaume Gauvrit committed Jan 9, 2024
1 parent c56c3b7 commit 7d8c8ef
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 18 deletions.
34 changes: 25 additions & 9 deletions src/blacksmith/service/_async/route_proxy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, Type, Union

try:
from types import NoneType, UnionType
from types import UnionType
except ImportError: # coverage: ignore
# python 3.7 compat
NoneType = type(None) # coverage: ignore
UnionType = object() # coverage: ignore

from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union
UnionType = Union # type: ignore

from pydantic import ValidationError
from result import Err, Ok, Result
from typing_extensions import get_origin

Expand Down Expand Up @@ -54,14 +54,27 @@ def build_timeout(timeout: ClientTimeout) -> HTTPTimeout:
def is_union(typ: Type[Any]) -> bool:
type_origin = get_origin(typ)
if type_origin:
if type_origin is Union: # Optional[T]
if type_origin is Union: # Union[T, U] or even Optional[T]
return True

if type_origin is UnionType: # T | U
return True
return False


def build_request(typ: Type[Any], params: Mapping[str, Any]) -> Request:
if is_union(typ):
err: Optional[Exception] = None
for t in typ.__args__: # type: ignore
try:
return build_request(t, params) # type: ignore
except ValidationError as e:
err = e
if err:
raise err
return typ(**params)


class AsyncRouteProxy(Generic[TCollectionResponse, TResponse, TError_co]):
"""Proxy from resource to its associate routes."""

Expand Down Expand Up @@ -109,18 +122,21 @@ def _prepare_request(
raise NoContractException(method, self.name, self.client_name)

param_schema, return_schema = resource.contract[method]
build_params: Request
if isinstance(params, dict):
params = param_schema(**params)
build_params = build_request(param_schema, params)
elif params is None:
params = param_schema()
build_params = param_schema()
elif not isinstance(params, param_schema):
raise WrongRequestTypeException(
params.__class__, # type: ignore
method,
self.name,
self.client_name,
)
req = params.to_http_request(method, self.endpoint + resource.path)
else:
build_params = params
req = build_params.to_http_request(method, self.endpoint + resource.path)
return (resource.path, req, return_schema)

def _prepare_response(
Expand Down
44 changes: 39 additions & 5 deletions src/blacksmith/service/_sync/route_proxy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, Type, Union

try:
from types import UnionType
except ImportError: # coverage: ignore
# python 3.7 compat
UnionType = Union # type: ignore

from pydantic import ValidationError
from result import Err, Ok, Result
from typing_extensions import get_origin

from blacksmith.domain.error import AbstractErrorParser, TError_co
from blacksmith.domain.exceptions import (
Expand Down Expand Up @@ -43,6 +51,30 @@ def build_timeout(timeout: ClientTimeout) -> HTTPTimeout:
return timeout


def is_union(typ: Type[Any]) -> bool:
type_origin = get_origin(typ)
if type_origin:
if type_origin is Union: # Union[T, U] or even Optional[T]
return True

if type_origin is UnionType: # T | U
return True
return False


def build_request(typ: Type[Any], params: Mapping[str, Any]) -> Request:
if is_union(typ):
err: Optional[Exception] = None
for t in typ.__args__: # type: ignore
try:
return build_request(t, params) # type: ignore
except ValidationError as e:
err = e
if err:
raise err
return typ(**params)


class SyncRouteProxy(Generic[TCollectionResponse, TResponse, TError_co]):
"""Proxy from resource to its associate routes."""

Expand Down Expand Up @@ -90,18 +122,21 @@ def _prepare_request(
raise NoContractException(method, self.name, self.client_name)

param_schema, return_schema = resource.contract[method]
build_params: Request
if isinstance(params, dict):
params = param_schema(**params)
build_params = build_request(param_schema, params)
elif params is None:
params = param_schema()
build_params = param_schema()
elif not isinstance(params, param_schema):
raise WrongRequestTypeException(
params.__class__, # type: ignore
method,
self.name,
self.client_name,
)
req = params.to_http_request(method, self.endpoint + resource.path)
else:
build_params = params
req = build_params.to_http_request(method, self.endpoint + resource.path)
return (resource.path, req, return_schema)

def _prepare_response(
Expand All @@ -127,7 +162,6 @@ def _prepare_collection_response(
response_schema: Optional[Type[Response]],
collection_parser: Optional[Type[AbstractCollectionParser]],
) -> Result[CollectionIterator[TCollectionResponse], TError_co]:

if result.is_err():
return Err(self.error_parser(result.unwrap_err()))
else:
Expand Down
7 changes: 6 additions & 1 deletion tests/functionals/test_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class CreateItem(Request):
size: SizeEnum = PostBodyField(SizeEnum.m)


class CreateItemIntSize(Request):
name: str = PostBodyField()
size: int = PostBodyField(2)


class ListItem(Request):
name: Optional[str] = QueryStringField(None)

Expand All @@ -60,7 +65,7 @@ class UpdateItem(GetItem):
collection_path="/items",
collection_contract={
"GET": (ListItem, Item),
"POST": (CreateItem, None),
"POST": (CreateItem | CreateItemIntSize, None),
},
path="/items/{item_name}",
contract={
Expand Down
44 changes: 44 additions & 0 deletions tests/unittests/_async/test_service_route_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from pydantic import BaseModel, Field
from result import Result
from typing_extensions import Literal

from blacksmith import Request
from blacksmith.domain.exceptions import (
Expand All @@ -24,6 +25,7 @@
from blacksmith.service._async.base import AsyncAbstractTransport
from blacksmith.service._async.route_proxy import (
AsyncRouteProxy,
build_request,
build_timeout,
is_union,
)
Expand Down Expand Up @@ -79,6 +81,48 @@ def test_is_union(params: Mapping[str, Any]):
assert is_union(params["type"]) is params["expected"]


class Foo(BaseModel):
typ: Literal["foo"]


class Bar(BaseModel):
typ: Literal["bar"]


class Foobar(BaseModel):
obj: Union[Foo, Bar]


@pytest.mark.parametrize(
"params",
[
pytest.param(
{"type": Foo, "params": {"typ": "foo"}, "expected": Foo(typ="foo")},
id="simple",
),
pytest.param(
{
"type": Foobar,
"params": {"obj": {"typ": "foo"}},
"expected": Foobar(obj=Foo(typ="foo")),
},
id="union",
),
pytest.param(
{
"type": Foobar,
"params": {"obj": {"typ": "bar"}},
"expected": Foobar(obj=Bar(typ="bar")),
},
id="union",
),
],
)
def test_build_request(params: Mapping[str, Any]):
req = build_request(params["type"], params["params"])
assert req == params["expected"]


async def test_route_proxy_prepare_middleware(
dummy_http_request: HTTPRequest, echo_middleware: AsyncAbstractTransport
):
Expand Down
66 changes: 63 additions & 3 deletions tests/unittests/_sync/test_service_route_proxy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any
from typing import Any, Mapping, Union

import pytest
from pydantic import BaseModel, Field
from result import Result
from typing_extensions import Literal

from blacksmith import Request
from blacksmith.domain.exceptions import (
Expand All @@ -22,7 +23,12 @@
from blacksmith.middleware._sync.auth import SyncHTTPAuthorizationMiddleware
from blacksmith.middleware._sync.base import SyncHTTPAddHeadersMiddleware
from blacksmith.service._sync.base import SyncAbstractTransport
from blacksmith.service._sync.route_proxy import SyncRouteProxy, build_timeout
from blacksmith.service._sync.route_proxy import (
SyncRouteProxy,
build_request,
build_timeout,
is_union,
)
from blacksmith.typing import ClientName, Path
from tests.unittests.dummy_registry import GetParam, GetResponse, PostParam

Expand All @@ -48,7 +54,6 @@ def __call__(
path: Path,
timeout: HTTPTimeout,
) -> HTTPResponse:

if self.resp.status_code >= 400:
raise HTTPError(f"{self.resp.status_code} blah", req, self.resp)
return self.resp
Expand All @@ -63,6 +68,61 @@ def test_build_timeout() -> None:
assert timeout == HTTPTimeout(5.0, 2.0)


@pytest.mark.parametrize(
"params",
[
pytest.param({"type": int, "expected": False}, id="int"),
pytest.param({"type": str, "expected": False}, id="str"),
pytest.param({"type": int | str, "expected": True}, id="int | str"),
pytest.param({"type": Union[int, str], "expected": True}, id="Union[int, str]"),
],
)
def test_is_union(params: Mapping[str, Any]):
assert is_union(params["type"]) is params["expected"]


class Foo(BaseModel):
typ: Literal["foo"]


class Bar(BaseModel):
typ: Literal["bar"]


class Foobar(BaseModel):
obj: Union[Foo, Bar]


@pytest.mark.parametrize(
"params",
[
pytest.param(
{"type": Foo, "params": {"typ": "foo"}, "expected": Foo(typ="foo")},
id="simple",
),
pytest.param(
{
"type": Foobar,
"params": {"obj": {"typ": "foo"}},
"expected": Foobar(obj=Foo(typ="foo")),
},
id="union",
),
pytest.param(
{
"type": Foobar,
"params": {"obj": {"typ": "bar"}},
"expected": Foobar(obj=Bar(typ="bar")),
},
id="union",
),
],
)
def test_build_request(params: Mapping[str, Any]):
req = build_request(params["type"], params["params"])
assert req == params["expected"]


def test_route_proxy_prepare_middleware(
dummy_http_request: HTTPRequest, echo_middleware: SyncAbstractTransport
):
Expand Down
38 changes: 38 additions & 0 deletions tests/unittests/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Literal

import pytest

import blacksmith
from blacksmith.domain import registry
from blacksmith.domain.exceptions import ConfigurationError, UnregisteredClientException
from blacksmith.domain.model import PathInfoField, PostBodyField, Request, Response
from blacksmith.domain.model.params import QueryStringField
from blacksmith.domain.registry import Registry


Expand Down Expand Up @@ -111,6 +114,41 @@ class DummyRequest(Request):
assert api["dummies"].resource.contract["GET"][1] is None


def test_registry_with_union_type() -> None:
class FooRequest(Request):
name: str = PathInfoField()
type: Literal["foo"] = QueryStringField()

class BarRequest(Request):
name: str = PathInfoField()
type: Literal["bar"] = QueryStringField()

registry = Registry()
registry.register(
"dummies_api",
"dummies",
"api",
"v5",
path="/dummies/{name}",
contract={
"GET": (FooRequest | BarRequest, None),
},
)

assert registry.client_service == {"dummies_api": ("api", "v5")}
assert set(registry.clients.keys()) == {"dummies_api"}

assert set(registry.clients["dummies_api"].keys()) == {"dummies"}

api = registry.clients["dummies_api"]
assert api["dummies"].resource is not None
assert api["dummies"].resource.contract is not None
assert api["dummies"].resource.path == "/dummies/{name}"
assert set(api["dummies"].resource.contract.keys()) == {"GET"}
assert api["dummies"].resource.contract["GET"][0] == FooRequest | BarRequest
assert api["dummies"].resource.contract["GET"][1] is None


def test_registry_only_collection() -> None:
class DummyRequest(Request):
pass
Expand Down

0 comments on commit 7d8c8ef

Please sign in to comment.