Skip to content

Commit

Permalink
Refactoring and fixing minor problems
Browse files Browse the repository at this point in the history
  • Loading branch information
bmlnr committed Jan 29, 2024
1 parent b77ab9d commit cdd6dd2
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 63 deletions.
31 changes: 17 additions & 14 deletions src/pynamodb_utils/attributes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from enum import Enum
from typing import Collection, FrozenSet, Union
from typing import Collection, FrozenSet, Union, Optional

import six
from pynamodb.constants import NUMBER
from pynamodb.attributes import MapAttribute, NumberAttribute, UnicodeAttribute


Expand Down Expand Up @@ -54,22 +55,24 @@ def __str__(self) -> str:


class EnumNumberAttribute(NumberAttribute):
attr_type = NUMBER

def __init__(
self,
enum,
hash_key=False,
range_key=False,
null=None,
default=None,
attr_name=None,
hash_key: bool = False,
range_key: bool = False,
null: Optional[bool] = None,
default: Optional[Enum] = None,
attr_name: Optional[str] = None,
):
if isinstance(enum, Enum):
raise ValueError("enum must be Enum class")
self.enum = enum
super().__init__(
hash_key=hash_key,
range_key=range_key,
default=default,
default=default.value if default else None,
null=null,
attr_name=attr_name,
)
Expand All @@ -93,13 +96,13 @@ def deserialize(self, value: str) -> str:

class EnumUnicodeAttribute(UnicodeAttribute):
def __init__(
self,
hash_key=False,
range_key=False,
null=None,
default=None,
attr_name=None,
enum=None,
self,
hash_key=False,
range_key=False,
null=None,
default=None,
attr_name=None,
enum=None,
):
if isinstance(enum, Enum):
raise ValueError("enum must be Enum class")
Expand Down
8 changes: 4 additions & 4 deletions src/pynamodb_utils/conditions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import operator
from functools import reduce
from typing import Any, Callable, Dict, List, Set
from typing import Any, Callable, Dict, List, Set, Optional

from pynamodb.attributes import Attribute
from pynamodb.expressions.condition import Condition
Expand All @@ -17,8 +17,8 @@ def create_model_condition(
args: Dict[str, Any],
_operator: Callable = operator.and_,
raise_exception: bool = True,
unavailable_attributes: List[str] = []
) -> Condition:
unavailable_attributes: Optional[List[str]] = None
) -> Optional[Condition]:
"""
Function creates pynamodb conditions based on input dictionary (args)
Parameters:
Expand All @@ -34,7 +34,7 @@ def create_model_condition(

available_attributes: Set[str] = get_available_attributes_list(
model=model,
unavaiable_attrs=unavailable_attributes
unavailable_attrs=unavailable_attributes
)

key: str
Expand Down
6 changes: 6 additions & 0 deletions src/pynamodb_utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ class FilterError(Error):

class SerializerError(Error):
pass


class IndexNotFoundError(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
20 changes: 13 additions & 7 deletions src/pynamodb_utils/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import timezone
from typing import Any, List
from typing import Any, List, Optional

from pynamodb.attributes import UTCDateTimeAttribute
from pynamodb.expressions.condition import Condition
Expand All @@ -20,12 +20,14 @@ def get_conditions_from_json(cls, query: dict, raise_exception: bool = True) ->
Parameters:
query (dict): A decimal integer
raise_exception (bool): Throwing an exception in case of an error
Returns:
condition (Condition): computed pynamodb condition
"""
query_unavailable_attributes: List[str] = getattr(cls.Meta, "query_unavailable_attributes", [])
return ConditionsSerializer(cls, query_unavailable_attributes).load(data=query, raise_exception=raise_exception)
return ConditionsSerializer(cls, query_unavailable_attributes).load(data=query,
raise_exception=raise_exception)

