From a6d149bb84ce02c67df049d88f75386c04b6a6f7 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Mon, 24 Jun 2024 21:34:32 +0200 Subject: [PATCH] move from attr to dataclass+fastapi.Query() for GET models --- stac_fastapi/api/stac_fastapi/api/models.py | 48 +++++++++++-------- stac_fastapi/api/tests/test_models.py | 30 ++++++++++-- .../extensions/core/aggregation/request.py | 21 ++++---- .../extensions/core/fields/request.py | 13 +++-- .../extensions/core/filter/request.py | 12 +++-- .../extensions/core/query/request.py | 8 ++-- .../extensions/core/sort/request.py | 13 +++-- .../extensions/core/transaction.py | 14 +++--- stac_fastapi/extensions/tests/test_filter.py | 10 ++++ .../types/stac_fastapi/types/search.py | 30 ++++++++---- 10 files changed, 136 insertions(+), 63 deletions(-) diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index 2716fe7fb..ed050a6b2 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -1,12 +1,13 @@ """Api request/response models.""" import importlib.util +from dataclasses import dataclass, make_dataclass from typing import List, Optional, Type, Union -import attr -from fastapi import Path +from fastapi import Path, Query from pydantic import BaseModel, create_model from stac_pydantic.shared import BBox +from typing_extensions import Annotated from stac_fastapi.types.extension import ApiExtension from stac_fastapi.types.rfc3339 import DateTimeType @@ -37,11 +38,11 @@ def create_request_model( mixins = mixins or [] - models = [base_model] + extension_models + mixins + models = extension_models + mixins + [base_model] # Handle GET requests if all([issubclass(m, APIRequest) for m in models]): - return attr.make_class(model_name, attrs={}, bases=tuple(models)) + return make_dataclass(model_name, [], bases=tuple(models)) # Handle POST requests elif all([issubclass(m, BaseModel) for m in models]): @@ -80,34 +81,43 @@ def create_post_request_model( ) -@attr.s # type:ignore +@dataclass class CollectionUri(APIRequest): """Get or delete collection.""" - collection_id: str = attr.ib(default=Path(..., description="Collection ID")) + collection_id: Annotated[str, Path(description="Collection ID")] -@attr.s -class ItemUri(CollectionUri): +@dataclass +class ItemUri(APIRequest): """Get or delete item.""" - item_id: str = attr.ib(default=Path(..., description="Item ID")) + collection_id: Annotated[str, Path(description="Collection ID")] + item_id: Annotated[str, Path(description="Item ID")] -@attr.s +@dataclass class EmptyRequest(APIRequest): """Empty request.""" ... -@attr.s -class ItemCollectionUri(CollectionUri): +@dataclass +class ItemCollectionUri(APIRequest): """Get item collection.""" - limit: int = attr.ib(default=10) - bbox: Optional[BBox] = attr.ib(default=None, converter=str2bbox) - datetime: Optional[DateTimeType] = attr.ib(default=None, converter=str_to_interval) + collection_id: Annotated[str, Path(description="Collection ID")] + limit: Annotated[int, Query()] = 10 + bbox: Annotated[Optional[BBox], Query()] = None + datetime: Annotated[Optional[DateTimeType], Query()] = None + + def __post_init__(self): + """convert attributes.""" + if self.bbox: + self.bbox = str2bbox(self.bbox) # type: ignore + if self.datetime: + self.datetime = str_to_interval(self.datetime) # type: ignore class POSTTokenPagination(BaseModel): @@ -116,11 +126,11 @@ class POSTTokenPagination(BaseModel): token: Optional[str] = None -@attr.s +@dataclass class GETTokenPagination(APIRequest): """Token pagination for GET requests.""" - token: Optional[str] = attr.ib(default=None) + token: Annotated[Optional[str], Query()] = None class POSTPagination(BaseModel): @@ -129,11 +139,11 @@ class POSTPagination(BaseModel): page: Optional[str] = None -@attr.s +@dataclass class GETPagination(APIRequest): """Page based pagination for GET requests.""" - page: Optional[str] = attr.ib(default=None) + page: Annotated[Optional[str], Query()] = None # Test for ORJSON and use it rather than stdlib JSON where supported diff --git a/stac_fastapi/api/tests/test_models.py b/stac_fastapi/api/tests/test_models.py index cbff0f53d..24ed59a18 100644 --- a/stac_fastapi/api/tests/test_models.py +++ b/stac_fastapi/api/tests/test_models.py @@ -1,6 +1,8 @@ import json import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient from pydantic import ValidationError from stac_fastapi.api.models import create_get_request_model, create_post_request_model @@ -26,13 +28,33 @@ def test_create_get_request_model(): datetime="2020-01-01T00:00:00Z", limit=10, filter="test==test", - # FIXME: https://github.com/stac-utils/stac-fastapi/issues/638 - # hyphen aliases are not properly working - # **{"filter-crs": "epsg:4326", "filter-lang": "cql2-text"}, + filter_crs="epsg:4326", + filter_lang="cql2-text", ) assert model.collections == ["test1", "test2"] - # assert model.filter_crs == "epsg:4326" + assert model.filter_crs == "epsg:4326" + + app = FastAPI() + + @app.get("/test") + def route(model=Depends(request_model)): + return model + + with TestClient(app) as client: + resp = client.get( + "/test", + params={ + "collections": "test1,test2", + "filter-crs": "epsg:4326", + "filter-lang": "cql2-text", + }, + ) + assert resp.status_code == 200 + response_dict = resp.json() + assert response_dict["collections"] == ["test1", "test2"] + assert response_dict["filter_crs"] == "epsg:4326" + assert response_dict["filter_lang"] == "cql2-text" @pytest.mark.parametrize( diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py index fcab3323f..97fa553be 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py @@ -1,24 +1,23 @@ """Request model for the Aggregation extension.""" +from dataclasses import dataclass from typing import List, Optional, Union -import attr +from fastapi import Query +from pydantic import BaseModel, Field +from typing_extensions import Annotated -from stac_fastapi.extensions.core.filter.request import ( - FilterExtensionGetRequest, - FilterExtensionPostRequest, -) -from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest +from stac_fastapi.types.search import APIRequest -@attr.s -class AggregationExtensionGetRequest(BaseSearchGetRequest, FilterExtensionGetRequest): +@dataclass +class AggregationExtensionGetRequest(APIRequest): """Aggregation Extension GET request model.""" - aggregations: Optional[str] = attr.ib(default=None) + aggregations: Annotated[Optional[str], Query()] = None -class AggregationExtensionPostRequest(BaseSearchPostRequest, FilterExtensionPostRequest): +class AggregationExtensionPostRequest(BaseModel): """Aggregation Extension POST request model.""" - aggregations: Optional[Union[str, List[str]]] = attr.ib(default=None) + aggregations: Optional[Union[str, List[str]]] = Field(default=None) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py index e08572ca0..a77539c0b 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py @@ -1,10 +1,12 @@ """Request models for the fields extension.""" import warnings +from dataclasses import dataclass from typing import Dict, Optional, Set -import attr +from fastapi import Query from pydantic import BaseModel, Field +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest, str2list @@ -68,11 +70,16 @@ def filter_fields(self) -> Dict: } -@attr.s +@dataclass class FieldsExtensionGetRequest(APIRequest): """Additional fields for the GET request.""" - fields: Optional[str] = attr.ib(default=None, converter=str2list) + fields: Annotated[Optional[str], Query()] = None + + def __post_init__(self): + """convert attributes.""" + if self.fields: + self.fields = str2list(self.fields) # type: ignore class FieldsExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py index 35a17bf36..970804b6d 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py @@ -1,22 +1,24 @@ """Filter extension request models.""" +from dataclasses import dataclass from typing import Any, Dict, Literal, Optional -import attr +from fastapi import Query from pydantic import BaseModel, Field +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest FilterLang = Literal["cql-json", "cql2-json", "cql2-text"] -@attr.s +@dataclass class FilterExtensionGetRequest(APIRequest): """Filter extension GET request model.""" - filter: Optional[str] = attr.ib(default=None) - filter_crs: Optional[str] = Field(alias="filter-crs", default=None) - filter_lang: Optional[FilterLang] = Field(alias="filter-lang", default="cql2-text") + filter: Annotated[Optional[str], Query()] = None + filter_crs: Annotated[Optional[str], Query(alias="filter-crs")] = None + filter_lang: Annotated[Optional[FilterLang], Query(alias="filter-lang")] = "cql2-text" class FilterExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py index 7f8425e70..d431b0dea 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py @@ -1,18 +1,20 @@ """Request model for the Query extension.""" +from dataclasses import dataclass from typing import Any, Dict, Optional -import attr +from fastapi import Query from pydantic import BaseModel +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest -@attr.s +@dataclass class QueryExtensionGetRequest(APIRequest): """Query Extension GET request model.""" - query: Optional[str] = attr.ib(default=None) + query: Annotated[Optional[str], Query()] = None class QueryExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py index 377067ff9..7165d2e31 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py @@ -1,20 +1,27 @@ # encoding: utf-8 """Request model for the Sort Extension.""" +from dataclasses import dataclass from typing import List, Optional -import attr +from fastapi import Query from pydantic import BaseModel from stac_pydantic.api.extensions.sort import SortExtension as PostSortModel +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest, str2list -@attr.s +@dataclass class SortExtensionGetRequest(APIRequest): """Sortby Parameter for GET requests.""" - sortby: Optional[str] = attr.ib(default=None, converter=str2list) + sortby: Annotated[Optional[str], Query()] = None + + def __post_init__(self): + """convert attributes.""" + if self.sortby: + self.sortby = str2list(self.sortby) # type: ignore class SortExtensionPostRequest(BaseModel): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py index a1c2391f6..27f2291d1 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py @@ -1,5 +1,6 @@ """Transaction extension.""" +from dataclasses import dataclass from typing import List, Optional, Type, Union import attr @@ -7,6 +8,7 @@ from stac_pydantic import Collection, Item, ItemCollection from stac_pydantic.shared import MimeTypes from starlette.responses import JSONResponse, Response +from typing_extensions import Annotated from stac_fastapi.api.models import CollectionUri, ItemUri from stac_fastapi.api.routes import create_async_endpoint @@ -15,25 +17,25 @@ from stac_fastapi.types.extension import ApiExtension -@attr.s +@dataclass class PostItem(CollectionUri): """Create Item.""" - item: Union[Item, ItemCollection] = attr.ib(default=Body(None)) + item: Annotated[Union[Item, ItemCollection], Body()] = None -@attr.s +@dataclass class PutItem(ItemUri): """Update Item.""" - item: Item = attr.ib(default=Body(None)) + item: Annotated[Item, Body()] = None -@attr.s +@dataclass class PutCollection(CollectionUri): """Update Collection.""" - collection: Collection = attr.ib(default=Body(None)) + collection: Annotated[Collection, Body()] = None @attr.s diff --git a/stac_fastapi/extensions/tests/test_filter.py b/stac_fastapi/extensions/tests/test_filter.py index ec9712f25..a13fb14c9 100644 --- a/stac_fastapi/extensions/tests/test_filter.py +++ b/stac_fastapi/extensions/tests/test_filter.py @@ -107,3 +107,13 @@ def test_search_filter_get(client: TestClient): ) assert not response_dict["filter_crs"] assert response_dict["filter_lang"] == "cql2-json" + + response = client.get( + "/search", + params={ + "collections": "collection1,collection2", + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert response_dict["collections"] == ["collection1", "collection2"] diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index cf6647340..649a1a8ef 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -3,9 +3,10 @@ """ import abc +from dataclasses import dataclass from typing import Dict, List, Optional, Union -import attr +from fastapi import Query from pydantic import PositiveInt from pydantic.functional_validators import AfterValidator from stac_pydantic.api import Search @@ -42,7 +43,7 @@ def str2bbox(x: str) -> Optional[BBox]: Limit = Annotated[PositiveInt, AfterValidator(crop)] -@attr.s # type:ignore +@dataclass class APIRequest(abc.ABC): """Generic API Request base class.""" @@ -52,16 +53,27 @@ def kwargs(self) -> Dict: return self.__dict__ -@attr.s +@dataclass class BaseSearchGetRequest(APIRequest): """Base arguments for GET Request.""" - collections: Optional[str] = attr.ib(default=None, converter=str2list) - ids: Optional[str] = attr.ib(default=None, converter=str2list) - bbox: Optional[BBox] = attr.ib(default=None, converter=str2bbox) - intersects: Optional[str] = attr.ib(default=None) - datetime: Optional[DateTimeType] = attr.ib(default=None, converter=str_to_interval) - limit: Optional[int] = attr.ib(default=10) + collections: Annotated[Optional[str], Query()] = None + ids: Annotated[Optional[str], Query()] = None + bbox: Annotated[Optional[BBox], Query()] = None + intersects: Annotated[Optional[str], Query()] = None + datetime: Annotated[Optional[DateTimeType], Query()] = None + limit: Annotated[Optional[int], Query()] = 10 + + def __post_init__(self): + """convert attributes.""" + if self.collections: + self.collections = str2list(self.collections) # type: ignore + if self.ids: + self.ids = str2list(self.ids) # type: ignore + if self.bbox: + self.bbox = str2bbox(self.bbox) # type: ignore + if self.datetime: + self.datetime = str_to_interval(self.datetime) # type: ignore class BaseSearchPostRequest(Search):