diff --git a/advanced_alchemy/filters.py b/advanced_alchemy/filters.py index ab6b6685..2b6676b4 100644 --- a/advanced_alchemy/filters.py +++ b/advanced_alchemy/filters.py @@ -41,7 +41,7 @@ from operator import attrgetter from typing import TYPE_CHECKING, Any, Generic, Literal, cast -from sqlalchemy import BinaryExpression, Delete, Select, Update, and_, any_, or_, text +from sqlalchemy import BinaryExpression, Delete, Select, Update, and_, any_, exists, or_, select, text from typing_extensions import TypeVar if TYPE_CHECKING: @@ -58,9 +58,11 @@ __all__ = ( "BeforeAfter", "CollectionFilter", + "ExistsFilter", "FilterTypes", "InAnyFilter", "LimitOffset", + "NotExistsFilter", "NotInCollectionFilter", "NotInSearchFilter", "OnBeforeAfter", @@ -79,7 +81,7 @@ "StatementTypeT", bound="ReturningDelete[tuple[Any]] | ReturningUpdate[tuple[Any]] | Select[tuple[Any]] | Select[Any] | Update | Delete", ) -FilterTypes: TypeAlias = "BeforeAfter | OnBeforeAfter | CollectionFilter[Any] | LimitOffset | OrderBy | SearchFilter | NotInCollectionFilter[Any] | NotInSearchFilter" +FilterTypes: TypeAlias = "BeforeAfter | OnBeforeAfter | CollectionFilter[Any] | LimitOffset | OrderBy | SearchFilter | NotInCollectionFilter[Any] | NotInSearchFilter | ExistsFilter | NotExistsFilter" """Aggregate type alias of the types supported for collection filtering.""" @@ -576,3 +578,173 @@ def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]: - :meth:`sqlalchemy.sql.expression.ColumnOperators.notilike`: NOT ILIKE """ return attrgetter("not_ilike" if self.ignore_case else "not_like") + + +@dataclass +class ExistsFilter(StatementFilter): + """Filter for EXISTS subqueries. + + This filter creates an EXISTS condition using a list of column expressions. + The expressions can be combined using either AND or OR logic. + + Args: + field_name: The field to narrow against in the EXISTS clause + values: List of SQLAlchemy column expressions to use in the EXISTS clause + operator: If "and", combines conditions with AND, otherwise uses OR. Defaults to "and". + + Example: + Basic usage with AND conditions:: + + from sqlalchemy import select + from advanced_alchemy.filters import ExistsFilter + + filter = ExistsFilter( + field_name="User.is_active", + values=[User.email.like("%@example.com%")], + ) + statement = filter.append_to_statement( + select(Organization), Organization + ) + + Using OR conditions:: + + filter = ExistsFilter( + field_name="User.role", + values=[User.role == "admin", User.role == "owner"], + operator="or", + ) + """ + + values: list[ColumnElement[bool]] + """List of SQLAlchemy column expressions to use in the EXISTS clause.""" + operator: Literal["and", "or"] = "and" + """If "and", combines conditions with AND, otherwise uses OR.""" + + @property + def _and(self) -> Callable[..., ColumnElement[bool]]: + """Return the SQL operator for combining multiple search clauses. + + Returns: + Callable[..., ColumnElement[bool]]: The `and_` operator for AND conditions + + See Also: + :func:`sqlalchemy.sql.expression.and_`: SQLAlchemy AND operator + """ + return and_ + + @property + def _or(self) -> Callable[..., ColumnElement[bool]]: + """Return the SQL operator for combining multiple search clauses. + + Returns: + Callable[..., ColumnElement[bool]]: The `or_` operator for OR conditions + + See Also: + :func:`sqlalchemy.sql.expression.or_`: SQLAlchemy OR operator + """ + return or_ + + def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT: + """Apply EXISTS condition to the statement. + + Args: + statement: The SQLAlchemy statement to modify + model: The SQLAlchemy model class + + Returns: + StatementTypeT: Modified statement with EXISTS condition + + Note: + The conditions are combined using AND or OR based on the operator parameter. + """ + if not self.values: + return statement + + if self.operator == "and": + exists_clause = select(model).where(self._and(*self.values)).exists() + exists_clause = select(model).where(self._or(*self.values)).exists() + return cast("StatementTypeT", statement.where(exists_clause)) + + +@dataclass +class NotExistsFilter(StatementFilter): + """Filter for NOT EXISTS subqueries. + + This filter creates a NOT EXISTS condition using a list of column expressions. + The expressions can be combined using either AND or OR logic. + + Args: + field_name: The field to narrow against in the NOT EXISTS clause + values: List of SQLAlchemy column expressions to use in the NOT EXISTS clause + operator: If "and", combines conditions with AND, otherwise uses OR. Defaults to "and". + + Example: + Basic usage with AND conditions:: + + from sqlalchemy import select + from advanced_alchemy.filters import NotExistsFilter + + filter = NotExistsFilter( + values=[User.email.like("%@example.com%")], + ) + statement = filter.append_to_statement( + select(Organization), Organization + ) + + Using OR conditions:: + + filter = NotExistsFilter( + values=[User.role == "admin", User.role == "owner"], + operator="or", + ) + """ + + values: list[ColumnElement[bool]] + """List of SQLAlchemy column expressions to use in the EXISTS clause.""" + operator: Literal["and", "or"] = "and" + """If "and", combines conditions with AND, otherwise uses OR.""" + + @property + def _and(self) -> Callable[..., ColumnElement[bool]]: + """Return the SQL operator for combining multiple search clauses. + + Returns: + Callable[..., ColumnElement[bool]]: The `and_` operator for AND conditions + + See Also: + :func:`sqlalchemy.sql.expression.and_`: SQLAlchemy AND operator + """ + return and_ + + @property + def _or(self) -> Callable[..., ColumnElement[bool]]: + """Return the SQL operator for combining multiple search clauses. + + Returns: + Callable[..., ColumnElement[bool]]: The `or_` operator for OR conditions + + See Also: + :func:`sqlalchemy.sql.expression.or_`: SQLAlchemy OR operator + """ + return or_ + + def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT: + """Apply NOT EXISTS condition to the statement. + + Args: + statement: The SQLAlchemy statement to modify + model: The SQLAlchemy model class + + Returns: + StatementTypeT: Modified statement with NOT EXISTS condition + + Note: + The conditions are combined using AND or OR based on the operator parameter. + """ + if not self.values: + return statement + + if self.operator == "and": + exists_clause = select(model).where(self._and(*self.values)).exists() + exists_clause = select(model).where(self._or(*self.values)).exists() + return cast("StatementTypeT", statement.where(~exists_clause)) diff --git a/tests/integration/test_filters.py b/tests/integration/test_filters.py new file mode 100644 index 00000000..83de3640 --- /dev/null +++ b/tests/integration/test_filters.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from typing import Generator + +import pytest +from sqlalchemy import String, create_engine, select +from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker + +from advanced_alchemy.base import BigIntBase +from advanced_alchemy.filters import ( + BeforeAfter, + CollectionFilter, + ExistsFilter, + LimitOffset, + NotExistsFilter, + NotInCollectionFilter, + OnBeforeAfter, + OrderBy, + SearchFilter, +) + + +class Movie(BigIntBase): + __tablename__ = "movies" + + title: Mapped[str] = mapped_column(String(length=100)) + release_date: Mapped[datetime] = mapped_column() + genre: Mapped[str] = mapped_column(String(length=50)) + + +@pytest.fixture() +def db_session(tmp_path: Path) -> Generator[Session, None, None]: + engine = create_engine(f"sqlite:///{tmp_path}/test_filters.sqlite", echo=True) + Movie.metadata.create_all(engine) + session_factory = sessionmaker(engine, expire_on_commit=False) + session = session_factory() + # Add test data + movie1 = Movie(title="The Matrix", release_date=datetime(1999, 3, 31, tzinfo=timezone.utc), genre="Action") + movie2 = Movie(title="The Hangover", release_date=datetime(2009, 6, 1, tzinfo=timezone.utc), genre="Comedy") + movie3 = Movie( + title="Shawshank Redemption", release_date=datetime(1994, 10, 14, tzinfo=timezone.utc), genre="Drama" + ) + session.add_all([movie1, movie2, movie3]) + session.commit() + yield session + session.close() + engine.dispose() + Path(tmp_path / "test_filters.sqlite").unlink(missing_ok=True) + + +def test_before_after_filter(db_session: Session) -> None: + before_after_filter = BeforeAfter( + field_name="release_date", before=datetime(1999, 3, 31, tzinfo=timezone.utc), after=None + ) + statement = before_after_filter.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 1 + + +def test_on_before_after_filter(db_session: Session) -> None: + on_before_after_filter = OnBeforeAfter( + field_name="release_date", on_or_before=None, on_or_after=datetime(1999, 3, 31, tzinfo=timezone.utc) + ) + statement = on_before_after_filter.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 2 + + +def test_collection_filter(db_session: Session) -> None: + collection_filter = CollectionFilter(field_name="title", values=["The Matrix", "Shawshank Redemption"]) + statement = collection_filter.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 2 + + +def test_not_in_collection_filter(db_session: Session) -> None: + not_in_collection_filter = NotInCollectionFilter(field_name="title", values=["The Hangover"]) + statement = not_in_collection_filter.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 2 + + +def test_exists_filter_basic(db_session: Session) -> None: + exists_filter_1 = ExistsFilter(values=[Movie.genre == "Action"]) + statement = exists_filter_1.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 1 + + exists_filter_2 = ExistsFilter(values=[Movie.genre.startswith("Action"), Movie.genre.startswith("Drama")]) + statement = exists_filter_2.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 2 + + +def test_exists_filter(db_session: Session) -> None: + exists_filter_1 = ExistsFilter(values=[Movie.title.startswith("The")]) + statement = exists_filter_1.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 3 + + exists_filter_2 = ExistsFilter( + values=[Movie.title.startswith("Shawshank Redemption"), Movie.title.startswith("The")], + operator="and", + ) + statement = exists_filter_2.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 0 + + exists_filter_3 = ExistsFilter( + values=[Movie.title.startswith("The"), Movie.title.startswith("Shawshank")], + operator="or", + ) + statement = exists_filter_3.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 3 + + +def test_not_exists_filter(db_session: Session) -> None: + not_exists_filter = NotExistsFilter(values=[Movie.title.like("%Hangover%")]) + statement = not_exists_filter.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 2 + + +def test_limit_offset_filter(db_session: Session) -> None: + limit_offset_filter = LimitOffset(limit=2, offset=1) + statement = limit_offset_filter.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 2 + + +def test_order_by_filter(db_session: Session) -> None: + order_by_filter = OrderBy(field_name="release_date", sort_order="asc") + statement = order_by_filter.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert results[0].title == "Shawshank Redemption" + order_by_filter = OrderBy(field_name="release_date", sort_order="desc") + statement = order_by_filter.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert results[0].title == "The Hangover" + + +def test_search_filter(db_session: Session) -> None: + search_filter = SearchFilter(field_name="title", value="Hangover") + statement = search_filter.append_to_statement(select(Movie), Movie) + results = db_session.execute(statement).scalars().all() + assert len(results) == 1