Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
[MYPY]: 54/220 warnings addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
amadolid committed Mar 6, 2024
1 parent 5912c20 commit 86eb8e3
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 86 deletions.
140 changes: 91 additions & 49 deletions jaclang_jaseci/collections/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""BaseCollection Interface."""

from os import getenv
from typing import Any, AsyncGenerator, Callable, Union
from typing import Any, AsyncGenerator, Awaitable, Callable, Optional, Union

from bson import ObjectId

Expand All @@ -23,17 +23,17 @@ class BaseCollection:
This interface use for connecting to mongodb.
"""

__collection__ = None
__collection_obj__: Collection = None
__indexes__ = []
__excluded__ = []
__collection__: Optional[str] = None
__collection_obj__: Optional[Collection] = None
__indexes__: list = []
__excluded__: list = []
__excluded_obj__ = None

__client__: AsyncIOMotorClient = None # type: ignore
__database__: Database = None
__client__: Optional[AsyncIOMotorClient] = None # type: ignore
__database__: Optional[Database] = None

@classmethod
def __document__(cls, doc: dict) -> dict:
def __document__(cls, doc: dict) -> Union[dict, object]:
"""
Return parsed version of document.
Expand All @@ -43,51 +43,51 @@ def __document__(cls, doc: dict) -> dict:
return doc

@classmethod
async def __documents__(cls, docs: AsyncIOMotorCursor) -> AsyncGenerator[Any, None]: # type: ignore
async def __documents__(cls, docs: AsyncIOMotorCursor) -> AsyncGenerator[Union[dict, object], None]: # type: ignore
"""
Return parsed version of multiple documents.
This the default parser after getting a list of documents.
You may override this to specify how/which class it will be casted/based.
"""
return (cls.__document__(doc) async for doc in docs)
return (cls.__document__(doc) async for doc in docs) # type: ignore[attr-defined]

@staticmethod
def get_client() -> AsyncIOMotorClient: # type: ignore
"""Return pymongo.database.Database for mongodb connection."""
if not isinstance(__class__.__client__, AsyncIOMotorClient):
__class__.__client__ = AsyncIOMotorClient(
if not isinstance(BaseCollection.__client__, AsyncIOMotorClient):
BaseCollection.__client__ = AsyncIOMotorClient(
getenv(
"DATABASE_HOST",
"mongodb://localhost/?retryWrites=true&w=majority",
),
server_api=ServerApi("1"),
)

return __class__.__client__
return BaseCollection.__client__

@staticmethod
async def get_session() -> ClientSession:
"""Return pymongo.client_session.ClientSession used for mongodb transactional operations."""
return await __class__.get_client().start_session()
return await BaseCollection.get_client().start_session() # type: ignore[attr-defined]

@staticmethod
def get_database() -> Database:
"""Return pymongo.database.Database for database connection based from current client connection."""
if not isinstance(__class__.__database__, Database):
__class__.__database__ = __class__.get_client().get_database(
if not isinstance(BaseCollection.__database__, Database):
BaseCollection.__database__ = BaseCollection.get_client().get_database( # type: ignore[attr-defined]
getenv("DATABASE_NAME", "jaclang")
)

return __class__.__database__
return BaseCollection.__database__

@staticmethod
def get_collection(collection: str) -> Collection:
"""Return pymongo.collection.Collection for collection connection based from current database connection."""
return __class__.get_database().get_collection(collection)
return BaseCollection.get_database().get_collection(collection)

@classmethod
async def collection(cls, session: ClientSession = None) -> Collection:
async def collection(cls, session: Optional[ClientSession] = None) -> Collection:
"""Return pymongo.collection.Collection for collection connection based from attribute of it's child class."""
if not isinstance(cls.__collection_obj__, Collection):
cls.__collection_obj__ = cls.get_collection(
Expand All @@ -105,17 +105,26 @@ async def collection(cls, session: ClientSession = None) -> Collection:
while cls.__indexes__:
idx = cls.__indexes__.pop()
idxs.append(IndexModel(idx.pop("fields"), **idx))
await cls.__collection_obj__.create_indexes(idxs, session=session)

ops = cls.__collection_obj__.create_indexes(idxs, session=session)
if isinstance(ops, Awaitable):
await ops

return cls.__collection_obj__

@classmethod
async def insert_one(cls, doc: dict, session: ClientSession = None) -> ObjectId:
async def insert_one(
cls, doc: dict, session: Optional[ClientSession] = None
) -> Optional[ObjectId]:
"""Insert single document and return the inserted id."""
try:
collection = await cls.collection(session=session)
result = await collection.insert_one(doc, session=session)
return result.inserted_id

ops = collection.insert_one(doc, session=session)
if isinstance(ops, Awaitable):
ops = await ops

return ops.inserted_id
except Exception:
if session:
raise
Expand All @@ -124,13 +133,17 @@ async def insert_one(cls, doc: dict, session: ClientSession = None) -> ObjectId:

@classmethod
async def insert_many(
cls, docs: list[dict], session: ClientSession = None
cls, docs: list[dict], session: Optional[ClientSession] = None
) -> list[ObjectId]:
"""Insert multiple documents and return the inserted ids."""
try:
collection = await cls.collection(session=session)
result = await collection.insert_many(docs, session=session)
return result.inserted_ids

ops = collection.insert_many(docs, session=session)
if isinstance(ops, Awaitable):
ops = await ops

return ops.inserted_ids
except Exception:
if session:
raise
Expand All @@ -139,13 +152,17 @@ async def insert_many(

@classmethod
async def update_one(
cls, filter: dict, update: dict, session: ClientSession = None
cls, filter: dict, update: dict, session: Optional[ClientSession] = None
) -> int:
"""Update single document and return if it's modified or not."""
try:
collection = await cls.collection(session=session)
result = await collection.update_one(filter, update, session=session)
return result.modified_count

ops = collection.update_one(filter, update, session=session)
if isinstance(ops, Awaitable):
ops = await ops

return ops.modified_count
except Exception:
if session:
raise
Expand All @@ -154,13 +171,17 @@ async def update_one(

@classmethod
async def update_many(
cls, filter: dict, update: dict, session: ClientSession = None
cls, filter: dict, update: dict, session: Optional[ClientSession] = None
) -> int:
"""Update multiple documents and return how many docs are modified."""
try:
collection = await cls.collection(session=session)
result = await collection.update_many(filter, update, session=session)
return result.modified_count

ops = collection.update_many(filter, update, session=session)
if isinstance(ops, Awaitable):
ops = await ops

return ops.modified_count
except Exception:
if session:
raise
Expand All @@ -169,7 +190,10 @@ async def update_many(

@classmethod
async def update_by_id(
cls, id: Union[str, ObjectId], update: dict, session: ClientSession = None
cls,
id: Union[str, ObjectId],
update: dict,
session: Optional[ClientSession] = None,
) -> int:
"""Update single document via ID and return if it's modified or not."""
return await cls.update_one({"_id": ObjectId(id)}, update, session)
Expand All @@ -179,8 +203,8 @@ async def find(
cls,
*args: list[Any],
cursor: Callable = lambda x: x,
**kwargs: dict[str, Any],
) -> list[object]:
**kwargs: Any, # noqa ANN401
) -> AsyncGenerator[Union[dict, object], None]:
"""Retrieve multiple documents."""
collection = await cls.collection()

Expand All @@ -191,16 +215,20 @@ async def find(
return await cls.__documents__(docs)

@classmethod
async def find_one(cls, *args: list[Any], **kwargs: dict[str, Any]) -> object:
async def find_one(cls, *args: Any, **kwargs: Any) -> object: # noqa ANN401
"""Retrieve single document from db."""
collection = await cls.collection()

if "projection" not in kwargs:
kwargs["projection"] = cls.__excluded_obj__

if result := await collection.find_one(*args, **kwargs):
return cls.__document__(result)
return result
ops = collection.find_one(*args, **kwargs)
if isinstance(ops, Awaitable):
ops = await ops

if ops:
return cls.__document__(ops)
return ops

@classmethod
async def find_by_id(
Expand All @@ -210,25 +238,35 @@ async def find_by_id(
return await cls.find_one({"_id": ObjectId(id)}, *args, **kwargs)

@classmethod
async def delete(cls, filter: dict, session: ClientSession = None) -> int:
async def delete(cls, filter: dict, session: Optional[ClientSession] = None) -> int:
"""Delete document/s via filter and return how many documents are deleted."""
try:
collection = await cls.collection(session=session)
result = await collection.delete_many(filter, session=session)
return result.deleted_count

ops = collection.delete_many(filter, session=session)
if isinstance(ops, Awaitable):
ops = await ops

return ops.deleted_count
except Exception:
if session:
raise
logger.exception(f"Error delete with filter:\n{filter}")
return 0

@classmethod
async def delete_one(cls, filter: dict, session: ClientSession = None) -> int:
async def delete_one(
cls, filter: dict, session: Optional[ClientSession] = None
) -> int:
"""Delete single document via filter and return if it's deleted or not."""
try:
collection = await cls.collection(session=session)
result = await collection.delete_one(filter, session=session)
return result.deleted_count

ops = collection.delete_one(filter, session=session)
if isinstance(ops, Awaitable):
ops = await ops

return ops.deleted_count
except Exception:
if session:
raise
Expand All @@ -237,7 +275,7 @@ async def delete_one(cls, filter: dict, session: ClientSession = None) -> int:

@classmethod
async def delete_by_id(
cls, id: Union[str, ObjectId], session: ClientSession = None
cls, id: Union[str, ObjectId], session: Optional[ClientSession] = None
) -> int:
"""Delete single document via ID and return if it's deleted or not."""
return await cls.delete_one({"_id": ObjectId(id)}, session)
Expand All @@ -246,13 +284,17 @@ async def delete_by_id(
async def bulk_write(
cls,
ops: list[Union[InsertOne, DeleteMany, DeleteOne, UpdateMany, UpdateOne]],
session: ClientSession = None,
session: Optional[ClientSession] = None,
) -> dict: # noqa ANN401
"""Bulk write operations."""
try:
collection = await cls.collection(session=session)
result = await collection.bulk_write(ops, session=session)
return result.bulk_api_result

_ops = collection.bulk_write(ops, session=session)
if isinstance(_ops, Awaitable):
_ops = await _ops

return _ops.bulk_api_result
except Exception:
if session:
raise
Expand Down
8 changes: 5 additions & 3 deletions jaclang_jaseci/collections/user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""UserCollection Interface."""

from typing import Optional

from .base import BaseCollection


Expand All @@ -11,9 +13,9 @@ class UserCollection(BaseCollection):
You may override this if you wish to implement different structure
"""

__collection__ = "user"
__excluded__ = ["password"]
__indexes__ = [{"fields": ["email"], "unique": True}]
__collection__: Optional[str] = "user"
__excluded__: list[str] = ["password"]
__indexes__: list[dict] = [{"fields": ["email"], "unique": True}]

@classmethod
async def find_by_email(cls, email: str) -> object:
Expand Down
Loading

0 comments on commit 86eb8e3

Please sign in to comment.