Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove pagination extension dependency and add request model attributes #718

Merged
merged 4 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@

## [Unreleased] - TBD

### Added

* Add attributes to `stac_fastapi.api.app.StacApi` to enable customization of request model for:
- `/collections`: **collections_get_request_model**, default to `EmptyRequest`
- `/collections/{collection_id}`: **collection_get_request_model**, default to `CollectionUri`
- `/collections/{collection_id}/items`: **items_get_request_model**, default to `ItemCollectionUri`
- `/collections/{collection_id}/items/{item_id}`: **item_get_request_model**, default to `ItemUri`

### Removed

* Removed the Filter Extension dependency from `AggregationExtensionPostRequest` and `AggregationExtensionGetRequest` [#716](https://github.com/stac-utils/stac-fastapi/pull/716)
* `pagination_extension` attribute in `stac_fastapi.api.app.StacApi`
* remove use of `pagination_extension` in `register_get_item_collection` function (User now need to construct the request model and pass it using `items_get_request_model` attribute)

## [3.0.0a4] - 2024-06-27

### Fixed
Expand Down
35 changes: 18 additions & 17 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers
from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware
from stac_fastapi.api.models import (
APIRequest,
CollectionUri,
EmptyRequest,
GeoJSONResponse,
ItemCollectionUri,
ItemUri,
create_request_model,
)
from stac_fastapi.api.openapi import update_openapi
from stac_fastapi.api.routes import Scope, add_route_dependencies, create_async_endpoint

# TODO: make this module not depend on `stac_fastapi.extensions`
from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension
from stac_fastapi.extensions.core import FieldsExtension
from stac_fastapi.types.config import ApiSettings, Settings
from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient
from stac_fastapi.types.extension import ApiExtension
Expand Down Expand Up @@ -108,7 +108,10 @@ class StacApi:
search_post_request_model: Type[BaseSearchPostRequest] = attr.ib(
default=BaseSearchPostRequest
)
pagination_extension = attr.ib(default=TokenPaginationExtension)
collections_get_request_model: Type[APIRequest] = attr.ib(default=EmptyRequest)
collection_get_request_model: Type[APIRequest] = attr.ib(default=CollectionUri)
items_get_request_model: Type[APIRequest] = attr.ib(default=ItemCollectionUri)
item_get_request_model: Type[APIRequest] = attr.ib(default=ItemUri)
response_class: Type[Response] = attr.ib(default=JSONResponse)
middlewares: List[Middleware] = attr.ib(
default=attr.Factory(
Expand Down Expand Up @@ -211,7 +214,9 @@ def register_get_item(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=create_async_endpoint(self.client.get_item, ItemUri),
endpoint=create_async_endpoint(
self.client.get_item, self.item_get_request_model
),
)

def register_post_search(self):
Expand Down Expand Up @@ -302,7 +307,9 @@ def register_get_collections(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=create_async_endpoint(self.client.all_collections, EmptyRequest),
endpoint=create_async_endpoint(
self.client.all_collections, self.collections_get_request_model
),
)

def register_get_collection(self):
Expand All @@ -329,7 +336,9 @@ def register_get_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=create_async_endpoint(self.client.get_collection, CollectionUri),
endpoint=create_async_endpoint(
self.client.get_collection, self.collection_get_request_model
),
)

def register_get_item_collection(self):
Expand All @@ -338,16 +347,6 @@ def register_get_item_collection(self):
Returns:
None
"""
pagination_extension = self.get_extension(self.pagination_extension)
if pagination_extension is not None:
mixins = [pagination_extension.GET]
else:
mixins = None
request_model = create_request_model(
"ItemCollectionURI",
base_model=ItemCollectionUri,
mixins=mixins,
)
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved
self.router.add_api_route(
name="Get ItemCollection",
path="/collections/{collection_id}/items",
Expand All @@ -366,7 +365,9 @@ def register_get_item_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=create_async_endpoint(self.client.item_collection, request_model),
endpoint=create_async_endpoint(
self.client.item_collection, self.items_get_request_model
),
)

def register_core(self):
Expand Down
9 changes: 9 additions & 0 deletions stac_fastapi/api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from starlette.testclient import TestClient

from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import ItemCollectionUri, create_request_model
from stac_fastapi.extensions.core import (
TokenPaginationExtension,
TransactionExtension,
Expand All @@ -13,6 +14,13 @@ class TestRouteDependencies:
@staticmethod
def _build_api(**overrides):
settings = config.ApiSettings()

items_get_request_model = create_request_model(
"ItemCollectionURI",
base_model=ItemCollectionUri,
mixins=[TokenPaginationExtension().GET],
)
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved

return StacApi(
**{
"settings": settings,
Expand All @@ -23,6 +31,7 @@ def _build_api(**overrides):
),
TokenPaginationExtension(),
],
"items_get_request_model": items_get_request_model,
**overrides,
}
)
Expand Down
71 changes: 70 additions & 1 deletion stac_fastapi/api/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional, Union

import pytest
from fastapi import Path, Query
from fastapi.testclient import TestClient
from pydantic import ValidationError
from stac_pydantic import api

from stac_fastapi.api import app
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
from stac_fastapi.api.models import (
APIRequest,
create_get_request_model,
create_post_request_model,
)
from stac_fastapi.extensions.core import FieldsExtension, FilterExtension
from stac_fastapi.types import stac
from stac_fastapi.types.config import ApiSettings
Expand Down Expand Up @@ -294,3 +300,66 @@ def item_collection(
else:
assert get_search.status_code == 200, get_search.text
assert post_search.status_code == 200, post_search.text


def test_request_model(AsyncTestCoreClient):
"""Test if request models are passed correctly."""

@dataclass
class CollectionsRequest(APIRequest):
user: str = Query(...)

@dataclass
class CollectionRequest(APIRequest):
collection_id: str = Path(description="Collection ID")
user: str = Query(...)

@dataclass
class ItemsRequest(APIRequest):
collection_id: str = Path(description="Collection ID")
user: str = Query(...)

@dataclass
class ItemRequest(APIRequest):
collection_id: str = Path(description="Collection ID")
item_id: str = Path(description="Item ID")
user: str = Query(...)
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved

test_app = app.StacApi(
settings=ApiSettings(),
client=AsyncTestCoreClient(),
collections_get_request_model=CollectionsRequest,
collection_get_request_model=CollectionRequest,
items_get_request_model=ItemsRequest,
item_get_request_model=ItemRequest,
extensions=[],
)

with TestClient(test_app.app) as client:
resp = client.get("/collections")
assert resp.status_code == 400

resp = client.get("/collections", params={"user": "Luke"})
assert resp.status_code == 200

resp = client.get("/collections/test_collection")
assert resp.status_code == 400

resp = client.get("/collections/test_collection", params={"user": "Leia"})
assert resp.status_code == 200

resp = client.get("/collections/test_collection/items")
assert resp.status_code == 400

resp = client.get(
"/collections/test_collection/items", params={"user": "Obi-Wan"}
)
assert resp.status_code == 200

resp = client.get("/collections/test_collection/items/test_item")
assert resp.status_code == 400

resp = client.get(
"/collections/test_collection/items/test_item", params={"user": "Chewbacca"}
)
assert resp.status_code == 200