Skip to content

Commit

Permalink
add tests for FieldsExtension impact on validation (#708)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonathan Healy <[email protected]>
  • Loading branch information
vincentsarago and jonhealy1 authored Jun 14, 2024
1 parent 68dfbd5 commit 80064f7
Showing 1 changed file with 106 additions and 2 deletions.
108 changes: 106 additions & 2 deletions stac_fastapi/api/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from stac_fastapi.api import app
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
from stac_fastapi.extensions.core.filter.filter import FilterExtension
from stac_fastapi.extensions.core import FieldsExtension, FilterExtension
from stac_fastapi.types import stac
from stac_fastapi.types.config import ApiSettings
from stac_fastapi.types.core import NumType
from stac_fastapi.types.core import BaseCoreClient, NumType
from stac_fastapi.types.search import BaseSearchPostRequest


Expand Down Expand Up @@ -190,3 +190,107 @@ def get_search(
assert landing.status_code == 200, landing.text
assert get_search.status_code == 200, get_search.text
assert post_search.status_code == 200, post_search.text


@pytest.mark.parametrize("validate", [True, False])
def test_fields_extension(validate, TestCoreClient, item_dict):
"""Test if fields Parameters are passed correctly."""

class BadCoreClient(BaseCoreClient):
def post_search(
self, search_request: BaseSearchPostRequest, **kwargs
) -> stac.ItemCollection:
return {"not": "a proper stac item"}

def get_search(
self,
collections: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
bbox: Optional[List[NumType]] = None,
intersects: Optional[str] = None,
datetime: Optional[Union[str, datetime]] = None,
limit: Optional[int] = 10,
**kwargs,
) -> stac.ItemCollection:
return {"not": "a proper stac item"}

def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item:
raise NotImplementedError

def all_collections(self, **kwargs) -> stac.Collections:
raise NotImplementedError

def get_collection(self, collection_id: str, **kwargs) -> stac.Collection:
raise NotImplementedError

def item_collection(
self,
collection_id: str,
bbox: Optional[List[Union[float, int]]] = None,
datetime: Optional[Union[str, datetime]] = None,
limit: int = 10,
token: str = None,
**kwargs,
) -> stac.ItemCollection:
raise NotImplementedError

test_app = app.StacApi(
settings=ApiSettings(enable_response_models=validate),
client=BadCoreClient(),
search_get_request_model=create_get_request_model([FieldsExtension()]),
search_post_request_model=create_post_request_model([FieldsExtension()]),
extensions=[FieldsExtension()],
)

with TestClient(test_app.app) as client:
get_search = client.get(
"/search",
params={"fields": "properties.datetime"},
)
post_search = client.post(
"/search",
json={
"collections": ["test"],
"fields": {
"include": ["properties.datetime"],
"exclude": [],
},
},
)

assert get_search.status_code == 200, get_search.text
assert post_search.status_code == 200, post_search.text

test_app = app.StacApi(
settings=ApiSettings(enable_response_models=validate),
client=BadCoreClient(),
search_get_request_model=create_get_request_model([FieldsExtension()]),
search_post_request_model=create_post_request_model([FieldsExtension()]),
extensions=[],
)

with TestClient(test_app.app) as client:
get_search = client.get(
"/search",
params={"fields": "properties.datetime"},
)
post_search = client.post(
"/search",
json={
"collections": ["test"],
"fields": {
"include": ["properties.datetime"],
"exclude": [],
},
},
)
if validate:
assert get_search.status_code == 500, (
get_search.json()["code"] == "ResponseValidationError"
)
assert post_search.status_code == 500, (
post_search.json()["code"] == "ResponseValidationError"
)
else:
assert get_search.status_code == 200, get_search.text
assert post_search.status_code == 200, post_search.text

0 comments on commit 80064f7

Please sign in to comment.