Skip to content

Commit

Permalink
feat: Added pynamodb 6 support
Browse files Browse the repository at this point in the history
  • Loading branch information
micmurawski committed Jan 30, 2024
2 parents 376792d + b7692cf commit dadad06
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 127 deletions.
38 changes: 21 additions & 17 deletions src/pynamodb_utils/attributes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
from enum import Enum
from typing import Collection, FrozenSet, Union
from typing import Collection, FrozenSet, Optional, Union

import six
from pynamodb.attributes import MapAttribute, NumberAttribute, UnicodeAttribute
from pynamodb.constants import NUMBER

from pynamodb_utils.exceptions import EnumSerializationException


class DynamicMapAttribute(MapAttribute):
Expand Down Expand Up @@ -58,15 +61,17 @@ def __str__(self) -> str:


class EnumNumberAttribute(NumberAttribute):
attr_type = NUMBER

def __init__(
self,
enum: Enum,
hash_key=False,
range_key=False,
null=None,
default: Enum = None,
default_for_new: Enum = None,
attr_name=None,
hash_key: bool = False,
range_key: bool = False,
null: Optional[bool] = None,
default: Optional[Enum] = None,
default_for_new: Optional[Enum] = None,
attr_name: Optional[str] = None,
):
if isinstance(enum, Enum):
raise ValueError("enum must be Enum class")
Expand Down Expand Up @@ -97,7 +102,7 @@ def serialize(self, value: Union[Enum, str]) -> str:
f'Value Error: {value} must be in {", ".join([item for item in self.enum.__members__.keys()])}'
)
except TypeError as e:
raise Exception(value, self.enum) from e
raise EnumSerializationException(f"Error serializing {value} with enum {self.enum}") from e

def deserialize(self, value: str) -> str:
return self.enum(int(value)).name
Expand All @@ -107,12 +112,12 @@ class EnumUnicodeAttribute(UnicodeAttribute):
def __init__(
self,
enum: Enum,
hash_key=False,
range_key=False,
null=None,
default: Enum = None,
default_for_new: Enum = None,
attr_name=None,
hash_key: bool = False,
range_key: bool = False,
null: Optional[bool] = None,
default: Optional[Enum] = None,
default_for_new: Optional[Enum] = None,
attr_name: Optional[str] = None,
):
if isinstance(enum, Enum):
raise ValueError("enum must be Enum class")
Expand All @@ -135,9 +140,8 @@ def __init__(
def serialize(self, value: Union[Enum, str]) -> str:
if isinstance(value, self.enum):
return str(value.value)
elif isinstance(value, str):
if value in self.enum.__members__.keys():
return getattr(self.enum, value).value
elif isinstance(value, str) and value in self.enum.__members__.keys():
return getattr(self.enum, value).value
raise ValueError(
f'Value Error: {value} must be in {", ".join([item for item in self.enum.__members__.keys()])}'
)
Expand Down
76 changes: 35 additions & 41 deletions src/pynamodb_utils/conditions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import operator
from functools import reduce
from typing import Any, Callable, Dict, List, Set
from typing import Any, Callable, Dict, List, Optional

from pynamodb.attributes import Attribute
from pynamodb.expressions.condition import Condition
Expand All @@ -12,13 +12,30 @@
from pynamodb_utils.utils import get_attribute, get_available_attributes_list


def _is_available(field_path: str, available_attributes: List, raise_exception: bool):
if "." in field_path:
_field_path = field_path.split(".", 1)[0] + ".*"
is_available = _field_path in available_attributes
else:
is_available = field_path in available_attributes
if not is_available and raise_exception:
raise FilterError(
message={
field_path: [
f"Parameter {field_path} does not exist."
f" Choose some of available: {', '.join(available_attributes)}"
]
}
)


def create_model_condition(
model: Model,
args: Dict[str, Any],
_operator: Callable = operator.and_,
raise_exception: bool = True,
unavailable_attributes: List[str] = []
) -> Condition:
model: Model,
args: Dict[str, Any],
_operator: Callable = operator.and_,
raise_exception: bool = True,
unavailable_attributes: Optional[List[str]] = None
) -> Optional[Condition]:
"""
Function creates pynamodb conditions based on input dictionary (args)
Parameters:
Expand All @@ -31,52 +48,29 @@ def create_model_condition(
condition (Condition): computed pynamodb condition
"""
conditions_list: List[Condition] = []

