Skip to content

Commit

Permalink
chore: updated type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed May 13, 2024
1 parent f88365e commit b0c0bb3
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 236 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ install: ## Install the project and
@if [ "$(VENV_EXISTS)" ]; then $(MAKE) destroy; fi
@if [ "$(VENV_EXISTS)" ]; then $(MAKE) clean; fi
@if [ "$(USING_PDM)" ]; then $(PDM) config venv.in_project true && python3 -m venv --copies .venv && . $(ENV_PREFIX)/activate && $(ENV_PREFIX)/pip install --quiet -U wheel setuptools cython mypy build pip; fi
@if [ "$(USING_PDM)" ]; then $(PDM) install -G:all; fi
@if [ "$(USING_PDM)" ]; then $(PDM) use -f .venv && $(PDM) install -d -G:all; fi
@echo "=> Install complete! Note: If you want to re-install re-run 'make install'"


Expand Down
38 changes: 20 additions & 18 deletions advanced_alchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import contextlib
import re
from datetime import date, datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, runtime_checkable
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
from uuid import UUID

from sqlalchemy import Date, Index, MetaData, Sequence, String, UniqueConstraint
Expand Down Expand Up @@ -33,7 +33,9 @@

if TYPE_CHECKING:
from sqlalchemy.sql import FromClause
from sqlalchemy.sql.schema import _NamingSchemaParameter as NamingSchemaParameter
from sqlalchemy.sql.schema import (
_NamingSchemaParameter as NamingSchemaParameter, # pyright: ignore[reportPrivateUsage]
)
from sqlalchemy.types import TypeEngine


Expand Down Expand Up @@ -78,11 +80,11 @@
"""Regular expression for table name"""


def merge_table_arguments(
def merge_table_arguments( # pyright: ignore[reportUnknownParameterType]
cls: DeclarativeBase,
*mixins: Any,
table_args: dict | tuple | None = None,
) -> tuple | dict:
table_args: dict | tuple | None = None, # pyright: ignore[reportMissingTypeArgument,reportUnknownParameterType]
) -> tuple | dict: # pyright: ignore[reportMissingTypeArgument]
"""Merge Table Arguments.
When using mixins that include their own table args, it is difficult to append info into the model such as a comment.
Expand All @@ -92,27 +94,27 @@ def merge_table_arguments(
Args:
cls (DeclarativeBase): This is the model that will get the table args
*mixins (Any): The mixins to add into the model
table_args: additional information to add to tableargs
table_args: additional information to add to table_args
Returns:
tuple | dict: The merged __table_args__ property
"""
args: list[Any] = []
kwargs: dict[str, Any] = {}

mixin_table_args = (getattr(super(base_cls, cls), "__table_args__", None) for base_cls in (cls, *mixins))
mixin_table_args = (getattr(super(base_cls, cls), "__table_args__", None) for base_cls in (cls, *mixins)) # pyright: ignore[reportArgumentType,reportUnknownArgumentType]

for arg_to_merge in (*mixin_table_args, table_args):
for arg_to_merge in (*mixin_table_args, table_args): # pyright: ignore[reportUnknownVariableType]
if arg_to_merge:
if isinstance(arg_to_merge, tuple):
last_positional_arg = arg_to_merge[-1]
args.extend(arg_to_merge[:-1])
last_positional_arg = arg_to_merge[-1] # pyright: ignore[reportUnknownVariableType]
args.extend(arg_to_merge[:-1]) # pyright: ignore[reportUnknownArgumentType]
if isinstance(last_positional_arg, dict):
kwargs.update(last_positional_arg)
kwargs.update(last_positional_arg) # pyright: ignore[reportUnknownArgumentType]
else:
args.append(last_positional_arg)
elif isinstance(arg_to_merge, dict):
kwargs.update(arg_to_merge)
kwargs.update(arg_to_merge) # pyright: ignore[reportUnknownArgumentType]

if args:
if kwargs:
Expand All @@ -126,8 +128,8 @@ class ModelProtocol(Protocol):
"""The base SQLAlchemy model protocol."""

__table__: FromClause
__mapper__: Mapper
__name__: ClassVar[str]
__mapper__: Mapper # pyright: ignore[reportMissingTypeArgument]
__name__: str

