diff --git a/src/pynamodb_utils/attributes.py b/src/pynamodb_utils/attributes.py index 465a554..92210a6 100644 --- a/src/pynamodb_utils/attributes.py +++ b/src/pynamodb_utils/attributes.py @@ -1,7 +1,8 @@ from enum import Enum -from typing import Collection, FrozenSet, Union +from typing import Collection, FrozenSet, Union, Optional import six +from pynamodb.constants import NUMBER from pynamodb.attributes import MapAttribute, NumberAttribute, UnicodeAttribute @@ -54,14 +55,16 @@ def __str__(self) -> str: class EnumNumberAttribute(NumberAttribute): + attr_type = NUMBER + def __init__( self, enum, - hash_key=False, - range_key=False, - null=None, - default=None, - attr_name=None, + hash_key: bool = False, + range_key: bool = False, + null: Optional[bool] = None, + default: Optional[Enum] = None, + attr_name: Optional[str] = None, ): if isinstance(enum, Enum): raise ValueError("enum must be Enum class") @@ -69,7 +72,7 @@ def __init__( super().__init__( hash_key=hash_key, range_key=range_key, - default=default, + default=default.value if default else None, null=null, attr_name=attr_name, ) @@ -93,13 +96,13 @@ def deserialize(self, value: str) -> str: class EnumUnicodeAttribute(UnicodeAttribute): def __init__( - self, - hash_key=False, - range_key=False, - null=None, - default=None, - attr_name=None, - enum=None, + self, + hash_key=False, + range_key=False, + null=None, + default=None, + attr_name=None, + enum=None, ): if isinstance(enum, Enum): raise ValueError("enum must be Enum class") diff --git a/src/pynamodb_utils/conditions.py b/src/pynamodb_utils/conditions.py index 5663e3a..4e8f3ea 100644 --- a/src/pynamodb_utils/conditions.py +++ b/src/pynamodb_utils/conditions.py @@ -1,6 +1,6 @@ import operator from functools import reduce -from typing import Any, Callable, Dict, List, Set +from typing import Any, Callable, Dict, List, Set, Optional from pynamodb.attributes import Attribute from pynamodb.expressions.condition import Condition @@ -17,8 +17,8 @@ def create_model_condition( args: Dict[str, Any], _operator: Callable = operator.and_, raise_exception: bool = True, - unavailable_attributes: List[str] = [] -) -> Condition: + unavailable_attributes: Optional[List[str]] = None +) -> Optional[Condition]: """ Function creates pynamodb conditions based on input dictionary (args) Parameters: @@ -34,7 +34,7 @@ def create_model_condition( available_attributes: Set[str] = get_available_attributes_list( model=model, - unavaiable_attrs=unavailable_attributes + unavailable_attrs=unavailable_attributes ) key: str diff --git a/src/pynamodb_utils/exceptions.py b/src/pynamodb_utils/exceptions.py index a77836c..f35d32d 100644 --- a/src/pynamodb_utils/exceptions.py +++ b/src/pynamodb_utils/exceptions.py @@ -10,3 +10,9 @@ class FilterError(Error): class SerializerError(Error): pass + + +class IndexNotFoundError(Exception): + def __init__(self, message: str): + self.message = message + super().__init__(self.message) diff --git a/src/pynamodb_utils/models.py b/src/pynamodb_utils/models.py index b766fce..a36aaf4 100644 --- a/src/pynamodb_utils/models.py +++ b/src/pynamodb_utils/models.py @@ -1,5 +1,5 @@ from datetime import timezone -from typing import Any, List +from typing import Any, List, Optional from pynamodb.attributes import UTCDateTimeAttribute from pynamodb.expressions.condition import Condition @@ -20,12 +20,14 @@ def get_conditions_from_json(cls, query: dict, raise_exception: bool = True) -> Parameters: query (dict): A decimal integer + raise_exception (bool): Throwing an exception in case of an error Returns: condition (Condition): computed pynamodb condition """ query_unavailable_attributes: List[str] = getattr(cls.Meta, "query_unavailable_attributes", []) - return ConditionsSerializer(cls, query_unavailable_attributes).load(data=query, raise_exception=raise_exception) + return ConditionsSerializer(cls, query_unavailable_attributes).load(data=query, + raise_exception=raise_exception) @classmethod def make_index_query(cls, query: dict, raise_exception: bool = True, **kwargs) -> ResultIterator[Model]: @@ -34,6 +36,7 @@ def make_index_query(cls, query: dict, raise_exception: bool = True, **kwargs) - Parameters: query (dict): A decimal integer + raise_exception (bool): Throwing an exception in case of an error Returns: result_iterator (result_iterator): result iterator for optimized query @@ -69,6 +72,9 @@ def _pop_path(obj: dict, path: str) -> Any: obj = obj[key] +TZ_INFO = "TZINFO" + + class TimestampedModel(Model): created_at = UTCDateTimeAttribute(default=get_timestamp) updated_at = UTCDateTimeAttribute(default=get_timestamp) @@ -77,17 +83,17 @@ class TimestampedModel(Model): class Meta: abstract = True - def save(self, condition=None): - tz_info = getattr(self.Meta, "TZINFO", None) + def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True): + tz_info = getattr(self.Meta, TZ_INFO, None) self.created_at = self.created_at.astimezone(tz=tz_info or timezone.utc) - self.updated_at = get_timestamp(tzinfo=tz_info) - super().save(condition=condition) + self.updated_at = get_timestamp(tz=tz_info) + super().save(condition=condition, add_version_condition=add_version_condition) def save_without_timestamp_update(self, condition=None): super().save(condition=condition) def soft_delete(self, condition=None): """ Puts delete_at timestamp """ - tz_info = getattr(self.Meta, "TZINFO", None) + tz_info = getattr(self.Meta, TZ_INFO, None) self.deleted_at = get_timestamp(tz_info) super().save(condition=condition) diff --git a/src/pynamodb_utils/utils.py b/src/pynamodb_utils/utils.py index fc06df7..f7f18ee 100644 --- a/src/pynamodb_utils/utils.py +++ b/src/pynamodb_utils/utils.py @@ -6,6 +6,7 @@ from pynamodb.models import Model from pynamodb_utils.attributes import DynamicMapAttribute +from pynamodb_utils.exceptions import IndexNotFoundError NoneType = type(None) @@ -31,7 +32,7 @@ def create_index_map( ).get("AttributeName") idx_map[(hash_key, range_key)] = getattr(model, k) except StopIteration as e: - raise Exception("Could not find index keys") from e + raise IndexNotFoundError("Could not find index keys") from e return idx_map @@ -53,7 +54,7 @@ def pick_index_keys( return keys -def parse_attr(attr: Attribute) -> Union[Attribute, Dict, List, datetime]: +def parse_attr(attr: Attribute) -> Union[Attribute, Dict, List, datetime, str]: """ Function parses attribute to corresponding values """ @@ -87,12 +88,12 @@ def get_attributes_list(model: Model, depth: int = 0) -> List[str]: return attrs -def get_available_attributes_list(model: Model, unavaiable_attrs: List[str] = []) -> Set[str]: +def get_available_attributes_list(model: Model, unavailable_attrs: Optional[List[str]] = None) -> List[str]: attrs: List[str] = get_attributes_list(model) - return [attr for attr in attrs if attr not in unavaiable_attrs] + return sorted(set(attr for attr in attrs if attr not in unavailable_attrs)) -def get_attribute(model: Model, attr_string: str) -> Attribute: +def get_attribute(model: Model, attr_string: str) -> Optional[Attribute]: """ Function gets nested attribute based on path (attr_string) """ @@ -108,5 +109,5 @@ def get_attribute(model: Model, attr_string: str) -> Attribute: return result -def get_timestamp(tzinfo: timezone = None) -> datetime: - return datetime.now(tzinfo or timezone.utc) +def get_timestamp(tz: timezone = None) -> datetime: + return datetime.now(tz or timezone.utc) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 38c8edc..a060427 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -34,7 +34,7 @@ class Post(AsDictModel, JSONQueryModel, TimestampedModel): sub_name = UnicodeAttribute(range_key=True) category = EnumAttribute(enum=CategoryEnum, default=CategoryEnum.finance) content = UnicodeAttribute() - tags = DynamicMapAttribute(default={}) + tags = DynamicMapAttribute(default=None) category_created_at_gsi = PostCategoryCreatedAtGSI() secret_parameter = UnicodeAttribute(default="secret") diff --git a/src/tests/test_enum.py b/src/tests/test_enum.py index 8c912be..18867ba 100644 --- a/src/tests/test_enum.py +++ b/src/tests/test_enum.py @@ -8,20 +8,20 @@ @freeze_time("2019-01-01 00:00:00+00:00") def test_enum_query_not_member_of(post_table): - Post = post_table - CategoryEnum = Post.category.enum + post = post_table + category_enum = post.category.enum - post = Post( + post = post( name="A weekly news.", sub_name="Shocking revelations", content="Last week took place...", - category=CategoryEnum.finance, + category=category_enum.finance, tags={"type": "news", "topics": ["stock exchange", "NYSE"]}, ) post.save() with pytest.raises(SerializerError) as e: - Post.get_conditions_from_json( + post.get_conditions_from_json( query={ "created_at__lte": str(datetime.now()), "sub_name__exists": None, @@ -30,21 +30,21 @@ def test_enum_query_not_member_of(post_table): "tags.topics__contains": ["NYSE"], } ) - assert e.value.message == { - "Query": {"category": ["1 is not member of finance, politics."]} - } + assert e.value.message == { + "Query": {"category": ["1 is not member of finance, politics."]} + } @freeze_time("2019-01-01 00:00:00+00:00") def test_enum_create_not_member_of(post_table): - Post = post_table + post = post_table with pytest.raises(ValueError) as e: - post = Post( + post = post( name="A weekly news.", content="Last week took place...", category=1, tags={"type": "news", "topics": ["stock exchange", "NYSE"]}, ) post.save() - assert str(e.value) == "Value Error: 1 must be in finance, politics" + assert str(e.value) == "Value Error: 1 must be in finance, politics" diff --git a/src/tests/test_general.py b/src/tests/test_general.py index 2c9d7f9..06a4316 100644 --- a/src/tests/test_general.py +++ b/src/tests/test_general.py @@ -8,14 +8,14 @@ @freeze_time("2019-01-01 00:00:00+00:00") def test_general(post_table): - Post = post_table - CategoryEnum = Post.category.enum + post = post_table + category_enum = post.category.enum - post = Post( + post = post( name="A weekly news.", sub_name="Shocking revelations", content="Last week took place...", - category=CategoryEnum.finance, + category=category_enum.finance, tags={"type": "news", "topics": ["stock exchange", "NYSE"]}, ) post.save() @@ -26,7 +26,7 @@ def test_general(post_table): "OR": {"tags.type__equals": "news", "tags.topics__contains": ["NYSE"]}, } - results = Post.make_index_query(query) + results = post.make_index_query(query) expected = { "content": "Last week took place...", @@ -43,26 +43,26 @@ def test_general(post_table): def test_bad_field(post_table): - Post = post_table - CategoryEnum = Post.category.enum + post = post_table + category_enum = post.category.enum - post = Post( + post = post( name="A weekly news.", sub_name="Shocking revelations", content="Last week took place...", - category=CategoryEnum.finance, + category=category_enum.finance, tags={"type": "news", "topics": ["stock exchange", "NYSE"]}, ) post.save() - with pytest.raises(SerializerError) as e: - Post.get_conditions_from_json(query={"tag.type__equals": "news"}) - assert e.message == { - "Query": { - "tag.type": [ - "Parameter tag does not exist. Choose some of " - "available: category, content, created_at, deleted_at, " - "name, sub_name, tags, updated_at" - ] - } + with pytest.raises(SerializerError) as exc_info: + post.get_conditions_from_json(query={"tag.type__equals": "news"}) + assert exc_info.value.message == { + "Query": { + "tag.type": [ + "Parameter tag.type does not exist. Choose some of " + "available: category, content, created_at, deleted_at, " + "name, sub_name, tags, tags.*, updated_at" + ] } + }