Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
[MYPY]: all warnings addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
amadolid committed Mar 7, 2024
1 parent ca61212 commit 9f181ac
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 241 deletions.
113 changes: 73 additions & 40 deletions jaclang_jaseci/collections/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
"""BaseCollection Interface."""

from os import getenv
from typing import Any, AsyncGenerator, Awaitable, Callable, Optional, Union
from typing import Any, AsyncGenerator, Awaitable, Callable, Optional, Union, cast

from bson import ObjectId

from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCursor

from pymongo import DeleteMany, DeleteOne, IndexModel, InsertOne, UpdateMany, UpdateOne
from pymongo.client_session import ClientSession
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorClientSession,
AsyncIOMotorCollection,
AsyncIOMotorCursor,
AsyncIOMotorDatabase,
)

from pymongo import (
DeleteMany,
DeleteOne,
IndexModel,
InsertOne,
MongoClient,
UpdateMany,
UpdateOne,
)
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.server_api import ServerApi
Expand All @@ -24,13 +37,13 @@ class BaseCollection:
"""

__collection__: Optional[str] = None
__collection_obj__: Optional[Collection] = None
__collection_obj__: Optional[AsyncIOMotorCollection] = None # type: ignore[valid-type]
__indexes__: list = []
__excluded__: list = []
__excluded_obj__ = None

__client__: Optional[AsyncIOMotorClient] = None # type: ignore
__database__: Optional[Database] = None
__client__: Optional[AsyncIOMotorClient] = None # type: ignore[valid-type]
__database__: Optional[AsyncIOMotorDatabase] = None # type: ignore[valid-type]

@classmethod
def __document__(cls, doc: dict) -> Union[dict, object]:
Expand All @@ -43,7 +56,9 @@ def __document__(cls, doc: dict) -> Union[dict, object]:
return doc

@classmethod
async def __documents__(cls, docs: AsyncIOMotorCursor) -> AsyncGenerator[Union[dict, object], None]: # type: ignore
async def __documents__(
cls, docs: AsyncIOMotorCursor # type: ignore[valid-type]
) -> AsyncGenerator[Union[dict, object], None]:
"""
Return parsed version of multiple documents.
Expand All @@ -53,7 +68,7 @@ async def __documents__(cls, docs: AsyncIOMotorCursor) -> AsyncGenerator[Union[d
return (cls.__document__(doc) async for doc in docs) # type: ignore[attr-defined]

@staticmethod
def get_client() -> AsyncIOMotorClient: # type: ignore
def get_client() -> AsyncIOMotorClient: # type: ignore[valid-type]
"""Return pymongo.database.Database for mongodb connection."""
if not isinstance(BaseCollection.__client__, AsyncIOMotorClient):
BaseCollection.__client__ = AsyncIOMotorClient(
Expand All @@ -67,29 +82,31 @@ def get_client() -> AsyncIOMotorClient: # type: ignore
return BaseCollection.__client__

@staticmethod
async def get_session() -> ClientSession:
async def get_session() -> AsyncIOMotorClientSession: # type: ignore[valid-type]
"""Return pymongo.client_session.ClientSession used for mongodb transactional operations."""
return await BaseCollection.get_client().start_session() # type: ignore[attr-defined]
return await cast(MongoClient, BaseCollection.get_client()).start_session() # type: ignore[misc]

@staticmethod
def get_database() -> Database:
def get_database() -> AsyncIOMotorDatabase: # type: ignore[valid-type]
"""Return pymongo.database.Database for database connection based from current client connection."""
if not isinstance(BaseCollection.__database__, Database):
BaseCollection.__database__ = BaseCollection.get_client().get_database( # type: ignore[attr-defined]
getenv("DATABASE_NAME", "jaclang")
)
if not isinstance(BaseCollection.__database__, AsyncIOMotorDatabase):
BaseCollection.__database__ = cast(
MongoClient, BaseCollection.get_client()
).get_database(getenv("DATABASE_NAME", "jaclang"))

return BaseCollection.__database__

@staticmethod
def get_collection(collection: str) -> Collection:
def get_collection(collection: str) -> AsyncIOMotorCollection: # type: ignore[valid-type]
"""Return pymongo.collection.Collection for collection connection based from current database connection."""
return BaseCollection.get_database().get_collection(collection)
return cast(Database, BaseCollection.get_database()).get_collection(collection)

@classmethod
async def collection(cls, session: Optional[ClientSession] = None) -> Collection:
async def collection(
cls, session: Optional[AsyncIOMotorClientSession] = None # type: ignore[valid-type]
) -> AsyncIOMotorCollection: # type: ignore[valid-type]
"""Return pymongo.collection.Collection for collection connection based from attribute of it's child class."""
if not isinstance(cls.__collection_obj__, Collection):
if not isinstance(cls.__collection_obj__, AsyncIOMotorCollection):
cls.__collection_obj__ = cls.get_collection(
getattr(cls, "__collection__", None) or cls.__name__.lower()
)
Expand All @@ -106,21 +123,23 @@ async def collection(cls, session: Optional[ClientSession] = None) -> Collection
idx = cls.__indexes__.pop()
idxs.append(IndexModel(idx.pop("fields"), **idx))

ops = cls.__collection_obj__.create_indexes(idxs, session=session)
ops = cast(Collection, cls.__collection_obj__).create_indexes(
idxs, session=session
)
if isinstance(ops, Awaitable):
await ops

return cls.__collection_obj__

@classmethod
async def insert_one(
cls, doc: dict, session: Optional[ClientSession] = None
cls, doc: dict, session: Optional[AsyncIOMotorClientSession] = None # type: ignore[valid-type]
) -> Optional[ObjectId]:
"""Insert single document and return the inserted id."""
try:
collection = await cls.collection(session=session)

ops = collection.insert_one(doc, session=session)
ops = cast(Collection, collection).insert_one(doc, session=session)
if isinstance(ops, Awaitable):
ops = await ops

Expand All @@ -133,13 +152,13 @@ async def insert_one(

@classmethod
async def insert_many(
cls, docs: list[dict], session: Optional[ClientSession] = None
cls, docs: list[dict], session: Optional[AsyncIOMotorClientSession] = None # type: ignore[valid-type]
) -> list[ObjectId]:
"""Insert multiple documents and return the inserted ids."""
try:
collection = await cls.collection(session=session)

ops = collection.insert_many(docs, session=session)
ops = cast(Collection, collection).insert_many(docs, session=session)
if isinstance(ops, Awaitable):
ops = await ops

Expand All @@ -152,13 +171,18 @@ async def insert_many(

@classmethod
async def update_one(
cls, filter: dict, update: dict, session: Optional[ClientSession] = None
cls,
filter: dict,
update: dict,
session: Optional[AsyncIOMotorClientSession] = None, # type: ignore[valid-type]
) -> int:
"""Update single document and return if it's modified or not."""
try:
collection = await cls.collection(session=session)

ops = collection.update_one(filter, update, session=session)
ops = cast(Collection, collection).update_one(
filter, update, session=session
)
if isinstance(ops, Awaitable):
ops = await ops

Expand All @@ -171,13 +195,18 @@ async def update_one(

@classmethod
async def update_many(
cls, filter: dict, update: dict, session: Optional[ClientSession] = None
cls,
filter: dict,
update: dict,
session: Optional[AsyncIOMotorClientSession] = None, # type: ignore[valid-type]
) -> int:
"""Update multiple documents and return how many docs are modified."""
try:
collection = await cls.collection(session=session)

ops = collection.update_many(filter, update, session=session)
ops = cast(Collection, collection).update_many(
filter, update, session=session
)
if isinstance(ops, Awaitable):
ops = await ops

Expand All @@ -193,7 +222,7 @@ async def update_by_id(
cls,
id: Union[str, ObjectId],
update: dict,
session: Optional[ClientSession] = None,
session: Optional[AsyncIOMotorClientSession] = None, # type: ignore[valid-type]
) -> int:
"""Update single document via ID and return if it's modified or not."""
return await cls.update_one({"_id": ObjectId(id)}, update, session)
Expand All @@ -211,7 +240,7 @@ async def find(
if "projection" not in kwargs:
kwargs["projection"] = cls.__excluded_obj__

docs = cursor(collection.find(*args, **kwargs))
docs = cursor(cast(Collection, collection).find(*args, **kwargs))
return await cls.__documents__(docs)

@classmethod
Expand All @@ -222,7 +251,7 @@ async def find_one(cls, *args: Any, **kwargs: Any) -> object: # noqa ANN401
if "projection" not in kwargs:
kwargs["projection"] = cls.__excluded_obj__

ops = collection.find_one(*args, **kwargs)
ops = cast(Collection, collection).find_one(*args, **kwargs)
if isinstance(ops, Awaitable):
ops = await ops

Expand All @@ -238,12 +267,14 @@ async def find_by_id(
return await cls.find_one({"_id": ObjectId(id)}, *args, **kwargs)

@classmethod
async def delete(cls, filter: dict, session: Optional[ClientSession] = None) -> int:
async def delete(
cls, filter: dict, session: Optional[AsyncIOMotorClientSession] = None # type: ignore[valid-type]
) -> int:
"""Delete document/s via filter and return how many documents are deleted."""
try:
collection = await cls.collection(session=session)

ops = collection.delete_many(filter, session=session)
ops = cast(Collection, collection).delete_many(filter, session=session)
if isinstance(ops, Awaitable):
ops = await ops

Expand All @@ -256,13 +287,13 @@ async def delete(cls, filter: dict, session: Optional[ClientSession] = None) ->

@classmethod
async def delete_one(
cls, filter: dict, session: Optional[ClientSession] = None
cls, filter: dict, session: Optional[AsyncIOMotorClientSession] = None # type: ignore[valid-type]
) -> int:
"""Delete single document via filter and return if it's deleted or not."""
try:
collection = await cls.collection(session=session)

ops = collection.delete_one(filter, session=session)
ops = cast(Collection, collection).delete_one(filter, session=session)
if isinstance(ops, Awaitable):
ops = await ops

Expand All @@ -275,7 +306,9 @@ async def delete_one(

@classmethod
async def delete_by_id(
cls, id: Union[str, ObjectId], session: Optional[ClientSession] = None
cls,
id: Union[str, ObjectId],
session: Optional[AsyncIOMotorClientSession] = None, # type: ignore[valid-type]
) -> int:
"""Delete single document via ID and return if it's deleted or not."""
return await cls.delete_one({"_id": ObjectId(id)}, session)
Expand All @@ -284,13 +317,13 @@ async def delete_by_id(
async def bulk_write(
cls,
ops: list[Union[InsertOne, DeleteMany, DeleteOne, UpdateMany, UpdateOne]],
session: Optional[ClientSession] = None,
session: Optional[AsyncIOMotorClientSession] = None, # type: ignore[valid-type]
) -> dict: # noqa ANN401
"""Bulk write operations."""
try:
collection = await cls.collection(session=session)

_ops = collection.bulk_write(ops, session=session)
_ops = cast(Collection, collection).bulk_write(ops, session=session)
if isinstance(_ops, Awaitable):
_ops = await _ops

Expand Down
Loading

0 comments on commit 9f181ac

Please sign in to comment.