Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update codebase to support fastapi>=0.100.0 and pydantic>=2.0.0 #447

Merged
merged 23 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
6a09192
update codebase to support fastapi>=0.100.0 and pydantic>=2.0.0
johnybx Aug 21, 2023
f234e0d
update comment
johnybx Aug 21, 2023
43268a7
update examples
johnybx Aug 21, 2023
742accd
fix unsupported typing, fix compatibility with older python versions
johnybx Aug 22, 2023
5bc801b
fix breaking change of `with_prefix` -> keep same behaviour as before…
johnybx Sep 4, 2023
d5618bd
Merge branch 'main' into main
arthurio Sep 5, 2023
c457b97
Merge remote-tracking branch 'origin/main' into johnybx/main
arthurio Sep 5, 2023
9449e25
Merge remote-tracking branch 'origin/main' into johnybx/main
arthurio Sep 5, 2023
4af4450
Merge remote-tracking branch 'origin/main' into johnybx/main
arthurio Sep 5, 2023
30f4c3b
fix coverage
johnybx Sep 6, 2023
3174131
update required dependencies in readme
johnybx Sep 6, 2023
bf7e5c2
Merge branch 'main' into main
johnybx Sep 6, 2023
b1f8c0d
fix linting
johnybx Sep 6, 2023
73baabe
fix supported pydantic versions
johnybx Sep 6, 2023
5551a2a
Merge remote-tracking branch 'origin/main' into johnybx/main
arthurio Sep 12, 2023
e76051a
Update examples/fastapi_filter_mongoengine.py
johnybx Sep 13, 2023
94c45f2
Update examples/fastapi_filter_sqlalchemy.py
johnybx Sep 13, 2023
eadbac0
Update fastapi_filter/base/filter.py
johnybx Sep 13, 2023
cf2473b
resolve conflicts, update poetry.lock
johnybx Sep 13, 2023
35ff5ae
Merge branch 'main' into main
johnybx Sep 13, 2023
d8406fa
Merge branch 'main' into main
johnybx Sep 13, 2023
0f1c161
fix linting
johnybx Sep 13, 2023
a267eea
add note and docs section about limitations of union types in filter
johnybx Sep 20, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ fastapi_filter.sqlite
poetry.toml
.pytest_cache/
.ruff_cache/
__pycache__
34 changes: 15 additions & 19 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ Add querystring filters to your api endpoints and show them in the swagger UI.
The supported backends are [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy) and
[MongoEngine](https://github.com/MongoEngine/mongoengine).


## Example

![Swagger UI](./swagger-ui.png)
Expand All @@ -25,16 +24,16 @@ as well as the type of operator, then tie your filter to a specific model.

By default, **fastapi_filter** supports the following operators:

- `neq`
- `gt`
- `gte`
- `in`
- `isnull`
- `lt`
- `lte`
- `not`/`ne`
- `not_in`/`nin`
- `like`/`ilike`
- `neq`
- `gt`
- `gte`
- `in`
- `isnull`
- `lt`
- `lte`
- `not`/`ne`
- `not_in`/`nin`
- `like`/`ilike`

_**Note:** Mysql excludes `None` values when using `in` filter_

Expand Down Expand Up @@ -89,7 +88,6 @@ Wherever you would use a `Depends`, replace it with `FilterDepends` if you are p
that `FilterDepends` converts the `list` filter fields to `str` so that they can be displayed and used in swagger.
It also handles turning `ValidationError` into `HTTPException(status_code=422)`.


### with_prefix

[link](https://github.com/arthurio/fastapi-filter/blob/main/fastapi_filter/base/filter.py#L21)
Expand Down Expand Up @@ -133,12 +131,11 @@ There is a specific field on the filter class that can be used for ordering. The
takes a list of string. From an API call perspective, just like the `__in` filters, you simply pass a comma separated
list of strings.

You can change the **direction** of the sorting (*asc* or *desc*) by prefixing with `-` or `+` (Optional, it's the
You can change the **direction** of the sorting (_asc_ or _desc_) by prefixing with `-` or `+` (Optional, it's the
default behavior if omitted).

If you don't want to allow ordering on your filter, just don't add `order_by` (or custom `ordering_field_name`) as a field and you are all set.


## Search

There is a specific field on the filter class that can be used for searching. The default name is `search` and it takes
Expand All @@ -148,7 +145,6 @@ You have to define what fields/columns to search in with the `search_model_field

If you don't want to allow searching on your filter, just don't add `search` (or custom `search_field_name`) as a field and you are all set.


### Example - Basic

```python
Expand Down Expand Up @@ -215,17 +211,17 @@ curl /users?custom_order_by=+id

### Restrict the `order_by` values

Add the following validator to your filter class:
Add the following field_validator to your filter class:

```python
from typing import Optional
from fastapi_filter.contrib.sqlalchemy import Filter
from pydantic import validator
from pydantic import field_validator

class MyFilter(Filter):
order_by: Optional[list[str]]

@validator("order_by")
@field_validator("order_by")
def restrict_sortable_fields(cls, value):
if value is None:
return None
Expand All @@ -241,7 +237,7 @@ class MyFilter(Filter):
```

1. If you want to restrict only on specific directions, like `-created_at` and `name` for example, you can remove this
line. Your `allowed_field_names` would be something like `["age", "-age", "-created_at"]`.
line. Your `allowed_field_names` would be something like `["age", "-age", "-created_at"]`.

### Example - Search

Expand Down
93 changes: 59 additions & 34 deletions fastapi_filter/base/filter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from collections import defaultdict
from collections.abc import Iterable
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from types import UnionType
from typing import Annotated, Any, Dict, Iterable, List, Optional, Tuple, Type, Union, get_args, get_origin

from fastapi import Depends
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Extra, ValidationError, create_model, fields, validator
from pydantic import (
BaseModel,
ConfigDict,
FieldValidationInfo,
PlainValidator,
ValidationError,
create_model,
field_validator,
)
from pydantic.fields import FieldInfo


class BaseFilterModel(BaseModel, extra=Extra.forbid):
class BaseFilterModel(BaseModel, extra="forbid"):
"""Abstract base filter class.

Provides the interface for filtering and ordering.
Expand Down Expand Up @@ -48,7 +56,7 @@ def filter(self, query): # pragma: no cover

@property
def filtering_fields(self):
fields = self.dict(exclude_none=True, exclude_unset=True)
fields = self.model_dump(exclude_none=True, exclude_unset=True)
fields.pop(self.Constants.ordering_field_name, None)
return fields.items()

Expand All @@ -66,13 +74,9 @@ def ordering_values(self):
"Make sure to add it to your filter class."
) from e

@validator("*", pre=True)
def split_str(cls, value, field): # pragma: no cover
...

@validator("*", pre=True, allow_reuse=True, check_fields=False)
def strip_order_by_values(cls, value, values, field):
if field.name != cls.Constants.ordering_field_name:
@field_validator("*", mode="before", check_fields=False)
def strip_order_by_values(cls, value, field: FieldValidationInfo):
if field.field_name != cls.Constants.ordering_field_name:
return value

if not value:
Expand All @@ -86,9 +90,9 @@ def strip_order_by_values(cls, value, values, field):

return stripped_values

@validator("*", allow_reuse=True, check_fields=False)
def validate_order_by(cls, value, values, field):
if field.name != cls.Constants.ordering_field_name:
@field_validator("*", mode="before", check_fields=False)
def validate_order_by(cls, value, field: FieldValidationInfo):
if field.field_name != cls.Constants.ordering_field_name:
return value

if not value:
Expand Down Expand Up @@ -135,9 +139,10 @@ def with_prefix(prefix: str, Filter: Type[BaseFilterModel]):
class NumberFilter(BaseModel):
count: Optional[int]

number_filter_prefixed, Annotation = with_prefix("number_filter", Filter)
class MainFilter(BaseModel):
name: str
number_filter: Optional[Filter] = FilterDepends(with_prefix("number_filter", Filter))
number_filter: Optional[Annotation] = FilterDepends(number_filter_prefixed)
```

As a result, you'll get the following filters:
Expand All @@ -156,9 +161,10 @@ class MainFilter(BaseModel):
class NumberFilter(BaseModel):
count: Optional[int] = Query(default=10, alias=counter)

number_filter_prefixed, Annotation = with_prefix("number_filter", Filter)
class MainFilter(BaseModel):
name: str
number_filter: Optional[Filter] = FilterDepends(with_prefix("number_filter", Filter))
number_filter: Optional[Annotation] = FilterDepends(number_filter_prefixed)
```

As a result, you'll get the following filters:
Expand All @@ -167,32 +173,51 @@ class MainFilter(BaseModel):
"""

class NestedFilter(Filter): # type: ignore[misc, valid-type]
class Config:
extra = Extra.forbid

@classmethod
def alias_generator(cls, string: str) -> str:
return f"{prefix}__{string}"
model_config = ConfigDict(extra="forbid", alias_generator=lambda string: f"{prefix}__{string}")

class Constants(Filter.Constants): # type: ignore[name-defined]
...

NestedFilter.Constants.prefix = prefix

return NestedFilter
def plain_validator(value):
# Make sure we validate Model.
# Probably would be better if this was subclass of specific Filter but
if issubclass(value.__class__, BaseModel):
value = value.model_dump()

if isinstance(value, dict):
stripped = {k.removeprefix(NestedFilter.Constants.prefix): v for k, v in value.items()}
return Filter(**stripped)

raise ValueError(f"Unexpected type: {type(value)}")

annotation = Annotated[Filter, PlainValidator(plain_validator)]

return NestedFilter, annotation
johnybx marked this conversation as resolved.
Show resolved Hide resolved


def _list_to_str_fields(Filter: Type[BaseFilterModel]):
ret: Dict[str, Tuple[Union[object, Type], Optional[FieldInfo]]] = {}
for f in Filter.__fields__.values():
field_info = deepcopy(f.field_info)
if f.shape == fields.SHAPE_LIST:
for name, f in Filter.model_fields.items():
field_info = deepcopy(f)
annotation = f.annotation

if get_origin(annotation) in [UnionType, Union]:
annotation_args: list = list(get_args(f.annotation))
if type(None) in annotation_args:
annotation_args.remove(type(None))
if len(annotation_args) == 1:
annotation = annotation_args[0]
# Not sure what to do if there is more then 1 value 🤔
johnybx marked this conversation as resolved.
Show resolved Hide resolved
# Do we need to handle Optional[Annotated[...]] ?
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should raise an exception for now until we figure this out?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left the comment here for future as a note that union of list type with other types is not supported just in case someone thought that for example type like list[str] | int | None should work with filter. At the same time I don't think that it is reasonable to support such types ( I mean transformation of list part to str split by comma in such complex types) without knowing real usecase. Also supporting such types would be problem in case of types like list[str] | str | None.

Raising exception here is not a good idea because the exception would be raised on any type which is union basically so for example int | float because this is union type and will have length 2 but at the same time it is valid type which we want to skip. Maybe I should just update comment to specifically mention that I am not sure what to do if the type is combination of list and other types like list[str] | str ? Because list filtering will not work in that case 🤔

What do you think ?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@johnybx Do you have time to make that change or do you want me to do it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arthurio sure I just wasn't sure if you are fine with what I wrote. I added note to code and also added small section about current limitations of union types in filters with list. Let me know if something should be changed 👍


if annotation is list or get_origin(annotation) is list:
if isinstance(field_info.default, Iterable):
field_info.default = ",".join(field_info.default)
ret[f.name] = (str if f.required else Optional[str], field_info)
ret[name] = (str if f.is_required() else Optional[str], field_info)
else:
field_type = Filter.__annotations__.get(f.name, f.outer_type_)
ret[f.name] = (field_type if f.required else Optional[field_type], field_info)
ret[name] = (f.annotation, field_info)

return ret

Expand All @@ -214,16 +239,16 @@ def FilterDepends(Filter: Type[BaseFilterModel], *, by_alias: bool = False, use_
class FilterWrapper(GeneratedFilter): # type: ignore[misc,valid-type]
def filter(self, *args, **kwargs):
try:
original_filter = Filter(**self.dict(by_alias=by_alias))
original_filter = Filter(**self.model_dump(by_alias=by_alias))
except ValidationError as e:
raise RequestValidationError(e.raw_errors) from e
raise RequestValidationError(e.errors()) from e
return original_filter.filter(*args, **kwargs)

def sort(self, *args, **kwargs):
try:
original_filter = Filter(**self.dict(by_alias=by_alias))
original_filter = Filter(**self.model_dump(by_alias=by_alias))
except ValidationError as e:
raise RequestValidationError(e.raw_errors) from e
raise RequestValidationError(e.errors()) from e
return original_filter.sort(*args, **kwargs)

return Depends(FilterWrapper)
19 changes: 11 additions & 8 deletions fastapi_filter/contrib/mongoengine/filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from mongoengine import QuerySet
from mongoengine.queryset.visitor import Q
from pydantic import validator
from pydantic import FieldValidationInfo, field_validator

from ...base.filter import BaseFilterModel

Expand Down Expand Up @@ -33,21 +33,24 @@ def sort(self, query: QuerySet) -> QuerySet:
return query
return query.order_by(*self.ordering_values)

@validator("*", pre=True)
def split_str(cls, value, field):
@field_validator("*", mode="before")
def split_str(cls, value, field: FieldValidationInfo):
if (
field.name == cls.Constants.ordering_field_name
or field.name.endswith("__in")
or field.name.endswith("__nin")
field.field_name == cls.Constants.ordering_field_name
or field.field_name.endswith("__in")
or field.field_name.endswith("__nin")
) and isinstance(value, str):
return [field.type_(v) for v in value.split(",")]
if not value:
# Empty string should return [] not ['']
return []
return [v for v in value.split(",")]
return value

def filter(self, query: QuerySet) -> QuerySet:
for field_name, value in self.filtering_fields:
field_value = getattr(self, field_name)
if isinstance(field_value, Filter):
if not field_value.dict(exclude_none=True, exclude_unset=True):
if not field_value.model_dump(exclude_none=True, exclude_unset=True):
continue

query = query.filter(**{f"{field_name}__in": field_value.filter(field_value.Constants.model.objects())})
Expand Down
17 changes: 10 additions & 7 deletions fastapi_filter/contrib/sqlalchemy/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Union
from warnings import warn

from pydantic import validator
from pydantic import FieldValidationInfo, field_validator
from sqlalchemy import or_
from sqlalchemy.orm import Query
from sqlalchemy.sql.selectable import Select
Expand Down Expand Up @@ -87,14 +87,17 @@ class Direction(str, Enum):
asc = "asc"
desc = "desc"

@validator("*", pre=True)
def split_str(cls, value, field):
@field_validator("*", mode="before")
def split_str(cls, value, field: FieldValidationInfo):
if (
field.name == cls.Constants.ordering_field_name
or field.name.endswith("__in")
or field.name.endswith("__not_in")
field.field_name == cls.Constants.ordering_field_name
or field.field_name.endswith("__in")
or field.field_name.endswith("__not_in")
) and isinstance(value, str):
return [field.type_(v) for v in value.split(",")]
if not value:
# Empty string should return [] not ['']
return []
return [v for v in value.split(",")]
return value

def filter(self, query: Union[Query, Select]):
Expand Down
Loading