diff --git a/examples/docker-compose.yml b/examples/docker-compose.yml new file mode 100644 index 00000000..66138d94 --- /dev/null +++ b/examples/docker-compose.yml @@ -0,0 +1,24 @@ +version: "3.3" + +services: + + elasticsearch: + image: elasticsearch:7.6.2 + # restart: always + ports: + - 9200:9200 + environment: + - node.name=fastapi-filter-es + - cluster.name=fastapi-filter-es-docker-cluster + - discovery.type=single-node + - bootstrap.memory_lock=true + - "ES_JAVA_OPTS=-Xms512m -Xmx512m" + ulimits: + memlock: + soft: -1 + hard: -1 + volumes: + - elasticsearch-data:/usr/share/elasticsearch/data + +volumes: + elasticsearch-data: \ No newline at end of file diff --git a/examples/fastapi_filter_elasticsearch_dsl.py b/examples/fastapi_filter_elasticsearch_dsl.py new file mode 100644 index 00000000..4035216e --- /dev/null +++ b/examples/fastapi_filter_elasticsearch_dsl.py @@ -0,0 +1,189 @@ +import logging +from typing import Any, List, Optional + +import uvicorn +from faker import Faker +from fastapi import FastAPI +from pydantic import BaseModel, ConfigDict, EmailStr + +from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter.contrib.elasticsearch_dsl import Filter + +fake = Faker() + +logger = logging.getLogger("uvicorn") +from datetime import datetime +from fnmatch import fnmatch + +from elasticsearch_dsl import Document, Keyword, connections, Integer, Nested, SearchAsYouType, InnerDoc + + +ALIAS = "address" +PATTERN = ALIAS + "-*" + + +class Address(InnerDoc): + street = Keyword() + city = SearchAsYouType() + country = Keyword() + number = Integer() + + +class User(Document): + name = SearchAsYouType() + email = Keyword() + age = Integer() + address = Nested(Address) + + @classmethod + def _matches(cls, hit): + return fnmatch(hit["_index"], PATTERN) + + class Index: + name = ALIAS + settings = {"number_of_shards": 1, "number_of_replicas": 0} + + +def setup(): + index_template = User._index.as_template(ALIAS, PATTERN) + index_template.save() + + if not User._index.exists(): + migrate(move_data=False) + + +def migrate(move_data=True, update_alias=True): + # construct a new index name by appending current timestamp + next_index = PATTERN.replace("*", datetime.now().strftime("%Y%m%d%H%M%S%f")) + es = connections.get_connection() + # create new index, it will use the settings from the template + es.indices.create(index=next_index) + if move_data: + # move data from current alias to the new index + es.reindex( + body={"source": {"index": ALIAS}, "dest": {"index": next_index}}, + request_timeout=3600, + ) + # refresh the index to make the changes visible + es.indices.refresh(index=next_index) + + if update_alias: + # repoint the alias to point to the newly created index + es.indices.update_aliases( + body={ + "actions": [ + {"remove": {"alias": ALIAS, "index": PATTERN}}, + {"add": {"alias": ALIAS, "index": next_index}}, + ] + } + ) + + +class AddressOut(BaseModel): + street: Optional[str] = None + city: str + number: int + country: str + + class Config: + orm_mode = True + + +class UserIn(BaseModel): + name: str + email: EmailStr + age: int + + +class UserOut(UserIn): + model_config = ConfigDict(from_attributes=True) + + name: str + email: EmailStr + age: int + address: Optional[AddressOut] = None + + +class AddressFilter(Filter): + street: Optional[str] = None + number: Optional[int] = None + number__gt: Optional[int] = None + number__gte: Optional[int] = None + number__lt: Optional[int] = None + number__lte: Optional[int] = None + street__isnull: Optional[bool] = None + country: Optional[str] = None + country_not: Optional[str] = None + city: Optional[str] = None + city__in: Optional[List[str]] = None + city__not_in: Optional[List[str]] = ["city"] + custom_order_by: Optional[List[str]] = None + custom_search: Optional[str] = None + order_by: List[str] = ["-street"] + + class Constants(Filter.Constants): + model = Address + # ordering_field_name = "street" + search_field_name = "custom_search" + search_model_fields = ["street", "country", "city"] + + +class UserFilter(Filter): + name: Optional[str] = None + address: Optional[AddressFilter] = FilterDepends(with_prefix("address", AddressFilter)) + age__lt: Optional[int] = None + # age__gte: int = Field(Query(description="this is a nice description")) + """Required field with a custom description. + + See: https://github.com/tiangolo/fastapi/issues/4700 for why we need to wrap `Query` in `Field`. + """ + order_by: List[str] = ["-age"] + search: Optional[str] = None + + class Constants(Filter.Constants): + model = User + search_model_fields = ["name"] + + +app = FastAPI() + + +@app.on_event("startup") +async def on_startup() -> None: + connections.create_connection(hosts="http://localhost:9200") + + setup() + migrate() + + for i in range(100): + if i % 5 == 0: + address = Address( + street=fake.street_address(), + city=fake.city(), + country=fake.country(), + number=fake.random_int(min=5, max=100), + ) + else: + address = Address(city=fake.city(), country=fake.country(), number=fake.random_int(min=5, max=100)) + user = User(name=fake.name(), email=fake.email(), age=fake.random_int(min=5, max=120), address=address) + user.save() + + +@app.on_event("shutdown") +async def on_shutdown() -> None: + s = Address.search().query("match_all") + s.delete() + + +@app.get("/users", response_model=List[UserOut]) +async def get_users( + user_filter: UserFilter = FilterDepends(with_prefix("my_custom_prefix", UserFilter), by_alias=True), +) -> Any: + query = user_filter.filter(User.search()) + query = user_filter.sort(query) + response = query.execute() + return [UserOut(**user.to_dict()) for user in response] + + +if __name__ == "__main__": + uvicorn.run("main:app", reload=True) diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 00000000..d2be7845 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,26 @@ +annotated-types==0.6.0 +anyio==4.3.0 +bson==0.5.10 +certifi==2024.2.2 +click==8.1.7 +dnspython==2.6.1 +elastic-transport==8.12.0 +elasticsearch==7.17.9 +elasticsearch-dsl==7.4.1 +email-validator==2.1.0.post1 +Faker==23.2.1 +fastapi==0.109.2 +h11==0.14.0 +idna==3.6 +mongoengine==0.27.0 +pydantic==2.6.2 +pydantic_core==2.16.3 +pymongo==4.6.2 +python-dateutil==2.8.2 +six==1.16.0 +sniffio==1.3.0 +starlette==0.36.3 +typing_extensions==4.9.0 +urllib3==1.26.18 +uvicorn==0.27.1 + diff --git a/fastapi_filter/contrib/elasticsearch_dsl/__init__.py b/fastapi_filter/contrib/elasticsearch_dsl/__init__.py new file mode 100644 index 00000000..93429855 --- /dev/null +++ b/fastapi_filter/contrib/elasticsearch_dsl/__init__.py @@ -0,0 +1,3 @@ +from .filter import Filter + +__all__ = ("Filter",) diff --git a/fastapi_filter/contrib/elasticsearch_dsl/filter.py b/fastapi_filter/contrib/elasticsearch_dsl/filter.py new file mode 100644 index 00000000..61cb12fc --- /dev/null +++ b/fastapi_filter/contrib/elasticsearch_dsl/filter.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +from elasticsearch_dsl import Q, Search +from elasticsearch_dsl.query import Query +from pydantic import ValidationInfo, field_validator + +from ...base.filter import BaseFilterModel + + +_operator_transformer = { + "neq": lambda value, field_name: ~Q("term", **{field_name: value}), + "gt": lambda value, field_name: Q("range", **{field_name: {"gt": value}}), + "gte": lambda value, field_name: Q("range", **{field_name: {"gte": value}}), + "lt": lambda value, field_name: Q("range", **{field_name: {"lt": value}}), + "lte": lambda value, field_name: Q("range", **{field_name: {"lte": value}}), + "in": lambda value, field_name: Q("terms", **{field_name: value}), + "isnull": lambda value, field_name: ~Q("exists", field=field_name) + if value is True + else Q("exists", field=field_name), + "not": lambda value, field_name: ~Q("term", **{field_name: value}), + "not_in": lambda value, field_name: ~Q("terms", **{field_name: value}), + "nin": lambda value, field_name: ~Q("terms", **{field_name: value}), +} + + +class Filter(BaseFilterModel): + """Base filter for elasticsearch_dsl related filters. + + Example: + ```python + + class MyModel(Document): + street = Keyword() + city = Keyword() + country = Keyword() + number = Integer() + + class MyModelFilter(Filter): + street: Optional[str] = None + number: Optional[int] = None + number__gt: Optional[int] = None + number__gte: Optional[int] = None + number__lt: Optional[int] = None + number__lte: Optional[int] = None + street__isnull: Optional[bool] = None + country: Optional[str] = None + country_not: Optional[str] = None + city: Optional[str] = None + city__in: Optional[List[str]] = None + city__not_in: Optional[List[str]] = ["city"] + custom_order_by: Optional[List[str]] = None + custom_search: Optional[str] = None + order_by: List[str] = ["-street"] + ``` + """ + + def sort(self, query: Search) -> Search: + if not self.ordering_values: + return query + return query.sort(*self.ordering_values) + + @field_validator("*", mode="before") + def split_str(cls, value, field: ValidationInfo): + if ( + field.field_name is not None + and ( + field.field_name == cls.Constants.ordering_field_name + or field.field_name.endswith("__in") + or field.field_name.endswith("__nin") + or field.field_name.endswith("__not_in") + ) + and isinstance(value, str) + ): + if not value: + # Empty string should return [] not [''] + return [] + return list(value.split(",")) + return value + + def make_query(self, field_name: str, value) -> Query: + if "__" in field_name: + field_name, operator = field_name.split("__") + query = _operator_transformer[operator](value, field_name) + elif field_name == self.Constants.search_field_name and hasattr(self.Constants, "search_model_fields"): + query = Q( + "multi_match", + type="bool_prefix", + fields=[ + field_gram + for field in self.Constants.search_model_fields + for field_gram in [f"{field}", f"{field}._2gram", f"{field}._3gram"] + ], + query=value, + ) + else: + query = Q("term", **{field_name: value}) + return query + + def filter(self, search: Search) -> Search: + queries = Q() + for field_name, value in self.filtering_fields: + field_value = getattr(self, field_name) + if isinstance(field_value, Filter): + nested_queries = Q() + for inner_field, inner_value in field_value.filtering_fields: + nested_queries &= self.make_query(f"{field_name}.{inner_field}", inner_value) + search.query("nested", path=field_name, query=nested_queries) + else: + queries &= self.make_query(field_name, value) + + return search.query(queries)