From 5b3e1b745c7412595d4a8360b3120c78dc7e0603 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bal=C3=A1zs=20Moln=C3=A1r?= <97176035+bmlnr@users.noreply.github.com> Date: Tue, 30 Jan 2024 18:16:30 +0100 Subject: [PATCH 1/2] Feature/add support for lists during conversion to dictionary (#4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feat: add support for lists during conversion to dictionary --------- Co-authored-by: MichaƂ Murawski --- src/pynamodb_utils/attributes.py | 38 +++++++++------- src/pynamodb_utils/conditions.py | 76 +++++++++++++++----------------- src/pynamodb_utils/exceptions.py | 8 ++++ src/pynamodb_utils/models.py | 20 ++++++--- src/pynamodb_utils/parsers.py | 4 +- src/pynamodb_utils/utils.py | 19 ++++---- src/tests/conftest.py | 20 ++++++--- src/tests/requirements.txt | 9 ++-- src/tests/test_enum.py | 22 ++++----- src/tests/test_general.py | 40 ++++++++--------- 10 files changed, 141 insertions(+), 115 deletions(-) diff --git a/src/pynamodb_utils/attributes.py b/src/pynamodb_utils/attributes.py index 465a554..e661b00 100644 --- a/src/pynamodb_utils/attributes.py +++ b/src/pynamodb_utils/attributes.py @@ -1,8 +1,11 @@ from enum import Enum -from typing import Collection, FrozenSet, Union +from typing import Collection, FrozenSet, Optional, Union import six from pynamodb.attributes import MapAttribute, NumberAttribute, UnicodeAttribute +from pynamodb.constants import NUMBER + +from pynamodb_utils.exceptions import EnumSerializationException class DynamicMapAttribute(MapAttribute): @@ -54,14 +57,16 @@ def __str__(self) -> str: class EnumNumberAttribute(NumberAttribute): + attr_type = NUMBER + def __init__( self, enum, - hash_key=False, - range_key=False, - null=None, - default=None, - attr_name=None, + hash_key: bool = False, + range_key: bool = False, + null: Optional[bool] = None, + default: Optional[Enum] = None, + attr_name: Optional[str] = None, ): if isinstance(enum, Enum): raise ValueError("enum must be Enum class") @@ -69,7 +74,7 @@ def __init__( super().__init__( hash_key=hash_key, range_key=range_key, - default=default, + default=default.value if default else None, null=null, attr_name=attr_name, ) @@ -85,7 +90,7 @@ def serialize(self, value: Union[Enum, str]) -> str: f'Value Error: {value} must be in {", ".join([item for item in self.enum.__members__.keys()])}' ) except TypeError as e: - raise Exception(value, self.enum) from e + raise EnumSerializationException(f"Error serializing {value} with enum {self.enum}") from e def deserialize(self, value: str) -> str: return self.enum(int(value)).name @@ -94,12 +99,12 @@ def deserialize(self, value: str) -> str: class EnumUnicodeAttribute(UnicodeAttribute): def __init__( self, - hash_key=False, - range_key=False, - null=None, - default=None, - attr_name=None, - enum=None, + enum, + hash_key: bool = False, + range_key: bool = False, + null: Optional[bool] = None, + default: Optional[Enum] = None, + attr_name: Optional[str] = None, ): if isinstance(enum, Enum): raise ValueError("enum must be Enum class") @@ -115,9 +120,8 @@ def __init__( def serialize(self, value: Union[Enum, str]) -> str: if isinstance(value, self.enum): return str(value.value) - elif isinstance(value, str): - if value in self.enum.__members__.keys(): - return getattr(self.enum, value).value + elif isinstance(value, str) and value in self.enum.__members__.keys(): + return getattr(self.enum, value).value raise ValueError( f'Value Error: {value} must be in {", ".join([item for item in self.enum.__members__.keys()])}' ) diff --git a/src/pynamodb_utils/conditions.py b/src/pynamodb_utils/conditions.py index 5663e3a..3c814df 100644 --- a/src/pynamodb_utils/conditions.py +++ b/src/pynamodb_utils/conditions.py @@ -1,6 +1,6 @@ import operator from functools import reduce -from typing import Any, Callable, Dict, List, Set +from typing import Any, Callable, Dict, List, Optional from pynamodb.attributes import Attribute from pynamodb.expressions.condition import Condition @@ -12,13 +12,30 @@ from pynamodb_utils.utils import get_attribute, get_available_attributes_list +def _is_available(field_path: str, available_attributes: List, raise_exception: bool): + if "." in field_path: + _field_path = field_path.split(".", 1)[0] + ".*" + is_available = _field_path in available_attributes + else: + is_available = field_path in available_attributes + if not is_available and raise_exception: + raise FilterError( + message={ + field_path: [ + f"Parameter {field_path} does not exist." + f" Choose some of available: {', '.join(available_attributes)}" + ] + } + ) + + def create_model_condition( - model: Model, - args: Dict[str, Any], - _operator: Callable = operator.and_, - raise_exception: bool = True, - unavailable_attributes: List[str] = [] -) -> Condition: + model: Model, + args: Dict[str, Any], + _operator: Callable = operator.and_, + raise_exception: bool = True, + unavailable_attributes: Optional[List[str]] = None +) -> Optional[Condition]: """ Function creates pynamodb conditions based on input dictionary (args) Parameters: @@ -31,52 +48,29 @@ def create_model_condition( condition (Condition): computed pynamodb condition """ conditions_list: List[Condition] = [] - - available_attributes: Set[str] = get_available_attributes_list( + available_attributes: List[str] = get_available_attributes_list( model=model, - unavaiable_attrs=unavailable_attributes + unavailable_attrs=unavailable_attributes ) - - key: str - value: Any for key, value in args.items(): - array: List[str] = key.rsplit('__', 1) + array: List[str] = key.rsplit("__", 1) field_path: str = array[0] - operator_name: str = array[1] if len(array) > 1 and array[1] != 'not' else '' - - if "." in field_path: - _field_path = field_path.split(".", 1)[0] + ".*" - is_available = _field_path in available_attributes - else: - - is_available = field_path in available_attributes - - if operator_name.replace('not_', '') not in OPERATORS_MAPPING: - raise FilterError( - message={key: [f'Operator {operator_name} does not exist.' - f' Choose some of available: {", ".join(OPERATORS_MAPPING.keys())}']} - ) - if not is_available and raise_exception: + operator_name: str = array[1] if len(array) > 1 and array[1] != "not" else "" + if operator_name.replace("not_", "") not in OPERATORS_MAPPING: raise FilterError( - message={ - field_path: [ - f"Parameter {field_path} does not exist." - f' Choose some of available: {", ".join(available_attributes)}' - ] - } + message={key: [f"Operator {operator_name} does not exist." + f" Choose some of available: {', '.join(OPERATORS_MAPPING.keys())}"]} ) - + _is_available(field_path, available_attributes, raise_exception) attr: Attribute = get_attribute(model, field_path) - if isinstance(attr, (Attribute, Path)): if 'not_' in operator_name: - operator_name = operator_name.replace('not_', '') + operator_name = operator_name.replace("not_", "") operator_handler = OPERATORS_MAPPING[operator_name] conditions_list.append(~operator_handler(model, field_path, attr, value)) else: operator_handler = OPERATORS_MAPPING[operator_name] conditions_list.append(operator_handler(model, field_path, attr, value)) - if not conditions_list: - return None - else: + if conditions_list: return reduce(_operator, conditions_list) + return None diff --git a/src/pynamodb_utils/exceptions.py b/src/pynamodb_utils/exceptions.py index a77836c..6546a54 100644 --- a/src/pynamodb_utils/exceptions.py +++ b/src/pynamodb_utils/exceptions.py @@ -10,3 +10,11 @@ class FilterError(Error): class SerializerError(Error): pass + + +class IndexNotFoundError(Exception): + pass + + +class EnumSerializationException(Exception): + pass diff --git a/src/pynamodb_utils/models.py b/src/pynamodb_utils/models.py index b766fce..a36aaf4 100644 --- a/src/pynamodb_utils/models.py +++ b/src/pynamodb_utils/models.py @@ -1,5 +1,5 @@ from datetime import timezone -from typing import Any, List +from typing import Any, List, Optional from pynamodb.attributes import UTCDateTimeAttribute from pynamodb.expressions.condition import Condition @@ -20,12 +20,14 @@ def get_conditions_from_json(cls, query: dict, raise_exception: bool = True) -> Parameters: query (dict): A decimal integer + raise_exception (bool): Throwing an exception in case of an error Returns: condition (Condition): computed pynamodb condition """ query_unavailable_attributes: List[str] = getattr(cls.Meta, "query_unavailable_attributes", []) - return ConditionsSerializer(cls, query_unavailable_attributes).load(data=query, raise_exception=raise_exception) + return ConditionsSerializer(cls, query_unavailable_attributes).load(data=query, + raise_exception=raise_exception) @classmethod def make_index_query(cls, query: dict, raise_exception: bool = True, **kwargs) -> ResultIterator[Model]: @@ -34,6 +36,7 @@ def make_index_query(cls, query: dict, raise_exception: bool = True, **kwargs) - Parameters: query (dict): A decimal integer + raise_exception (bool): Throwing an exception in case of an error Returns: result_iterator (result_iterator): result iterator for optimized query @@ -69,6 +72,9 @@ def _pop_path(obj: dict, path: str) -> Any: obj = obj[key] +TZ_INFO = "TZINFO" + + class TimestampedModel(Model): created_at = UTCDateTimeAttribute(default=get_timestamp) updated_at = UTCDateTimeAttribute(default=get_timestamp) @@ -77,17 +83,17 @@ class TimestampedModel(Model): class Meta: abstract = True - def save(self, condition=None): - tz_info = getattr(self.Meta, "TZINFO", None) + def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True): + tz_info = getattr(self.Meta, TZ_INFO, None) self.created_at = self.created_at.astimezone(tz=tz_info or timezone.utc) - self.updated_at = get_timestamp(tzinfo=tz_info) - super().save(condition=condition) + self.updated_at = get_timestamp(tz=tz_info) + super().save(condition=condition, add_version_condition=add_version_condition) def save_without_timestamp_update(self, condition=None): super().save(condition=condition) def soft_delete(self, condition=None): """ Puts delete_at timestamp """ - tz_info = getattr(self.Meta, "TZINFO", None) + tz_info = getattr(self.Meta, TZ_INFO, None) self.deleted_at = get_timestamp(tz_info) super().save(condition=condition) diff --git a/src/pynamodb_utils/parsers.py b/src/pynamodb_utils/parsers.py index 422a6a7..40e8067 100644 --- a/src/pynamodb_utils/parsers.py +++ b/src/pynamodb_utils/parsers.py @@ -58,13 +58,13 @@ def default_list_parser(value: List[Any], field_name: str, model: Model) -> List raise FilterError(message={field_name: [f"{value} is not valid type of {field_name}."]}) -def default_dict_parser(value: Dict, field_name: str, *args) -> Dict[Any, Any]: +def default_dict_parser(value: Dict, field_name: str, *args) -> Union[Dict[Any, Any], str]: if isinstance(value, (dict, NoneType)): return value elif isinstance(value, str): try: return json.dumps(value, default=str) - except (ValueError, json.JSONDecodeError): + except ValueError: pass raise FilterError( message={field_name: [f"{value} is not valid type of {field_name}."]} diff --git a/src/pynamodb_utils/utils.py b/src/pynamodb_utils/utils.py index 6f75b99..8d06df4 100644 --- a/src/pynamodb_utils/utils.py +++ b/src/pynamodb_utils/utils.py @@ -1,11 +1,12 @@ from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from pynamodb.attributes import Attribute, MapAttribute from pynamodb.indexes import GlobalSecondaryIndex, LocalSecondaryIndex from pynamodb.models import Model from pynamodb_utils.attributes import DynamicMapAttribute +from pynamodb_utils.exceptions import IndexNotFoundError NoneType = type(None) @@ -31,7 +32,7 @@ def create_index_map( ).get("AttributeName") idx_map[(hash_key, range_key)] = getattr(model, k) except StopIteration as e: - raise Exception("Could not find index keys") from e + raise IndexNotFoundError("Could not find index keys") from e return idx_map @@ -53,12 +54,14 @@ def pick_index_keys( return keys -def parse_attr(attr: Attribute) -> Union[Dict, datetime]: +def parse_attr(attr: Attribute) -> Union[Attribute, Dict, List, datetime, str]: """ Function parses attribute to corresponding values """ if isinstance(attr, DynamicMapAttribute): return attr.as_dict() + elif isinstance(attr, List): + return [parse_attr(el) for el in attr] elif isinstance(attr, MapAttribute): return parse_attrs_to_dict(attr) elif isinstance(attr, datetime): @@ -85,12 +88,12 @@ def get_attributes_list(model: Model, depth: int = 0) -> List[str]: return attrs -def get_available_attributes_list(model: Model, unavaiable_attrs: List[str] = []) -> Set[str]: +def get_available_attributes_list(model: Model, unavailable_attrs: Optional[List[str]] = None) -> List[str]: attrs: List[str] = get_attributes_list(model) - return [attr for attr in attrs if attr not in unavaiable_attrs] + return sorted(set(attr for attr in attrs if attr not in unavailable_attrs)) -def get_attribute(model: Model, attr_string: str) -> Attribute: +def get_attribute(model: Model, attr_string: str) -> Optional[Attribute]: """ Function gets nested attribute based on path (attr_string) """ @@ -106,5 +109,5 @@ def get_attribute(model: Model, attr_string: str) -> Attribute: return result -def get_timestamp(tzinfo: timezone = None) -> datetime: - return datetime.now(tzinfo or timezone.utc) +def get_timestamp(tz: timezone = None) -> datetime: + return datetime.now(tz or timezone.utc) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 38c8edc..3922c81 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -1,8 +1,9 @@ import enum +import os from datetime import timezone import pytest -from moto import mock_dynamodb +from moto import mock_aws from pynamodb.attributes import UnicodeAttribute, UTCDateTimeAttribute from pynamodb.indexes import AllProjection, GlobalSecondaryIndex @@ -10,13 +11,22 @@ @pytest.fixture -def dynamodb(): - with mock_dynamodb(): +def aws_environ(): + env_vars = { + "AWS_DEFAULT_REGION": "us-east-1" + } + with mock_aws(): + for k, v in env_vars.items(): + os.environ[k] = v + yield + for k in env_vars: + del os.environ[k] + @pytest.fixture -def post_table(dynamodb): +def post_table(aws_environ): class CategoryEnum(enum.Enum): finance = enum.auto() politics = enum.auto() @@ -34,7 +44,7 @@ class Post(AsDictModel, JSONQueryModel, TimestampedModel): sub_name = UnicodeAttribute(range_key=True) category = EnumAttribute(enum=CategoryEnum, default=CategoryEnum.finance) content = UnicodeAttribute() - tags = DynamicMapAttribute(default={}) + tags = DynamicMapAttribute(default=None) category_created_at_gsi = PostCategoryCreatedAtGSI() secret_parameter = UnicodeAttribute(default="secret") diff --git a/src/tests/requirements.txt b/src/tests/requirements.txt index b672d5c..8b31d74 100644 --- a/src/tests/requirements.txt +++ b/src/tests/requirements.txt @@ -1,4 +1,5 @@ -pytest>=7.0.1 -pytest-cov>=3.0.0 -moto==4.1.7 -freezegun>=1.1.0 +pytest>=8.0.0 +pytest-cov>=4.1.0 +moto[dynamodb]>=5.0.0 +freezegun>=1.4.0 +pynamodb>=6.0.0 \ No newline at end of file diff --git a/src/tests/test_enum.py b/src/tests/test_enum.py index 8c912be..18867ba 100644 --- a/src/tests/test_enum.py +++ b/src/tests/test_enum.py @@ -8,20 +8,20 @@ @freeze_time("2019-01-01 00:00:00+00:00") def test_enum_query_not_member_of(post_table): - Post = post_table - CategoryEnum = Post.category.enum + post = post_table + category_enum = post.category.enum - post = Post( + post = post( name="A weekly news.", sub_name="Shocking revelations", content="Last week took place...", - category=CategoryEnum.finance, + category=category_enum.finance, tags={"type": "news", "topics": ["stock exchange", "NYSE"]}, ) post.save() with pytest.raises(SerializerError) as e: - Post.get_conditions_from_json( + post.get_conditions_from_json( query={ "created_at__lte": str(datetime.now()), "sub_name__exists": None, @@ -30,21 +30,21 @@ def test_enum_query_not_member_of(post_table): "tags.topics__contains": ["NYSE"], } ) - assert e.value.message == { - "Query": {"category": ["1 is not member of finance, politics."]} - } + assert e.value.message == { + "Query": {"category": ["1 is not member of finance, politics."]} + } @freeze_time("2019-01-01 00:00:00+00:00") def test_enum_create_not_member_of(post_table): - Post = post_table + post = post_table with pytest.raises(ValueError) as e: - post = Post( + post = post( name="A weekly news.", content="Last week took place...", category=1, tags={"type": "news", "topics": ["stock exchange", "NYSE"]}, ) post.save() - assert str(e.value) == "Value Error: 1 must be in finance, politics" + assert str(e.value) == "Value Error: 1 must be in finance, politics" diff --git a/src/tests/test_general.py b/src/tests/test_general.py index 2c9d7f9..cd9df75 100644 --- a/src/tests/test_general.py +++ b/src/tests/test_general.py @@ -8,14 +8,14 @@ @freeze_time("2019-01-01 00:00:00+00:00") def test_general(post_table): - Post = post_table - CategoryEnum = Post.category.enum + post = post_table + category_enum = post.category.enum - post = Post( + post = post( name="A weekly news.", sub_name="Shocking revelations", content="Last week took place...", - category=CategoryEnum.finance, + category=category_enum.finance, tags={"type": "news", "topics": ["stock exchange", "NYSE"]}, ) post.save() @@ -26,7 +26,7 @@ def test_general(post_table): "OR": {"tags.type__equals": "news", "tags.topics__contains": ["NYSE"]}, } - results = Post.make_index_query(query) + results = post.make_index_query(query) expected = { "content": "Last week took place...", @@ -43,26 +43,26 @@ def test_general(post_table): def test_bad_field(post_table): - Post = post_table - CategoryEnum = Post.category.enum + post = post_table + category_enum = post.category.enum - post = Post( + post = post( name="A weekly news.", sub_name="Shocking revelations", content="Last week took place...", - category=CategoryEnum.finance, + category=category_enum.finance, tags={"type": "news", "topics": ["stock exchange", "NYSE"]}, ) post.save() - with pytest.raises(SerializerError) as e: - Post.get_conditions_from_json(query={"tag.type__equals": "news"}) - assert e.message == { - "Query": { - "tag.type": [ - "Parameter tag does not exist. Choose some of " - "available: category, content, created_at, deleted_at, " - "name, sub_name, tags, updated_at" - ] - } - } + with pytest.raises(SerializerError) as exc_info: + post.get_conditions_from_json(query={"tag.type__equals": "news"}) + assert exc_info.value.message == { + "Query": { + "tag.type": [ + "Parameter tag.type does not exist. Choose some of " + "available: category, content, created_at, deleted_at, " + "name, sub_name, tags, tags.*, updated_at" + ] + } + } From b7692cf13d7d13652823dc251666a69d3199ef4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Murawski?= Date: Tue, 30 Jan 2024 18:18:20 +0100 Subject: [PATCH 2/2] feat: add support for lists during conversion to dictionary --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3a3772d..c751517 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ def read(*parts): setup( name="pynamodb_utils", - version="1.3.6", + version="1.3.7", author="Michal Murawski", author_email="mmurawski777@gmail.com", description="Utilities package for pynamodb.",