Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Async Support #100

Closed
wants to merge 10 commits into from
6 changes: 6 additions & 0 deletions .tox/py36/.tox-info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"ToxEnv": {
"name": "py36",
"type": "VirtualEnvRunner"
}
}
6 changes: 6 additions & 0 deletions .tox/py37/.tox-info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"ToxEnv": {
"name": "py37",
"type": "VirtualEnvRunner"
}
}
6 changes: 6 additions & 0 deletions .tox/py38/.tox-info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"ToxEnv": {
"name": "py38",
"type": "VirtualEnvRunner"
}
}
6 changes: 6 additions & 0 deletions .tox/py39/.tox-info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"ToxEnv": {
"name": "py39",
"type": "VirtualEnvRunner"
}
}
31 changes: 16 additions & 15 deletions sqlalchemy_mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# high-level mixins
from .activerecord import ActiveRecordMixin, ModelNotFoundError
from .activerecordasync import ActiveRecordMixinAsync
from .smartquery import SmartQueryMixin, smart_query
from .eagerload import EagerLoadMixin, JOINED, SUBQUERY
from .repr import ReprMixin
Expand All @@ -16,19 +17,19 @@ class AllFeaturesMixin(ActiveRecordMixin, SmartQueryMixin, ReprMixin, SerializeM
__abstract__ = True
__repr__ = ReprMixin.__repr__


__all__ = [
"ActiveRecordMixin",
"AllFeaturesMixin",
"EagerLoadMixin",
"InspectionMixin",
"JOINED",
"ModelNotFoundError",
"ReprMixin",
"SerializeMixin",
"SessionMixin",
"smart_query",
"SmartQueryMixin",
"SUBQUERY",
"TimestampsMixin",
]
'SessionMixin',
'InspectionMixin',
'ActiveRecordMixin',
'ModelNotFoundError',
'ActiveRecordMixinAsync',
'SmartQueryMixin',
'smart_query',
'EagerLoadMixin',
'JOINED',
'SUBQUERY',
'ReprMixin',
'SerializeMixin',
'TimestampsMixin',
'AllFeaturesMixin',
]
30 changes: 26 additions & 4 deletions sqlalchemy_mixins/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,31 @@
from .serialize import SerializeMixin
from .session import SessionMixin
from .inspection import InspectionMixin
from .activerecord import ActiveRecordMixin, ModelNotFoundError
from .activerecordasync import ActiveRecordMixinAsync
from .smartquery import SmartQueryMixin, smart_query
from .eagerload import EagerLoadMixin, JOINED, SUBQUERY
from .repr import ReprMixin
from .smartquery import SmartQueryMixin
from .activerecord import ActiveRecordMixin
from .serialize import SerializeMixin
from .timestamp import TimestampsMixin


class AllFeaturesMixin(ActiveRecordMixin, SmartQueryMixin, ReprMixin, SerializeMixin):
__abstract__ = True
__repr__ = ReprMixin.__repr__ # type: ignore
__repr__ = ReprMixin.__repr__

__all__ = [
'SessionMixin',
'InspectionMixin',
'ActiveRecordMixin',
'ModelNotFoundError',
'ActiveRecordMixinAsync',
'SmartQueryMixin',
'smart_query',
'EagerLoadMixin',
'JOINED',
'SUBQUERY',
'ReprMixin',
'SerializeMixin',
'TimestampsMixin',
'AllFeaturesMixin',
]
201 changes: 201 additions & 0 deletions sqlalchemy_mixins/activerecordasync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from sqlalchemy import select
from sqlalchemy.orm import Query
from .utils import classproperty
from .session import SessionMixin
from .inspection import InspectionMixin
from .activerecord import ModelNotFoundError
from . import smartquery as SmaryQuery

get_root_cls = SmaryQuery._get_root_cls
def async_root_cls(query: Query):
"""Monkey patch SmaryQuery to handle async queries."""
try:
return get_root_cls(query)
except ValueError:
# Handle async queries
if query.__dict__["_propagate_attrs"]["plugin_subject"].class_:
return query.__dict__["_propagate_attrs"]["plugin_subject"].class_
raise

SmaryQuery._get_root_cls = lambda query: async_root_cls(query)


class ActiveRecordMixinAsync(InspectionMixin, SessionMixin):
__abstract__ = True

@classproperty
def query(cls):
"""
Override the default query property to handle async session.
"""
if not hasattr(cls.session, "query"):
return select(cls)

return cls.session.query(cls)

async def save_async(self):
"""
Async version of :meth:`save` method.

:see: :meth:`save` method for more information.
"""
async with self.session() as session:
try:
session.add(self)
await session.commit()
return self
except:
await session.rollback()
raise

@classmethod
async def create_async(cls, **kwargs):
"""
Async version of :meth:`create` method.

:see: :meth:`create`
"""
return await cls().fill(**kwargs).save_async()

async def update_async(self, **kwargs):
"""
Async version of :meth:`update` method.

:see: :meth:`update`
"""
return await self.fill(**kwargs).save_async()

async def delete_async(self):
"""
Async version of :meth:`delete` method.

:see: :meth:`delete`
"""
async with self.session() as session:
try:
session.sync_session.delete(self)
await session.commit()
return self
except:
await session.rollback()
raise
finally:
await session.flush()

