From 23c8aaaed4a6a4eefc98443387147baa0c63f839 Mon Sep 17 00:00:00 2001 From: ivan <79514623510@yandex.ru> Date: Mon, 30 Sep 2024 12:03:11 +0500 Subject: [PATCH 1/4] add AsyncAbstractRepository(#133) --- integration_test/test_readme.py | 3 +- phulpyfile.py | 4 +- pydantic_mongo/__init__.py | 1 + pydantic_mongo/abstract_repository.py | 145 ++--------- pydantic_mongo/async_abstract_repository.py | 218 ++++++++++++++++ pydantic_mongo/base_abstract_repository.py | 111 ++++++++ requirements_test.txt | 3 +- test/conftest.py | 1 - test/test_async_repository.py | 275 ++++++++++++++++++++ test/test_enhance_meta.py | 5 +- test/test_fields.py | 1 - test/test_repository.py | 4 +- 12 files changed, 633 insertions(+), 138 deletions(-) create mode 100644 pydantic_mongo/async_abstract_repository.py create mode 100644 pydantic_mongo/base_abstract_repository.py create mode 100644 test/test_async_repository.py diff --git a/integration_test/test_readme.py b/integration_test/test_readme.py index 703e0e0..f544c57 100644 --- a/integration_test/test_readme.py +++ b/integration_test/test_readme.py @@ -6,11 +6,12 @@ def extract_python_snippets(content): # Regular expression pattern for finding Python code blocks - pattern = r'```python(.*?)```' + pattern = r"```python(.*?)```" snippets = re.findall(pattern, content, re.DOTALL) return snippets + def evaluate_snippet(snippet): # Capture the output of the snippet output_buffer = io.StringIO() diff --git a/phulpyfile.py b/phulpyfile.py index d0fc7fa..50e9adc 100644 --- a/phulpyfile.py +++ b/phulpyfile.py @@ -1,5 +1,5 @@ import xml.etree.ElementTree as ET -from os import system, unlink +from os import system from os.path import dirname, join from phulpy import task @@ -46,6 +46,6 @@ def integration_test(phulpy): @task def typecheck(phulpy): - result = system('mypy pydantic_mongo test --check-untyped-defs') + result = system("mypy pydantic_mongo test --check-untyped-defs") if result: raise Exception("lint test failed") diff --git a/pydantic_mongo/__init__.py b/pydantic_mongo/__init__.py index 83bfe65..c4e6a07 100644 --- a/pydantic_mongo/__init__.py +++ b/pydantic_mongo/__init__.py @@ -4,6 +4,7 @@ __all__ = [ "AbstractRepository", + "AsyncAbstractRepository", "ObjectIdField", "ObjectIdAnnotation", "PydanticObjectId", diff --git a/pydantic_mongo/abstract_repository.py b/pydantic_mongo/abstract_repository.py index 112f961..d2fd197 100644 --- a/pydantic_mongo/abstract_repository.py +++ b/pydantic_mongo/abstract_repository.py @@ -1,111 +1,33 @@ -from typing import ( - Any, - Dict, - Generic, - Iterable, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, -) +from typing import Any, Dict, Iterable, Optional, Type, Union, cast -from pydantic import BaseModel from pymongo import UpdateOne from pymongo.collection import Collection from pymongo.database import Database from pymongo.results import InsertOneResult, UpdateResult -from .pagination import ( - Edge, - decode_pagination_cursor, - encode_pagination_cursor, - get_pagination_cursor_payload, +from .base_abstract_repository import ( + BaseAbstractRepository, + ModelWithId, + OutputT, + Sort, + T, ) - -T = TypeVar("T", bound=BaseModel) -OutputT = TypeVar("OutputT", bound=BaseModel) -Sort = Sequence[Tuple[str, int]] - - -class ModelWithId(BaseModel): - id: Any +from .pagination import Edge, encode_pagination_cursor, get_pagination_cursor_payload -class AbstractRepository(Generic[T]): +class AbstractRepository(BaseAbstractRepository[T]): class Meta: collection_name: str def __init__(self, database: Database): - super().__init__() self.__database: Database = database - self.__document_class = ( - getattr(self.Meta, "document_class") - if hasattr(self.Meta, "document_class") - else self.__orig_bases__[0].__args__[0] # type: ignore - ) - self.__collection_name = self.Meta.collection_name - self.__validate() - - """ - Get pymongo collection - """ + super().__init__() def get_collection(self) -> Collection: - return self.__database[self.__collection_name] - - def __validate(self): - if "id" not in self.__document_class.model_fields: - raise Exception("Document class should have id field") - if not self.__collection_name: - raise Exception("Meta should contain collection name") - - @staticmethod - def to_document(model: T) -> dict: - """ - Convert model to document - :param model: - :return: dict - """ - model_with_id = cast(ModelWithId, model) - data = model_with_id.model_dump() - data.pop("id") - if model_with_id.id: - data["_id"] = model_with_id.id - return data - - def __map_id(self, data: dict) -> dict: - query = data.copy() - if "id" in data: - query["_id"] = query.pop("id") - return query - - def __map_sort(self, sort: Sort) -> Optional[Sort]: - result = [] - for item in sort: - key = item[0] - ordering = item[1] - if key == "id": - key = "_id" - result.append((key, ordering)) - return result - - def to_model_custom(self, output_type: Type[OutputT], data: dict) -> OutputT: - """ - Convert document to model with custom output type - """ - data_copy = data.copy() - if "_id" in data_copy: - data_copy["id"] = data_copy.pop("_id") - return output_type.model_validate(data_copy) - - def to_model(self, data: dict) -> T: """ - Convert document to model + Get pymongo collection """ - return self.to_model_custom(self.__document_class, data) + return self.__database[self._collection_name] def save(self, model: T) -> Union[InsertOneResult, UpdateResult]: """ @@ -174,7 +96,7 @@ def find_one_by(self, query: dict) -> Optional[T]: """ Find entity by mongo query """ - result = self.get_collection().find_one(self.__map_id(query)) + result = self.get_collection().find_one(self._map_id(query)) return self.to_model(result) if result else None def find_by_with_output_type( @@ -196,9 +118,9 @@ def find_by_with_output_type( :param projection: :return: """ - mapped_projection = self.__map_id(projection) if projection else None - mapped_sort = self.__map_sort(sort) if sort else None - cursor = self.get_collection().find(self.__map_id(query), mapped_projection) + mapped_projection = self._map_id(projection) if projection else None + mapped_sort = self._map_sort(sort) if sort else None + cursor = self.get_collection().find(self._map_id(query), mapped_projection) if limit: cursor.limit(limit) if skip: @@ -219,7 +141,7 @@ def find_by( Find entities by mongo query """ return self.find_by_with_output_type( - output_type=self.__document_class, + output_type=self._document_class, query=query, skip=skip, limit=limit, @@ -227,37 +149,6 @@ def find_by( projection=projection, ) - def get_pagination_query( - self, - query: dict, - after: Optional[str] = None, - before: Optional[str] = None, - sort: Optional[Sort] = None, - ) -> dict: - """ - Build pagination query based on the cursor and sort - """ - generated_query: dict = {"$and": [query]} - selected_cursor = after or before - - if selected_cursor and sort: - cursor_data = decode_pagination_cursor(selected_cursor) - dict_values = [] - for i, sort_expression in enumerate(sort): - if after: - compare_operator = "$gt" if sort_expression[1] > 0 else "$lt" - else: - compare_operator = "$lt" if sort_expression[1] > 0 else "$gt" - dict_values.append( - (sort_expression[0], {compare_operator: cursor_data[i]}) - ) - generated_query["$and"].append(dict(dict_values)) - - if len(generated_query["$and"]) == 1: - generated_query = query or {} - - return generated_query - def paginate_with_output_type( self, output_type: Type[OutputT], @@ -314,7 +205,7 @@ def paginate( Return type is an iterable of Edge objects, which contain the model and the cursor """ return self.paginate_with_output_type( - self.__document_class, + self._document_class, query, limit, after=after, diff --git a/pydantic_mongo/async_abstract_repository.py b/pydantic_mongo/async_abstract_repository.py new file mode 100644 index 0000000..b29e26e --- /dev/null +++ b/pydantic_mongo/async_abstract_repository.py @@ -0,0 +1,218 @@ +from typing import Any, Dict, Iterable, Optional, Type, Union, cast + +from pymongo import UpdateOne +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.results import InsertOneResult, UpdateResult + +from .base_abstract_repository import ( + BaseAbstractRepository, + ModelWithId, + OutputT, + Sort, + T, +) +from .pagination import Edge, encode_pagination_cursor, get_pagination_cursor_payload + + +class AsyncAbstractRepository(BaseAbstractRepository[T]): + class Meta: + collection_name: str + + def __init__(self, database: AsyncDatabase): + self.__database: AsyncDatabase = database + super().__init__() + + def get_collection(self) -> AsyncCollection: + """ + Get pymongo collection + """ + return self.__database[self._collection_name] + + async def save(self, model: T) -> Union[InsertOneResult, UpdateResult]: + """ + Save entity to database. It will update the entity if it has id, otherwise it will insert it. + """ + document = self.to_document(model) + model_with_id = cast(ModelWithId, model) + + if model_with_id.id: + mongo_id = document.pop("_id") + return await self.get_collection().update_one( + {"_id": mongo_id}, {"$set": document}, upsert=True + ) + + result = await self.get_collection().insert_one(document) + model_with_id.id = result.inserted_id + return result + + async def save_many(self, models: Iterable[T]): + """ + Save multiple entities to database + """ + models_to_insert = [] + models_to_update = [] + + for model in models: + model_with_id = cast(ModelWithId, model) + if model_with_id.id: + models_to_update.append(model) + else: + models_to_insert.append(model) + if len(models_to_insert) > 0: + result = await self.get_collection().insert_many( + (self.to_document(model) for model in models_to_insert) + ) + + for idx, inserted_id in enumerate(result.inserted_ids): + cast(ModelWithId, models_to_insert[idx]).id = inserted_id + + if len(models_to_update) == 0: + return + + documents_to_update = [self.to_document(model) for model in models_to_update] + mongo_ids = [doc.pop("_id") for doc in documents_to_update] + bulk_operations = [ + UpdateOne({"_id": mongo_id}, {"$set": document}, upsert=True) + for mongo_id, document in zip(mongo_ids, documents_to_update) + ] + await self.get_collection().bulk_write(bulk_operations) + + async def delete(self, model: T): + return await self.get_collection().delete_one( + {"_id": cast(ModelWithId, model).id} + ) + + async def delete_by_id(self, _id: Any): + return await self.get_collection().delete_one({"_id": _id}) + + async def find_one_by_id(self, _id: Any) -> Optional[T]: + """ + Find entity by id + + Note: The id should be of the same type as the id field in the document class, ie. ObjectId + """ + return await self.find_one_by({"id": _id}) + + async def find_one_by(self, query: dict) -> Optional[T]: + """ + Find entity by mongo query + """ + result = await self.get_collection().find_one(self._map_id(query)) + return self.to_model(result) if result else None + + async def find_by_with_output_type( + self, + output_type: Type[OutputT], + query: dict, + skip: Optional[int] = None, + limit: Optional[int] = None, + sort: Optional[Sort] = None, + projection: Optional[Dict[str, int]] = None, + ) -> Iterable[OutputT]: + """ + Find entities by mongo query allowing custom output type + :param output_type: + :param query: + :param skip: + :param limit: + :param sort: + :param projection: + :return: + """ + mapped_projection = self._map_id(projection) if projection else None + mapped_sort = self._map_sort(sort) if sort else None + cursor = self.get_collection().find(self._map_id(query), mapped_projection) + if limit: + cursor.limit(limit) + if skip: + cursor.skip(skip) + if mapped_sort: + cursor.sort(mapped_sort) + + return [self.to_model_custom(output_type, doc) async for doc in cursor] + + async def find_by( + self, + query: dict, + skip: Optional[int] = None, + limit: Optional[int] = None, + sort: Optional[Sort] = None, + projection: Optional[Dict[str, int]] = None, + ) -> Iterable[T]: + """ " + Find entities by mongo query + """ + return await self.find_by_with_output_type( + output_type=self._document_class, + query=query, + skip=skip, + limit=limit, + sort=sort, + projection=projection, + ) + + async def paginate_with_output_type( + self, + output_type: Type[OutputT], + query: dict, + limit: int, + after: Optional[str] = None, + before: Optional[str] = None, + sort: Optional[Sort] = None, + projection: Optional[Dict[str, int]] = None, + ) -> Iterable[Edge[OutputT]]: + """ + Paginate entities by mongo query allowing custom output type + """ + sort_keys = [] + + if not sort: + sort = [("_id", 1)] + + for sort_expression in sort: + sort_keys.append(sort_expression[0]) + + models = await self.find_by_with_output_type( + output_type, + query=self.get_pagination_query( + query=query, after=after, before=before, sort=sort + ), + limit=limit, + sort=sort, + projection=projection, + ) + + return map( + lambda model: Edge[OutputT]( + node=model, + cursor=encode_pagination_cursor( + get_pagination_cursor_payload(model, sort_keys) + ), + ), + models, + ) + + async def paginate( + self, + query: dict, + limit: int, + after: Optional[str] = None, + before: Optional[str] = None, + sort: Optional[Sort] = None, + projection: Optional[Dict[str, int]] = None, + ) -> Iterable[Edge[T]]: + """ + Paginate entities by mongo query using cursor based pagination + + Return type is an iterable of Edge objects, which contain the model and the cursor + """ + return await self.paginate_with_output_type( + self._document_class, + query, + limit, + after=after, + before=before, + sort=sort, + projection=projection, + ) diff --git a/pydantic_mongo/base_abstract_repository.py b/pydantic_mongo/base_abstract_repository.py new file mode 100644 index 0000000..5d8caa9 --- /dev/null +++ b/pydantic_mongo/base_abstract_repository.py @@ -0,0 +1,111 @@ +from typing import Any, Generic, Optional, Sequence, Tuple, Type, TypeVar, cast + +from pydantic import BaseModel + +from .pagination import decode_pagination_cursor + +T = TypeVar("T", bound=BaseModel) +OutputT = TypeVar("OutputT", bound=BaseModel) +Sort = Sequence[Tuple[str, int]] + + +class ModelWithId(BaseModel): + id: Any + + +class BaseAbstractRepository(Generic[T]): + class Meta: + collection_name: str + + def __init__(self): + super().__init__() + + self._document_class = ( + getattr(self.Meta, "document_class") + if hasattr(self.Meta, "document_class") + else self.__orig_bases__[0].__args__[0] # type: ignore + ) + self._collection_name = self.Meta.collection_name + self.__validate() + + def __validate(self): + if "id" not in self._document_class.model_fields: + raise Exception("Document class should have id field") + if not self._collection_name: + raise Exception("Meta should contain collection name") + + @staticmethod + def to_document(model: T) -> dict: + """ + Convert model to document + :param model: + :return: dict + """ + model_with_id = cast(ModelWithId, model) + data = model_with_id.model_dump() + data.pop("id") + if model_with_id.id: + data["_id"] = model_with_id.id + return data + + def _map_id(self, data: dict) -> dict: + query = data.copy() + if "id" in data: + query["_id"] = query.pop("id") + return query + + def _map_sort(self, sort: Sort) -> Optional[Sort]: + result = [] + for item in sort: + key = item[0] + ordering = item[1] + if key == "id": + key = "_id" + result.append((key, ordering)) + return result + + def to_model_custom(self, output_type: Type[OutputT], data: dict) -> OutputT: + """ + Convert document to model with custom output type + """ + data_copy = data.copy() + if "_id" in data_copy: + data_copy["id"] = data_copy.pop("_id") + return output_type.model_validate(data_copy) + + def to_model(self, data: dict) -> T: + """ + Convert document to model + """ + return self.to_model_custom(self._document_class, data) + + def get_pagination_query( + self, + query: dict, + after: Optional[str] = None, + before: Optional[str] = None, + sort: Optional[Sort] = None, + ) -> dict: + """ + Build pagination query based on the cursor and sort + """ + generated_query: dict = {"$and": [query]} + selected_cursor = after or before + + if selected_cursor and sort: + cursor_data = decode_pagination_cursor(selected_cursor) + dict_values = [] + for i, sort_expression in enumerate(sort): + if after: + compare_operator = "$gt" if sort_expression[1] > 0 else "$lt" + else: + compare_operator = "$lt" if sort_expression[1] > 0 else "$gt" + dict_values.append( + (sort_expression[0], {compare_operator: cursor_data[i]}) + ) + generated_query["$and"].append(dict(dict_values)) + + if len(generated_query["$and"]) == 1: + generated_query = query or {} + + return generated_query diff --git a/requirements_test.txt b/requirements_test.txt index 5a80ee6..1efb09b 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -6,9 +6,10 @@ phulpy==1.0.10 pytest==8.2.0 pytest-cov==4.1.0 pytest-mock==3.14.0 +pytest-asyncio==0.24.0 mongomock==4.1.2 pydantic==2.7.1 -pymongo==4.7.0 +pymongo==4.9.0 mypy==1.10.0 mypy-extensions==1.0.0 black==24.4.2 diff --git a/test/conftest.py b/test/conftest.py index 69bffd4..34c2030 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,6 +1,5 @@ import mongomock import pytest -from pymongo.database import Database @pytest.fixture(scope="session") diff --git a/test/test_async_repository.py b/test/test_async_repository.py new file mode 100644 index 0000000..e6d2f6d --- /dev/null +++ b/test/test_async_repository.py @@ -0,0 +1,275 @@ +from typing import List, Optional, cast + +import pytest +from bson import ObjectId +from pydantic import BaseModel, Field +from pymongo import AsyncMongoClient + +from pydantic_mongo import AbstractRepository, PydanticObjectId +from pydantic_mongo.async_abstract_repository import AsyncAbstractRepository +from pydantic_mongo.errors import PaginationError + + +class Foo(BaseModel): + count: int + size: Optional[float] = None + + +class Bar(BaseModel): + apple: str = Field(default="x") + banana: str = Field(default="y") + + +class Spam(BaseModel): + id: Optional[PydanticObjectId] = None + foo: Optional[Foo] = None + bars: Optional[List[Bar]] = None + + +class SpamRepository(AsyncAbstractRepository[Spam]): + class Meta: + collection_name = "spams" + + +@pytest.fixture +def database(): + import asyncio + + client = AsyncMongoClient("mongodb://root:example@0.0.0.0:27017") + asyncio.run(client.drop_database("db")) + + return client.db + + +class AsyncTestRepository: + @pytest.mark.asyncio + async def test_save(self, database): + spam_repository = SpamRepository(database=database) + foo = Foo(count=1, size=1.0) + bar = Bar() + spam = Spam(foo=foo, bars=[bar]) + await spam_repository.save(spam) + + assert { + "_id": ObjectId(spam.id), + "foo": {"count": 1, "size": 1.0}, + "bars": [{"apple": "x", "banana": "y"}], + } == database["spams"].find()[0] + + cast(Foo, spam.foo).count = 2 + await spam_repository.save(spam) + + assert { + "_id": ObjectId(spam.id), + "foo": {"count": 2, "size": 1.0}, + "bars": [{"apple": "x", "banana": "y"}], + } == database["spams"].find()[0] + + @pytest.mark.asyncio + async def test_save_upsert(self, database): + spam_repository = SpamRepository(database=database) + spam = Spam( + id=ObjectId("65012da68ea5a4798502f710"), foo=Foo(count=1, size=1.0), bars=[] + ) + await spam_repository.save(spam) + + assert { + "_id": ObjectId(spam.id), + "foo": {"count": 1, "size": 1.0}, + "bars": [], + } == database["spams"].find()[0] + + @pytest.mark.asyncio + async def test_save_many(self, database): + spam_repository = SpamRepository(database=database) + spams = [ + Spam(), + Spam(id=ObjectId("65012da68ea5a4798502f710")), + ] + await spam_repository.save_many(spams) + + initial_data = [ + { + "_id": ObjectId(spams[0].id), + "foo": None, + "bars": None, + }, + { + "_id": ObjectId(spams[1].id), + "foo": None, + "bars": None, + }, + ] + + assert initial_data == list(database["spams"].find()) + + # Calling save_many again will only update + await spam_repository.save_many(spams) + assert initial_data == list(database["spams"].find()) + + # Calling save_many with only a new model will only insert + new_span = Spam() + await spam_repository.save_many([new_span]) + assert new_span.id is not None + assert 3 == await database["spams"].count_documents({}) + + @pytest.mark.asyncio + async def test_delete(self, database): + spam_repository = SpamRepository(database=database) + foo = Foo(count=1, size=1.0) + bar = Bar() + spam = Spam(foo=foo, bars=[bar]) + await spam_repository.save(spam) + + result = await spam_repository.find_one_by_id(spam.id) + assert result is not None + + await spam_repository.delete(spam) + result = await spam_repository.find_one_by_id(spam.id) + assert result is None + + @pytest.mark.asyncio + async def test_delete_by_id(self, database): + spam_repository = SpamRepository(database=database) + foo = Foo(count=1, size=1.0) + bar = Bar() + spam = Spam(foo=foo, bars=[bar]) + await spam_repository.save(spam) + + result = await spam_repository.find_one_by_id(spam.id) + assert result is not None + + await spam_repository.delete_by_id(spam.id) + result = await spam_repository.find_one_by_id(spam.id) + assert result is None + + @pytest.mark.asyncio + async def test_find_by_id(self, database): + spam_id = ObjectId("611827f2878b88b49ebb69fc") + await database["spams"].insert_one( + { + "_id": spam_id, + "foo": {"count": 2, "size": 1.0}, + "bars": [{"apple": "x", "banana": "y"}], + } + ) + + spam_repository = SpamRepository(database=database) + result = await spam_repository.find_one_by_id(spam_id) + + assert result is not None + assert result.bars is not None + assert issubclass(Spam, type(result)) + assert spam_id == result.id + assert "x" == result.bars[0].apple + + @pytest.mark.asyncio + async def test_find_by(self, database): + await database["spams"].insert_many( + [ + { + "foo": {"count": 2, "size": 1.0}, + "bars": [{"apple": "x", "banana": "y"}], + }, + { + "foo": {"count": 3, "size": 1.0}, + "bars": [{"apple": "x", "banana": "y"}], + }, + ] + ) + + spam_repository = SpamRepository(database=database) + + # Simple Find + result = await spam_repository.find_by({}) + results = [x for x in result] + assert 2 == len(results) + assert results[0].foo is not None + assert results[1].foo is not None + assert 2 == results[0].foo.count + assert 3 == results[1].foo.count + + # Find with optional parameters + result = await spam_repository.find_by( + {}, skip=10, limit=10, sort=[("foo.count", 1), ("id", 1)] + ) + results = [x for x in result] + assert 0 == len(results) + + @pytest.mark.asyncio + async def test_invalid_model_id_field(self, database): + class NoIdModel(BaseModel): + something: str + + class BrokenRepository(AbstractRepository[NoIdModel]): + class Meta: + collection_name = "spams" + + with pytest.raises(Exception): + BrokenRepository(database=database) + + @pytest.mark.asyncio + async def test_invalid_model_collection_name(self, database): + class BrokenRepository(AbstractRepository[Spam]): + class Meta: + collection_name = None + + with pytest.raises(Exception): + BrokenRepository(database=database) + + @pytest.mark.asyncio + async def test_paginate(self, database): + await database["spams"].insert_many( + [ + { + "_id": ObjectId("611b140f4eb6ee47e966860f"), + "foo": {"count": 2, "size": 1.0}, + "bars": [{"apple": "x", "banana": "y"}], + }, + { + "id": ObjectId("611b141cf533ca420b7580d6"), + "foo": {"count": 3, "size": 1.0}, + "bars": [{"apple": "x", "banana": "y"}], + }, + { + "_id": ObjectId("611b15241dea2ee3f7cbfe30"), + "foo": {"count": 2, "size": 1.0}, + "bars": [{"apple": "x", "banana": "y"}], + }, + { + "_id": ObjectId("611b157c859bde7de88c98ac"), + "foo": {"count": 2, "size": 1.0}, + "bars": [{"apple": "x", "banana": "y"}], + }, + { + "_id": ObjectId("611b158adec89d18984b7d90"), + "foo": {"count": 2, "size": 1.0}, + "bars": [{"apple": "x", "banana": "y"}], + }, + ] + ) + + spam_repository = SpamRepository(database=database) + + # Simple Find + result = list(await spam_repository.paginate({}, limit=10)) + assert len(result) == 5 + + # Find After + result = list( + await spam_repository.paginate( + {}, limit=10, after="eNqTYWBgYCljEAFS7AYMidKiXfdOzJWY4V07gYEBAD7HBkg=" + ) + ) + assert len(result) == 1 + + # Find Before + result = list( + await spam_repository.paginate( + {}, limit=10, before="eNqTYWBgYCljEAFS7AYMidKiXfdOzJWY4V07gYEBAD7HBkg=" + ) + ) + assert len(result) == 3 + + with pytest.raises(PaginationError): + await spam_repository.paginate({}, limit=10, after="invalid string") diff --git a/test/test_enhance_meta.py b/test/test_enhance_meta.py index 221e2a1..b2e06fe 100644 --- a/test/test_enhance_meta.py +++ b/test/test_enhance_meta.py @@ -1,8 +1,9 @@ import pytest -from pydantic import BaseModel, Field -from pydantic_mongo import AbstractRepository, PydanticObjectId +from pydantic import BaseModel from typing_extensions import Optional +from pydantic_mongo import AbstractRepository, PydanticObjectId + class HamModel(BaseModel): id: Optional[PydanticObjectId] diff --git a/test/test_fields.py b/test/test_fields.py index 6ff156d..073cf62 100644 --- a/test/test_fields.py +++ b/test/test_fields.py @@ -1,5 +1,4 @@ import pytest -from typing import Optional from bson import ObjectId from pydantic import BaseModel, ValidationError diff --git a/test/test_repository.py b/test/test_repository.py index 81d51bc..59f8b6c 100644 --- a/test/test_repository.py +++ b/test/test_repository.py @@ -61,9 +61,7 @@ def test_save(self, database): def test_save_upsert(self, database): spam_repository = SpamRepository(database=database) spam = Spam( - id=ObjectId("65012da68ea5a4798502f710"), - foo=Foo(count=1, size=1.0), - bars=[] + id=ObjectId("65012da68ea5a4798502f710"), foo=Foo(count=1, size=1.0), bars=[] ) spam_repository.save(spam) From 74fdff2d5c09f5da90e0db305cf70990089a775d Mon Sep 17 00:00:00 2001 From: ivan <79514623510@yandex.ru> Date: Fri, 4 Oct 2024 15:22:50 +0500 Subject: [PATCH 2/4] added missing type annotation --- test/test_async_repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_async_repository.py b/test/test_async_repository.py index e6d2f6d..e4aacd6 100644 --- a/test/test_async_repository.py +++ b/test/test_async_repository.py @@ -35,7 +35,7 @@ class Meta: def database(): import asyncio - client = AsyncMongoClient("mongodb://root:example@0.0.0.0:27017") + client: AsyncMongoClient = AsyncMongoClient("mongodb://root:example@0.0.0.0:27017") asyncio.run(client.drop_database("db")) return client.db From e34fa6315d295bb9ef171ec361e32d2261ae7e13 Mon Sep 17 00:00:00 2001 From: ivan <79514623510@yandex.ru> Date: Fri, 4 Oct 2024 15:36:01 +0500 Subject: [PATCH 3/4] fixed mongodb address in tests --- test/test_async_repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_async_repository.py b/test/test_async_repository.py index e4aacd6..f290b01 100644 --- a/test/test_async_repository.py +++ b/test/test_async_repository.py @@ -35,7 +35,7 @@ class Meta: def database(): import asyncio - client: AsyncMongoClient = AsyncMongoClient("mongodb://root:example@0.0.0.0:27017") + client: AsyncMongoClient = AsyncMongoClient("mongodb://localhost:27017") asyncio.run(client.drop_database("db")) return client.db From c4b62f0683af6986ef9f48999536f5afe4861e0d Mon Sep 17 00:00:00 2001 From: ivan <79514623510@yandex.ru> Date: Fri, 4 Oct 2024 16:04:54 +0500 Subject: [PATCH 4/4] fixed tests for AsyncAbstractRepository --- test/test_async_repository.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/test_async_repository.py b/test/test_async_repository.py index f290b01..9915cfc 100644 --- a/test/test_async_repository.py +++ b/test/test_async_repository.py @@ -41,7 +41,7 @@ def database(): return client.db -class AsyncTestRepository: +class TestAsyncRepository: @pytest.mark.asyncio async def test_save(self, database): spam_repository = SpamRepository(database=database) @@ -50,20 +50,22 @@ async def test_save(self, database): spam = Spam(foo=foo, bars=[bar]) await spam_repository.save(spam) + result = await database["spams"].find().to_list(length=None) assert { "_id": ObjectId(spam.id), "foo": {"count": 1, "size": 1.0}, "bars": [{"apple": "x", "banana": "y"}], - } == database["spams"].find()[0] + } == result[0] cast(Foo, spam.foo).count = 2 await spam_repository.save(spam) + result = await database["spams"].find().to_list(length=None) assert { "_id": ObjectId(spam.id), "foo": {"count": 2, "size": 1.0}, "bars": [{"apple": "x", "banana": "y"}], - } == database["spams"].find()[0] + } == result[0] @pytest.mark.asyncio async def test_save_upsert(self, database): @@ -73,11 +75,12 @@ async def test_save_upsert(self, database): ) await spam_repository.save(spam) + result = await database["spams"].find().to_list(length=None) assert { "_id": ObjectId(spam.id), "foo": {"count": 1, "size": 1.0}, "bars": [], - } == database["spams"].find()[0] + } == result[0] @pytest.mark.asyncio async def test_save_many(self, database): @@ -101,11 +104,13 @@ async def test_save_many(self, database): }, ] - assert initial_data == list(database["spams"].find()) + result = await database["spams"].find().to_list(length=None) + assert initial_data == result # Calling save_many again will only update await spam_repository.save_many(spams) - assert initial_data == list(database["spams"].find()) + result = await database["spams"].find().to_list(length=None) + assert initial_data == result # Calling save_many with only a new model will only insert new_span = Spam()