Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add AsyncAbstractRepository #134

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion integration_test/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

def extract_python_snippets(content):
# Regular expression pattern for finding Python code blocks
pattern = r'```python(.*?)```'
pattern = r"```python(.*?)```"
snippets = re.findall(pattern, content, re.DOTALL)

return snippets


def evaluate_snippet(snippet):
# Capture the output of the snippet
output_buffer = io.StringIO()
Expand Down
4 changes: 2 additions & 2 deletions phulpyfile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import xml.etree.ElementTree as ET
from os import system, unlink
from os import system
from os.path import dirname, join

from phulpy import task
Expand Down Expand Up @@ -46,6 +46,6 @@ def integration_test(phulpy):

@task
def typecheck(phulpy):
result = system('mypy pydantic_mongo test --check-untyped-defs')
result = system("mypy pydantic_mongo test --check-untyped-defs")
if result:
raise Exception("lint test failed")
1 change: 1 addition & 0 deletions pydantic_mongo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

__all__ = [
"AbstractRepository",
"AsyncAbstractRepository",
"ObjectIdField",
"ObjectIdAnnotation",
"PydanticObjectId",
Expand Down
145 changes: 18 additions & 127 deletions pydantic_mongo/abstract_repository.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,33 @@
from typing import (
Any,
Dict,
Generic,
Iterable,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from typing import Any, Dict, Iterable, Optional, Type, Union, cast

from pydantic import BaseModel
from pymongo import UpdateOne
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.results import InsertOneResult, UpdateResult

from .pagination import (
Edge,
decode_pagination_cursor,
encode_pagination_cursor,
get_pagination_cursor_payload,
from .base_abstract_repository import (
BaseAbstractRepository,
ModelWithId,
OutputT,
Sort,
T,
)

T = TypeVar("T", bound=BaseModel)
OutputT = TypeVar("OutputT", bound=BaseModel)
Sort = Sequence[Tuple[str, int]]


class ModelWithId(BaseModel):
id: Any
from .pagination import Edge, encode_pagination_cursor, get_pagination_cursor_payload


class AbstractRepository(Generic[T]):
class AbstractRepository(BaseAbstractRepository[T]):
class Meta:
collection_name: str

def __init__(self, database: Database):
super().__init__()
self.__database: Database = database
self.__document_class = (
getattr(self.Meta, "document_class")
if hasattr(self.Meta, "document_class")
else self.__orig_bases__[0].__args__[0] # type: ignore
)
self.__collection_name = self.Meta.collection_name
self.__validate()

"""
Get pymongo collection
"""
super().__init__()

def get_collection(self) -> Collection:
return self.__database[self.__collection_name]

def __validate(self):
if "id" not in self.__document_class.model_fields:
raise Exception("Document class should have id field")
if not self.__collection_name:
raise Exception("Meta should contain collection name")

@staticmethod
def to_document(model: T) -> dict:
"""
Convert model to document
:param model:
:return: dict
"""
model_with_id = cast(ModelWithId, model)
data = model_with_id.model_dump()
data.pop("id")
if model_with_id.id:
data["_id"] = model_with_id.id
return data

def __map_id(self, data: dict) -> dict:
query = data.copy()
if "id" in data:
query["_id"] = query.pop("id")
return query

def __map_sort(self, sort: Sort) -> Optional[Sort]:
result = []
for item in sort:
key = item[0]
ordering = item[1]
if key == "id":
key = "_id"
result.append((key, ordering))
return result

def to_model_custom(self, output_type: Type[OutputT], data: dict) -> OutputT:
"""
Convert document to model with custom output type
"""
data_copy = data.copy()
if "_id" in data_copy:
data_copy["id"] = data_copy.pop("_id")
return output_type.model_validate(data_copy)

def to_model(self, data: dict) -> T:
"""
Convert document to model
Get pymongo collection
"""
return self.to_model_custom(self.__document_class, data)
return self.__database[self._collection_name]

def save(self, model: T) -> Union[InsertOneResult, UpdateResult]:
"""
Expand Down Expand Up @@ -174,7 +96,7 @@ def find_one_by(self, query: dict) -> Optional[T]:
"""
Find entity by mongo query
"""
result = self.get_collection().find_one(self.__map_id(query))
result = self.get_collection().find_one(self._map_id(query))
return self.to_model(result) if result else None

def find_by_with_output_type(
Expand All @@ -196,9 +118,9 @@ def find_by_with_output_type(
:param projection:
:return:
"""
mapped_projection = self.__map_id(projection) if projection else None
mapped_sort = self.__map_sort(sort) if sort else None
cursor = self.get_collection().find(self.__map_id(query), mapped_projection)
mapped_projection = self._map_id(projection) if projection else None
mapped_sort = self._map_sort(sort) if sort else None
cursor = self.get_collection().find(self._map_id(query), mapped_projection)
if limit:
cursor.limit(limit)
if skip:
Expand All @@ -219,45 +141,14 @@ def find_by(
Find entities by mongo query
"""
return self.find_by_with_output_type(
output_type=self.__document_class,
output_type=self._document_class,
query=query,
skip=skip,
limit=limit,
sort=sort,
projection=projection,
)

def get_pagination_query(
self,
query: dict,
after: Optional[str] = None,
before: Optional[str] = None,
sort: Optional[Sort] = None,
) -> dict:
"""
Build pagination query based on the cursor and sort
"""
generated_query: dict = {"$and": [query]}
selected_cursor = after or before

if selected_cursor and sort:
cursor_data = decode_pagination_cursor(selected_cursor)
dict_values = []
for i, sort_expression in enumerate(sort):
if after:
compare_operator = "$gt" if sort_expression[1] > 0 else "$lt"
else:
compare_operator = "$lt" if sort_expression[1] > 0 else "$gt"
dict_values.append(
(sort_expression[0], {compare_operator: cursor_data[i]})
)
generated_query["$and"].append(dict(dict_values))

if len(generated_query["$and"]) == 1:
generated_query = query or {}

return generated_query

def paginate_with_output_type(
self,
output_type: Type[OutputT],
Expand Down Expand Up @@ -314,7 +205,7 @@ def paginate(
Return type is an iterable of Edge objects, which contain the model and the cursor
"""
return self.paginate_with_output_type(
self.__document_class,
self._document_class,
query,
limit,
after=after,
Expand Down
Loading
Loading