def to_dict(self, exclude: set[str] | None = None) -> dict[str, Any]:
"""Convert model to dictionary.
Expand Down Expand Up @@ -204,9 +206,9 @@ class AuditColumns:
class BasicAttributes:
"""Basic attributes for SQLALchemy tables and queries."""

__name__: ClassVar[str]
__name__: str
__table__: FromClause
__mapper__: Mapper
__mapper__: Mapper # pyright: ignore[reportMissingTypeArgument]

def to_dict(self, exclude: set[str] | None = None) -> dict[str, Any]:
"""Convert model to dictionary.
Expand All @@ -217,7 +219,7 @@ def to_dict(self, exclude: set[str] | None = None) -> dict[str, Any]:
exclude = {"sa_orm_sentinel", "_sentinel"}.union(self._sa_instance_state.unloaded).union(exclude or []) # type: ignore[attr-defined]
return {
field: getattr(self, field)
for field in self.__mapper__.columns.keys() # noqa: SIM118
for field in self.__mapper__.columns.keys() # pyright: ignore[reportUnknownMemberType] # noqa: SIM118
if field not in exclude
}

Expand Down Expand Up @@ -257,7 +259,7 @@ def _create_unique_slug_constraint(*_args: Any, **kwargs: Any) -> bool:
return not kwargs["dialect"].name.startswith("spanner")