@classmethod
async def destroy_async(cls, *ids):
"""
Async version of :meth:`destroy` method.

:see: :meth:`destroy`
"""
primary_key = cls._get_primary_key_name()
if primary_key:
async with cls.session() as session:
try:
for row in await cls.where_async(**{f"{primary_key}__in": ids}):
session.sync_session.delete(row)
await session.commit()
except:
await session.rollback()
raise
await session.flush()

@classmethod
async def select_async(cls, stmt=None, filters=None, sort_attrs=None, schema=None):
async with cls.session() as session:
if stmt is None:
stmt = cls.smart_query(
filters=filters, sort_attrs=sort_attrs, schema=schema)
return (await session.execute(stmt)).scalars()

@classmethod
async def where_async(cls, **filters):
"""
Aync version of where method.

:see: :meth:`where` method for more details.
"""
return await cls.select_async(filters=filters)

@classmethod
async def sort_async(cls, *columns):
"""
Async version of sort method.

:see: :meth:`sort` method for more details.
"""
return await cls.select_async(sort_attrs=columns)

@classmethod
async def all_async(cls):
"""
Async version of all method.
This is same as calling ``(await select_async()).all()``.

:see: :meth:`all` method for more details.
"""
return (await cls.select_async()).all()

@classmethod
async def first_async(cls):
"""
Async version of first method.
This is same as calling ``(await select_async()).first()``.

:see: :meth:`first` method for more details.
"""
return (await cls.select_async()).first()

@classmethod
async def find_async(cls, id_):
"""
Async version of find method.

:see: :meth:`find` method for more details.
"""
primary_key = cls._get_primary_key_name()
if primary_key:
return (await cls.where_async(**{primary_key: id_})).first()
return None

@classmethod
async def find_or_fail_async(cls, id_):
"""
Async version of find_or_fail method.

:see: :meth:`find_or_fail` method for more details.
"""
cursor = await cls.find_async(id_)
if cursor:
return cursor
else:
raise ModelNotFoundError("{} with id '{}' was not found"
.format(cls.__name__, id_))

@classmethod
async def with_async(cls, schema):
"""
Async version of with method.

:see: :meth:`with` method for more details.
"""
return await cls.select_async(cls.with_(schema))

@classmethod
async def with_joined_async(cls, *paths):
"""
Async version of with_joined method.

:see: :meth:`with_joined` method for more details.
"""
return await cls.select_async(cls.with_joined(*paths))

@classmethod
async def with_subquery_async(cls, *paths):
"""
Async version of with_subquery method.

:see: :meth:`with_subquery` method for more details.
"""
return await cls.select_async(cls.with_subquery(*paths))
59 changes: 59 additions & 0 deletions sqlalchemy_mixins/activerecordasync.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Dict, Iterable, List, Any, Optional

from sqlalchemy_mixins.inspection import InspectionMixin
from sqlalchemy_mixins.session import SessionMixin
from sqlalchemy_mixins.utils import classproperty
from sqlalchemy.orm import Query, QueryableAttribute


class ActiveRecordMixinAsync(InspectionMixin, SessionMixin):

@classproperty
def settable_attributes(cls) -> List[str]: ...

async def save_async(self) -> "ActiveRecordMixinAsync": ...

@classmethod
async def create_async(cls, **kwargs: Any) -> "ActiveRecordMixinAsync": ...

async def update_async(self, **kwargs: dict) -> "ActiveRecordMixinAsync": ...

async def delete_async(self) -> None: ...

@classmethod
async def destroy_async(cls, *ids: list) -> None: ...

@classmethod
async def all_async(cls) -> List["ActiveRecordMixinAsync"]: ...

@classmethod
async def first_async(cls) -> Optional["ActiveRecordMixinAsync"]: ...

@classmethod
async def find_async(cls, id_: Any) -> Optional["ActiveRecordMixinAsync"]: ...

@classmethod
async def find_or_fail_async(cls, id_: Any) -> "ActiveRecordMixinAsync": ...

@classmethod
async def select_async(cls,
stmt:Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
sort_attrs: Optional[Iterable[str]] = None,
schema: Optional[dict] = None
) -> "ActiveRecordMixinAsync": ...

@classmethod
async def where_async(cls, **filters: Any) -> Query: ...

@classmethod
async def sort_async(cls, *columns: str) -> Query: ...

@classmethod
async def with_async(cls, schema: dict) -> Query: ...

@classmethod
async def with_joined_async(cls, *paths: List[QueryableAttribute]) -> Query: ...

@classmethod
async def with_subquery_async(cls, *paths: List[QueryableAttribute]) -> Query: ...
2 changes: 1 addition & 1 deletion sqlalchemy_mixins/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ def hybrid_methods_full(cls):

@classproperty
def hybrid_methods(cls):
return list(cls.hybrid_methods_full.keys())
return list(cls.hybrid_methods_full.keys())
2 changes: 1 addition & 1 deletion sqlalchemy_mixins/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ def query(cls):
"""
:rtype: Query
"""
return cls.session.query(cls)
return cls.session.query(cls)
2 changes: 1 addition & 1 deletion sqlalchemy_mixins/smartquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,4 +429,4 @@ def sort(cls, *columns):
Exanple 3 (with joins):
Post.sort('comments___rating', 'user___name').all()
"""
return cls.smart_query({}, columns)
return cls.smart_query({}, columns)
Loading