Skip to content

Commit

Permalink
feat(filters): adds an ExistsFilter and NotExists filter
Browse files Browse the repository at this point in the history
Fixes #331
  • Loading branch information
cofin committed Jan 20, 2025
1 parent 114b0cd commit 21826ca
Show file tree
Hide file tree
Showing 2 changed files with 323 additions and 2 deletions.
176 changes: 174 additions & 2 deletions advanced_alchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Generic, Literal, cast

from sqlalchemy import BinaryExpression, Delete, Select, Update, and_, any_, or_, text
from sqlalchemy import BinaryExpression, Delete, Select, Update, and_, any_, exists, or_, select, text
from typing_extensions import TypeVar

if TYPE_CHECKING:
Expand All @@ -58,9 +58,11 @@
__all__ = (
"BeforeAfter",
"CollectionFilter",
"ExistsFilter",
"FilterTypes",
"InAnyFilter",
"LimitOffset",
"NotExistsFilter",
"NotInCollectionFilter",
"NotInSearchFilter",
"OnBeforeAfter",
Expand All @@ -79,7 +81,7 @@
"StatementTypeT",
bound="ReturningDelete[tuple[Any]] | ReturningUpdate[tuple[Any]] | Select[tuple[Any]] | Select[Any] | Update | Delete",
)
FilterTypes: TypeAlias = "BeforeAfter | OnBeforeAfter | CollectionFilter[Any] | LimitOffset | OrderBy | SearchFilter | NotInCollectionFilter[Any] | NotInSearchFilter"
FilterTypes: TypeAlias = "BeforeAfter | OnBeforeAfter | CollectionFilter[Any] | LimitOffset | OrderBy | SearchFilter | NotInCollectionFilter[Any] | NotInSearchFilter | ExistsFilter | NotExistsFilter"
"""Aggregate type alias of the types supported for collection filtering."""


Expand Down Expand Up @@ -576,3 +578,173 @@ def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]:
- :meth:`sqlalchemy.sql.expression.ColumnOperators.notilike`: NOT ILIKE
"""
return attrgetter("not_ilike" if self.ignore_case else "not_like")


@dataclass
class ExistsFilter(StatementFilter):
"""Filter for EXISTS subqueries.
This filter creates an EXISTS condition using a list of column expressions.
The expressions can be combined using either AND or OR logic.
Args:
field_name: The field to narrow against in the EXISTS clause
values: List of SQLAlchemy column expressions to use in the EXISTS clause
operator: If "and", combines conditions with AND, otherwise uses OR. Defaults to "and".
Example:
Basic usage with AND conditions::
from sqlalchemy import select
from advanced_alchemy.filters import ExistsFilter
filter = ExistsFilter(
field_name="User.is_active",
values=[User.email.like("%@example.com%")],
)
statement = filter.append_to_statement(
select(Organization), Organization
)
Using OR conditions::
filter = ExistsFilter(
field_name="User.role",
values=[User.role == "admin", User.role == "owner"],
operator="or",
)
"""

values: list[ColumnElement[bool]]
"""List of SQLAlchemy column expressions to use in the EXISTS clause."""
operator: Literal["and", "or"] = "and"
"""If "and", combines conditions with AND, otherwise uses OR."""

@property
def _and(self) -> Callable[..., ColumnElement[bool]]:
"""Return the SQL operator for combining multiple search clauses.
Returns:
Callable[..., ColumnElement[bool]]: The `and_` operator for AND conditions
See Also:
:func:`sqlalchemy.sql.expression.and_`: SQLAlchemy AND operator
"""
return and_

@property
def _or(self) -> Callable[..., ColumnElement[bool]]:
"""Return the SQL operator for combining multiple search clauses.
Returns:
Callable[..., ColumnElement[bool]]: The `or_` operator for OR conditions
See Also:
:func:`sqlalchemy.sql.expression.or_`: SQLAlchemy OR operator
"""
return or_

def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply EXISTS condition to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with EXISTS condition
Note:
The conditions are combined using AND or OR based on the operator parameter.
"""
if not self.values:
return statement

if self.operator == "and":
exists_clause = select(model).where(self._and(*self.values)).exists()
exists_clause = select(model).where(self._or(*self.values)).exists()
return cast("StatementTypeT", statement.where(exists_clause))


@dataclass
class NotExistsFilter(StatementFilter):
"""Filter for NOT EXISTS subqueries.
This filter creates a NOT EXISTS condition using a list of column expressions.
The expressions can be combined using either AND or OR logic.
Args:
field_name: The field to narrow against in the NOT EXISTS clause
values: List of SQLAlchemy column expressions to use in the NOT EXISTS clause
operator: If "and", combines conditions with AND, otherwise uses OR. Defaults to "and".
Example:
Basic usage with AND conditions::
from sqlalchemy import select
from advanced_alchemy.filters import NotExistsFilter
filter = NotExistsFilter(
values=[User.email.like("%@example.com%")],
)
statement = filter.append_to_statement(
select(Organization), Organization
)
Using OR conditions::
filter = NotExistsFilter(
values=[User.role == "admin", User.role == "owner"],
operator="or",
)
"""

values: list[ColumnElement[bool]]
"""List of SQLAlchemy column expressions to use in the EXISTS clause."""
operator: Literal["and", "or"] = "and"
"""If "and", combines conditions with AND, otherwise uses OR."""