@classmethod
def make_index_query(cls, query: dict, raise_exception: bool = True, **kwargs) -> ResultIterator[Model]:
Expand All @@ -34,6 +36,7 @@ def make_index_query(cls, query: dict, raise_exception: bool = True, **kwargs) -
Parameters:
query (dict): A decimal integer
raise_exception (bool): Throwing an exception in case of an error
Returns:
result_iterator (result_iterator): result iterator for optimized query
Expand Down Expand Up @@ -69,6 +72,9 @@ def _pop_path(obj: dict, path: str) -> Any:
obj = obj[key]


TZ_INFO = "TZINFO"


class TimestampedModel(Model):
created_at = UTCDateTimeAttribute(default=get_timestamp)
updated_at = UTCDateTimeAttribute(default=get_timestamp)
Expand All @@ -77,17 +83,17 @@ class TimestampedModel(Model):
class Meta:
abstract = True

def save(self, condition=None):
tz_info = getattr(self.Meta, "TZINFO", None)
def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True):
tz_info = getattr(self.Meta, TZ_INFO, None)
self.created_at = self.created_at.astimezone(tz=tz_info or timezone.utc)
self.updated_at = get_timestamp(tzinfo=tz_info)
super().save(condition=condition)
self.updated_at = get_timestamp(tz=tz_info)
super().save(condition=condition, add_version_condition=add_version_condition)

def save_without_timestamp_update(self, condition=None):
super().save(condition=condition)

def soft_delete(self, condition=None):
""" Puts delete_at timestamp """
tz_info = getattr(self.Meta, "TZINFO", None)
tz_info = getattr(self.Meta, TZ_INFO, None)
self.deleted_at = get_timestamp(tz_info)
super().save(condition=condition)
15 changes: 8 additions & 7 deletions src/pynamodb_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pynamodb.models import Model

from pynamodb_utils.attributes import DynamicMapAttribute
from pynamodb_utils.exceptions import IndexNotFoundError

NoneType = type(None)

Expand All @@ -31,7 +32,7 @@ def create_index_map(
).get("AttributeName")
idx_map[(hash_key, range_key)] = getattr(model, k)
except StopIteration as e:
raise Exception("Could not find index keys") from e
raise IndexNotFoundError("Could not find index keys") from e

return idx_map