available_attributes: Set[str] = get_available_attributes_list(
available_attributes: List[str] = get_available_attributes_list(
model=model,
unavaiable_attrs=unavailable_attributes
unavailable_attrs=unavailable_attributes
)

key: str
value: Any
for key, value in args.items():
array: List[str] = key.rsplit('__', 1)
array: List[str] = key.rsplit("__", 1)
field_path: str = array[0]
operator_name: str = array[1] if len(array) > 1 and array[1] != 'not' else ''

if "." in field_path:
_field_path = field_path.split(".", 1)[0] + ".*"
is_available = _field_path in available_attributes
else:

is_available = field_path in available_attributes

if operator_name.replace('not_', '') not in OPERATORS_MAPPING:
raise FilterError(
message={key: [f'Operator {operator_name} does not exist.'
f' Choose some of available: {", ".join(OPERATORS_MAPPING.keys())}']}
)
if not is_available and raise_exception:
operator_name: str = array[1] if len(array) > 1 and array[1] != "not" else ""
if operator_name.replace("not_", "") not in OPERATORS_MAPPING:
raise FilterError(
message={
field_path: [
f"Parameter {field_path} does not exist."
f' Choose some of available: {", ".join(available_attributes)}'
]
}
message={key: [f"Operator {operator_name} does not exist."
f" Choose some of available: {', '.join(OPERATORS_MAPPING.keys())}"]}
)

_is_available(field_path, available_attributes, raise_exception)
attr: Attribute = get_attribute(model, field_path)

if isinstance(attr, (Attribute, Path)):
if 'not_' in operator_name:
operator_name = operator_name.replace('not_', '')
operator_name = operator_name.replace("not_", "")
operator_handler = OPERATORS_MAPPING[operator_name]
conditions_list.append(~operator_handler(model, field_path, attr, value))
else:
operator_handler = OPERATORS_MAPPING[operator_name]
conditions_list.append(operator_handler(model, field_path, attr, value))
if not conditions_list:
return None
else:
if conditions_list:
return reduce(_operator, conditions_list)
return None
8 changes: 8 additions & 0 deletions src/pynamodb_utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,11 @@ class FilterError(Error):

class SerializerError(Error):
pass


class IndexNotFoundError(Exception):
pass


class EnumSerializationException(Exception):
pass
20 changes: 13 additions & 7 deletions src/pynamodb_utils/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import timezone
from typing import Any, List
from typing import Any, List, Optional

from pynamodb.attributes import UTCDateTimeAttribute
from pynamodb.expressions.condition import Condition
Expand All @@ -20,12 +20,14 @@ def get_conditions_from_json(cls, query: dict, raise_exception: bool = True) ->
Parameters:
query (dict): A decimal integer
raise_exception (bool): Throwing an exception in case of an error
Returns:
condition (Condition): computed pynamodb condition
"""
query_unavailable_attributes: List[str] = getattr(cls.Meta, "query_unavailable_attributes", [])
return ConditionsSerializer(cls, query_unavailable_attributes).load(data=query, raise_exception=raise_exception)
return ConditionsSerializer(cls, query_unavailable_attributes).load(data=query,
raise_exception=raise_exception)

@classmethod
def make_index_query(cls, query: dict, raise_exception: bool = True, **kwargs) -> ResultIterator[Model]:
Expand All @@ -34,6 +36,7 @@ def make_index_query(cls, query: dict, raise_exception: bool = True, **kwargs) -
Parameters:
query (dict): A decimal integer
raise_exception (bool): Throwing an exception in case of an error
Returns:
result_iterator (result_iterator): result iterator for optimized query
Expand Down Expand Up @@ -69,6 +72,9 @@ def _pop_path(obj: dict, path: str) -> Any:
obj = obj[key]


TZ_INFO = "TZINFO"


class TimestampedModel(Model):
created_at = UTCDateTimeAttribute(default=get_timestamp)
updated_at = UTCDateTimeAttribute(default=get_timestamp)
Expand All @@ -77,17 +83,17 @@ class TimestampedModel(Model):
class Meta:
abstract = True

def save(self, condition=None):
tz_info = getattr(self.Meta, "TZINFO", None)
def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True):
tz_info = getattr(self.Meta, TZ_INFO, None)
self.created_at = self.created_at.astimezone(tz=tz_info or timezone.utc)
self.updated_at = get_timestamp(tzinfo=tz_info)
super().save(condition=condition)
self.updated_at = get_timestamp(tz=tz_info)
super().save(condition=condition, add_version_condition=add_version_condition)

def save_without_timestamp_update(self, condition=None):
super().save(condition=condition)

def soft_delete(self, condition=None):
""" Puts delete_at timestamp """
tz_info = getattr(self.Meta, "TZINFO", None)
tz_info = getattr(self.Meta, TZ_INFO, None)
self.deleted_at = get_timestamp(tz_info)
super().save(condition=condition)
4 changes: 2 additions & 2 deletions src/pynamodb_utils/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ def default_list_parser(value: List[Any], field_name: str, model: Model) -> List
raise FilterError(message={field_name: [f"{value} is not valid type of {field_name}."]})


def default_dict_parser(value: Dict, field_name: str, *args) -> Dict[Any, Any]:
def default_dict_parser(value: Dict, field_name: str, *args) -> Union[Dict[Any, Any], str]:
if isinstance(value, (dict, NoneType)):
return value
elif isinstance(value, str):
try:
return json.dumps(value, default=str)
except (ValueError, json.JSONDecodeError):
except ValueError:
pass
raise FilterError(
message={field_name: [f"{value} is not valid type of {field_name}."]}
Expand Down
19 changes: 11 additions & 8 deletions src/pynamodb_utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from pynamodb.attributes import Attribute, MapAttribute
from pynamodb.indexes import GlobalSecondaryIndex, LocalSecondaryIndex
from pynamodb.models import Model

from pynamodb_utils.attributes import DynamicMapAttribute
from pynamodb_utils.exceptions import IndexNotFoundError

NoneType = type(None)

Expand All @@ -31,7 +32,7 @@ def create_index_map(
).get("AttributeName")
idx_map[(hash_key, range_key)] = getattr(model, k)
except StopIteration as e:
raise Exception("Could not find index keys") from e
raise IndexNotFoundError("Could not find index keys") from e

return idx_map

Expand All @@ -53,12 +54,14 @@ def pick_index_keys(
return keys


def parse_attr(attr: Attribute) -> Union[Dict, datetime]:
def parse_attr(attr: Attribute) -> Union[Attribute, Dict, List, datetime, str]:
"""
Function parses attribute to corresponding values
"""
if isinstance(attr, DynamicMapAttribute):
return attr.as_dict()
elif isinstance(attr, List):
return [parse_attr(el) for el in attr]
elif isinstance(attr, MapAttribute):
return parse_attrs_to_dict(attr)
elif isinstance(attr, datetime):
Expand All @@ -85,12 +88,12 @@ def get_attributes_list(model: Model, depth: int = 0) -> List[str]:
return attrs


def get_available_attributes_list(model: Model, unavaiable_attrs: List[str] = []) -> Set[str]:
def get_available_attributes_list(model: Model, unavailable_attrs: Optional[List[str]] = None) -> List[str]:
attrs: List[str] = get_attributes_list(model)
return [attr for attr in attrs if attr not in unavaiable_attrs]
return sorted(set(attr for attr in attrs if attr not in unavailable_attrs))


def get_attribute(model: Model, attr_string: str) -> Attribute:
def get_attribute(model: Model, attr_string: str) -> Optional[Attribute]:
"""
Function gets nested attribute based on path (attr_string)
"""
Expand All @@ -106,5 +109,5 @@ def get_attribute(model: Model, attr_string: str) -> Attribute:
return result


def get_timestamp(tzinfo: timezone = None) -> datetime:
return datetime.now(tzinfo or timezone.utc)
def get_timestamp(tz: timezone = None) -> datetime:
return datetime.now(tz or timezone.utc)
27 changes: 11 additions & 16 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,30 @@
from datetime import timezone

import pytest
from moto import mock_dynamodb
from moto import mock_aws
from pynamodb.attributes import UnicodeAttribute, UTCDateTimeAttribute
from pynamodb.indexes import AllProjection, GlobalSecondaryIndex

from pynamodb_utils import AsDictModel, DynamicMapAttribute, EnumAttribute, JSONQueryModel, TimestampedModel


@pytest.fixture(scope="session")
@pytest.fixture
def aws_environ():
vars = {
env_vars = {
"AWS_DEFAULT_REGION": "us-east-1"
}
for k, v in vars.items():
os.environ[k] = v

yield

for k in vars:
del os.environ[k]
with mock_aws():
for k, v in env_vars.items():
os.environ[k] = v


@pytest.fixture
def dynamodb(aws_environ):
with mock_dynamodb():
yield

for k in env_vars:
del os.environ[k]


@pytest.fixture
def post_table(dynamodb):
def post_table(aws_environ):
class CategoryEnum(enum.Enum):
finance = enum.auto()
politics = enum.auto()
Expand All @@ -49,7 +44,7 @@ class Post(AsDictModel, JSONQueryModel, TimestampedModel):
sub_name = UnicodeAttribute(range_key=True)
category = EnumAttribute(enum=CategoryEnum, default=CategoryEnum.finance)
content = UnicodeAttribute()
tags = DynamicMapAttribute(default={})
tags = DynamicMapAttribute(default=None)
category_created_at_gsi = PostCategoryCreatedAtGSI()
secret_parameter = UnicodeAttribute(default="secret")

Expand Down
Loading

0 comments on commit dadad06

Please sign in to comment.