@property
def _and(self) -> Callable[..., ColumnElement[bool]]:
"""Return the SQL operator for combining multiple search clauses.
Returns:
Callable[..., ColumnElement[bool]]: The `and_` operator for AND conditions
See Also:
:func:`sqlalchemy.sql.expression.and_`: SQLAlchemy AND operator
"""
return and_

@property
def _or(self) -> Callable[..., ColumnElement[bool]]:
"""Return the SQL operator for combining multiple search clauses.
Returns:
Callable[..., ColumnElement[bool]]: The `or_` operator for OR conditions
See Also:
:func:`sqlalchemy.sql.expression.or_`: SQLAlchemy OR operator
"""
return or_

def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
"""Apply NOT EXISTS condition to the statement.
Args:
statement: The SQLAlchemy statement to modify
model: The SQLAlchemy model class
Returns:
StatementTypeT: Modified statement with NOT EXISTS condition
Note:
The conditions are combined using AND or OR based on the operator parameter.
"""
if not self.values:
return statement

if self.operator == "and":
exists_clause = select(model).where(self._and(*self.values)).exists()
exists_clause = select(model).where(self._or(*self.values)).exists()
return cast("StatementTypeT", statement.where(~exists_clause))
149 changes: 149 additions & 0 deletions tests/integration/test_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations

from datetime import datetime, timezone
from pathlib import Path
from typing import Generator

import pytest
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker

from advanced_alchemy.base import BigIntBase
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
ExistsFilter,
LimitOffset,
NotExistsFilter,
NotInCollectionFilter,
OnBeforeAfter,
OrderBy,
SearchFilter,
)


class Movie(BigIntBase):
__tablename__ = "movies"

title: Mapped[str] = mapped_column(String(length=100))
release_date: Mapped[datetime] = mapped_column()
genre: Mapped[str] = mapped_column(String(length=50))


@pytest.fixture()
def db_session(tmp_path: Path) -> Generator[Session, None, None]:
engine = create_engine(f"sqlite:///{tmp_path}/test_filters.sqlite", echo=True)
Movie.metadata.create_all(engine)
session_factory = sessionmaker(engine, expire_on_commit=False)
session = session_factory()
# Add test data
movie1 = Movie(title="The Matrix", release_date=datetime(1999, 3, 31, tzinfo=timezone.utc), genre="Action")
movie2 = Movie(title="The Hangover", release_date=datetime(2009, 6, 1, tzinfo=timezone.utc), genre="Comedy")
movie3 = Movie(
title="Shawshank Redemption", release_date=datetime(1994, 10, 14, tzinfo=timezone.utc), genre="Drama"
)
session.add_all([movie1, movie2, movie3])
session.commit()
yield session
session.close()
engine.dispose()
Path(tmp_path / "test_filters.sqlite").unlink(missing_ok=True)


def test_before_after_filter(db_session: Session) -> None:
before_after_filter = BeforeAfter(
field_name="release_date", before=datetime(1999, 3, 31, tzinfo=timezone.utc), after=None
)
statement = before_after_filter.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 1


def test_on_before_after_filter(db_session: Session) -> None:
on_before_after_filter = OnBeforeAfter(
field_name="release_date", on_or_before=None, on_or_after=datetime(1999, 3, 31, tzinfo=timezone.utc)
)
statement = on_before_after_filter.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 2


def test_collection_filter(db_session: Session) -> None:
collection_filter = CollectionFilter(field_name="title", values=["The Matrix", "Shawshank Redemption"])
statement = collection_filter.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 2


def test_not_in_collection_filter(db_session: Session) -> None:
not_in_collection_filter = NotInCollectionFilter(field_name="title", values=["The Hangover"])
statement = not_in_collection_filter.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 2


def test_exists_filter_basic(db_session: Session) -> None:
exists_filter_1 = ExistsFilter(values=[Movie.genre == "Action"])
statement = exists_filter_1.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 1

exists_filter_2 = ExistsFilter(values=[Movie.genre.startswith("Action"), Movie.genre.startswith("Drama")])
statement = exists_filter_2.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 2


def test_exists_filter(db_session: Session) -> None:
exists_filter_1 = ExistsFilter(values=[Movie.title.startswith("The")])
statement = exists_filter_1.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 3

exists_filter_2 = ExistsFilter(
values=[Movie.title.startswith("Shawshank Redemption"), Movie.title.startswith("The")],
operator="and",
)
statement = exists_filter_2.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 0

exists_filter_3 = ExistsFilter(
values=[Movie.title.startswith("The"), Movie.title.startswith("Shawshank")],
operator="or",
)
statement = exists_filter_3.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 3


def test_not_exists_filter(db_session: Session) -> None:
not_exists_filter = NotExistsFilter(values=[Movie.title.like("%Hangover%")])
statement = not_exists_filter.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 2


def test_limit_offset_filter(db_session: Session) -> None:
limit_offset_filter = LimitOffset(limit=2, offset=1)
statement = limit_offset_filter.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 2


def test_order_by_filter(db_session: Session) -> None:
order_by_filter = OrderBy(field_name="release_date", sort_order="asc")
statement = order_by_filter.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert results[0].title == "Shawshank Redemption"
order_by_filter = OrderBy(field_name="release_date", sort_order="desc")
statement = order_by_filter.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert results[0].title == "The Hangover"


def test_search_filter(db_session: Session) -> None:
search_filter = SearchFilter(field_name="title", value="Hangover")
statement = search_filter.append_to_statement(select(Movie), Movie)
results = db_session.execute(statement).scalars().all()
assert len(results) == 1

0 comments on commit 21826ca

Please sign in to comment.