Skip to content

Commit

Permalink
remove pagination extension dependency and add request model attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentsarago committed Jun 26, 2024
1 parent f311dc6 commit 5de6249
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 19 deletions.
12 changes: 11 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,23 @@

## [Unreleased] - TBD

### Added

* Add attributes to `stac_fastapi.api.app.StacApi` to enable customization of request model for:
- `/collections`: **collections_get_request_model**, default to `EmptyRequest`
- `/collections/{collection_id}`: **collection_get_request_model**, default to `CollectionUri`
- `/collections/{collection_id}/items`: **items_get_request_model**, default to `ItemCollectionUri`
- `/collections/{collection_id}/items/{item_id}`: **item_get_request_model**, default to `ItemUri`

### Fixed

* Updated default filter language in filter extension's POST search request model to match the extension's documentation [#711](https://github.com/stac-utils/stac-fastapi/issues/711)

### Removed

* Removed the Filter Extension depenency from `AggregationExtensionPostRequest` and `AggregationExtensionGetRequest` [#716](https://github.com/stac-utils/stac-fastapi/pull/716)
* Removed the Filter Extension dependency from `AggregationExtensionPostRequest` and `AggregationExtensionGetRequest` [#716](https://github.com/stac-utils/stac-fastapi/pull/716)
* `pagination_extension` attribute in `stac_fastapi.api.app.StacApi`
* remove use of `pagination_extension` in `register_get_item_collection` function (User now need to construct the request model and pass it using `items_get_request_model` attribute)

## [3.0.0a3] - 2024-06-13

Expand Down
35 changes: 18 additions & 17 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers
from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware
from stac_fastapi.api.models import (
APIRequest,
CollectionUri,
EmptyRequest,
GeoJSONResponse,
ItemCollectionUri,
ItemUri,
create_request_model,
)
from stac_fastapi.api.openapi import update_openapi
from stac_fastapi.api.routes import Scope, add_route_dependencies, create_async_endpoint

# TODO: make this module not depend on `stac_fastapi.extensions`
from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension
from stac_fastapi.extensions.core import FieldsExtension
from stac_fastapi.types.config import ApiSettings, Settings
from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient
from stac_fastapi.types.extension import ApiExtension
Expand Down Expand Up @@ -108,7 +108,10 @@ class StacApi:
search_post_request_model: Type[BaseSearchPostRequest] = attr.ib(
default=BaseSearchPostRequest
)
pagination_extension = attr.ib(default=TokenPaginationExtension)
collections_get_request_model: Type[APIRequest] = attr.ib(default=EmptyRequest)
collection_get_request_model: Type[APIRequest] = attr.ib(default=CollectionUri)
items_get_request_model: Type[APIRequest] = attr.ib(default=ItemCollectionUri)
item_get_request_model: Type[APIRequest] = attr.ib(default=ItemUri)
response_class: Type[Response] = attr.ib(default=JSONResponse)
middlewares: List[Middleware] = attr.ib(
default=attr.Factory(
Expand Down Expand Up @@ -211,7 +214,9 @@ def register_get_item(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=create_async_endpoint(self.client.get_item, ItemUri),
endpoint=create_async_endpoint(
self.client.get_item, self.item_get_request_model
),
)

def register_post_search(self):
Expand Down Expand Up @@ -302,7 +307,9 @@ def register_get_collections(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=create_async_endpoint(self.client.all_collections, EmptyRequest),
endpoint=create_async_endpoint(
self.client.all_collections, self.collections_get_request_model
),
)

def register_get_collection(self):
Expand All @@ -329,7 +336,9 @@ def register_get_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=create_async_endpoint(self.client.get_collection, CollectionUri),
endpoint=create_async_endpoint(
self.client.get_collection, self.collection_get_request_model
),
)

def register_get_item_collection(self):
Expand All @@ -338,16 +347,6 @@ def register_get_item_collection(self):
Returns:
None
"""
pagination_extension = self.get_extension(self.pagination_extension)
if pagination_extension is not None:
mixins = [pagination_extension.GET]
else:
mixins = None
request_model = create_request_model(
"ItemCollectionURI",
base_model=ItemCollectionUri,
mixins=mixins,
)
self.router.add_api_route(
name="Get ItemCollection",
path="/collections/{collection_id}/items",
Expand All @@ -366,7 +365,9 @@ def register_get_item_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=create_async_endpoint(self.client.item_collection, request_model),
endpoint=create_async_endpoint(
self.client.item_collection, self.items_get_request_model
),
)

def register_core(self):
Expand Down
9 changes: 9 additions & 0 deletions stac_fastapi/api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from starlette.testclient import TestClient

from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import ItemCollectionUri, create_request_model
from stac_fastapi.extensions.core import (
TokenPaginationExtension,
TransactionExtension,
Expand All @@ -13,6 +14,13 @@ class TestRouteDependencies:
@staticmethod
def _build_api(**overrides):
settings = config.ApiSettings()

items_get_request_model = create_request_model(
"ItemCollectionURI",
base_model=ItemCollectionUri,
mixins=[TokenPaginationExtension().GET],
)

return StacApi(
**{
"settings": settings,
Expand All @@ -23,6 +31,7 @@ def _build_api(**overrides):
),
TokenPaginationExtension(),
],
"items_get_request_model": items_get_request_model,
**overrides,
}
)
Expand Down
69 changes: 68 additions & 1 deletion stac_fastapi/api/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional, Union

import pytest
from fastapi import Path, Query
from fastapi.testclient import TestClient
from pydantic import ValidationError
from stac_pydantic import api

from stac_fastapi.api import app
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
from stac_fastapi.api.models import (
APIRequest,
create_get_request_model,
create_post_request_model,
)
from stac_fastapi.extensions.core import FieldsExtension, FilterExtension
from stac_fastapi.types import stac
from stac_fastapi.types.config import ApiSettings
Expand Down Expand Up @@ -294,3 +300,64 @@ def item_collection(
else:
assert get_search.status_code == 200, get_search.text
assert post_search.status_code == 200, post_search.text


def test_request_model(AsyncTestCoreClient):
"""Test if request models are passed correctly."""

@dataclass
class CollectionsRequest(APIRequest):
user: str = Query(...)

@dataclass
class CollectionRequest(APIRequest):
collection_id: str = Path(description="Collection ID")
user: str = Query(...)

@dataclass
class ItemsRequest(APIRequest):
collection_id: str = Path(description="Collection ID")
user: str = Query(...)

@dataclass
class ItemRequest(APIRequest):
collection_id: str = Path(description="Collection ID")
item_id: str = Path(description="Item ID")
user: str = Query(...)

test_app = app.StacApi(
settings=ApiSettings(),
client=AsyncTestCoreClient(),
collections_get_request_model=CollectionsRequest,
collection_get_request_model=CollectionRequest,
items_get_request_model=ItemsRequest,
item_get_request_model=ItemRequest,
extensions=[],
)

with TestClient(test_app.app) as client:
resp = client.get("/collections")
assert resp.status_code == 400

resp = client.get("/collections", params={"user": "luke"})
assert resp.status_code == 200

resp = client.get("/collections/test_collection")
assert resp.status_code == 400

resp = client.get("/collections/test_collection", params={"user": "luke"})
assert resp.status_code == 200

resp = client.get("/collections/test_collection/items")
assert resp.status_code == 400

resp = client.get("/collections/test_collection/items", params={"user": "luke"})
assert resp.status_code == 200

resp = client.get("/collections/test_collection/items/test_item")
assert resp.status_code == 400

resp = client.get(
"/collections/test_collection/items/test_item", params={"user": "luke"}
)
assert resp.status_code == 200

0 comments on commit 5de6249

Please sign in to comment.