Expand All @@ -53,7 +54,7 @@ def pick_index_keys(
return keys


def parse_attr(attr: Attribute) -> Union[Attribute, Dict, List, datetime]:
def parse_attr(attr: Attribute) -> Union[Attribute, Dict, List, datetime, str]:
"""
Function parses attribute to corresponding values
"""
Expand Down Expand Up @@ -87,12 +88,12 @@ def get_attributes_list(model: Model, depth: int = 0) -> List[str]:
return attrs


def get_available_attributes_list(model: Model, unavaiable_attrs: List[str] = []) -> Set[str]:
def get_available_attributes_list(model: Model, unavailable_attrs: Optional[List[str]] = None) -> List[str]:
attrs: List[str] = get_attributes_list(model)
return [attr for attr in attrs if attr not in unavaiable_attrs]
return sorted(set(attr for attr in attrs if attr not in unavailable_attrs))


def get_attribute(model: Model, attr_string: str) -> Attribute:
def get_attribute(model: Model, attr_string: str) -> Optional[Attribute]:
"""
Function gets nested attribute based on path (attr_string)
"""
Expand All @@ -108,5 +109,5 @@ def get_attribute(model: Model, attr_string: str) -> Attribute:
return result


def get_timestamp(tzinfo: timezone = None) -> datetime:
return datetime.now(tzinfo or timezone.utc)
def get_timestamp(tz: timezone = None) -> datetime:
return datetime.now(tz or timezone.utc)
2 changes: 1 addition & 1 deletion src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Post(AsDictModel, JSONQueryModel, TimestampedModel):
sub_name = UnicodeAttribute(range_key=True)
category = EnumAttribute(enum=CategoryEnum, default=CategoryEnum.finance)
content = UnicodeAttribute()
tags = DynamicMapAttribute(default={})
tags = DynamicMapAttribute(default=None)
category_created_at_gsi = PostCategoryCreatedAtGSI()
secret_parameter = UnicodeAttribute(default="secret")

Expand Down
22 changes: 11 additions & 11 deletions src/tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@

@freeze_time("2019-01-01 00:00:00+00:00")
def test_enum_query_not_member_of(post_table):
Post = post_table
CategoryEnum = Post.category.enum
post = post_table
category_enum = post.category.enum

post = Post(
post = post(
name="A weekly news.",
sub_name="Shocking revelations",
content="Last week took place...",
category=CategoryEnum.finance,
category=category_enum.finance,
tags={"type": "news", "topics": ["stock exchange", "NYSE"]},
)
post.save()

with pytest.raises(SerializerError) as e:
Post.get_conditions_from_json(
post.get_conditions_from_json(
query={
"created_at__lte": str(datetime.now()),
"sub_name__exists": None,
Expand All @@ -30,21 +30,21 @@ def test_enum_query_not_member_of(post_table):
"tags.topics__contains": ["NYSE"],
}
)
assert e.value.message == {
"Query": {"category": ["1 is not member of finance, politics."]}
}
assert e.value.message == {
"Query": {"category": ["1 is not member of finance, politics."]}
}


@freeze_time("2019-01-01 00:00:00+00:00")
def test_enum_create_not_member_of(post_table):
Post = post_table
post = post_table

with pytest.raises(ValueError) as e:
post = Post(
post = post(
name="A weekly news.",
content="Last week took place...",
category=1,
tags={"type": "news", "topics": ["stock exchange", "NYSE"]},
)
post.save()
assert str(e.value) == "Value Error: 1 must be in finance, politics"
assert str(e.value) == "Value Error: 1 must be in finance, politics"
38 changes: 19 additions & 19 deletions src/tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

@freeze_time("2019-01-01 00:00:00+00:00")
def test_general(post_table):
Post = post_table
CategoryEnum = Post.category.enum
post = post_table
category_enum = post.category.enum

post = Post(
post = post(
name="A weekly news.",
sub_name="Shocking revelations",
content="Last week took place...",
category=CategoryEnum.finance,
category=category_enum.finance,
tags={"type": "news", "topics": ["stock exchange", "NYSE"]},
)
post.save()
Expand All @@ -26,7 +26,7 @@ def test_general(post_table):
"OR": {"tags.type__equals": "news", "tags.topics__contains": ["NYSE"]},
}

results = Post.make_index_query(query)
results = post.make_index_query(query)

expected = {
"content": "Last week took place...",
Expand All @@ -43,26 +43,26 @@ def test_general(post_table):


def test_bad_field(post_table):
Post = post_table
CategoryEnum = Post.category.enum
post = post_table
category_enum = post.category.enum

post = Post(
post = post(
name="A weekly news.",
sub_name="Shocking revelations",
content="Last week took place...",
category=CategoryEnum.finance,
category=category_enum.finance,
tags={"type": "news", "topics": ["stock exchange", "NYSE"]},
)
post.save()

with pytest.raises(SerializerError) as e:
Post.get_conditions_from_json(query={"tag.type__equals": "news"})
assert e.message == {
"Query": {
"tag.type": [
"Parameter tag does not exist. Choose some of "
"available: category, content, created_at, deleted_at, "
"name, sub_name, tags, updated_at"
]
}
with pytest.raises(SerializerError) as exc_info:
post.get_conditions_from_json(query={"tag.type__equals": "news"})
assert exc_info.value.message == {
"Query": {
"tag.type": [
"Parameter tag.type does not exist. Choose some of "
"available: category, content, created_at, deleted_at, "
"name, sub_name, tags, tags.*, updated_at"
]
}
}

0 comments on commit cdd6dd2

Please sign in to comment.