diff --git a/CHANGES.md b/CHANGES.md index fc4a11e9e..df7ae0d3c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -18,6 +18,7 @@ ### Changed +* Replaced `@attrs` with python `@dataclass` for `APIRequest` (model for GET request) class type [#714](https://github.com/stac-utils/stac-fastapi/pull/714) * Moved `GETPagination`, `POSTPagination`, `GETTokenPagination` and `POSTTokenPagination` to `stac_fastapi.extensions.core.pagination.request` submodule [#717](https://github.com/stac-utils/stac-fastapi/pull/717) ## [3.0.0a4] - 2024-06-27 diff --git a/docs/src/migrations/v3.0.0.md b/docs/src/migrations/v3.0.0.md index 6cbb3605a..8bc86f940 100644 --- a/docs/src/migrations/v3.0.0.md +++ b/docs/src/migrations/v3.0.0.md @@ -13,7 +13,6 @@ Most of the **stac-fastapi's** dependencies have been upgraded. Moving from pyda In addition to pydantic v2 update, `stac-pydantic` has been updated to better match the STAC and STAC-API specifications (see https://github.com/stac-utils/stac-pydantic/blob/main/CHANGELOG.md#310-2024-05-21) - ## Deprecation * the `ContextExtension` have been removed (see https://github.com/stac-utils/stac-pydantic/pull/138) and was replaced by optional `NumberMatched` and `NumberReturned` attributes, defined by the OGC features specification. @@ -24,6 +23,49 @@ In addition to pydantic v2 update, `stac-pydantic` has been updated to better ma * `PostFieldsExtension.filter_fields` property has been removed. +## `attr` -> `dataclass` for APIRequest models + +Models for **GET** requests, defining the path and query parameters, now uses python `dataclass` instead of `attr`. + +```python +# before +@attr.s +class CollectionModel(APIRequest): + collections: Optional[str] = attr.ib(default=None, converter=str2list) + +# now +@dataclass +class CollectionModel(APIRequest): + collections: Annotated[Optional[str], Query()] = None + + def __post_init__(self): + """convert attributes.""" + if self.collections: + self.collections = str2list(self.collections) # type: ignore + +``` + +!!! warning + + if you want to extend a class with a `required` attribute (without default), you will have to write all the attributes to avoid having *non-default* attributes defined after *default* attributes (ref: https://github.com/stac-utils/stac-fastapi/pull/714/files#r1651557338) + + ```python + @dataclass + class A: + value: Annotated[str, Query()] + + # THIS WON'T WORK + @dataclass + class B(A): + another_value: Annotated[str, Query(...)] + + # DO THIS + @dataclass + class B(A): + another_value: Annotated[str, Query(...)] + value: Annotated[str, Query()] + ``` + ## Middlewares configuration The `StacApi.middlewares` attribute has been updated to accept a list of `starlette.middleware.Middleware`. This enables dynamic configuration of middlewares (see https://github.com/stac-utils/stac-fastapi/pull/442). diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index 307be14a7..7a39fe49a 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -1,12 +1,13 @@ """Api request/response models.""" import importlib.util +from dataclasses import dataclass, make_dataclass from typing import List, Optional, Type, Union -import attr -from fastapi import Path +from fastapi import Path, Query from pydantic import BaseModel, create_model from stac_pydantic.shared import BBox +from typing_extensions import Annotated from stac_fastapi.types.extension import ApiExtension from stac_fastapi.types.rfc3339 import DateTimeType @@ -37,11 +38,11 @@ def create_request_model( mixins = mixins or [] - models = [base_model] + extension_models + mixins + models = extension_models + mixins + [base_model] # Handle GET requests if all([issubclass(m, APIRequest) for m in models]): - return attr.make_class(model_name, attrs={}, bases=tuple(models)) + return make_dataclass(model_name, [], bases=tuple(models)) # Handle POST requests elif all([issubclass(m, BaseModel) for m in models]): @@ -80,34 +81,43 @@ def create_post_request_model( ) -@attr.s # type:ignore +@dataclass class CollectionUri(APIRequest): """Get or delete collection.""" - collection_id: str = attr.ib(default=Path(..., description="Collection ID")) + collection_id: Annotated[str, Path(description="Collection ID")] -@attr.s -class ItemUri(CollectionUri): +@dataclass +class ItemUri(APIRequest): """Get or delete item.""" - item_id: str = attr.ib(default=Path(..., description="Item ID")) + collection_id: Annotated[str, Path(description="Collection ID")] + item_id: Annotated[str, Path(description="Item ID")] -@attr.s +@dataclass class EmptyRequest(APIRequest): """Empty request.""" ... -@attr.s -class ItemCollectionUri(CollectionUri): +@dataclass +class ItemCollectionUri(APIRequest): """Get item collection.""" - limit: int = attr.ib(default=10) - bbox: Optional[BBox] = attr.ib(default=None, converter=str2bbox) - datetime: Optional[DateTimeType] = attr.ib(default=None, converter=str_to_interval) + collection_id: Annotated[str, Path(description="Collection ID")] + limit: Annotated[int, Query()] = 10 + bbox: Annotated[Optional[BBox], Query()] = None + datetime: Annotated[Optional[DateTimeType], Query()] = None + + def __post_init__(self): + """convert attributes.""" + if self.bbox: + self.bbox = str2bbox(self.bbox) # type: ignore + if self.datetime: + self.datetime = str_to_interval(self.datetime) # type: ignore # Test for ORJSON and use it rather than stdlib JSON where supported diff --git a/stac_fastapi/api/tests/test_models.py b/stac_fastapi/api/tests/test_models.py index cbff0f53d..24ed59a18 100644 --- a/stac_fastapi/api/tests/test_models.py +++ b/stac_fastapi/api/tests/test_models.py @@ -1,6 +1,8 @@ import json import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient from pydantic import ValidationError from stac_fastapi.api.models import create_get_request_model, create_post_request_model @@ -26,13 +28,33 @@ def test_create_get_request_model(): datetime="2020-01-01T00:00:00Z", limit=10, filter="test==test", - # FIXME: https://github.com/stac-utils/stac-fastapi/issues/638 - # hyphen aliases are not properly working - # **{"filter-crs": "epsg:4326", "filter-lang": "cql2-text"}, + filter_crs="epsg:4326", + filter_lang="cql2-text", ) assert model.collections == ["test1", "test2"] - # assert model.filter_crs == "epsg:4326" + assert model.filter_crs == "epsg:4326" + + app = FastAPI() + + @app.get("/test") + def route(model=Depends(request_model)): + return model + + with TestClient(app) as client: + resp = client.get( + "/test", + params={ + "collections": "test1,test2", + "filter-crs": "epsg:4326", + "filter-lang": "cql2-text", + }, + ) + assert resp.status_code == 200 + response_dict = resp.json() + assert response_dict["collections"] == ["test1", "test2"] + assert response_dict["filter_crs"] == "epsg:4326" + assert response_dict["filter_lang"] == "cql2-text" @pytest.mark.parametrize( diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py index 08ebe0cfc..325fc55ee 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py @@ -1,8 +1,11 @@ """Request model for the Aggregation extension.""" +from dataclasses import dataclass from typing import List, Optional -import attr +from fastapi import Query +from pydantic import Field +from typing_extensions import Annotated from stac_fastapi.types.search import ( BaseSearchGetRequest, @@ -11,14 +14,20 @@ ) -@attr.s +@dataclass class AggregationExtensionGetRequest(BaseSearchGetRequest): """Aggregation Extension GET request model.""" - aggregations: Optional[str] = attr.ib(default=None, converter=str2list) + aggregations: Annotated[Optional[str], Query()] = None + + def __post_init__(self): + """convert attributes.""" + super().__post_init__() + if self.aggregations: + self.aggregations = str2list(self.aggregations) # type: ignore class AggregationExtensionPostRequest(BaseSearchPostRequest): """Aggregation Extension POST request model.""" - aggregations: Optional[List[str]] = attr.ib(default=None) + aggregations: Optional[List[str]] = Field(default=None) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py index e08572ca0..a77539c0b 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py @@ -1,10 +1,12 @@ """Request models for the fields extension.""" import warnings +from dataclasses import dataclass from typing import Dict, Optional, Set -import attr +from fastapi import Query from pydantic import BaseModel, Field +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest, str2list @@ -68,11 +70,16 @@ def filter_fields(self) -> Dict: } -@attr.s +@dataclass class FieldsExtensionGetRequest(APIRequest): """Additional fields for the GET request.""" - fields: Optional[str] = attr.ib(default=None, converter=str2list) + fields: Annotated[Optional[str], Query()] = None + + def __post_init__(self): + """convert attributes.""" + if self.fields: + self.fields = str2list(self.fields) # type: ignore class FieldsExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py index 35a17bf36..970804b6d 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py @@ -1,22 +1,24 @@ """Filter extension request models.""" +from dataclasses import dataclass from typing import Any, Dict, Literal, Optional -import attr +from fastapi import Query from pydantic import BaseModel, Field +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest FilterLang = Literal["cql-json", "cql2-json", "cql2-text"] -@attr.s +@dataclass class FilterExtensionGetRequest(APIRequest): """Filter extension GET request model.""" - filter: Optional[str] = attr.ib(default=None) - filter_crs: Optional[str] = Field(alias="filter-crs", default=None) - filter_lang: Optional[FilterLang] = Field(alias="filter-lang", default="cql2-text") + filter: Annotated[Optional[str], Query()] = None + filter_crs: Annotated[Optional[str], Query(alias="filter-crs")] = None + filter_lang: Annotated[Optional[FilterLang], Query(alias="filter-lang")] = "cql2-text" class FilterExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py index 9524ee324..94d98df65 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py @@ -1,18 +1,20 @@ """Pagination extension request models.""" +from dataclasses import dataclass from typing import Optional -import attr +from fastapi import Query from pydantic import BaseModel +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest -@attr.s +@dataclass class GETTokenPagination(APIRequest): """Token pagination for GET requests.""" - token: Optional[str] = attr.ib(default=None) + token: Annotated[Optional[str], Query()] = None class POSTTokenPagination(BaseModel): @@ -21,11 +23,11 @@ class POSTTokenPagination(BaseModel): token: Optional[str] = None -@attr.s +@dataclass class GETPagination(APIRequest): """Page based pagination for GET requests.""" - page: Optional[str] = attr.ib(default=None) + page: Annotated[Optional[str], Query()] = None class POSTPagination(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py index 7f8425e70..d431b0dea 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py @@ -1,18 +1,20 @@ """Request model for the Query extension.""" +from dataclasses import dataclass from typing import Any, Dict, Optional -import attr +from fastapi import Query from pydantic import BaseModel +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest -@attr.s +@dataclass class QueryExtensionGetRequest(APIRequest): """Query Extension GET request model.""" - query: Optional[str] = attr.ib(default=None) + query: Annotated[Optional[str], Query()] = None class QueryExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py index 377067ff9..7165d2e31 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py @@ -1,20 +1,27 @@ # encoding: utf-8 """Request model for the Sort Extension.""" +from dataclasses import dataclass from typing import List, Optional -import attr +from fastapi import Query from pydantic import BaseModel from stac_pydantic.api.extensions.sort import SortExtension as PostSortModel +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest, str2list -@attr.s +@dataclass class SortExtensionGetRequest(APIRequest): """Sortby Parameter for GET requests.""" - sortby: Optional[str] = attr.ib(default=None, converter=str2list) + sortby: Annotated[Optional[str], Query()] = None + + def __post_init__(self): + """convert attributes.""" + if self.sortby: + self.sortby = str2list(self.sortby) # type: ignore class SortExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py index a1c2391f6..27f2291d1 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py @@ -1,5 +1,6 @@ """Transaction extension.""" +from dataclasses import dataclass from typing import List, Optional, Type, Union import attr @@ -7,6 +8,7 @@ from stac_pydantic import Collection, Item, ItemCollection from stac_pydantic.shared import MimeTypes from starlette.responses import JSONResponse, Response +from typing_extensions import Annotated from stac_fastapi.api.models import CollectionUri, ItemUri from stac_fastapi.api.routes import create_async_endpoint @@ -15,25 +17,25 @@ from stac_fastapi.types.extension import ApiExtension -@attr.s +@dataclass class PostItem(CollectionUri): """Create Item.""" - item: Union[Item, ItemCollection] = attr.ib(default=Body(None)) + item: Annotated[Union[Item, ItemCollection], Body()] = None -@attr.s +@dataclass class PutItem(ItemUri): """Update Item.""" - item: Item = attr.ib(default=Body(None)) + item: Annotated[Item, Body()] = None -@attr.s +@dataclass class PutCollection(CollectionUri): """Update Collection.""" - collection: Collection = attr.ib(default=Body(None)) + collection: Annotated[Collection, Body()] = None @attr.s diff --git a/stac_fastapi/extensions/tests/test_aggregation.py b/stac_fastapi/extensions/tests/test_aggregation.py index c96e316ae..480cc669f 100644 --- a/stac_fastapi/extensions/tests/test_aggregation.py +++ b/stac_fastapi/extensions/tests/test_aggregation.py @@ -1,11 +1,15 @@ from typing import Iterator import pytest +from fastapi import Depends, FastAPI from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi from stac_fastapi.extensions.core import AggregationExtension from stac_fastapi.extensions.core.aggregation.client import BaseAggregationClient +from stac_fastapi.extensions.core.aggregation.request import ( + AggregationExtensionGetRequest, +) from stac_fastapi.extensions.core.aggregation.types import ( Aggregation, AggregationCollection, @@ -100,3 +104,31 @@ def core_client() -> DummyCoreClient: @pytest.fixture def aggregations_client() -> BaseAggregationClient: return BaseAggregationClient() + + +def test_agg_get_query(): + """test AggregationExtensionGetRequest model.""" + app = FastAPI() + + @app.get("/test") + def test(query=Depends(AggregationExtensionGetRequest)): + return query + + with TestClient(app) as client: + response = client.get("/test") + assert response.is_success + params = response.json() + assert not params["collections"] + assert not params["aggregations"] + + response = client.get( + "/test", + params={ + "collections": "collection1,collection2", + "aggregations": "prop1,prop2", + }, + ) + assert response.is_success + params = response.json() + assert params["collections"] == ["collection1", "collection2"] + assert params["aggregations"] == ["prop1", "prop2"] diff --git a/stac_fastapi/extensions/tests/test_filter.py b/stac_fastapi/extensions/tests/test_filter.py index ca72dc51a..a13fb14c9 100644 --- a/stac_fastapi/extensions/tests/test_filter.py +++ b/stac_fastapi/extensions/tests/test_filter.py @@ -21,7 +21,8 @@ def get_item(self, *args, **kwargs): raise NotImplementedError def get_search(self, *args, **kwargs): - raise NotImplementedError + _ = kwargs.pop("request", None) + return kwargs def post_search(self, *args, **kwargs): return args[0].model_dump() @@ -73,3 +74,46 @@ def test_search_filter_post_filter_lang_non_default(client: TestClient): assert response.is_success, response.json() response_dict = response.json() assert response_dict["filter_lang"] == filter_lang_value + + +def test_search_filter_get(client: TestClient): + """Test search GET endpoint with filter ext.""" + response = client.get( + "/search", + params={ + "filter": "id='item_id' AND collection='collection_id'", + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert not response_dict["collections"] + assert response_dict["filter"] == "id='item_id' AND collection='collection_id'" + assert not response_dict["filter_crs"] + assert response_dict["filter_lang"] == "cql2-text" + + response = client.get( + "/search", + params={ + "filter": {"op": "=", "args": [{"property": "id"}, "test-item"]}, + "filter-lang": "cql2-json", + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert not response_dict["collections"] + assert ( + response_dict["filter"] + == "{'op': '=', 'args': [{'property': 'id'}, 'test-item']}" + ) + assert not response_dict["filter_crs"] + assert response_dict["filter_lang"] == "cql2-json" + + response = client.get( + "/search", + params={ + "collections": "collection1,collection2", + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert response_dict["collections"] == ["collection1", "collection2"] diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index cf6647340..649a1a8ef 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -3,9 +3,10 @@ """ import abc +from dataclasses import dataclass from typing import Dict, List, Optional, Union -import attr +from fastapi import Query from pydantic import PositiveInt from pydantic.functional_validators import AfterValidator from stac_pydantic.api import Search @@ -42,7 +43,7 @@ def str2bbox(x: str) -> Optional[BBox]: Limit = Annotated[PositiveInt, AfterValidator(crop)] -@attr.s # type:ignore +@dataclass class APIRequest(abc.ABC): """Generic API Request base class.""" @@ -52,16 +53,27 @@ def kwargs(self) -> Dict: return self.__dict__ -@attr.s +@dataclass class BaseSearchGetRequest(APIRequest): """Base arguments for GET Request.""" - collections: Optional[str] = attr.ib(default=None, converter=str2list) - ids: Optional[str] = attr.ib(default=None, converter=str2list) - bbox: Optional[BBox] = attr.ib(default=None, converter=str2bbox) - intersects: Optional[str] = attr.ib(default=None) - datetime: Optional[DateTimeType] = attr.ib(default=None, converter=str_to_interval) - limit: Optional[int] = attr.ib(default=10) + collections: Annotated[Optional[str], Query()] = None + ids: Annotated[Optional[str], Query()] = None + bbox: Annotated[Optional[BBox], Query()] = None + intersects: Annotated[Optional[str], Query()] = None + datetime: Annotated[Optional[DateTimeType], Query()] = None + limit: Annotated[Optional[int], Query()] = 10 + + def __post_init__(self): + """convert attributes.""" + if self.collections: + self.collections = str2list(self.collections) # type: ignore + if self.ids: + self.ids = str2list(self.ids) # type: ignore + if self.bbox: + self.bbox = str2bbox(self.bbox) # type: ignore + if self.datetime: + self.datetime = str_to_interval(self.datetime) # type: ignore class BaseSearchPostRequest(Search):