Skip to content

Commit

Permalink
fix(idempotency): add support for Optional type when serializing outp…
Browse files Browse the repository at this point in the history
…ut (#5590)

* Accepting None when working with output serialization

* Fix Python3.8/3.9

* Make mypy happy

* Making it work in python 3.8 and 3.9

* Making it work in python 3.8 and 3.9
  • Loading branch information
leandrodamascena authored Dec 23, 2024
1 parent 1261c07 commit 7a7f10c
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BaseIdempotencyModelSerializer,
BaseIdempotencySerializer,
)
from aws_lambda_powertools.utilities.idempotency.serialization.functions import get_actual_type

DataClass = Any

Expand All @@ -37,6 +38,9 @@ def from_dict(self, data: dict) -> DataClass:

@classmethod
def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer:

model_type = get_actual_type(model_type=model_type)

if model_type is None:
raise IdempotencyNoSerializationModelError("No serialization model was supplied")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import sys
from typing import Any, Optional, Union, get_args, get_origin

# Conditionally import or define UnionType based on Python version
if sys.version_info >= (3, 10):
from types import UnionType # Available in Python 3.10+
else:
UnionType = Union # Fallback for Python 3.8 and 3.9

from aws_lambda_powertools.utilities.idempotency.exceptions import (
IdempotencyModelTypeError,
)


def get_actual_type(model_type: Any) -> Any:
"""
Extract the actual type from a potentially Optional or Union type.
This function handles types that may be wrapped in Optional or Union,
including the Python 3.10+ Union syntax (Type | None).
Parameters
----------
model_type: Any
The type to analyze. Can be a simple type, Optional[Type], BaseModel, dataclass
Returns
-------
The actual type without Optional or Union wrappers.
Raises:
IdempotencyModelTypeError: If the type specification is invalid
(e.g., Union with multiple non-None types).
"""

# Get the origin of the type (e.g., Union, Optional)
origin = get_origin(model_type)

# Check if type is Union, Optional, or UnionType (Python 3.10+)
if origin in (Union, Optional) or (sys.version_info >= (3, 10) and origin in (Union, UnionType)):
# Get type arguments
args = get_args(model_type)

# Filter out NoneType
actual_type = _extract_non_none_types(args)

# Ensure only one non-None type exists
if len(actual_type) != 1:
raise IdempotencyModelTypeError(
"Invalid type: expected a single type, optionally wrapped in Optional or Union with None.",
)

return actual_type[0]

# If not a Union/Optional type, return original type
return model_type


def _extract_non_none_types(args: tuple) -> list:
"""Extract non-None types from type arguments."""
return [arg for arg in args if arg is not type(None)]
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
BaseIdempotencyModelSerializer,
BaseIdempotencySerializer,
)
from aws_lambda_powertools.utilities.idempotency.serialization.functions import get_actual_type


class PydanticSerializer(BaseIdempotencyModelSerializer):
Expand All @@ -34,6 +35,9 @@ def from_dict(self, data: dict) -> BaseModel:

@classmethod
def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer:

model_type = get_actual_type(model_type=model_type)

if model_type is None:
raise IdempotencyNoSerializationModelError("No serialization model was supplied")

Expand Down
5 changes: 4 additions & 1 deletion docs/utilities/idempotency.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ By default, `idempotent_function` serializes, stores, and returns your annotated

The output serializer supports any JSON serializable data, **Python Dataclasses** and **Pydantic Models**.

!!! info "When using the `output_serializer` parameter, the data will continue to be stored in your persistent storage as a JSON string."
!!! info
When using the `output_serializer` parameter, the data will continue to be stored in your persistent storage as a JSON string.

Function returns must be annotated with a single type, optionally wrapped in `Optional` or `Union` with `None`.

=== "Pydantic"

Expand Down
49 changes: 48 additions & 1 deletion tests/functional/idempotency/_boto3/test_idempotency.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import datetime
import warnings
from typing import Any
from typing import Any, Optional
from unittest.mock import MagicMock, Mock

import jmespath
Expand Down Expand Up @@ -2014,3 +2014,50 @@ def lambda_handler(event, context):

stubber.assert_no_pending_responses()
stubber.deactivate()


@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"])
def test_idempotent_function_serialization_dataclass_with_optional_return(output_serializer_type: str):
# GIVEN
dataclasses = get_dataclasses_lib()
config = IdempotencyConfig(use_local_cache=True)
mock_event = {"customer_id": "fake", "transaction_id": "fake-id"}
idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_dataclass_with_optional_return.<locals>.collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501
persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key)

@dataclasses.dataclass
class PaymentInput:
customer_id: str
transaction_id: str

@dataclasses.dataclass
class PaymentOutput:
customer_id: str
transaction_id: str

if output_serializer_type == "explicit":
output_serializer = DataclassSerializer(
model=PaymentOutput,
)
else:
output_serializer = DataclassSerializer

@idempotent_function(
data_keyword_argument="payment",
persistence_store=persistence_layer,
config=config,
output_serializer=output_serializer,
)
def collect_payment(payment: PaymentInput) -> Optional[PaymentOutput]:
return PaymentOutput(**dataclasses.asdict(payment))

# WHEN
payment = PaymentInput(**mock_event)
first_call: PaymentOutput = collect_payment(payment=payment)
assert first_call.customer_id == payment.customer_id
assert first_call.transaction_id == payment.transaction_id
assert isinstance(first_call, PaymentOutput)
second_call: PaymentOutput = collect_payment(payment=payment)
assert isinstance(second_call, PaymentOutput)
assert second_call.customer_id == payment.customer_id
assert second_call.transaction_id == payment.transaction_id
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import pytest
from pydantic import BaseModel

Expand Down Expand Up @@ -219,3 +221,47 @@ def collect_payment(payment: Payment):

# THEN idempotency key assertion happens at MockPersistenceLayer
assert result == payment.transaction_id


@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"])
def test_idempotent_function_serialization_pydantic_with_optional_return(output_serializer_type: str):
# GIVEN
config = IdempotencyConfig(use_local_cache=True)
mock_event = {"customer_id": "fake", "transaction_id": "fake-id"}
idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_pydantic_with_optional_return.<locals>.collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501
persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key)

class PaymentInput(BaseModel):
customer_id: str
transaction_id: str

class PaymentOutput(BaseModel):
customer_id: str
transaction_id: str

if output_serializer_type == "explicit":
output_serializer = PydanticSerializer(
model=PaymentOutput,
)
else:
output_serializer = PydanticSerializer

@idempotent_function(
data_keyword_argument="payment",
persistence_store=persistence_layer,
config=config,
output_serializer=output_serializer,
)
def collect_payment(payment: PaymentInput) -> Optional[PaymentOutput]:
return PaymentOutput(**payment.dict())

# WHEN
payment = PaymentInput(**mock_event)
first_call: PaymentOutput = collect_payment(payment=payment)
assert first_call.customer_id == payment.customer_id
assert first_call.transaction_id == payment.transaction_id
assert isinstance(first_call, PaymentOutput)
second_call: PaymentOutput = collect_payment(payment=payment)
assert isinstance(second_call, PaymentOutput)
assert second_call.customer_id == payment.customer_id
assert second_call.transaction_id == payment.transaction_id

0 comments on commit 7a7f10c

Please sign in to comment.