@declared_attr.directive
def __table_args__(cls) -> tuple | dict:
def __table_args__(cls) -> tuple | dict: # pyright: ignore[reportMissingTypeArgument,reportUnknownParameterType]
return (
UniqueConstraint(
cls.slug,
Expand Down
4 changes: 2 additions & 2 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import random
import string
from typing import TYPE_CHECKING, Any, Final, Iterable, Literal, cast
from typing import TYPE_CHECKING, Any, Final, Iterable, List, Literal, cast

from sqlalchemy import (
Result,
Expand Down Expand Up @@ -1281,7 +1281,7 @@ async def list(
instances = list(result.scalars())
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
return cast("List[ModelT]", instances)

def filter_collection_by_kwargs(
self,
Expand Down
4 changes: 2 additions & 2 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import random
import string
from typing import TYPE_CHECKING, Any, Final, Iterable, Literal, cast
from typing import TYPE_CHECKING, Any, Final, Iterable, List, Literal, cast

from sqlalchemy import (
Result,
Expand Down Expand Up @@ -1282,7 +1282,7 @@ def list(
instances = list(result.scalars())
for instance in instances:
self._expunge(instance, auto_expunge=auto_expunge)
return instances
return cast("List[ModelT]", instances)

def filter_collection_by_kwargs(
self,
Expand Down
16 changes: 4 additions & 12 deletions advanced_alchemy/repository/typing.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,23 @@
from typing import TYPE_CHECKING, Any, Tuple, TypeVar
from typing import TYPE_CHECKING, Any, Tuple, TypeVar, Union

if TYPE_CHECKING:
from sqlalchemy import Select
from sqlalchemy import RowMapping, Select

from advanced_alchemy import base
from advanced_alchemy.repository._async import SQLAlchemyAsyncRepository
from advanced_alchemy.repository._sync import SQLAlchemySyncRepository

__all__ = (
"ModelT",
"SelectT",
"RowT",
"SQLAlchemySyncRepositoryT",
"SQLAlchemyAsyncRepositoryT",
"MISSING",
)

T = TypeVar("T")
ModelT = TypeVar("ModelT", bound="base.ModelProtocol")


SelectT = TypeVar("SelectT", bound="Select[Any]")
RowT = TypeVar("RowT", bound=Tuple[Any, ...])


SQLAlchemySyncRepositoryT = TypeVar("SQLAlchemySyncRepositoryT", bound="SQLAlchemySyncRepository")
SQLAlchemyAsyncRepositoryT = TypeVar("SQLAlchemyAsyncRepositoryT", bound="SQLAlchemyAsyncRepository")
RowMappingT = TypeVar("RowMappingT", bound="RowMapping")
ModelOrRowMappingT = TypeVar("ModelOrRowMappingT", bound="Union[base.ModelProtocol, RowMapping]")


class _MISSING:
Expand Down
45 changes: 19 additions & 26 deletions advanced_alchemy/service/_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
from uuid import UUID

from advanced_alchemy.filters import FilterTypes, LimitOffset
from advanced_alchemy.repository.typing import ModelT
from advanced_alchemy.repository.typing import ModelOrRowMappingT
from advanced_alchemy.service.pagination import OffsetPagination

if TYPE_CHECKING:
from sqlalchemy import ColumnElement

from advanced_alchemy.service.typing import FilterTypeT, ModelDTOT, RowMappingT
from advanced_alchemy.service.typing import FilterTypeT, ModelDTOT

try:
from msgspec import Struct, convert
from msgspec import Struct, convert # pyright: ignore[reportAssignmentType]
except ImportError: # pragma: nocover

class Struct: # type: ignore[no-redef]
Expand All @@ -34,11 +34,11 @@ def convert(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef] # noqa:


try:
from pydantic import BaseModel
from pydantic.type_adapter import TypeAdapter
from pydantic.main import ModelMetaclass # pyright: ignore[reportAssignmentType]
from pydantic.type_adapter import TypeAdapter # pyright: ignore[reportAssignmentType]
except ImportError: # pragma: nocover

class BaseModel: # type: ignore[no-redef]
class ModelMetaclass: # type: ignore[no-redef]
"""Placeholder Implementation"""

class TypeAdapter: # type: ignore[no-redef]
Expand Down Expand Up @@ -81,13 +81,13 @@ def _default_deserializer(

def _find_filter(
filter_type: type[FilterTypeT],
*filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes],
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes],
) -> FilterTypeT | None:
"""Get the filter specified by filter type from the filters.
Args:
filter_type: The type of filter to find.
*filters: filter types to apply to the query
filters: filter types to apply to the query
Returns:
The match filter instance or None
Expand All @@ -99,18 +99,11 @@ def _find_filter(


def to_schema(
data: ModelT | Sequence[ModelT] | Sequence[RowMappingT] | RowMappingT,
data: ModelOrRowMappingT | Sequence[ModelOrRowMappingT],
total: int | None = None,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] = EMPTY_FILTER,
schema_type: type[ModelT | ModelDTOT | RowMappingT] | None = None,
) -> (
ModelT
| OffsetPagination[ModelT]
| ModelDTOT
| OffsetPagination[ModelDTOT]
| RowMappingT
| OffsetPagination[RowMappingT]
):
schema_type: type[ModelDTOT] | None = None,
) -> ModelOrRowMappingT | OffsetPagination[ModelOrRowMappingT] | ModelDTOT | OffsetPagination[ModelDTOT]:
if schema_type is not None and issubclass(schema_type, Struct):
if not isinstance(data, Sequence):
return convert( # type: ignore # noqa: PGH003
Expand All @@ -124,7 +117,7 @@ def to_schema(
],
),
)
limit_offset = _find_filter(LimitOffset, *filters)
limit_offset = _find_filter(LimitOffset, filters=filters)
total = total or len(data)
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0)
return OffsetPagination[schema_type]( # type: ignore[valid-type]
Expand All @@ -144,10 +137,10 @@ def to_schema(
total=total,
)

if schema_type is not None and issubclass(schema_type, BaseModel):
if schema_type is not None and issubclass(schema_type, ModelMetaclass):
if not isinstance(data, Sequence):
return TypeAdapter(schema_type).validate_python(data, from_attributes=True) # type: ignore # noqa: PGH003
limit_offset = _find_filter(LimitOffset, *filters)
return TypeAdapter(schema_type).validate_python(data, from_attributes=True) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType,reportAttributeAccessIssue,reportCallIssue]
limit_offset = _find_filter(LimitOffset, filters=filters)
total = total if total else len(data)
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0)
return OffsetPagination[schema_type]( # type: ignore[valid-type]
Expand All @@ -157,12 +150,12 @@ def to_schema(
total=total,
)
if not issubclass(type(data), Sequence):
return data # type: ignore[return-value]
limit_offset = _find_filter(LimitOffset, *filters)
return cast("ModelOrRowMappingT", data)
limit_offset = _find_filter(LimitOffset, filters=filters)
total = total or len(data) # type: ignore[arg-type]
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0) # type: ignore[arg-type]
return OffsetPagination[ModelT](
items=data, # type: ignore[arg-type]
return OffsetPagination[ModelOrRowMappingT](
items=cast("List[ModelOrRowMappingT]", data),
limit=limit_offset.limit,
offset=limit_offset.offset,
total=total,
Expand Down
Loading

0 comments on commit b0c0bb3

Please sign in to comment.