From eea9c8227705a35425b7e2f864946a787f9ecaef Mon Sep 17 00:00:00 2001 From: Thomas Maschler Date: Thu, 4 Apr 2024 21:24:55 -0400 Subject: [PATCH] remove response_model module, update openapi schema --- stac_fastapi/api/stac_fastapi/api/app.py | 84 ++++++++-- stac_fastapi/api/tests/conftest.py | 38 ++--- stac_fastapi/api/tests/test_app.py | 149 ++++++++++++------ .../types/stac_fastapi/types/config.py | 2 - stac_fastapi/types/stac_fastapi/types/core.py | 106 ++++++------- .../stac_fastapi/types/response_model.py | 27 ---- .../types/tests/test_response_model.py | 39 ----- 7 files changed, 234 insertions(+), 211 deletions(-) delete mode 100644 stac_fastapi/types/stac_fastapi/types/response_model.py delete mode 100644 stac_fastapi/types/tests/test_response_model.py diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 1ecc7b5ba..7db477526 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -7,10 +7,10 @@ from fastapi import APIRouter, FastAPI from fastapi.openapi.utils import get_openapi from fastapi.params import Depends -from stac_pydantic import Collection, Item, ItemCollection -from stac_pydantic.api import ConformanceClasses, LandingPage +from stac_pydantic import api from stac_pydantic.api.collections import Collections from stac_pydantic.api.version import STAC_API_VERSION +from stac_pydantic.shared import MimeTypes from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers @@ -127,8 +127,16 @@ def register_landing_page(self): name="Landing Page", path="/", response_model=( - LandingPage if self.settings.enable_response_models else None + api.LandingPage if self.settings.enable_response_models else None ), + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": api.LandingPage, + }, + }, response_class=self.response_class, response_model_exclude_unset=False, response_model_exclude_none=True, @@ -148,8 +156,16 @@ def register_conformance_classes(self): name="Conformance Classes", path="/conformance", response_model=( - ConformanceClasses if self.settings.enable_response_models else None + api.ConformanceClasses if self.settings.enable_response_models else None ), + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": api.ConformanceClasses, + }, + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -168,7 +184,15 @@ def register_get_item(self): self.router.add_api_route( name="Get Item", path="/collections/{collection_id}/items/{item_id}", - response_model=Item if self.settings.enable_response_models else None, + response_model=api.Item if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.Item, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -189,10 +213,18 @@ def register_post_search(self): name="Search", path="/search", response_model=( - (ItemCollection if not fields_ext else None) + (api.ItemCollection if not fields_ext else None) if self.settings.enable_response_models else None ), + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.ItemCollection, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -213,10 +245,18 @@ def register_get_search(self): name="Search", path="/search", response_model=( - (ItemCollection if not fields_ext else None) + (api.ItemCollection if not fields_ext else None) if self.settings.enable_response_models else None ), + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.ItemCollection, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -238,6 +278,14 @@ def register_get_collections(self): response_model=( Collections if self.settings.enable_response_models else None ), + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Collections, + }, + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -256,7 +304,17 @@ def register_get_collection(self): self.router.add_api_route( name="Get Collection", path="/collections/{collection_id}", - response_model=Collection if self.settings.enable_response_models else None, + response_model=api.Collection + if self.settings.enable_response_models + else None, + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": api.Collection, + }, + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -286,8 +344,16 @@ def register_get_item_collection(self): name="Get ItemCollection", path="/collections/{collection_id}/items", response_model=( - ItemCollection if self.settings.enable_response_models else None + api.ItemCollection if self.settings.enable_response_models else None ), + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.ItemCollection, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, diff --git a/stac_fastapi/api/tests/conftest.py b/stac_fastapi/api/tests/conftest.py index 6fa1471d3..cd5049736 100644 --- a/stac_fastapi/api/tests/conftest.py +++ b/stac_fastapi/api/tests/conftest.py @@ -5,7 +5,7 @@ from stac_pydantic import Collection, Item from stac_pydantic.api.utils import link_factory -from stac_fastapi.types import core, response_model +from stac_fastapi.types import core, stac from stac_fastapi.types.core import NumType from stac_fastapi.types.search import BaseSearchPostRequest @@ -67,9 +67,9 @@ def TestCoreClient(collection_dict, item_dict): class CoreClient(core.BaseCoreClient): def post_search( self, search_request: BaseSearchPostRequest, **kwargs - ) -> response_model.ItemCollection: - return response_model.ItemCollection( - type="FeatureCollection", features=[response_model.Item(**item_dict)] + ) -> stac.ItemCollection: + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] ) def get_search( @@ -81,19 +81,17 @@ def get_search( datetime: Optional[Union[str, datetime]] = None, limit: Optional[int] = 10, **kwargs, - ) -> response_model.ItemCollection: - return response_model.ItemCollection( - type="FeatureCollection", features=[response_model.Item(**item_dict)] + ) -> stac.ItemCollection: + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] ) - def get_item( - self, item_id: str, collection_id: str, **kwargs - ) -> response_model.Item: - return response_model.Item(**item_dict) + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: + return stac.Item(**item_dict) - def all_collections(self, **kwargs) -> response_model.Collections: - return response_model.Collections( - collections=[response_model.Collection(**collection_dict)], + def all_collections(self, **kwargs) -> stac.Collections: + return stac.Collections( + collections=[stac.Collection(**collection_dict)], links=[ {"href": "test", "rel": "root"}, {"href": "test", "rel": "self"}, @@ -101,10 +99,8 @@ def all_collections(self, **kwargs) -> response_model.Collections: ], ) - def get_collection( - self, collection_id: str, **kwargs - ) -> response_model.Collection: - return response_model.Collection(**collection_dict) + def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: + return stac.Collection(**collection_dict) def item_collection( self, @@ -114,9 +110,9 @@ def item_collection( limit: int = 10, token: str = None, **kwargs, - ) -> response_model.ItemCollection: - return response_model.ItemCollection( - type="FeatureCollection", features=[response_model.Item(**item_dict)] + ) -> stac.ItemCollection: + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] ) return CoreClient diff --git a/stac_fastapi/api/tests/test_app.py b/stac_fastapi/api/tests/test_app.py index ae5a859d2..9b4e0e828 100644 --- a/stac_fastapi/api/tests/test_app.py +++ b/stac_fastapi/api/tests/test_app.py @@ -1,55 +1,113 @@ -import importlib from datetime import datetime from typing import List, Optional, Union import pytest from fastapi.testclient import TestClient -from pydantic import BaseModel +from pydantic import ValidationError +from stac_pydantic import api -from stac_fastapi.api.app import StacApi +from stac_fastapi.api import app from stac_fastapi.api.models import create_get_request_model, create_post_request_model from stac_fastapi.extensions.core.filter.filter import FilterExtension -from stac_fastapi.types import core, response_model, search +from stac_fastapi.types import stac from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.core import NumType from stac_fastapi.types.search import BaseSearchPostRequest -@pytest.mark.parametrize( - "validate, response_type", - [ - ("True", BaseModel), - ("False", dict), - ], -) -def test_client_response_type(validate, response_type, TestCoreClient, monkeypatch): - """Test for correct response type when VALIDATE_RESPONSE is set.""" - monkeypatch.setenv("VALIDATE_RESPONSE", validate) +def test_client_response_type(TestCoreClient): + """Test all GET endpoints. Verify that responses are valid STAC items.""" - importlib.reload(response_model) - importlib.reload(core) - - test_app = StacApi( + test_app = app.StacApi( settings=ApiSettings(), client=TestCoreClient(), ) - class MockRequest: - base_url = "http://test" - app = test_app.app - - assert isinstance(TestCoreClient().landing_page(request=MockRequest()), response_type) - assert isinstance(TestCoreClient().get_collection("test"), response_type) - assert isinstance(TestCoreClient().all_collections(), response_type) - assert isinstance(TestCoreClient().get_item("test", "test"), response_type) - assert isinstance(TestCoreClient().item_collection("test"), response_type) - assert isinstance( - TestCoreClient().post_search(search.BaseSearchPostRequest()), response_type + with TestClient(test_app.app) as client: + landing = client.get("/") + collection = client.get("/collections/test") + collections = client.get("/collections") + item = client.get("/collections/test/items/test") + item_collection = client.get( + "/collections/test/items", + params={"limit": 10}, + ) + get_search = client.get( + "/search", + params={ + "collections": ["test"], + }, + ) + post_search = client.post( + "/search", + json={ + "collections": ["test"], + }, + ) + + assert landing.status_code == 200, landing.text + api.LandingPage(**landing.json()) + + assert collection.status_code == 200, collection.text + api.Collection(**collection.json()) + + assert collections.status_code == 200, collections.text + api.collections.Collections(**collections.json()) + + assert item.status_code == 200, item.text + api.Item(**item.json()) + + assert item_collection.status_code == 200, item_collection.text + api.ItemCollection(**item_collection.json()) + + assert get_search.status_code == 200, get_search.text + api.ItemCollection(**get_search.json()) + + assert post_search.status_code == 200, post_search.text + api.ItemCollection(**post_search.json()) + + +@pytest.mark.parametrize("validate", [True, False]) +def test_client_invalid_response_type(validate, TestCoreClient, item_dict): + """Check if the build in response validation switch works.""" + + class InValidResponseClient(TestCoreClient): + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: + item_dict.pop("bbox") + item_dict.pop("geometry") + return stac.Item(**item_dict) + + test_app = app.StacApi( + settings=ApiSettings(enable_response_models=validate), + client=InValidResponseClient(), ) - assert isinstance( - TestCoreClient().get_search(), - response_type, + + with TestClient(test_app.app) as client: + item = client.get("/collections/test/items/test") + + # Even if API validation passes, we should receive an invalid item + if item.status_code == 200: + with pytest.raises(ValidationError): + api.Item(**item.json()) + + # If internal validation is on, we should expect an internal error + if validate: + assert item.status_code == 500, item.text + else: + assert item.status_code == 200, item.text + + +def test_client_openapi(TestCoreClient): + """Test if response models are all documented with OpenAPI.""" + + test_app = app.StacApi( + settings=ApiSettings(), + client=TestCoreClient(), ) + test_app.app.openapi() + components = ["LandingPage", "Collection", "Collections", "Item", "ItemCollection"] + for component in components: + assert component in test_app.app.openapi_schema["components"]["schemas"] def test_filter_extension(TestCoreClient, item_dict): @@ -58,14 +116,14 @@ def test_filter_extension(TestCoreClient, item_dict): class FilterClient(TestCoreClient): def post_search( self, search_request: BaseSearchPostRequest, **kwargs - ) -> response_model.ItemCollection: + ) -> stac.ItemCollection: search_request.collections = ["test"] search_request.filter = {} search_request.filter_crs = "EPSG:4326" search_request.filter_lang = "cql2-text" - return response_model.ItemCollection( - type="FeatureCollection", features=[response_model.Item(**item_dict)] + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] ) def get_search( @@ -80,7 +138,7 @@ def get_search( filter_crs: Optional[str] = None, filter_lang: Optional[str] = None, **kwargs, - ) -> response_model.ItemCollection: + ) -> stac.ItemCollection: # Check if all filter parameters are passed correctly assert filter == "TEST" @@ -94,13 +152,13 @@ def get_search( # assert filter_crs == "EPSG:4326" # assert filter_lang == "cql2-text" - return response_model.ItemCollection( - type="FeatureCollection", features=[response_model.Item(**item_dict)] + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] ) post_request_model = create_post_request_model([FilterExtension()]) - test_app = StacApi( + test_app = app.StacApi( settings=ApiSettings(), client=FilterClient(post_request_model=post_request_model), search_get_request_model=create_get_request_model([FilterExtension()]), @@ -108,14 +166,6 @@ def get_search( ) with TestClient(test_app.app) as client: - landing = client.get("/") - collection = client.get("/collections/test") - collections = client.get("/collections") - item = client.get("/collections/test/items/test") - item_collection = client.get( - "/collections/test/items", - params={"limit": 10}, - ) get_search = client.get( "/search", params={ @@ -134,10 +184,5 @@ def get_search( }, ) - assert landing.status_code == 200, landing.text - assert collection.status_code == 200, collection.text - assert collections.status_code == 200, collections.text - assert item.status_code == 200, item.text - assert item_collection.status_code == 200, item_collection.text assert get_search.status_code == 200, get_search.text assert post_search.status_code == 200, post_search.text diff --git a/stac_fastapi/types/stac_fastapi/types/config.py b/stac_fastapi/types/stac_fastapi/types/config.py index 203adf4a1..f3fd4d655 100644 --- a/stac_fastapi/types/stac_fastapi/types/config.py +++ b/stac_fastapi/types/stac_fastapi/types/config.py @@ -31,8 +31,6 @@ class ApiSettings(BaseSettings): openapi_url: str = "/api" docs_url: str = "/api.html" - validate_response: bool = False - model_config = SettingsConfigDict(env_file=".env", extra="allow") diff --git a/stac_fastapi/types/stac_fastapi/types/core.py b/stac_fastapi/types/stac_fastapi/types/core.py index 1c25aacc4..bfa77772b 100644 --- a/stac_fastapi/types/stac_fastapi/types/core.py +++ b/stac_fastapi/types/stac_fastapi/types/core.py @@ -7,14 +7,13 @@ import attr from fastapi import Request -from pydantic import BaseModel from stac_pydantic import Collection, Item, ItemCollection from stac_pydantic.api.version import STAC_API_VERSION from stac_pydantic.links import Relations from stac_pydantic.shared import MimeTypes from starlette.responses import Response -from stac_fastapi.types import response_model +from stac_fastapi.types import stac from stac_fastapi.types.config import Settings from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES from stac_fastapi.types.extension import ApiExtension @@ -265,8 +264,8 @@ def _landing_page( base_url: str, conformance_classes: List[str], extension_schemas: List[str], - ) -> Dict[str, Any]: - landing_page = response_model.LandingPage( + ) -> stac.LandingPage: + landing_page = stac.LandingPage( type="Catalog", id=self.landing_page_id, title=self.title, @@ -276,42 +275,42 @@ def _landing_page( links=[ { "rel": Relations.self.value, - "type": MimeTypes.json, + "type": MimeTypes.json.value, "href": base_url, }, { "rel": Relations.root.value, - "type": MimeTypes.json, + "type": MimeTypes.json.value, "href": base_url, }, { - "rel": "data", - "type": MimeTypes.json, + "rel": Relations.data.value, + "type": MimeTypes.json.value, "href": urljoin(base_url, "collections"), }, { "rel": Relations.conformance.value, - "type": MimeTypes.json, + "type": MimeTypes.json.value, "title": "STAC/OGC conformance classes implemented by this server", "href": urljoin(base_url, "conformance"), }, { "rel": Relations.search.value, - "type": MimeTypes.geojson, + "type": MimeTypes.geojson.value, "title": "STAC search", "href": urljoin(base_url, "search"), "method": "GET", }, { "rel": Relations.search.value, - "type": MimeTypes.geojson, + "type": MimeTypes.geojson.value, "title": "STAC search", "href": urljoin(base_url, "search"), "method": "POST", }, { "rel": Relations.service_desc.value, - "type": MimeTypes.geojson, + "type": MimeTypes.geojson.value, "title": "Service Description", "href": Settings.get().openapi_url, }, @@ -319,10 +318,7 @@ def _landing_page( stac_extensions=extension_schemas, ) - if isinstance(landing_page, BaseModel): - return landing_page.model_dump(mode="json") - else: - return landing_page + return landing_page @attr.s # type:ignore @@ -364,7 +360,7 @@ def list_conformance_classes(self): return base_conformance - def landing_page(self, **kwargs) -> response_model.LandingPage: + def landing_page(self, **kwargs) -> stac.LandingPage: """Landing page. Called with `GET /`. @@ -383,10 +379,8 @@ def landing_page(self, **kwargs) -> response_model.LandingPage: # Add Collections links _collections = self.all_collections(request=kwargs["request"]) - if isinstance(_collections, BaseModel): - collections = _collections.model_dump(mode="json") - else: - collections = _collections + collections = _collections + for collection in collections["collections"]: landing_page["links"].append( { @@ -400,8 +394,8 @@ def landing_page(self, **kwargs) -> response_model.LandingPage: # Add OpenAPI URL landing_page["links"].append( { - "rel": "service-desc", - "type": "application/vnd.oai.openapi+json;version=3.0", + "rel": Relations.service_desc.value, + "type": MimeTypes.openapi.value, "title": "OpenAPI service description", "href": urljoin( str(request.base_url), request.app.openapi_url.lstrip("/") @@ -412,16 +406,16 @@ def landing_page(self, **kwargs) -> response_model.LandingPage: # Add human readable service-doc landing_page["links"].append( { - "rel": "service-doc", - "type": "text/html", + "rel": Relations.service_doc.value, + "type": MimeTypes.html.value, "title": "OpenAPI service documentation", "href": urljoin(str(request.base_url), request.app.docs_url.lstrip("/")), } ) - return response_model.LandingPage(**landing_page) + return stac.LandingPage(**landing_page) - def conformance(self, **kwargs) -> response_model.Conformance: + def conformance(self, **kwargs) -> stac.Conformance: """Conformance classes. Called with `GET /conformance`. @@ -429,12 +423,12 @@ def conformance(self, **kwargs) -> response_model.Conformance: Returns: Conformance classes which the server conforms to. """ - return response_model.Conformance(conformsTo=self.conformance_classes()) + return stac.Conformance(conformsTo=self.conformance_classes()) @abc.abstractmethod def post_search( self, search_request: BaseSearchPostRequest, **kwargs - ) -> response_model.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (POST). Called with `POST /search`. @@ -457,7 +451,7 @@ def get_search( datetime: Optional[Union[str, datetime]] = None, limit: Optional[int] = 10, **kwargs, - ) -> response_model.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (GET). Called with `GET /search`. @@ -468,7 +462,7 @@ def get_search( ... @abc.abstractmethod - def get_item(self, item_id: str, collection_id: str, **kwargs) -> response_model.Item: + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: """Get item by id. Called with `GET /collections/{collection_id}/items/{item_id}`. @@ -483,7 +477,7 @@ def get_item(self, item_id: str, collection_id: str, **kwargs) -> response_model ... @abc.abstractmethod - def all_collections(self, **kwargs) -> response_model.Collections: + def all_collections(self, **kwargs) -> stac.Collections: """Get all available collections. Called with `GET /collections`. @@ -494,7 +488,7 @@ def all_collections(self, **kwargs) -> response_model.Collections: ... @abc.abstractmethod - def get_collection(self, collection_id: str, **kwargs) -> response_model.Collection: + def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: """Get collection by id. Called with `GET /collections/{collection_id}`. @@ -516,7 +510,7 @@ def item_collection( limit: int = 10, token: str = None, **kwargs, - ) -> response_model.ItemCollection: + ) -> stac.ItemCollection: """Get all items from a specific collection. Called with `GET /collections/{collection_id}/items` @@ -561,7 +555,7 @@ def extension_is_enabled(self, extension: str) -> bool: """Check if an api extension is enabled.""" return any([type(ext).__name__ == extension for ext in self.extensions]) - async def landing_page(self, **kwargs) -> response_model.LandingPage: + async def landing_page(self, **kwargs) -> stac.LandingPage: """Landing page. Called with `GET /`. @@ -580,10 +574,8 @@ async def landing_page(self, **kwargs) -> response_model.LandingPage: # Add Collections links _collections = await self.all_collections(request=kwargs["request"]) - if isinstance(_collections, BaseModel): - collections = _collections.model_dump(mode="json") - else: - collections = _collections + collections = _collections + for collection in collections["collections"]: landing_page["links"].append( { @@ -597,8 +589,8 @@ async def landing_page(self, **kwargs) -> response_model.LandingPage: # Add OpenAPI URL landing_page["links"].append( { - "rel": "service-desc", - "type": "application/vnd.oai.openapi+json;version=3.0", + "rel": Relations.service_desc.value, + "type": MimeTypes.openapi.value, "title": "OpenAPI service description", "href": urljoin( str(request.base_url), request.app.openapi_url.lstrip("/") @@ -609,16 +601,16 @@ async def landing_page(self, **kwargs) -> response_model.LandingPage: # Add human readable service-doc landing_page["links"].append( { - "rel": "service-doc", - "type": "text/html", + "rel": Relations.service_doc.value, + "type": MimeTypes.html.value, "title": "OpenAPI service documentation", "href": urljoin(str(request.base_url), request.app.docs_url.lstrip("/")), } ) - return response_model.LandingPage(**landing_page) + return stac.LandingPage(**landing_page) - async def conformance(self, **kwargs) -> response_model.Conformance: + async def conformance(self, **kwargs) -> stac.Conformance: """Conformance classes. Called with `GET /conformance`. @@ -626,12 +618,12 @@ async def conformance(self, **kwargs) -> response_model.Conformance: Returns: Conformance classes which the server conforms to. """ - return response_model.Conformance(conformsTo=self.conformance_classes()) + return stac.Conformance(conformsTo=self.conformance_classes()) @abc.abstractmethod async def post_search( self, search_request: BaseSearchPostRequest, **kwargs - ) -> response_model.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (POST). Called with `POST /search`. @@ -650,15 +642,11 @@ async def get_search( collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, bbox: Optional[List[NumType]] = None, + intersects: Optional[str] = None, datetime: Optional[Union[str, datetime]] = None, limit: Optional[int] = 10, - query: Optional[str] = None, - token: Optional[str] = None, - fields: Optional[List[str]] = None, - sortby: Optional[str] = None, - intersects: Optional[str] = None, **kwargs, - ) -> response_model.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (GET). Called with `GET /search`. @@ -669,9 +657,7 @@ async def get_search( ... @abc.abstractmethod - async def get_item( - self, item_id: str, collection_id: str, **kwargs - ) -> response_model.Item: + async def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: """Get item by id. Called with `GET /collections/{collection_id}/items/{item_id}`. @@ -686,7 +672,7 @@ async def get_item( ... @abc.abstractmethod - async def all_collections(self, **kwargs) -> response_model.Collections: + async def all_collections(self, **kwargs) -> stac.Collections: """Get all available collections. Called with `GET /collections`. @@ -697,9 +683,7 @@ async def all_collections(self, **kwargs) -> response_model.Collections: ... @abc.abstractmethod - async def get_collection( - self, collection_id: str, **kwargs - ) -> response_model.Collection: + async def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: """Get collection by id. Called with `GET /collections/{collection_id}`. @@ -721,7 +705,7 @@ async def item_collection( limit: int = 10, token: str = None, **kwargs, - ) -> response_model.ItemCollection: + ) -> stac.ItemCollection: """Get all items from a specific collection. Called with `GET /collections/{collection_id}/items` diff --git a/stac_fastapi/types/stac_fastapi/types/response_model.py b/stac_fastapi/types/stac_fastapi/types/response_model.py deleted file mode 100644 index 76d266a81..000000000 --- a/stac_fastapi/types/stac_fastapi/types/response_model.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Response models for STAC FastAPI. -Depending on settings models are either TypeDicts or Pydantic models.""" - -from stac_pydantic import api - -from stac_fastapi.types import stac -from stac_fastapi.types.config import ApiSettings - -settings = ApiSettings() - -if settings.validate_response: - response_model = api -else: - response_model = stac - - -LandingPage = response_model.LandingPage -Collection = response_model.Collection -Collections = response_model.Collections -Item = response_model.Item -ItemCollection = response_model.ItemCollection -try: - Conformance = response_model.Conformance -except AttributeError: - # TODO: class name needs to be fixed in stac_pydantic - # stac-utils/stac-pydantic#136 - Conformance = response_model.ConformanceClasses diff --git a/stac_fastapi/types/tests/test_response_model.py b/stac_fastapi/types/tests/test_response_model.py deleted file mode 100644 index 3086c1352..000000000 --- a/stac_fastapi/types/tests/test_response_model.py +++ /dev/null @@ -1,39 +0,0 @@ -import importlib -import os - -import pytest -from pydantic import BaseModel - -from stac_fastapi.types import response_model - - -@pytest.fixture -def cleanup(): - old_environ = dict(os.environ) - yield - os.environ.clear() - os.environ.update(old_environ) - - -@pytest.mark.parametrize( - "validate, response_type", - [ - ("True", BaseModel), - ("False", dict), - ], -) -def test_response_model(validate, response_type, cleanup): - os.environ["VALIDATE_RESPONSE"] = str(validate) - importlib.reload(response_model) - - landing_page = response_model.LandingPage( - id="test", - description="test", - links=[ - {"href": "test", "rel": "root"}, - {"href": "test", "rel": "self"}, - {"href": "test", "rel": "service-desc"}, - ], - ) - - assert isinstance(landing_page, response_type)