diff --git a/.github/workflows/cicd.yaml b/.github/workflows/cicd.yaml index f86cb6786..3da9d9cdc 100644 --- a/.github/workflows/cicd.yaml +++ b/.github/workflows/cicd.yaml @@ -10,33 +10,9 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] timeout-minutes: 20 - services: - db_service: - image: ghcr.io/stac-utils/pgstac:v0.7.1 - env: - POSTGRES_USER: username - POSTGRES_PASSWORD: password - POSTGRES_DB: postgis - POSTGRES_HOST: localhost - POSTGRES_PORT: 5432 - PGUSER: username - PGPASSWORD: password - PGDATABASE: postgis - ALLOW_IP_RANGE: 0.0.0.0/0 - # Set health checks to wait until postgres has started - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 10s - --health-retries 10 - --log-driver none - ports: - # Maps tcp port 5432 on service container to the host - - 5432:5432 - steps: - name: Check out repository code uses: actions/checkout@v4 @@ -55,18 +31,18 @@ jobs: - name: Install types run: | - pip install ./stac_fastapi/types[dev] + python -m pip install ./stac_fastapi/types[dev] - name: Install core api run: | - pip install ./stac_fastapi/api[dev] + python -m pip install ./stac_fastapi/api[dev] - name: Install Extensions run: | - pip install ./stac_fastapi/extensions[dev] + python -m pip install ./stac_fastapi/extensions[dev] - name: Test - run: pytest -svvv + run: python -m pytest -svvv env: ENVIRONMENT: testing @@ -93,18 +69,19 @@ jobs: run: | python -m pip install ./stac_fastapi/types[dev] - - name: Install extensions - run: | - python -m pip install ./stac_fastapi/extensions - - name: Install core api run: | python -m pip install ./stac_fastapi/api[dev,benchmark] + - name: Install extensions + run: | + python -m pip install ./stac_fastapi/extensions + - name: Run Benchmark run: python -m pytest stac_fastapi/api/tests/benchmarks.py --benchmark-only --benchmark-columns 'min, max, mean, median' --benchmark-json output.json - name: Store and benchmark result + if: github.repository == 'stac-utils/stac-fastapi' uses: benchmark-action/github-action-benchmark@v1 with: name: STAC FastAPI Benchmarks diff --git a/.github/workflows/deploy_mkdocs.yml b/.github/workflows/deploy_mkdocs.yml index 7132fdb6c..a3469aad8 100644 --- a/.github/workflows/deploy_mkdocs.yml +++ b/.github/workflows/deploy_mkdocs.yml @@ -20,17 +20,17 @@ jobs: - name: Checkout main uses: actions/checkout@v4 - - name: Set up Python 3.8 + - name: Set up Python 3.11 uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.11 - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install \ - stac_fastapi/api[docs] \ stac_fastapi/types[docs] \ + stac_fastapi/api[docs] \ stac_fastapi/extensions[docs] \ - name: update API docs @@ -40,14 +40,6 @@ jobs: --exclude_source \ --overwrite \ stac_fastapi - env: - POSTGRES_USER: username - POSTGRES_PASS: password - POSTGRES_DBNAME: postgis - POSTGRES_HOST: localhost - POSTGRES_PORT: 5432 - POSTGRES_HOST_READER: localhost - POSTGRES_HOST_WRITER: localhost - name: Deploy docs run: mkdocs gh-deploy --force -f docs/mkdocs.yml diff --git a/.gitignore b/.gitignore index 908694a3a..3b2a1fea8 100644 --- a/.gitignore +++ b/.gitignore @@ -129,6 +129,7 @@ docs/api/* # Virtualenv venv +.venv/ # IDE .vscode \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 193edc5c7..68c3b8567 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,7 @@ repos: - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.0.267" + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.2.2" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - - repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black + - id: ruff-format diff --git a/CHANGES.md b/CHANGES.md index bc31368d8..939471bbb 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,12 +1,154 @@ # Changelog -## [Unreleased] +## [Unreleased] - TBD + +### Changed + +* add more openapi metadata in input models [#734](https://github.com/stac-utils/stac-fastapi/pull/734) ### Added -* Add benchmark in CI ([#650](https://github.com/stac-utils/stac-fastapi/pull/650)) * Add Free-text Extension to third party extensions ([#655](https://github.com/stac-utils/stac-fastapi/pull/655)) +## [3.0.0b2] - 2024-07-09 + +### Changed + +* move back to `@attrs` (instead of dataclass) for `APIRequest` (model for GET request) class type [#729](https://github.com/stac-utils/stac-fastapi/pull/729) + +## [3.0.0b1] - 2024-07-05 + +### Added + +* Add attributes to `stac_fastapi.api.app.StacApi` to enable customization of request model for: + - `/collections`: **collections_get_request_model**, default to `EmptyRequest` + - `/collections/{collection_id}`: **collection_get_request_model**, default to `CollectionUri` + - `/collections/{collection_id}/items`: **items_get_request_model**, default to `ItemCollectionUri` + - `/collections/{collection_id}/items/{item_id}`: **item_get_request_model**, default to `ItemUri` + +### Removed + +* Removed the Filter Extension dependency from `AggregationExtensionPostRequest` and `AggregationExtensionGetRequest` [#716](https://github.com/stac-utils/stac-fastapi/pull/716) +* Removed `pagination_extension` attribute in `stac_fastapi.api.app.StacApi` +* Removed use of `pagination_extension` in `register_get_item_collection` function (User now need to construct the request model and pass it using `items_get_request_model` attribute) +* Removed use of `FieldsExtension` in `stac_fastapi.api.app.StacApi`. If users use `FieldsExtension`, they would have to handle overpassing the model validation step by returning a `JSONResponse` from the `post_search` and `get_search` client methods. + +### Changed + +* Replaced `@attrs` with python `@dataclass` for `APIRequest` (model for GET request) class type [#714](https://github.com/stac-utils/stac-fastapi/pull/714) +* Moved `GETPagination`, `POSTPagination`, `GETTokenPagination` and `POSTTokenPagination` to `stac_fastapi.extensions.core.pagination.request` submodule [#717](https://github.com/stac-utils/stac-fastapi/pull/717) +* update FastAPI requirement to `>=0.111.0` + +## [3.0.0a4] - 2024-06-27 + +### Fixed + +* Updated default filter language in filter extension's POST search request model to match the extension's documentation [#711](https://github.com/stac-utils/stac-fastapi/issues/711) + +### Removed + +* Removed the Filter Extension depenency from `AggregationExtensionPostRequest` and `AggregationExtensionGetRequest` [#716](https://github.com/stac-utils/stac-fastapi/pull/716) +* Removed `add_middleware` method in `StacApi` object and let starlette handle the middleware stack creation [721](https://github.com/stac-utils/stac-fastapi/pull/721) + +## [3.0.0a3] - 2024-06-13 + +### Added + +* Add base support for the Aggregation extension [#684](https://github.com/stac-utils/stac-fastapi/pull/684) + +### Changed + +* Added option for default route dependencies `*` can be used for `path` or `method` to match all allowed route. ([#705](https://github.com/stac-utils/stac-fastapi/pull/705)) +* Moved `AsyncBaseFiltersClient` and `BaseFiltersClient` classes in `stac_fastapi.extensions.core.filter.client` submodule ([#704](https://github.com/stac-utils/stac-fastapi/pull/704)) +* Removed `default_includes` from `stac_fastapi.types.config.ApiSettings` ([#706](https://github.com/stac-utils/stac-fastapi/pull/706)) +* Deprecated *Fields* extension `PostFieldsExtension.filter_fields` property ([#706](https://github.com/stac-utils/stac-fastapi/pull/706)) + +## [3.0.0a2] - 2024-05-31 + +### Fixed + +* Fix missing default (`None`) for optional `query` attribute in `QueryExtensionPostRequest` model ([#701](https://github.com/stac-utils/stac-fastapi/pull/701)) + +## [3.0.0a1] - 2024-05-22 + +### Changed + +* Switch from `fastapi` to `fastapi-slim` to avoid installing unwanted dependencies. ([#687](https://github.com/stac-utils/stac-fastapi/pull/687)) +* Replace Enum with `Literal` for `FilterLang`. ([#686](https://github.com/stac-utils/stac-fastapi/pull/686)) +* Update stac-pydantic requirement to `~3.1` ([#697](https://github.com/stac-utils/stac-fastapi/pull/697)) + +### Removed + +* Pystac as it was just used for a datetime to string function. ([#690](https://github.com/stac-utils/stac-fastapi/pull/690)) + +### Fixed + +* Make `str_to_interval` not return a tuple for single-value input (fixing `datetime` argument as passed to `get_search`). ([#692](https://github.com/stac-utils/stac-fastapi/pull/692)) + +## [3.0.0a0] - 2024-05-06 + +### Added + +* Add enhanced middleware configuration to the StacApi class, enabling specific middleware options and dynamic addition post-application initialization. ([#442](https://github.com/stac-utils/stac-fastapi/pull/442)) +* Add Response Model to OpenAPI, even if model validation is turned off ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) + +## Changed + +* Update to pydantic v2 and stac_pydantic v3 ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) +* Removed internal Search and Operator Types in favor of stac_pydantic Types ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) +* Fix response model validation ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) +* Use status code 201 for Item/Collection creation ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) +* Replace Black with Ruff Format ([#625](https://github.com/stac-utils/stac-fastapi/pull/625)) +* add `response_class` in the route definitions for `FilterExtension` + +## [2.5.5.post1] - 2024-04-25 + +### Fixed + +* Fix `service-doc` and `service-desc` url in landing page when using router prefix for `AsyncBaseCoreClient`. ([#675](https://github.com/stac-utils/stac-fastapi/pull/675)) + +## [2.5.5] - 2024-04-24 + +### Fixed + +* Fix `service-doc` and `service-desc` url in landing page when using router prefix. ([#673](https://github.com/stac-utils/stac-fastapi/pull/673)) + +## [2.5.4] - 2024-04-24 + +### Fixed + +* Fix missing payload for the PUT `collection/{collection_id}` endpoint ([#665](https://github.com/stac-utils/stac-fastapi/issues/665)) +* Return 400 for datetime errors ([#670](https://github.com/stac-utils/stac-fastapi/pull/670)) + +## [2.5.3] - 2024-04-23 + +### Fixed + +* Remove the str2list converter from intersection queries via BaseSearchGetRequest ([#668](https://github.com/stac-utils/stac-fastapi/pull/668)) +* Apply datetime converter in ItemCollection endpoint model ([#667](https://github.com/stac-utils/stac-fastapi/pull/667)) + +## [2.5.2] - 2024-04-19 + +### Fixed + +* BaseSearchGetRequest datetime validator str_to_interval not allowing GET /search requests with datetime = None ([#662](https://github.com/stac-utils/stac-fastapi/pull/662)) + +## [2.5.1] - 2024-04-18 + +### Fixed + +* Fixed warnings.warn deprecation syntax for response class and the context extension ([#660](https://github.com/stac-utils/stac-fastapi/pull/660)) + +## [2.5.0] - 2024-04-12 + +### Added + +* Add benchmark in CI ([#650](https://github.com/stac-utils/stac-fastapi/pull/650)) +* Add `/queryables` link to the landing page ([#587](https://github.com/stac-utils/stac-fastapi/pull/587)) +- `id`, `title`, `description` and `api_version` fields can be customized via env variables +* Add `DeprecationWarning` for the `ContextExtension` +* Add support for Python 3.12 + ### Changed * Updated the collection update endpoint to match with the collection-transaction extension. ([#630](https://github.com/stac-utils/stac-fastapi/issues/630)) @@ -304,7 +446,21 @@ * First PyPi release! -[Unreleased]: +[Unreleased]: +[3.0.0b2]: +[3.0.0b1]: +[3.0.0a4]: +[3.0.0a3]: +[3.0.0a2]: +[3.0.0a1]: +[3.0.0a0]: +[2.5.5.post1]: +[2.5.5]: +[2.5.4]: +[2.5.3]: +[2.5.2]: +[2.5.1]: +[2.5.0]: [2.4.9]: [2.4.8]: [2.4.7]: diff --git a/Dockerfile b/Dockerfile index 2187ac53e..9b6817182 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.8-slim as base +FROM python:3.11-slim as base # Any python libraries that require system libraries to be installed will likely # need the following packages in order to build @@ -16,6 +16,6 @@ WORKDIR /app COPY . /app -RUN pip install -e ./stac_fastapi/types[dev] && \ - pip install -e ./stac_fastapi/api[dev] && \ - pip install -e ./stac_fastapi/extensions[dev] +RUN python -m pip install -e ./stac_fastapi/types[dev] && \ + python -m pip install -e ./stac_fastapi/api[dev] && \ + python -m pip install -e ./stac_fastapi/extensions[dev] diff --git a/Dockerfile.docs b/Dockerfile.docs index e3c7447e5..6c7f00843 100644 --- a/Dockerfile.docs +++ b/Dockerfile.docs @@ -1,4 +1,4 @@ -FROM python:3.8-slim +FROM python:3.11-slim # build-essential is required to build a wheel for ciso8601 RUN apt update && apt install -y build-essential @@ -11,8 +11,8 @@ COPY . /opt/src WORKDIR /opt/src RUN python -m pip install \ - stac_fastapi/api \ stac_fastapi/types \ + stac_fastapi/api \ stac_fastapi/extensions CMD ["pdocs", \ @@ -21,4 +21,4 @@ CMD ["pdocs", \ "docs/api/", \ "--exclude_source", \ "--overwrite", \ - "stac_fastapi"] \ No newline at end of file + "stac_fastapi"] diff --git a/Makefile b/Makefile index e802fbb54..eef5dae35 100644 --- a/Makefile +++ b/Makefile @@ -4,10 +4,10 @@ image: .PHONY: install install: - pip install wheel && \ - pip install -e ./stac_fastapi/api[dev] && \ - pip install -e ./stac_fastapi/types[dev] && \ - pip install -e ./stac_fastapi/extensions[dev] + python -m pip install wheel && \ + python -m pip install -e ./stac_fastapi/types[dev] && \ + python -m pip install -e ./stac_fastapi/api[dev] && \ + python -m pip install -e ./stac_fastapi/extensions[dev] .PHONY: docs-image docs-image: @@ -21,4 +21,4 @@ docs: docs-image .PHONY: test test: image - pytest . \ No newline at end of file + python -m pytest . diff --git a/README.md b/README.md index 350ce2589..02c155993 100644 --- a/README.md +++ b/README.md @@ -41,16 +41,28 @@ Backends are hosted in their own repositories: `stac-fastapi` was initially developed by [arturo-ai](https://github.com/arturo-ai). + +## Response Model Validation + +A common question when using this package is how request and response types are validated? + +This package uses [`stac-pydantic`](https://github.com/stac-utils/stac-pydantic) to validate and document STAC objects. However, by default, validation of response types is turned off and the API will simply forward responses without validating them against the Pydantic model first. This decision was made with the assumption that responses usually come from a (typed) database and can be considered safe. Extra validation would only increase latency, in particular for large payloads. + +To turn on response validation, set `ENABLE_RESPONSE_MODELS` to `True`. Either as an environment variable or directly in the `ApiSettings`. + +With the introduction of Pydantic 2, the extra [time it takes to validate models became negatable](https://github.com/stac-utils/stac-fastapi/pull/625#issuecomment-2045824578). While `ENABLE_RESPONSE_MODELS` still defaults to `False` there should be no penalty for users to turn on this feature but users discretion is advised. + + ## Installation ```bash # Install from PyPI -pip install stac-fastapi.api stac-fastapi.types stac-fastapi.extensions +python -m pip install stac-fastapi.types stac-fastapi.api stac-fastapi.extensions # Install a backend of your choice -pip install stac-fastapi.sqlalchemy +python -m pip install stac-fastapi.sqlalchemy # or -pip install stac-fastapi.pgstac +python -m pip install stac-fastapi.pgstac ``` Other backends may be available from other sources, search [PyPI](https://pypi.org/) for more. @@ -60,14 +72,14 @@ Other backends may be available from other sources, search [PyPI](https://pypi.o Install the packages in editable mode: ```shell -pip install -e \ - 'stac_fastapi/api[dev]' \ +python -m pip install -e \ 'stac_fastapi/types[dev]' \ + 'stac_fastapi/api[dev]' \ 'stac_fastapi/extensions[dev]' ``` To run the tests: ```shell -pytest +python -m pytest ``` diff --git a/RELEASING.md b/RELEASING.md index 8aa14afcb..3a23940f6 100644 --- a/RELEASING.md +++ b/RELEASING.md @@ -4,7 +4,13 @@ This is a checklist for releasing a new version of **stac-fastapi**. 1. Determine the next version. We currently do not have published versioning guidelines, but there is some text on the subject here: . 2. Create a release branch named `release/vX.Y.Z`, where `X.Y.Z` is the new version. -3. Search and replace all instances of the current version number with the new version. As of this writing, there's five different `version.py` files, and one `VERSION` file, in the repo. +3. Search and replace all instances of the current version number with the new version. As of this writing, there's 3 different `version.py` files, and one `VERSION` file, in the repo. + + Note: You can use [`bump-my-version`](https://github.com/callowayproject/bump-my-version) CLI + ``` + bump-my-version bump --new-version 3.1.0 + ``` + 4. Update [CHANGES.md](./CHANGES.md) for the new version. Add the appropriate header, and update the links at the bottom of the file. 5. Audit CHANGES.md for completeness and accuracy. Also, ensure that the changes in this version are appropriate for the version number change (i.e. if you're making breaking changes, you should be increasing the `MAJOR` version number). 6. (optional) If you have permissions, run `scripts/publish --test` to test your PyPI publish. If successful, the published packages will be available on . diff --git a/VERSION b/VERSION index 158349812..2aa4d8f0a 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.4.9 \ No newline at end of file +3.0.0b2 diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 9d9eea0fb..79af024a0 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -36,6 +36,10 @@ nav: - core: - module: api/stac_fastapi/extensions/core/index.md - context: api/stac_fastapi/extensions/core/context.md + - free_text: + - module: api/stac_fastapi/extensions/core/free_text/index.md + - free_text: api/stac_fastapi/extensions/core/free_text/free_text.md + - request: api/stac_fastapi/extensions/core/free_text/request.md - filter: - module: api/stac_fastapi/extensions/core/filter/index.md - filter: api/stac_fastapi/extensions/core/filter/filter.md @@ -60,7 +64,6 @@ nav: - version: api/stac_fastapi/extensions/version.md - third_party: - bulk_transactions: api/stac_fastapi/extensions/third_party/bulk_transactions.md - - free_text: api/stac_fastapi/extensions/third_party/free_text.md - index: api/stac_fastapi/extensions/third_party/index.md - stac_fastapi.types: - module: api/stac_fastapi/types/index.md @@ -75,6 +78,8 @@ nav: - search: api/stac_fastapi/types/search.md - stac: api/stac_fastapi/types/stac.md - version: api/stac_fastapi/types/version.md + - Migration Guides: + - v2.5 -> v3.0: migrations/v3.0.0.md - Performance Benchmarks: benchmarks.html - Development - Contributing: "contributing.md" - Release Notes: "release-notes.md" diff --git a/docs/src/migrations/v3.0.0.md b/docs/src/migrations/v3.0.0.md new file mode 100644 index 000000000..f781687c3 --- /dev/null +++ b/docs/src/migrations/v3.0.0.md @@ -0,0 +1,275 @@ + +# stac-fastapi v3.0 Migration Guide + +This document aims to help you update your application from **stac-fastapi** 2.5 to 3.0.0. + +## Dependencies + +- **pydantic~=2.0** +- **fastapi>=0.111** +- **stac-pydantic~=3.1** + +Most of the **stac-fastapi's** dependencies have been upgraded. Moving from pydantic v1 to v2 is mostly the one update bringing most breaking changes (see https://docs.pydantic.dev/latest/migration/). + +In addition to pydantic v2 update, `stac-pydantic` has been updated to better match the STAC and STAC-API specifications (see https://github.com/stac-utils/stac-pydantic/blob/main/CHANGELOG.md#310-2024-05-21) + +## Deprecation + +* the `ContextExtension` have been removed (see https://github.com/stac-utils/stac-pydantic/pull/138) and was replaced by optional `NumberMatched` and `NumberReturned` attributes, defined by the OGC features specification. + +* `stac_fastapi.api.config_openapi` method was removed (see https://github.com/stac-utils/stac-fastapi/pull/523) + +* passing `response_class` in `stac_fastapi.api.routes.create_async_endpoint` is now deprecated. The response class now has to be set when registering the endpoint to the application (see https://github.com/stac-utils/stac-fastapi/issues/461) + +* `PostFieldsExtension.filter_fields` property has been removed. + +## Middlewares configuration + +The `StacApi.middlewares` attribute has been updated to accept a list of `starlette.middleware.Middleware`. This enables dynamic configuration of middlewares (see https://github.com/stac-utils/stac-fastapi/pull/442). + +```python +# before +class myMiddleware(mainMiddleware): + option1 = option1 + option2 = option2 + +stac = StacApi( + middlewares=[ + myMiddleware, + ] +) + +# now +stac = StacApi( + middlewares=[ + Middleware(myMiddleware, option1, option2), + ] +) +``` + +## Request Models + +In stac-fastapi v2.0, users could already customize both GET/POST search request models. For v3.0, we've added more attributes to enable other endpoints customization: + +- `collections_get_request_model`: GET request model for the `/collections` endpoint (default to `EmptyRequest`) +- `collection_get_request_model`: GET request model for the `/collections/{collection_id}` endpoint (default to `stac_fastapi.api.models.CollectionUri`) +- `items_get_request_model`: GET request model for the `/collections/{collection_id}/items` endpoint (default to `stac_fastapi.api.models.ItemCollectionUri`) +- `item_get_request_model`: GET request model for the `/collections/{collection_id}/items/{item_id}` endpoint (default to `stac_fastapi.api.models.ItemUri`) + +```python +# before +getSearchModel = create_request_model( + model_name="SearchGetRequest", + base_model=BaseSearchGetRequest + extensions=[...], + request_type="GET" +) +stac = StacApi( + search_get_request_model=getSearchModel, + search_post_request_model=..., +) + +# now +@attr.s +class CollectionsRequest(APIRequest): + user: Annotated[str, Query(...)] = attr.ib() + +stac = StacApi( + search_get_request_model=getSearchModel, + search_post_request_model=postSearchModel, + collections_get_request_model=CollectionsRequest, + collection_get_request_model=..., + items_get_request_model=..., + item_get_request_model=..., +) +``` + +## APIRequest - GET Request Model + +Most of the **GET** endpoints are configured with `stac_fastapi.types.search.APIRequest` base class. + +e.g the BaseSearchGetRequest, default for the `GET - /search` endpoint: + +```python +@attr.s +class BaseSearchGetRequest(APIRequest): + """Base arguments for GET Request.""" + + collections: Optional[List[str]] = attr.ib(default=None, converter=_collection_converter) + ids: Optional[List[str]] = attr.ib(default=None, converter=_ids_converter) + bbox: Optional[BBox] = attr.ib(default=None, converter=_bbox_converter) + intersects: Annotated[Optional[str], Query()] = attr.ib(default=None) + datetime: Optional[DateTimeType] = attr.ib( + default=None, converter=_datetime_converter + ) + limit: Annotated[Optional[int], Query()] = attr.ib(default=10) +``` + +We use [*python attrs*](https://www.attrs.org/en/stable/) to construct those classes. **Type Hint** for each attribute is important and should be defined using `Annotated[{type}, fastapi.Query()]` form. + +```python +@attr.s +class SomeRequest(APIRequest): + user_number: Annotated[Optional[int], Query(alias="user-number")] = attr.ib(default=None) +``` + +Note: when an attribute has a `converter` (e.g `_ids_converter`), the **Type Hint** should be defined directly in the converter: + +```python +def _ids_converter( + val: Annotated[ + Optional[str], + Query( + description="Array of Item ids to return.", + ), + ] = None, +) -> Optional[List[str]]: + return str2list(val) + +@attr.s +class BaseSearchGetRequest(APIRequest): + """Base arguments for GET Request.""" + + ids: Optional[List[str]] = attr.ib(default=None, converter=_ids_converter) +``` + +## Filter extension + +`default_includes` attribute has been removed from the `ApiSettings` object. If you need `defaults` includes you can overwrite the `FieldExtension` models (see https://github.com/stac-utils/stac-fastapi/pull/706). + +```python +# before +stac = StacApi( + extensions=[ + FieldsExtension() + ] +) + +# now +class PostFieldsExtension(requests.PostFieldsExtension): + include: Optional[Set[str]] = Field( + default_factory=lambda: { + "id", + "type", + "stac_version", + "geometry", + "bbox", + "links", + "assets", + "properties.datetime", + "collection", + } + ) + exclude: Optional[Set[str]] = set() + + +class FieldsExtensionPostRequest(BaseModel): + """Additional fields and schema for the POST request.""" + + fields: Optional[PostFieldsExtension] = Field(PostFieldsExtension()) + + +class FieldsExtension(FieldsExtensionBase): + """Override the POST model""" + + POST = FieldsExtensionPostRequest + + +from stac_fastapi.api.app import StacApi + +stac = StacApi( + extensions=[ + FieldsExtension() + ] +) +``` + +## Pagination extension + +In stac-fastapi v3.0, we removed the `pagination_extension` attribute in `stac_fastapi.api.app.StacApi`. This attribute was used within the `register_get_item_collection` to update the request model for the `/collections/{collection_id}/items` endpoint. + +It's now up to the user to create the request model and use the `items_get_request_model=` attribute in the StacApi object. + +```python +# before +stac=StacApi( + pagination_extension=TokenPaginationExtension, + extension=[TokenPaginationExtension] +) + +# now +items_get_request_model = create_request_model( + "ItemCollectionURI", + base_model=ItemCollectionUri, + mixins=[TokenPaginationExtension().GET], +) + +stac=StacApi( + extension=[TokenPaginationExtension], + items_get_request_model=items_get_request_model, +) +``` + + +## Fields extension and model validation + +When using the `Fields` extension, the `/search` endpoint should be able to return `**invalid** STAC Items. This creates an issue when *model validation* is enabled at the application level. + +Previously when adding the `FieldsExtension` to the extensions list and if setting output model validation, we were turning off the validation for both GET/POST `/search` endpoints. This was by-passing validation even when users were not using the `fields` options in requests. + +In `stac-fastapi` v3.0, implementers will have to by-pass the *validation step* at `Client` level by returning `JSONResponse` from the `post_search` and `get_search` client methods. + +```python +# before +class BadCoreClient(BaseCoreClient): + def post_search( + self, search_request: BaseSearchPostRequest, **kwargs + ) -> stac.ItemCollection: + return {"not": "a proper stac item"} + + def get_search( + self, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[List[NumType]] = None, + intersects: Optional[str] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: Optional[int] = 10, + **kwargs, + ) -> stac.ItemCollection: + return {"not": "a proper stac item"} + +# now +class BadCoreClient(BaseCoreClient): + def post_search( + self, search_request: BaseSearchPostRequest, **kwargs + ) -> stac.ItemCollection: + resp = {"not": "a proper stac item"} + + # if `fields` extension is enabled, then we return a JSONResponse + # to avoid Item validation + if getattr(search_request, "fields", None): + return JSONResponse(content=resp) + + return resp + + def get_search( + self, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[List[NumType]] = None, + intersects: Optional[str] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: Optional[int] = 10, + **kwargs, + ) -> stac.ItemCollection: + resp = {"not": "a proper stac item"} + + # if `fields` extension is enabled, then we return a JSONResponse + # to avoid Item validation + if "fields" in kwargs: + return JSONResponse(content=resp) + + return resp + +``` diff --git a/docs/src/tips-and-tricks.md b/docs/src/tips-and-tricks.md index ca5463c59..5398112ef 100644 --- a/docs/src/tips-and-tricks.md +++ b/docs/src/tips-and-tricks.md @@ -22,6 +22,10 @@ If needed, you can edit the `allow_origins` parameter to only allow CORS request ## Enable the Context extension +!!! Warning + + The `ContextExtension` is deprecated and will be removed in 3.0. See https://github.com/radiantearth/stac-api-spec/issues/396 + The Context STAC extension provides information on the number of items matched and returned from a STAC search. This is required by various other STAC-related tools, such as the pystac command-line client. To enable the extension, edit your backend's `app.py` and add the following import: @@ -30,4 +34,66 @@ To enable the extension, edit your backend's `app.py` and add the following impo from stac_fastapi.extensions.core.context import ContextExtension ``` + and then edit the `api = StacApi(...` call to add `ContextExtension()` to the list given as the `extensions` parameter. + +## Set API title, description and version + +For the landing page, you can set the API title, description and version using environment variables. + +- `STAC_FASTAPI_VERSION` (string) is the version number of your API instance (this is not the STAC version). +- `STAC FASTAPI_TITLE` (string) should be a self-explanatory title for your API. +- `STAC FASTAPI_DESCRIPTION` (string) should be a good description for your API. It can contain CommonMark. +- `STAC_FASTAPI_LANDING_ID` (string) is a unique identifier for your Landing page. + + +## Default `includes` in Fields extension (POST request) + +The [**Fields** API extension](https://github.com/stac-api-extensions/fields) enables to filter in/out STAC Items keys (e.g `geometry`). The default behavior is to not filter out anything, but this can be overridden by providing a custom `FieldsExtensionPostRequest` class: + +```python +from typing import Optional, Set + +import attr +from stac_fastapi.extensions.core import FieldsExtension as FieldsExtensionBase +from stac_fastapi.extensions.core.fields import request +from pydantic import BaseModel, Field + + +class PostFieldsExtension(requests.PostFieldsExtension): + include: Optional[Set[str]] = Field( + default_factory=lambda: { + "id", + "type", + "stac_version", + "geometry", + "bbox", + "links", + "assets", + "properties.datetime", + "collection", + } + ) + exclude: Optional[Set[str]] = set() + + +class FieldsExtensionPostRequest(BaseModel): + """Additional fields and schema for the POST request.""" + + fields: Optional[PostFieldsExtension] = Field(PostFieldsExtension()) + + +class FieldsExtension(FieldsExtensionBase): + """Override the POST model""" + + POST = FieldsExtensionPostRequest + + +from stac_fastapi.api.app import StacApi + +stac = StacApi( + extensions=[ + FieldsExtension() + ] +) +``` diff --git a/pyproject.toml b/pyproject.toml index 162a81b1e..9f4172999 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,8 @@ [tool.ruff] +target-version = "py38" # minimum supported version line-length = 90 + +[tool.ruff.lint] select = [ "C9", "D1", @@ -9,13 +12,60 @@ select = [ "W", ] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "**/tests/**/*.py" = ["D1"] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["stac_fastapi"] known-third-party = ["stac_pydantic", "fastapi"] section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"] -[tool.black] -target-version = ["py38", "py39", "py310", "py311"] +[tool.ruff.format] +quote-style = "double" + +[tool.bumpversion] +current_version = "3.0.0b2" +parse = """(?x) + (?P\\d+)\\. + (?P\\d+)\\. + (?P\\d+) + (?: + (?Pa|b|rc) # pre-release label + (?P\\d+) # pre-release version number + )? # pre-release section is optional + (?: + \\.post + (?P\\d+) # post-release version number + )? # post-release section is optional +""" +serialize = [ + "{major}.{minor}.{patch}.post{post_n}", + "{major}.{minor}.{patch}{pre_l}{pre_n}", + "{major}.{minor}.{patch}", +] + +search = "{current_version}" +replace = "{new_version}" +regex = false +tag = false +commit = true + +[[tool.bumpversion.files]] +filename = "VERSION" +search = "{current_version}" +replace = "{new_version}" + +[[tool.bumpversion.files]] +filename = "stac_fastapi/api/stac_fastapi/api/version.py" +search = '__version__ = "{current_version}"' +replace = '__version__ = "{new_version}"' + +[[tool.bumpversion.files]] +filename = "stac_fastapi/extensions/stac_fastapi/extensions/version.py" +search = '__version__ = "{current_version}"' +replace = '__version__ = "{new_version}"' + +[[tool.bumpversion.files]] +filename = "stac_fastapi/types/stac_fastapi/types/version.py" +search = '__version__ = "{current_version}"' +replace = '__version__ = "{new_version}"' diff --git a/stac_fastapi/api/setup.py b/stac_fastapi/api/setup.py index 9dfa86ac9..5050d3a7c 100644 --- a/stac_fastapi/api/setup.py +++ b/stac_fastapi/api/setup.py @@ -6,9 +6,6 @@ desc = f.read() install_requires = [ - "attrs", - "pydantic[dotenv]<2", - "stac_pydantic==2.0.*", "brotli_asgi", "stac-fastapi.types", ] @@ -21,7 +18,6 @@ "pytest-asyncio", "pre-commit", "requests", - "pystac[validation]==1.*", ], "benchmark": [ "pytest-benchmark", @@ -41,6 +37,10 @@ "Intended Audience :: Information Technology", "Intended Audience :: Science/Research", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", ], keywords="STAC FastAPI COG", diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 557896d8f..5148f2baf 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -1,4 +1,6 @@ """Fastapi app creation.""" + + from typing import Any, Dict, List, Optional, Tuple, Type, Union import attr @@ -6,27 +8,25 @@ from fastapi import APIRouter, FastAPI from fastapi.openapi.utils import get_openapi from fastapi.params import Depends -from stac_pydantic import Collection, Item, ItemCollection -from stac_pydantic.api import ConformanceClasses, LandingPage +from stac_pydantic import api from stac_pydantic.api.collections import Collections -from stac_pydantic.version import STAC_VERSION +from stac_pydantic.api.version import STAC_API_VERSION +from stac_pydantic.shared import MimeTypes +from starlette.middleware import Middleware from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware from stac_fastapi.api.models import ( + APIRequest, CollectionUri, EmptyRequest, GeoJSONResponse, ItemCollectionUri, ItemUri, - create_request_model, ) from stac_fastapi.api.openapi import update_openapi from stac_fastapi.api.routes import Scope, add_route_dependencies, create_async_endpoint - -# TODO: make this module not depend on `stac_fastapi.extensions` -from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension from stac_fastapi.types.config import ApiSettings, Settings from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient from stac_fastapi.types.extension import ApiExtension @@ -83,21 +83,40 @@ class StacApi: converter=update_openapi, ) router: APIRouter = attr.ib(default=attr.Factory(APIRouter)) - title: str = attr.ib(default="stac-fastapi") - api_version: str = attr.ib(default="0.1") - stac_version: str = attr.ib(default=STAC_VERSION) - description: str = attr.ib(default="stac-fastapi") + title: str = attr.ib( + default=attr.Factory( + lambda self: self.settings.stac_fastapi_title, takes_self=True + ) + ) + api_version: str = attr.ib( + default=attr.Factory( + lambda self: self.settings.stac_fastapi_version, takes_self=True + ) + ) + stac_version: str = attr.ib(default=STAC_API_VERSION) + description: str = attr.ib( + default=attr.Factory( + lambda self: self.settings.stac_fastapi_description, takes_self=True + ) + ) search_get_request_model: Type[BaseSearchGetRequest] = attr.ib( default=BaseSearchGetRequest ) search_post_request_model: Type[BaseSearchPostRequest] = attr.ib( default=BaseSearchPostRequest ) - pagination_extension = attr.ib(default=TokenPaginationExtension) + collections_get_request_model: Type[APIRequest] = attr.ib(default=EmptyRequest) + collection_get_request_model: Type[APIRequest] = attr.ib(default=CollectionUri) + items_get_request_model: Type[APIRequest] = attr.ib(default=ItemCollectionUri) + item_get_request_model: Type[APIRequest] = attr.ib(default=ItemUri) response_class: Type[Response] = attr.ib(default=JSONResponse) - middlewares: List = attr.ib( + middlewares: List[Middleware] = attr.ib( default=attr.Factory( - lambda: [BrotliMiddleware, CORSMiddleware, ProxyHeaderMiddleware] + lambda: [ + Middleware(BrotliMiddleware), + Middleware(CORSMiddleware), + Middleware(ProxyHeaderMiddleware), + ] ) ) route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[]) @@ -125,9 +144,17 @@ def register_landing_page(self): self.router.add_api_route( name="Landing Page", path="/", - response_model=LandingPage - if self.settings.enable_response_models - else None, + response_model=( + api.LandingPage if self.settings.enable_response_models else None + ), + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": api.LandingPage, + }, + }, response_class=self.response_class, response_model_exclude_unset=False, response_model_exclude_none=True, @@ -144,9 +171,17 @@ def register_conformance_classes(self): self.router.add_api_route( name="Conformance Classes", path="/conformance", - response_model=ConformanceClasses - if self.settings.enable_response_models - else None, + response_model=( + api.Conformance if self.settings.enable_response_models else None + ), + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": api.Conformance, + }, + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -163,12 +198,22 @@ def register_get_item(self): self.router.add_api_route( name="Get Item", path="/collections/{collection_id}/items/{item_id}", - response_model=Item if self.settings.enable_response_models else None, + response_model=api.Item if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.Item, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint(self.client.get_item, ItemUri), + endpoint=create_async_endpoint( + self.client.get_item, self.item_get_request_model + ), ) def register_post_search(self): @@ -177,13 +222,20 @@ def register_post_search(self): Returns: None """ - fields_ext = self.get_extension(FieldsExtension) self.router.add_api_route( name="Search", path="/search", - response_model=(ItemCollection if not fields_ext else None) + response_model=api.ItemCollection if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.ItemCollection, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -199,13 +251,20 @@ def register_get_search(self): Returns: None """ - fields_ext = self.get_extension(FieldsExtension) self.router.add_api_route( name="Search", path="/search", - response_model=(ItemCollection if not fields_ext else None) + response_model=api.ItemCollection if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.ItemCollection, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -224,14 +283,24 @@ def register_get_collections(self): self.router.add_api_route( name="Get Collections", path="/collections", - response_model=Collections - if self.settings.enable_response_models - else None, + response_model=( + Collections if self.settings.enable_response_models else None + ), + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Collections, + }, + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint(self.client.all_collections, EmptyRequest), + endpoint=create_async_endpoint( + self.client.all_collections, self.collections_get_request_model + ), ) def register_get_collection(self): @@ -243,12 +312,24 @@ def register_get_collection(self): self.router.add_api_route( name="Get Collection", path="/collections/{collection_id}", - response_model=Collection if self.settings.enable_response_models else None, + response_model=api.Collection + if self.settings.enable_response_models + else None, + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": api.Collection, + }, + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint(self.client.get_collection, CollectionUri), + endpoint=create_async_endpoint( + self.client.get_collection, self.collection_get_request_model + ), ) def register_get_item_collection(self): @@ -257,27 +338,27 @@ def register_get_item_collection(self): Returns: None """ - pagination_extension = self.get_extension(self.pagination_extension) - if pagination_extension is not None: - mixins = [pagination_extension.GET] - else: - mixins = None - request_model = create_request_model( - "ItemCollectionURI", - base_model=ItemCollectionUri, - mixins=mixins, - ) self.router.add_api_route( name="Get ItemCollection", path="/collections/{collection_id}/items", - response_model=ItemCollection - if self.settings.enable_response_models - else None, + response_model=( + api.ItemCollection if self.settings.enable_response_models else None + ), + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": api.ItemCollection, + }, + }, response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint(self.client.item_collection, request_model), + endpoint=create_async_endpoint( + self.client.item_collection, self.items_get_request_model + ), ) def register_core(self): @@ -364,10 +445,6 @@ def __attrs_post_init__(self): self.client.title = self.title self.client.description = self.description - fields_ext = self.get_extension(FieldsExtension) - if fields_ext: - self.settings.default_includes = fields_ext.default_includes - Settings.set(self.settings) self.app.state.settings = self.settings @@ -393,8 +470,11 @@ def __attrs_post_init__(self): self.app.openapi = self.customize_openapi # add middlewares + if self.middlewares and self.app.middleware_stack is not None: + raise RuntimeError("Cannot add middleware after an application has started") + for middleware in self.middlewares: - self.app.add_middleware(middleware) + self.app.user_middleware.insert(0, middleware) # customize route dependencies for scopes, dependencies in self.route_dependencies: diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index ccbe4ee14..74a1c7312 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -18,10 +18,11 @@ class ApiExtensions(enum.Enum): query = "query" sort = "sort" transaction = "transaction" + aggregation = "aggregation" + free_text = "free-text" class AddOns(enum.Enum): """Enumeration of available third party add ons.""" bulk_transaction = "bulk-transaction" - free_text = "free-text" diff --git a/stac_fastapi/api/stac_fastapi/api/errors.py b/stac_fastapi/api/stac_fastapi/api/errors.py index 3f052bd31..6d90ba63a 100644 --- a/stac_fastapi/api/stac_fastapi/api/errors.py +++ b/stac_fastapi/api/stac_fastapi/api/errors.py @@ -4,7 +4,7 @@ from typing import Callable, Dict, Type, TypedDict from fastapi import FastAPI -from fastapi.exceptions import RequestValidationError +from fastapi.exceptions import RequestValidationError, ResponseValidationError from starlette import status from starlette.requests import Request from starlette.responses import JSONResponse @@ -27,6 +27,7 @@ DatabaseError: status.HTTP_424_FAILED_DEPENDENCY, Exception: status.HTTP_500_INTERNAL_SERVER_ERROR, InvalidQueryParameter: status.HTTP_400_BAD_REQUEST, + ResponseValidationError: status.HTTP_500_INTERNAL_SERVER_ERROR, } diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index 3ed67d6c9..2ba3ef570 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,4 +1,5 @@ """Api middleware.""" + import re import typing from http.client import HTTP_PORT, HTTPS_PORT diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index 53f376aa0..5a239b9f0 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -1,13 +1,12 @@ """Api request/response models.""" -import importlib.util -from typing import Optional, Type, Union +from typing import List, Optional, Type, Union import attr -from fastapi import Body, Path +from fastapi import Path, Query from pydantic import BaseModel, create_model -from pydantic.fields import UndefinedType from stac_pydantic.shared import BBox +from typing_extensions import Annotated from stac_fastapi.types.extension import ApiExtension from stac_fastapi.types.rfc3339 import DateTimeType @@ -15,15 +14,22 @@ APIRequest, BaseSearchGetRequest, BaseSearchPostRequest, - str2bbox, + _bbox_converter, + _datetime_converter, ) +try: + import orjson # noqa + from fastapi.responses import ORJSONResponse as JSONResponse +except ImportError: # pragma: nocover + from starlette.responses import JSONResponse + def create_request_model( model_name="SearchGetRequest", base_model: Union[Type[BaseModel], APIRequest] = BaseSearchGetRequest, - extensions: Optional[ApiExtension] = None, - mixins: Optional[Union[BaseModel, APIRequest]] = None, + extensions: Optional[List[ApiExtension]] = None, + mixins: Optional[Union[List[BaseModel], List[APIRequest]]] = None, request_type: Optional[str] = "GET", ) -> Union[Type[BaseModel], APIRequest]: """Create a pydantic model for validating request bodies.""" @@ -46,40 +52,19 @@ def create_request_model( # Handle POST requests elif all([issubclass(m, BaseModel) for m in models]): for model in models: - for k, v in model.__fields__.items(): - field_info = v.field_info - body = Body( - None - if isinstance(field_info.default, UndefinedType) - else field_info.default, - default_factory=field_info.default_factory, - alias=field_info.alias, - alias_priority=field_info.alias_priority, - title=field_info.title, - description=field_info.description, - const=field_info.const, - gt=field_info.gt, - ge=field_info.ge, - lt=field_info.lt, - le=field_info.le, - multiple_of=field_info.multiple_of, - min_items=field_info.min_items, - max_items=field_info.max_items, - min_length=field_info.min_length, - max_length=field_info.max_length, - regex=field_info.regex, - extra=field_info.extra, - ) - fields[k] = (v.outer_type_, body) + for k, field_info in model.model_fields.items(): + fields[k] = (field_info.annotation, field_info) return create_model(model_name, **fields, __base__=base_model) raise TypeError("Mixed Request Model types. Check extension request types.") def create_get_request_model( - extensions, base_model: BaseSearchGetRequest = BaseSearchGetRequest -): + extensions: Optional[List[ApiExtension]], + base_model: BaseSearchGetRequest = BaseSearchGetRequest, +) -> APIRequest: """Wrap create_request_model to create the GET request model.""" + return create_request_model( "SearchGetRequest", base_model=base_model, @@ -89,8 +74,9 @@ def create_get_request_model( def create_post_request_model( - extensions, base_model: BaseSearchPostRequest = BaseSearchPostRequest -): + extensions: Optional[List[ApiExtension]], + base_model: BaseSearchPostRequest = BaseSearchPostRequest, +) -> Type[BaseModel]: """Wrap create_request_model to create the POST request model.""" return create_request_model( "SearchPostRequest", @@ -100,18 +86,19 @@ def create_post_request_model( ) -@attr.s # type:ignore +@attr.s class CollectionUri(APIRequest): - """Delete collection.""" + """Get or delete collection.""" - collection_id: str = attr.ib(default=Path(..., description="Collection ID")) + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() @attr.s -class ItemUri(CollectionUri): - """Delete item.""" +class ItemUri(APIRequest): + """Get or delete item.""" - item_id: str = attr.ib(default=Path(..., description="Item ID")) + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() + item_id: Annotated[str, Path(description="Item ID")] = attr.ib() @attr.s @@ -122,63 +109,24 @@ class EmptyRequest(APIRequest): @attr.s -class ItemCollectionUri(CollectionUri): +class ItemCollectionUri(APIRequest): """Get item collection.""" - limit: int = attr.ib(default=10) - bbox: Optional[BBox] = attr.ib(default=None, converter=str2bbox) - datetime: Optional[DateTimeType] = attr.ib(default=None) - - -class POSTTokenPagination(BaseModel): - """Token pagination model for POST requests.""" - - token: Optional[str] = None - - -@attr.s -class GETTokenPagination(APIRequest): - """Token pagination for GET requests.""" - - token: Optional[str] = attr.ib(default=None) - - -class POSTPagination(BaseModel): - """Page based pagination for POST requests.""" - - page: Optional[str] = None - - -@attr.s -class GETPagination(APIRequest): - """Page based pagination for GET requests.""" - - page: Optional[str] = attr.ib(default=None) - - -# Test for ORJSON and use it rather than stdlib JSON where supported -if importlib.util.find_spec("orjson") is not None: - from fastapi.responses import ORJSONResponse - - class GeoJSONResponse(ORJSONResponse): - """JSON with custom, vendor content-type.""" - - media_type = "application/geo+json" - - class JSONSchemaResponse(ORJSONResponse): - """JSON with custom, vendor content-type.""" + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() + limit: Annotated[int, Query()] = attr.ib(default=10) + bbox: Optional[BBox] = attr.ib(default=None, converter=_bbox_converter) + datetime: Optional[DateTimeType] = attr.ib( + default=None, converter=_datetime_converter + ) - media_type = "application/schema+json" -else: - from starlette.responses import JSONResponse +class GeoJSONResponse(JSONResponse): + """JSON with custom, vendor content-type.""" - class GeoJSONResponse(JSONResponse): - """JSON with custom, vendor content-type.""" + media_type = "application/geo+json" - media_type = "application/geo+json" - class JSONSchemaResponse(JSONResponse): - """JSON with custom, vendor content-type.""" +class JSONSchemaResponse(JSONResponse): + """JSON with custom, vendor content-type.""" - media_type = "application/schema+json" + media_type = "application/schema+json" diff --git a/stac_fastapi/api/stac_fastapi/api/openapi.py b/stac_fastapi/api/stac_fastapi/api/openapi.py index a38a70bae..ab90ce425 100644 --- a/stac_fastapi/api/stac_fastapi/api/openapi.py +++ b/stac_fastapi/api/stac_fastapi/api/openapi.py @@ -1,4 +1,5 @@ """openapi.""" + import warnings from fastapi import FastAPI @@ -43,9 +44,7 @@ async def patched_openapi_endpoint(req: Request) -> Response: # Get the response from the old endpoint function response: JSONResponse = await old_endpoint(req) # Update the content type header in place - response.headers[ - "content-type" - ] = "application/vnd.oai.openapi+json;version=3.0" + response.headers["content-type"] = "application/vnd.oai.openapi+json;version=3.0" # Return the updated response return response diff --git a/stac_fastapi/api/stac_fastapi/api/routes.py b/stac_fastapi/api/stac_fastapi/api/routes.py index 66b76d2d7..bd6f4d9cf 100644 --- a/stac_fastapi/api/stac_fastapi/api/routes.py +++ b/stac_fastapi/api/stac_fastapi/api/routes.py @@ -1,5 +1,6 @@ """Route factories.""" +import copy import functools import inspect import warnings @@ -45,7 +46,7 @@ def create_async_endpoint( """ if response_class: - warnings.warns( + warnings.warn( "`response_class` option is deprecated, please set the Response class directly in the endpoint.", # noqa: E501 DeprecationWarning, ) @@ -100,15 +101,28 @@ def add_route_dependencies( Allows a developer to add dependencies to a route after the route has been defined. + "*" can be used for path or method to match all allowed routes. + Returns: None """ for scope in scopes: + _scope = copy.deepcopy(scope) for route in routes: - match, _ = route.matches({"type": "http", **scope}) + if scope["path"] == "*": + _scope["path"] = route.path + + if scope["method"] == "*": + _scope["method"] = list(route.methods)[0] + + match, _ = route.matches({"type": "http", **_scope}) if match != Match.FULL: continue + # Ignore paths without dependants, e.g. /api, /api.html, /docs/oauth2-redirect + if not hasattr(route, "dependant"): + continue + # Mimicking how APIRoute handles dependencies: # https://github.com/tiangolo/fastapi/blob/1760da0efa55585c19835d81afa8ca386036c325/fastapi/routing.py#L408-L412 for depends in dependencies[::-1]: diff --git a/stac_fastapi/api/stac_fastapi/api/version.py b/stac_fastapi/api/stac_fastapi/api/version.py index bb0c7c379..7296e8a98 100644 --- a/stac_fastapi/api/stac_fastapi/api/version.py +++ b/stac_fastapi/api/stac_fastapi/api/version.py @@ -1,2 +1,2 @@ """Library version.""" -__version__ = "2.4.9" +__version__ = "3.0.0b2" diff --git a/stac_fastapi/api/tests/benchmarks.py b/stac_fastapi/api/tests/benchmarks.py index ad73d2424..475250d7f 100644 --- a/stac_fastapi/api/tests/benchmarks.py +++ b/stac_fastapi/api/tests/benchmarks.py @@ -17,6 +17,7 @@ collections = [ stac_types.Collection( id=f"test_collection_{n}", + type="Collection", title="Test Collection", description="A test collection", keywords=["test"], @@ -25,7 +26,7 @@ "spatial": {"bbox": [[-180, -90, 180, 90]]}, "temporal": {"interval": [["2000-01-01T00:00:00Z", None]]}, }, - links=collection_links.dict(exclude_none=True), + links=collection_links.model_dump(exclude_none=True), ) for n in range(0, 10) ] @@ -37,7 +38,7 @@ geometry={"type": "Point", "coordinates": [0, 0]}, bbox=[-180, -90, 180, 90], properties={"datetime": "2000-01-01T00:00:00Z"}, - links=item_links.dict(exclude_none=True), + links=item_links.model_dump(exclude_none=True), assets={}, ) for n in range(0, 1000) @@ -160,9 +161,7 @@ def f(): benchmark.group = "Collection With Model validation" if validate else "Collection" benchmark.name = "Collection With Model validation" if validate else "Collection" - benchmark.fullname = ( - "Collection With Model validation" if validate else "Collection" - ) + benchmark.fullname = "Collection With Model validation" if validate else "Collection" response = benchmark(f) assert response.status_code == 200 diff --git a/stac_fastapi/api/tests/conftest.py b/stac_fastapi/api/tests/conftest.py new file mode 100644 index 000000000..33919e83e --- /dev/null +++ b/stac_fastapi/api/tests/conftest.py @@ -0,0 +1,175 @@ +from datetime import datetime +from typing import List, Optional, Union + +import pytest +from stac_pydantic import Collection, Item +from stac_pydantic.api.utils import link_factory + +from stac_fastapi.types import core, stac +from stac_fastapi.types.core import NumType +from stac_fastapi.types.search import BaseSearchPostRequest + +collection_links = link_factory.CollectionLinks("/", "test").create_links() +item_links = link_factory.ItemLinks("/", "test", "test").create_links() + + +@pytest.fixture +def _collection(): + return Collection( + type="Collection", + id="test_collection", + title="Test Collection", + description="A test collection", + keywords=["test"], + license="proprietary", + extent={ + "spatial": {"bbox": [[-180, -90, 180, 90]]}, + "temporal": {"interval": [["2000-01-01T00:00:00Z", None]]}, + }, + links=collection_links, + ) + + +@pytest.fixture +def collection(_collection: Collection): + return _collection.model_dump_json() + + +@pytest.fixture +def collection_dict(_collection: Collection): + return _collection.model_dump(mode="json") + + +@pytest.fixture +def _item(): + return Item( + id="test_item", + type="Feature", + geometry={"type": "Point", "coordinates": [0, 0]}, + bbox=[-180, -90, 180, 90], + properties={"datetime": "2000-01-01T00:00:00Z"}, + links=item_links, + assets={}, + ) + + +@pytest.fixture +def item(_item: Item): + return _item.model_dump_json() + + +@pytest.fixture +def item_dict(_item: Item): + return _item.model_dump(mode="json") + + +@pytest.fixture +def TestCoreClient(collection_dict, item_dict): + class CoreClient(core.BaseCoreClient): + def post_search( + self, search_request: BaseSearchPostRequest, **kwargs + ) -> stac.ItemCollection: + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] + ) + + def get_search( + self, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[List[NumType]] = None, + intersects: Optional[str] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: Optional[int] = 10, + **kwargs, + ) -> stac.ItemCollection: + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] + ) + + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: + return stac.Item(**item_dict) + + def all_collections(self, **kwargs) -> stac.Collections: + return stac.Collections( + collections=[stac.Collection(**collection_dict)], + links=[ + {"href": "test", "rel": "root"}, + {"href": "test", "rel": "self"}, + {"href": "test", "rel": "parent"}, + ], + ) + + def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: + return stac.Collection(**collection_dict) + + def item_collection( + self, + collection_id: str, + bbox: Optional[List[Union[float, int]]] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: int = 10, + token: str = None, + **kwargs, + ) -> stac.ItemCollection: + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] + ) + + return CoreClient + + +@pytest.fixture +def AsyncTestCoreClient(collection_dict, item_dict): + class AsyncCoreClient(core.AsyncBaseCoreClient): + async def post_search( + self, search_request: BaseSearchPostRequest, **kwargs + ) -> stac.ItemCollection: + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] + ) + + async def get_search( + self, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[List[NumType]] = None, + intersects: Optional[str] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: Optional[int] = 10, + **kwargs, + ) -> stac.ItemCollection: + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] + ) + + async def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: + return stac.Item(**item_dict) + + async def all_collections(self, **kwargs) -> stac.Collections: + return stac.Collections( + collections=[stac.Collection(**collection_dict)], + links=[ + {"href": "test", "rel": "root"}, + {"href": "test", "rel": "self"}, + {"href": "test", "rel": "parent"}, + ], + ) + + async def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: + return stac.Collection(**collection_dict) + + async def item_collection( + self, + collection_id: str, + bbox: Optional[List[Union[float, int]]] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: int = 10, + token: str = None, + **kwargs, + ) -> stac.ItemCollection: + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] + ) + + return AsyncCoreClient diff --git a/stac_fastapi/api/tests/test_api.py b/stac_fastapi/api/tests/test_api.py index 91b50371e..7db4d9a5e 100644 --- a/stac_fastapi/api/tests/test_api.py +++ b/stac_fastapi/api/tests/test_api.py @@ -2,7 +2,11 @@ from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi -from stac_fastapi.extensions.core import TokenPaginationExtension, TransactionExtension +from stac_fastapi.api.models import ItemCollectionUri, create_request_model +from stac_fastapi.extensions.core import ( + TokenPaginationExtension, + TransactionExtension, +) from stac_fastapi.types import config, core @@ -10,6 +14,13 @@ class TestRouteDependencies: @staticmethod def _build_api(**overrides): settings = config.ApiSettings() + + items_get_request_model = create_request_model( + "ItemCollectionURI", + base_model=ItemCollectionUri, + mixins=[TokenPaginationExtension().GET], + ) + return StacApi( **{ "settings": settings, @@ -20,6 +31,7 @@ def _build_api(**overrides): ), TokenPaginationExtension(), ], + "items_get_request_model": items_get_request_model, **overrides, } ) @@ -41,11 +53,29 @@ def _assert_dependency_applied(api, routes): method=route["method"].lower(), url=path, auth=("bob", "dobbs"), - content='{"dummy": "payload"}', + content=route["payload"], headers={"content-type": "application/json"}, ) assert ( - response.status_code == 200 + 200 <= response.status_code < 300 + ), "Authenticated requests should be accepted" + assert response.json() == "dummy response" + + @staticmethod + def _assert_dependency_not_applied(api, routes): + with TestClient(api.app) as client: + for route in routes: + path = route["path"].format( + collectionId="test_collection", itemId="test_item" + ) + response = client.request( + method=route["method"].lower(), + url=path, + content=route["payload"], + headers={"content-type": "application/json"}, + ) + assert ( + 200 <= response.status_code < 300 ), "Authenticated requests should be accepted" assert response.json() == "dummy response" @@ -58,32 +88,318 @@ def test_openapi_content_type(self): == "application/vnd.oai.openapi+json;version=3.0" ) - def test_build_api_with_route_dependencies(self): + def test_build_api_with_route_dependencies(self, collection, item): routes = [ - {"path": "/collections", "method": "POST"}, - {"path": "/collections/{collectionId}", "method": "PUT"}, - {"path": "/collections/{collectionId}", "method": "DELETE"}, - {"path": "/collections/{collectionId}/items", "method": "POST"}, - {"path": "/collections/{collectionId}/items/{itemId}", "method": "PUT"}, - {"path": "/collections/{collectionId}/items/{itemId}", "method": "DELETE"}, + {"path": "/collections", "method": "POST", "payload": collection}, + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + {"path": "/collections/{collectionId}", "method": "DELETE", "payload": ""}, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": "", + }, ] dependencies = [Depends(must_be_bob)] api = self._build_api(route_dependencies=[(routes, dependencies)]) self._assert_dependency_applied(api, routes) - def test_add_route_dependencies_after_building_api(self): + def test_add_route_dependencies_after_building_api(self, collection, item): routes = [ - {"path": "/collections", "method": "POST"}, - {"path": "/collections/{collectionId}", "method": "PUT"}, - {"path": "/collections/{collectionId}", "method": "DELETE"}, - {"path": "/collections/{collectionId}/items", "method": "POST"}, - {"path": "/collections/{collectionId}/items/{itemId}", "method": "PUT"}, - {"path": "/collections/{collectionId}/items/{itemId}", "method": "DELETE"}, + {"path": "/collections", "method": "POST", "payload": collection}, + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + {"path": "/collections/{collectionId}", "method": "DELETE", "payload": ""}, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": "", + }, ] api = self._build_api() api.add_route_dependencies(scopes=routes, dependencies=[Depends(must_be_bob)]) self._assert_dependency_applied(api, routes) + def test_build_api_with_default_route_dependencies(self, collection, item): + routes = [{"path": "*", "method": "*"}] + test_routes = [ + {"path": "/collections", "method": "POST", "payload": collection}, + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + {"path": "/collections/{collectionId}", "method": "DELETE", "payload": ""}, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": "", + }, + ] + dependencies = [Depends(must_be_bob)] + api = self._build_api(route_dependencies=[(routes, dependencies)]) + self._assert_dependency_applied(api, test_routes) + + def test_build_api_with_default_path_route_dependencies(self, collection, item): + routes = [{"path": "*", "method": "POST"}] + test_routes = [ + { + "path": "/collections", + "method": "POST", + "payload": collection, + }, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + ] + test_not_routes = [ + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + { + "path": "/collections/{collectionId}", + "method": "DELETE", + "payload": "", + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": "", + }, + ] + dependencies = [Depends(must_be_bob)] + api = self._build_api(route_dependencies=[(routes, dependencies)]) + self._assert_dependency_applied(api, test_routes) + self._assert_dependency_not_applied(api, test_not_routes) + + def test_build_api_with_default_method_route_dependencies(self, collection, item): + routes = [ + { + "path": "/collections/{collectionId}", + "method": "*", + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "*", + }, + ] + test_routes = [ + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + { + "path": "/collections/{collectionId}", + "method": "DELETE", + "payload": "", + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": "", + }, + ] + test_not_routes = [ + { + "path": "/collections", + "method": "POST", + "payload": collection, + }, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + ] + dependencies = [Depends(must_be_bob)] + api = self._build_api(route_dependencies=[(routes, dependencies)]) + self._assert_dependency_applied(api, test_routes) + self._assert_dependency_not_applied(api, test_not_routes) + + def test_add_default_route_dependencies_after_building_api(self, collection, item): + routes = [{"path": "*", "method": "*"}] + test_routes = [ + { + "path": "/collections", + "method": "POST", + "payload": collection, + }, + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + { + "path": "/collections/{collectionId}", + "method": "DELETE", + "payload": "", + }, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": "", + }, + ] + api = self._build_api() + api.add_route_dependencies(scopes=routes, dependencies=[Depends(must_be_bob)]) + self._assert_dependency_applied(api, test_routes) + + def test_add_default_path_route_dependencies_after_building_api( + self, collection, item + ): + routes = [{"path": "*", "method": "POST"}] + test_routes = [ + { + "path": "/collections", + "method": "POST", + "payload": collection, + }, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + ] + test_not_routes = [ + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + { + "path": "/collections/{collectionId}", + "method": "DELETE", + "payload": "", + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": "", + }, + ] + api = self._build_api() + api.add_route_dependencies(scopes=routes, dependencies=[Depends(must_be_bob)]) + self._assert_dependency_applied(api, test_routes) + self._assert_dependency_not_applied(api, test_not_routes) + + def test_add_default_method_route_dependencies_after_building_api( + self, collection, item + ): + routes = [ + { + "path": "/collections/{collectionId}", + "method": "*", + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "*", + }, + ] + test_routes = [ + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + { + "path": "/collections/{collectionId}", + "method": "DELETE", + "payload": "", + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": "", + }, + ] + test_not_routes = [ + { + "path": "/collections", + "method": "POST", + "payload": collection, + }, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + ] + api = self._build_api() + api.add_route_dependencies(scopes=routes, dependencies=[Depends(must_be_bob)]) + self._assert_dependency_applied(api, test_routes) + self._assert_dependency_not_applied(api, test_not_routes) + class DummyCoreClient(core.BaseCoreClient): def all_collections(self, *args, **kwargs): diff --git a/stac_fastapi/api/tests/test_app.py b/stac_fastapi/api/tests/test_app.py new file mode 100644 index 000000000..0ddcb2429 --- /dev/null +++ b/stac_fastapi/api/tests/test_app.py @@ -0,0 +1,389 @@ +from datetime import datetime +from typing import List, Optional, Union + +import attr +import pytest +from fastapi import Path, Query +from fastapi.testclient import TestClient +from pydantic import ValidationError +from stac_pydantic import api +from typing_extensions import Annotated + +from stac_fastapi.api import app +from stac_fastapi.api.models import ( + APIRequest, + JSONResponse, + create_get_request_model, + create_post_request_model, +) +from stac_fastapi.extensions.core import FieldsExtension, FilterExtension +from stac_fastapi.types import stac +from stac_fastapi.types.config import ApiSettings +from stac_fastapi.types.core import BaseCoreClient, NumType +from stac_fastapi.types.search import BaseSearchPostRequest + + +def test_client_response_type(TestCoreClient): + """Test all GET endpoints. Verify that responses are valid STAC items.""" + + test_app = app.StacApi( + settings=ApiSettings(), + client=TestCoreClient(), + ) + + with TestClient(test_app.app) as client: + landing = client.get("/") + collection = client.get("/collections/test") + collections = client.get("/collections") + item = client.get("/collections/test/items/test") + item_collection = client.get( + "/collections/test/items", + params={"limit": 10}, + ) + get_search = client.get( + "/search", + params={ + "collections": ["test"], + }, + ) + post_search = client.post( + "/search", + json={ + "collections": ["test"], + }, + ) + + assert landing.status_code == 200, landing.text + api.LandingPage(**landing.json()) + + assert collection.status_code == 200, collection.text + api.Collection(**collection.json()) + + assert collections.status_code == 200, collections.text + api.collections.Collections(**collections.json()) + + assert item.status_code == 200, item.text + api.Item(**item.json()) + + assert item_collection.status_code == 200, item_collection.text + api.ItemCollection(**item_collection.json()) + + assert get_search.status_code == 200, get_search.text + api.ItemCollection(**get_search.json()) + + assert post_search.status_code == 200, post_search.text + api.ItemCollection(**post_search.json()) + + +@pytest.mark.parametrize("validate", [True, False]) +def test_client_invalid_response_type(validate, TestCoreClient, item_dict): + """Check if the build in response validation switch works.""" + + class InValidResponseClient(TestCoreClient): + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: + item_dict.pop("bbox") + item_dict.pop("geometry") + return stac.Item(**item_dict) + + test_app = app.StacApi( + settings=ApiSettings(enable_response_models=validate), + client=InValidResponseClient(), + ) + + with TestClient(test_app.app) as client: + item = client.get("/collections/test/items/test") + + # Even if API validation passes, we should receive an invalid item + if item.status_code == 200: + with pytest.raises(ValidationError): + api.Item(**item.json()) + + # If internal validation is on, we should expect an internal error + if validate: + assert item.status_code == 500, item.text + else: + assert item.status_code == 200, item.text + + +def test_client_openapi(TestCoreClient): + """Test if response models are all documented with OpenAPI.""" + + test_app = app.StacApi( + settings=ApiSettings(), + client=TestCoreClient(), + ) + test_app.app.openapi() + components = ["LandingPage", "Collection", "Collections", "Item", "ItemCollection"] + for component in components: + assert component in test_app.app.openapi_schema["components"]["schemas"] + + +@pytest.mark.parametrize("validate", [True, False]) +def test_filter_extension(validate, TestCoreClient, item_dict): + """Test if Filter Parameters are passed correctly.""" + + class FilterClient(TestCoreClient): + def post_search( + self, search_request: BaseSearchPostRequest, **kwargs + ) -> stac.ItemCollection: + search_request.collections = ["test"] + search_request.filter = {} + search_request.filter_crs = "EPSG:4326" + search_request.filter_lang = "cql2-text" + + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] + ) + + def get_search( + self, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[List[NumType]] = None, + intersects: Optional[str] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: Optional[int] = 10, + filter: Optional[str] = None, + filter_crs: Optional[str] = None, + filter_lang: Optional[str] = None, + **kwargs, + ) -> stac.ItemCollection: + # Check if all filter parameters are passed correctly + + assert filter == "TEST" + + # FIXME: https://github.com/stac-utils/stac-fastapi/issues/638 + # hyphen alias for filter_crs and filter_lang are currently not working + # Query parameters `filter-crs` and `filter-lang` + # should be recognized by the API + # They are present in the `request.query_params` but not in the `kwargs` + + # assert filter_crs == "EPSG:4326" + # assert filter_lang == "cql2-text" + + return stac.ItemCollection( + type="FeatureCollection", features=[stac.Item(**item_dict)] + ) + + post_request_model = create_post_request_model([FilterExtension()]) + + test_app = app.StacApi( + settings=ApiSettings(enable_response_models=validate), + client=FilterClient(post_request_model=post_request_model), + search_get_request_model=create_get_request_model([FilterExtension()]), + search_post_request_model=post_request_model, + extensions=[FilterExtension()], + ) + + with TestClient(test_app.app) as client: + landing = client.get("/") + get_search = client.get( + "/search", + params={ + "filter": "TEST", + "filter-crs": "EPSG:4326", + "filter-lang": "cql2-text", + }, + ) + post_search = client.post( + "/search", + json={ + "collections": ["test"], + "filter": {}, + "filter-crs": "EPSG:4326", + "filter-lang": "cql2-text", + }, + ) + + assert landing.status_code == 200, landing.text + assert get_search.status_code == 200, get_search.text + assert post_search.status_code == 200, post_search.text + + +@pytest.mark.parametrize("validate", [True, False]) +def test_fields_extension(validate, TestCoreClient, item_dict): + """Test if fields Parameters are passed correctly.""" + + class BadCoreClient(BaseCoreClient): + def post_search( + self, search_request: BaseSearchPostRequest, **kwargs + ) -> stac.ItemCollection: + resp = {"not": "a proper stac item"} + + # if `fields` extension is enabled, then we return a JSONResponse + # to avoid Item validation + if getattr(search_request, "fields", None): + return JSONResponse(content=resp) + + return resp + + def get_search( + self, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[List[NumType]] = None, + intersects: Optional[str] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: Optional[int] = 10, + **kwargs, + ) -> stac.ItemCollection: + resp = {"not": "a proper stac item"} + + # if `fields` extension is enabled, then we return a JSONResponse + # to avoid Item validation + if "fields" in kwargs: + return JSONResponse(content=resp) + + return resp + + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: + raise NotImplementedError + + def all_collections(self, **kwargs) -> stac.Collections: + raise NotImplementedError + + def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: + raise NotImplementedError + + def item_collection( + self, + collection_id: str, + bbox: Optional[List[Union[float, int]]] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: int = 10, + token: str = None, + **kwargs, + ) -> stac.ItemCollection: + raise NotImplementedError + + # With FieldsExtension + test_app = app.StacApi( + settings=ApiSettings(enable_response_models=validate), + client=BadCoreClient(), + search_get_request_model=create_get_request_model([FieldsExtension()]), + search_post_request_model=create_post_request_model([FieldsExtension()]), + extensions=[FieldsExtension()], + ) + + with TestClient(test_app.app) as client: + get_search = client.get( + "/search", + params={"fields": "properties.datetime"}, + ) + post_search = client.post( + "/search", + json={ + "collections": ["test"], + "fields": { + "include": ["properties.datetime"], + "exclude": [], + }, + }, + ) + + # With or without validation, /search endpoints will always return 200 + # because we have the `FieldsExtension` enabled, so the endpoint + # will avoid the model validation (by returning JSONResponse) + assert get_search.status_code == 200, get_search.text + assert post_search.status_code == 200, post_search.text + + # Without FieldsExtension + test_app = app.StacApi( + settings=ApiSettings(enable_response_models=validate), + client=BadCoreClient(), + search_get_request_model=create_get_request_model([]), + search_post_request_model=create_post_request_model([]), + extensions=[], + ) + + with TestClient(test_app.app) as client: + get_search = client.get( + "/search", + params={"fields": "properties.datetime"}, + ) + post_search = client.post( + "/search", + json={ + "collections": ["test"], + "fields": { + "include": ["properties.datetime"], + "exclude": [], + }, + }, + ) + + if validate: + # NOTE: the `fields` options will be ignored by fastAPI because it's + # not part of the request model, so the client should not by-pass the validation + assert get_search.status_code == 500, ( + get_search.json()["code"] == "ResponseValidationError" + ) + assert post_search.status_code == 500, ( + post_search.json()["code"] == "ResponseValidationError" + ) + else: + assert get_search.status_code == 200, get_search.text + assert post_search.status_code == 200, post_search.text + + +def test_request_model(AsyncTestCoreClient): + """Test if request models are passed correctly.""" + + @attr.s + class CollectionsRequest(APIRequest): + user: Annotated[str, Query(...)] = attr.ib() + + @attr.s + class CollectionRequest(APIRequest): + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() + user: Annotated[str, Query(...)] = attr.ib() + + @attr.s + class ItemsRequest(APIRequest): + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() + user: Annotated[str, Query(...)] = attr.ib() + + @attr.s + class ItemRequest(APIRequest): + collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib() + item_id: Annotated[str, Path(description="Item ID")] = attr.ib() + user: Annotated[str, Query(...)] = attr.ib() + + test_app = app.StacApi( + settings=ApiSettings(), + client=AsyncTestCoreClient(), + collections_get_request_model=CollectionsRequest, + collection_get_request_model=CollectionRequest, + items_get_request_model=ItemsRequest, + item_get_request_model=ItemRequest, + extensions=[], + ) + + with TestClient(test_app.app) as client: + resp = client.get("/collections") + assert resp.status_code == 400 + + resp = client.get("/collections", params={"user": "Luke"}) + assert resp.status_code == 200 + + resp = client.get("/collections/test_collection") + assert resp.status_code == 400 + + resp = client.get("/collections/test_collection", params={"user": "Leia"}) + assert resp.status_code == 200 + + resp = client.get("/collections/test_collection/items") + assert resp.status_code == 400 + + resp = client.get( + "/collections/test_collection/items", params={"user": "Obi-Wan"} + ) + assert resp.status_code == 200 + + resp = client.get("/collections/test_collection/items/test_item") + assert resp.status_code == 400 + + resp = client.get( + "/collections/test_collection/items/test_item", params={"user": "Chewbacca"} + ) + assert resp.status_code == 200 diff --git a/stac_fastapi/api/tests/test_app_prefix.py b/stac_fastapi/api/tests/test_app_prefix.py new file mode 100644 index 000000000..5d604477c --- /dev/null +++ b/stac_fastapi/api/tests/test_app_prefix.py @@ -0,0 +1,146 @@ +import urllib +from typing import Optional + +import pytest +from fastapi import APIRouter +from starlette.testclient import TestClient + +from stac_fastapi.api.app import StacApi +from stac_fastapi.types.config import ApiSettings + + +def get_link(landing_page, rel_type, method: Optional[str] = None): + return next( + filter( + lambda link: link["rel"] == rel_type + and (not method or link.get("method") == method), + landing_page["links"], + ), + None, + ) + + +@pytest.mark.parametrize("prefix", ["", "/a_prefix"]) +def test_api_prefix(TestCoreClient, prefix): + api_settings = ApiSettings( + openapi_url=f"{prefix}/api", + docs_url=f"{prefix}/api.html", + ) + + api = StacApi( + settings=api_settings, + client=TestCoreClient(), + router=APIRouter(prefix=prefix), + ) + + with TestClient(api.app, base_url="http://stac.io") as client: + landing = client.get(f"{prefix}/") + assert landing.status_code == 200, landing.json() + + service_doc = client.get(f"{prefix}/api.html") + assert service_doc.status_code == 200, service_doc.text + + service_desc = client.get(f"{prefix}/api") + assert service_desc.status_code == 200, service_desc.json() + + conformance = client.get(f"{prefix}/conformance") + assert conformance.status_code == 200, conformance.json() + + # NOTE: The collections/collection/items/item links do not have the prefix + # because they are created in the fixtures + collections = client.get(f"{prefix}/collections") + assert collections.status_code == 200, collections.json() + collection_id = collections.json()["collections"][0]["id"] + + collection = client.get(f"{prefix}/collections/{collection_id}") + assert collection.status_code == 200, collection.json() + + items = client.get(f"{prefix}/collections/{collection_id}/items") + assert items.status_code == 200, items.json() + + item_id = items.json()["features"][0]["id"] + item = client.get(f"{prefix}/collections/{collection_id}/items/{item_id}") + assert item.status_code == 200, item.json() + + link_tests = [ + ("root", "application/json", "/"), + ("conformance", "application/json", "/conformance"), + ("data", "application/json", "/collections"), + ("search", "application/geo+json", "/search"), + ("service-doc", "text/html", "/api.html"), + ("service-desc", "application/vnd.oai.openapi+json;version=3.0", "/api"), + ] + + for rel_type, expected_media_type, expected_path in link_tests: + link = get_link(landing.json(), rel_type) + + assert link is not None, f"Missing {rel_type} link in landing page" + assert link.get("type") == expected_media_type + + link_path = urllib.parse.urlsplit(link.get("href")).path + assert link_path == prefix + expected_path + + resp = client.get(prefix + expected_path) + assert resp.status_code == 200 + + +@pytest.mark.parametrize("prefix", ["", "/a_prefix"]) +def test_async_api_prefix(AsyncTestCoreClient, prefix): + api_settings = ApiSettings( + openapi_url=f"{prefix}/api", + docs_url=f"{prefix}/api.html", + ) + + api = StacApi( + settings=api_settings, + client=AsyncTestCoreClient(), + router=APIRouter(prefix=prefix), + ) + + with TestClient(api.app, base_url="http://stac.io") as client: + landing = client.get(f"{prefix}/") + assert landing.status_code == 200, landing.json() + + service_doc = client.get(f"{prefix}/api.html") + assert service_doc.status_code == 200, service_doc.text + + service_desc = client.get(f"{prefix}/api") + assert service_desc.status_code == 200, service_desc.json() + + conformance = client.get(f"{prefix}/conformance") + assert conformance.status_code == 200, conformance.json() + + collections = client.get(f"{prefix}/collections") + assert collections.status_code == 200, collections.json() + collection_id = collections.json()["collections"][0]["id"] + + collection = client.get(f"{prefix}/collections/{collection_id}") + assert collection.status_code == 200, collection.json() + + items = client.get(f"{prefix}/collections/{collection_id}/items") + assert items.status_code == 200, items.json() + + item_id = items.json()["features"][0]["id"] + item = client.get(f"{prefix}/collections/{collection_id}/items/{item_id}") + assert item.status_code == 200, item.json() + + link_tests = [ + ("root", "application/json", "/"), + ("conformance", "application/json", "/conformance"), + ("data", "application/json", "/collections"), + ("search", "application/geo+json", "/search"), + ("service-doc", "text/html", "/api.html"), + ("service-desc", "application/vnd.oai.openapi+json;version=3.0", "/api"), + ] + + for rel_type, expected_media_type, expected_path in link_tests: + link = get_link(landing.json(), rel_type) + + assert link is not None, f"Missing {rel_type} link in landing page" + assert link.get("type") == expected_media_type + + link_path = urllib.parse.urlsplit(link.get("href")).path + assert link_path == prefix + expected_path + + resp = client.get(prefix + expected_path) + assert resp.status_code == 200 diff --git a/stac_fastapi/api/tests/test_middleware.py b/stac_fastapi/api/tests/test_middleware.py index 041dc410c..00e7f8038 100644 --- a/stac_fastapi/api/tests/test_middleware.py +++ b/stac_fastapi/api/tests/test_middleware.py @@ -1,6 +1,8 @@ from unittest import mock import pytest +from fastapi import Request +from fastapi.responses import JSONResponse from starlette.applications import Starlette from starlette.testclient import TestClient @@ -166,3 +168,31 @@ def test_cors_middleware(test_client): resp = test_client.get("/_mgmt/ping", headers={"Origin": "http://netloc"}) assert resp.status_code == 200 assert resp.headers["access-control-allow-origin"] == "*" + + +def test_middleware_stack(): + stac_api = StacApi( + settings=ApiSettings(), client=mock.create_autospec(BaseCoreClient) + ) + + def exception_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=400, + content={"customerrordetail": "yoo", "body": "yo"}, + ) + + class CustomException(Exception): + "Custom Exception" + + pass + + stac_api.app.add_exception_handler(CustomException, exception_handler) + + @stac_api.app.get("/error") + def error_endpoint(): + raise CustomException("got you!") + + with TestClient(stac_api.app) as client: + resp = client.get("/error") + assert resp.status_code == 400 + assert resp.json()["customerrordetail"] == "yoo" diff --git a/stac_fastapi/api/tests/test_models.py b/stac_fastapi/api/tests/test_models.py new file mode 100644 index 000000000..b0c2ad90e --- /dev/null +++ b/stac_fastapi/api/tests/test_models.py @@ -0,0 +1,131 @@ +import json + +import pytest +from fastapi import Depends, FastAPI, HTTPException +from fastapi.testclient import TestClient +from pydantic import ValidationError + +from stac_fastapi.api.models import create_get_request_model, create_post_request_model +from stac_fastapi.extensions.core import FieldsExtension, FilterExtension, SortExtension +from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest + + +def test_create_get_request_model(): + request_model = create_get_request_model( + extensions=[FilterExtension(), FieldsExtension()], + base_model=BaseSearchGetRequest, + ) + + model = request_model( + collections="test1,test2", + ids="test1,test2", + bbox="0,0,1,1", + intersects=json.dumps( + { + "type": "Polygon", + "coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]], + } + ), + datetime="2020-01-01T00:00:00Z", + limit=10, + filter="test==test", + filter_crs="epsg:4326", + filter_lang="cql2-text", + ) + + assert model.collections == ["test1", "test2"] + assert model.filter_crs == "epsg:4326" + + with pytest.raises(HTTPException): + request_model(datetime="yo") + + app = FastAPI() + + @app.get("/test") + def route(model=Depends(request_model)): + return model + + with TestClient(app) as client: + resp = client.get( + "/test", + params={ + "collections": "test1,test2", + "filter-crs": "epsg:4326", + "filter-lang": "cql2-text", + }, + ) + assert resp.status_code == 200 + response_dict = resp.json() + assert response_dict["collections"] == ["test1", "test2"] + assert response_dict["filter_crs"] == "epsg:4326" + assert response_dict["filter_lang"] == "cql2-text" + + +@pytest.mark.parametrize( + "filter,passes", + [(None, True), ({"test": "test"}, True), ("test==test", False), ([], False)], +) +def test_create_post_request_model(filter, passes): + request_model = create_post_request_model( + extensions=[FilterExtension(), FieldsExtension()], + base_model=BaseSearchPostRequest, + ) + + if not passes: + with pytest.raises(ValidationError): + model = request_model(filter=filter) + else: + model = request_model( + collections=["test1", "test2"], + ids=["test1", "test2"], + bbox=[0, 0, 1, 1], + datetime="2020-01-01T00:00:00Z", + limit=10, + filter=filter, + **{"filter-crs": "epsg:4326", "filter-lang": "cql2-text"}, + ) + + assert model.collections == ["test1", "test2"] + assert model.filter_crs == "epsg:4326" + assert model.filter == filter + + +@pytest.mark.parametrize( + "sortby,passes", + [ + (None, True), + ( + [ + {"field": "test", "direction": "asc"}, + {"field": "test2", "direction": "desc"}, + ], + True, + ), + ({"field": "test", "direction": "desc"}, False), + ("test", False), + ], +) +def test_create_post_request_model_nested_fields(sortby, passes): + request_model = create_post_request_model( + extensions=[SortExtension()], + base_model=BaseSearchPostRequest, + ) + + if not passes: + with pytest.raises(ValidationError): + model = request_model(sortby=sortby) + else: + model = request_model( + collections=["test1", "test2"], + ids=["test1", "test2"], + bbox=[0, 0, 1, 1], + datetime="2020-01-01T00:00:00Z", + limit=10, + sortby=sortby, + ) + + assert model.collections == ["test1", "test2"] + if model.sortby is None: + assert sortby is None + else: + assert model.model_dump(mode="json")["sortby"] == sortby diff --git a/stac_fastapi/extensions/setup.py b/stac_fastapi/extensions/setup.py index a70ea5855..39bc59b3f 100644 --- a/stac_fastapi/extensions/setup.py +++ b/stac_fastapi/extensions/setup.py @@ -1,14 +1,12 @@ """stac_fastapi: extensions module.""" + from setuptools import find_namespace_packages, setup with open("README.md") as f: desc = f.read() install_requires = [ - "attrs", - "pydantic[dotenv]<2", - "stac_pydantic==2.0.*", "stac-fastapi.types", "stac-fastapi.api", ] @@ -36,6 +34,10 @@ "Intended Audience :: Information Technology", "Intended Audience :: Science/Research", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", ], keywords="STAC FastAPI COG", diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py index 96317fe4a..fa935d8e8 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py @@ -1,16 +1,21 @@ """stac_api.extensions.core module.""" + +from .aggregation import AggregationExtension from .context import ContextExtension from .fields import FieldsExtension from .filter import FilterExtension +from .free_text import FreeTextExtension from .pagination import PaginationExtension, TokenPaginationExtension from .query import QueryExtension from .sort import SortExtension from .transaction import TransactionExtension __all__ = ( + "AggregationExtension", "ContextExtension", "FieldsExtension", "FilterExtension", + "FreeTextExtension", "PaginationExtension", "QueryExtension", "SortExtension", diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/__init__.py new file mode 100644 index 000000000..2a7fc7a71 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/__init__.py @@ -0,0 +1,5 @@ +"""Aggregation extension module.""" + +from .aggregation import AggregationExtension + +__all__ = ["AggregationExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/aggregation.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/aggregation.py new file mode 100644 index 000000000..c6e892914 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/aggregation.py @@ -0,0 +1,111 @@ +"""Aggregation Extension.""" +from enum import Enum +from typing import List, Union + +import attr +from fastapi import APIRouter, FastAPI + +from stac_fastapi.api.models import CollectionUri, EmptyRequest +from stac_fastapi.api.routes import create_async_endpoint +from stac_fastapi.types.extension import ApiExtension + +from .client import AsyncBaseAggregationClient, BaseAggregationClient +from .request import AggregationExtensionGetRequest, AggregationExtensionPostRequest + + +class AggregationConformanceClasses(str, Enum): + """Conformance classes for the Aggregation extension. + + See + https://github.com/stac-api-extensions/aggregation + """ + + AGGREGATION = "https://api.stacspec.org/v0.3.0/aggregation" + + +@attr.s +class AggregationExtension(ApiExtension): + """Aggregation Extension. + + The purpose of the Aggregation Extension is to provide an endpoint similar to + the Search endpoint (/search), but which will provide aggregated information + on matching Items rather than the Items themselves. This is highly influenced + by the Elasticsearch and OpenSearch aggregation endpoint, but with a more + regular structure for responses. + + The Aggregation extension adds several endpoints which allow the retrieval of + available aggregation fields and aggregation buckets based on a seearch query: + GET /aggregations + POST /aggregations + GET /collections/{collection_id}/aggregations + POST /collections/{collection_id}/aggregations + GET /aggregate + POST /aggregate + GET /collections/{collection_id}/aggregate + POST /collections/{collection_id}/aggregate + + https://github.com/stac-api-extensions/aggregation/blob/main/README.md + + Attributes: + conformance_classes: Conformance classes provided by the extension + """ + + GET = AggregationExtensionGetRequest + POST = AggregationExtensionPostRequest + + client: Union[AsyncBaseAggregationClient, BaseAggregationClient] = attr.ib( + factory=BaseAggregationClient + ) + + conformance_classes: List[str] = attr.ib( + default=[AggregationConformanceClasses.AGGREGATION] + ) + router: APIRouter = attr.ib(factory=APIRouter) + + def register(self, app: FastAPI) -> None: + """Register the extension with a FastAPI application. + + Args: + app: target FastAPI application. + + Returns: + None + """ + self.router.prefix = app.state.router_prefix + self.router.add_api_route( + name="Aggregations", + path="/aggregations", + methods=["GET", "POST"], + endpoint=create_async_endpoint(self.client.get_aggregations, EmptyRequest), + ) + self.router.add_api_route( + name="Collection Aggregations", + path="/collections/{collection_id}/aggregations", + methods=["GET", "POST"], + endpoint=create_async_endpoint(self.client.get_aggregations, CollectionUri), + ) + self.router.add_api_route( + name="Aggregate", + path="/aggregate", + methods=["GET"], + endpoint=create_async_endpoint(self.client.aggregate, self.GET), + ) + self.router.add_api_route( + name="Aggregate", + path="/aggregate", + methods=["POST"], + endpoint=create_async_endpoint(self.client.aggregate, self.POST), + ) + self.router.add_api_route( + name="Collection Aggregate", + path="/collections/{collection_id}/aggregate", + methods=["GET"], + endpoint=create_async_endpoint(self.client.aggregate, self.GET), + ) + self.router.add_api_route( + name="Collection Aggregate", + path="/collections/{collection_id}/aggregate", + methods=["POST"], + endpoint=create_async_endpoint(self.client.aggregate, self.POST), + ) + app.include_router(self.router, tags=["Aggregation Extension"]) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/client.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/client.py new file mode 100644 index 000000000..23d90fb28 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/client.py @@ -0,0 +1,131 @@ +"""Aggregation extensions clients.""" + +import abc +from typing import List, Optional, Union + +import attr +from geojson_pydantic.geometries import Geometry +from stac_pydantic.shared import BBox + +from stac_fastapi.types.rfc3339 import DateTimeType + +from .types import Aggregation, AggregationCollection + + +@attr.s +class BaseAggregationClient(abc.ABC): + """Defines a pattern for implementing the STAC aggregation extension.""" + + # BUCKET = Bucket + # AGGREGAION = Aggregation + # AGGREGATION_COLLECTION = AggregationCollection + + def get_aggregations( + self, collection_id: Optional[str] = None, **kwargs + ) -> AggregationCollection: + """Get the aggregations available for the given collection_id. + + If collection_id is None, returns the available aggregations over all + collections. + """ + return AggregationCollection( + type="AggregationCollection", + aggregations=[Aggregation(name="total_count", data_type="integer")], + links=[ + { + "rel": "root", + "type": "application/json", + "href": "https://example.org/", + }, + { + "rel": "self", + "type": "application/json", + "href": "https://example.org/aggregations", + }, + ], + ) + + def aggregate( + self, collection_id: Optional[str] = None, **kwargs + ) -> AggregationCollection: + """Return the aggregation buckets for a given search result""" + return AggregationCollection( + type="AggregationCollection", + aggregations=[], + links=[ + { + "rel": "root", + "type": "application/json", + "href": "https://example.org/", + }, + { + "rel": "self", + "type": "application/json", + "href": "https://example.org/aggregations", + }, + ], + ) + + +@attr.s +class AsyncBaseAggregationClient(abc.ABC): + """Defines an async pattern for implementing the STAC aggregation extension.""" + + # BUCKET = Bucket + # AGGREGAION = Aggregation + # AGGREGATION_COLLECTION = AggregationCollection + + async def get_aggregations( + self, collection_id: Optional[str] = None, **kwargs + ) -> AggregationCollection: + """Get the aggregations available for the given collection_id. + + If collection_id is None, returns the available aggregations over all + collections. + """ + return AggregationCollection( + type="AggregationCollection", + aggregations=[Aggregation(name="total_count", data_type="integer")], + links=[ + { + "rel": "root", + "type": "application/json", + "href": "https://example.org/", + }, + { + "rel": "self", + "type": "application/json", + "href": "https://example.org/aggregations", + }, + ], + ) + + async def aggregate( + self, + collection_id: Optional[str] = None, + aggregations: Optional[Union[str, List[str]]] = None, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[BBox] = None, + intersects: Optional[Geometry] = None, + datetime: Optional[DateTimeType] = None, + limit: Optional[int] = 10, + **kwargs, + ) -> AggregationCollection: + """Return the aggregation buckets for a given search result""" + return AggregationCollection( + type="AggregationCollection", + aggregations=[], + links=[ + { + "rel": "root", + "type": "application/json", + "href": "https://example.org/", + }, + { + "rel": "self", + "type": "application/json", + "href": "https://example.org/aggregations", + }, + ], + ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py new file mode 100644 index 000000000..4e72e0005 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py @@ -0,0 +1,39 @@ +"""Request model for the Aggregation extension.""" + +from typing import List, Optional + +import attr +from fastapi import Query +from pydantic import Field +from typing_extensions import Annotated + +from stac_fastapi.types.search import ( + BaseSearchGetRequest, + BaseSearchPostRequest, + str2list, +) + + +def _agg_converter( + val: Annotated[ + Optional[str], + Query(description="A list of aggregations to compute and return."), + ] = None, +) -> Optional[List[str]]: + return str2list(val) + + +@attr.s +class AggregationExtensionGetRequest(BaseSearchGetRequest): + """Aggregation Extension GET request model.""" + + aggregations: Optional[List[str]] = attr.ib(default=None, converter=_agg_converter) + + +class AggregationExtensionPostRequest(BaseSearchPostRequest): + """Aggregation Extension POST request model.""" + + aggregations: Optional[List[str]] = Field( + default=None, + description="A list of aggregations to compute and return.", + ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/types.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/types.py new file mode 100644 index 000000000..428b65225 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/types.py @@ -0,0 +1,36 @@ +"""Aggregation Extension types.""" + +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import Field +from typing_extensions import TypedDict + +from stac_fastapi.types.rfc3339 import DateTimeType + + +class Bucket(TypedDict, total=False): + """A STAC aggregation bucket.""" + + key: str + data_type: str + frequency: Optional[Dict] = None + _from: Optional[Union[int, float]] = Field(alias="from", default=None) + to: Optional[Optional[Union[int, float]]] = None + + +class Aggregation(TypedDict, total=False): + """A STAC aggregation.""" + + name: str + data_type: str + buckets: Optional[List[Bucket]] = None + overflow: Optional[int] = None + value: Optional[Union[str, int, DateTimeType]] = None + + +class AggregationCollection(TypedDict, total=False): + """STAC Item Aggregation Collection.""" + + type: Literal["AggregationCollection"] + aggregations: List[Aggregation] + links: List[Dict[str, Any]] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/context.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/context.py index 90faae914..4037ba938 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/context.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/context.py @@ -1,4 +1,6 @@ """Context extension.""" + +import warnings from typing import List, Optional import attr @@ -24,6 +26,14 @@ class ContextExtension(ApiExtension): default="https://raw.githubusercontent.com/stac-api-extensions/context/v1.0.0-rc.2/json-schema/schema.json" ) + def __attrs_post_init__(self): + """init.""" + warnings.warn( + "The ContextExtension is deprecated and will be removed in 3.0.", + DeprecationWarning, + stacklevel=1, + ) + def register(self, app: FastAPI) -> None: """Register the extension with a FastAPI application. diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py index b9a246b63..087d01b7a 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/__init__.py @@ -1,6 +1,5 @@ """Fields extension module.""" - from .fields import FieldsExtension __all__ = ["FieldsExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py index df4cd44de..90b4b2697 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/fields.py @@ -1,5 +1,6 @@ """Fields extension.""" -from typing import List, Optional, Set + +from typing import List, Optional import attr from fastapi import FastAPI @@ -34,19 +35,6 @@ class FieldsExtension(ApiExtension): conformance_classes: List[str] = attr.ib( factory=lambda: ["https://api.stacspec.org/v1.0.0/item-search#fields"] ) - default_includes: Set[str] = attr.ib( - factory=lambda: { - "id", - "type", - "stac_version", - "geometry", - "bbox", - "links", - "assets", - "properties.datetime", - "collection", - } - ) schema_href: Optional[str] = attr.ib(default=None) def register(self, app: FastAPI) -> None: diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py index 4cfbd3293..d3737ea49 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/fields/request.py @@ -1,11 +1,13 @@ """Request models for the fields extension.""" -from typing import Dict, Optional, Set +import warnings +from typing import Dict, List, Optional, Set import attr +from fastapi import Query from pydantic import BaseModel, Field +from typing_extensions import Annotated -from stac_fastapi.types.config import Settings from stac_fastapi.types.search import APIRequest, str2list @@ -39,6 +41,7 @@ def _get_field_dict(fields: Optional[Set[str]]) -> Dict: field_dict[parent].add(key) else: field_dict[field] = ... # type:ignore + return field_dict @property @@ -49,10 +52,17 @@ def filter_fields(self) -> Dict: the included and excluded fields passed to the API Ref: https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude """ + warnings.warn( + """The `PostFieldsExtension.filter_fields` + method is deprecated and will be removed in 3.0.""", + DeprecationWarning, + stacklevel=1, + ) + # Always include default_includes, even if they # exist in the exclude list. include = (self.include or set()) - (self.exclude or set()) - include |= Settings.get().default_includes or set() + include |= set() return { "include": self._get_field_dict(include), @@ -60,14 +70,31 @@ def filter_fields(self) -> Dict: } +def _fields_converter( + val: Annotated[ + Optional[str], + Query( + description="Include or exclude fields from items body.", + json_schema_extra={ + "example": "properties.datetime", + }, + ), + ] = None, +) -> Optional[List[str]]: + return str2list(val) + + @attr.s class FieldsExtensionGetRequest(APIRequest): """Additional fields for the GET request.""" - fields: Optional[str] = attr.ib(default=None, converter=str2list) + fields: Optional[List[str]] = attr.ib(default=None, converter=_fields_converter) class FieldsExtensionPostRequest(BaseModel): """Additional fields and schema for the POST request.""" - fields: Optional[PostFieldsExtension] = Field(PostFieldsExtension()) + fields: Optional[PostFieldsExtension] = Field( + PostFieldsExtension(), + description="Include or exclude fields from items body.", + ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py index 78256bfd2..256f3e06e 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/__init__.py @@ -1,6 +1,5 @@ """Filter extension module.""" - from .filter import FilterExtension __all__ = ["FilterExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/client.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/client.py new file mode 100644 index 000000000..03ef96614 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/client.py @@ -0,0 +1,58 @@ +"""Filter extensions clients.""" + +import abc +from typing import Any, Dict, Optional + +import attr + + +@attr.s +class AsyncBaseFiltersClient(abc.ABC): + """Defines a pattern for implementing the STAC filter extension.""" + + async def get_queryables( + self, collection_id: Optional[str] = None, **kwargs + ) -> Dict[str, Any]: + """Get the queryables available for the given collection_id. + + If collection_id is None, returns the intersection of all queryables over all + collections. + + This base implementation returns a blank queryable schema. This is not allowed + under OGC CQL but it is allowed by the STAC API Filter Extension + https://github.com/radiantearth/stac-api-spec/tree/master/fragments/filter#queryables + """ + return { + "$schema": "https://json-schema.org/draft/2019-09/schema", + "$id": "https://example.org/queryables", + "type": "object", + "title": "Queryables for Example STAC API", + "description": "Queryable names for the example STAC API Item Search filter.", + "properties": {}, + } + + +@attr.s +class BaseFiltersClient(abc.ABC): + """Defines a pattern for implementing the STAC filter extension.""" + + def get_queryables( + self, collection_id: Optional[str] = None, **kwargs + ) -> Dict[str, Any]: + """Get the queryables available for the given collection_id. + + If collection_id is None, returns the intersection of all queryables over all + collections. + + This base implementation returns a blank queryable schema. This is not allowed + under OGC CQL but it is allowed by the STAC API Filter Extension + https://github.com/stac-api-extensions/filter#queryables + """ + return { + "$schema": "https://json-schema.org/draft/2019-09/schema", + "$id": "https://example.org/queryables", + "type": "object", + "title": "Queryables for Example STAC API", + "description": "Queryable names for the example STAC API Item Search filter.", + "properties": {}, + } diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py index 2f875907a..cd9463ec6 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py @@ -9,9 +9,9 @@ from stac_fastapi.api.models import CollectionUri, EmptyRequest, JSONSchemaResponse from stac_fastapi.api.routes import create_async_endpoint -from stac_fastapi.types.core import AsyncBaseFiltersClient, BaseFiltersClient from stac_fastapi.types.extension import ApiExtension +from .client import AsyncBaseFiltersClient, BaseFiltersClient from .request import FilterExtensionGetRequest, FilterExtensionPostRequest @@ -97,16 +97,30 @@ def register(self, app: FastAPI) -> None: name="Queryables", path="/queryables", methods=["GET"], - endpoint=create_async_endpoint( - self.client.get_queryables, EmptyRequest, self.response_class - ), + responses={ + 200: { + "content": { + "application/schema+json": {}, + }, + # TODO: add output model in stac-pydantic + }, + }, + response_class=self.response_class, + endpoint=create_async_endpoint(self.client.get_queryables, EmptyRequest), ) self.router.add_api_route( name="Collection Queryables", path="/collections/{collection_id}/queryables", methods=["GET"], - endpoint=create_async_endpoint( - self.client.get_queryables, CollectionUri, self.response_class - ), + responses={ + 200: { + "content": { + "application/schema+json": {}, + }, + # TODO: add output model in stac-pydantic + }, + }, + response_class=self.response_class, + endpoint=create_async_endpoint(self.client.get_queryables, CollectionUri), ) app.include_router(self.router, tags=["Filter Extension"]) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py index 1fcd6b0b9..30ac011b0 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/request.py @@ -1,41 +1,77 @@ """Filter extension request models.""" -from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional import attr +from fastapi import Query from pydantic import BaseModel, Field +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest - -class FilterLang(str, Enum): - """Choices for filter-lang value in a POST request. - - Based on - https://github.com/stac-api-extensions/filter#queryables - - Note the addition of cql2-json, which is used by the pgstac backend, - but is not included in the spec above. - """ - - cql_json = "cql-json" - cql2_json = "cql2-json" - cql2_text = "cql2-text" +FilterLang = Literal["cql-json", "cql2-json", "cql2-text"] @attr.s class FilterExtensionGetRequest(APIRequest): """Filter extension GET request model.""" - filter: Optional[str] = attr.ib(default=None) - filter_crs: Optional[str] = Field(alias="filter-crs", default=None) - filter_lang: Optional[FilterLang] = Field(alias="filter-lang", default="cql2-text") + filter: Annotated[ + Optional[str], + Query( + description="""A CQL filter expression for filtering items.\n +Supports `CQL-JSON` as defined in https://portal.ogc.org/files/96288\n +Remember to URL encode the CQL-JSON if using GET""", + json_schema_extra={ + "example": "id='LC08_L1TP_060247_20180905_20180912_01_T1_L1TP' AND collection='landsat8_l1tp'", # noqa: E501 + }, + ), + ] = attr.ib(default=None) + filter_crs: Annotated[ + Optional[str], + Query( + alias="filter-crs", + description="The coordinate reference system (CRS) used by spatial literals in the 'filter' value. Default is `http://www.opengis.net/def/crs/OGC/1.3/CRS84`", # noqa: E501 + ), + ] = attr.ib(default=None) + filter_lang: Annotated[ + Optional[FilterLang], + Query( + alias="filter-lang", + description="The CQL filter encoding that the 'filter' value uses.", + ), + ] = attr.ib(default="cql2-text") class FilterExtensionPostRequest(BaseModel): """Filter extension POST request model.""" - filter: Optional[Dict[str, Any]] = None - filter_crs: Optional[str] = Field(alias="filter-crs", default=None) - filter_lang: Optional[FilterLang] = Field(alias="filter-lang", default="cql-json") + filter: Optional[Dict[str, Any]] = Field( + default=None, + description="A CQL filter expression for filtering items.", + json_schema_extra={ + "example": { + "op": "and", + "args": [ + { + "op": "=", + "args": [ + {"property": "id"}, + "LC08_L1TP_060247_20180905_20180912_01_T1_L1TP", + ], + }, + {"op": "=", "args": [{"property": "collection"}, "landsat8_l1tp"]}, + ], + }, + }, + ) + filter_crs: Optional[str] = Field( + alias="filter-crs", + default=None, + description="The coordinate reference system (CRS) used by spatial literals in the 'filter' value. Default is `http://www.opengis.net/def/crs/OGC/1.3/CRS84`", # noqa: E501 + ) + filter_lang: Optional[FilterLang] = Field( + alias="filter-lang", + default="cql2-json", + description="The CQL filter encoding that the 'filter' value uses.", + ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/free_text/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/free_text/__init__.py new file mode 100644 index 000000000..1865d64f0 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/free_text/__init__.py @@ -0,0 +1,5 @@ +"""Query extension module.""" + +from .free_text import FreeTextConformanceClasses, FreeTextExtension + +__all__ = ["FreeTextExtension", "FreeTextConformanceClasses"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/free_text/free_text.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/free_text/free_text.py new file mode 100644 index 000000000..be1c389ac --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/free_text/free_text.py @@ -0,0 +1,64 @@ +"""Free-text extension.""" + +from enum import Enum +from typing import List, Optional + +import attr +from fastapi import FastAPI + +from stac_fastapi.types.extension import ApiExtension + +from .request import FreeTextExtensionGetRequest, FreeTextExtensionPostRequest + + +class FreeTextConformanceClasses(str, Enum): + """Conformance classes for the Free-Text extension. + + See https://github.com/stac-api-extensions/freetext-search + + """ + + # https://github.com/stac-api-extensions/freetext-search?tab=readme-ov-file#basic + SEARCH_BASIC = "https://api.stacspec.org/v1.0.0-rc.1/item-search#free-text" + COLLECTIONS_BASIC = "https://api.stacspec.org/v1.0.0-rc.1/collection-search#free-text" + ITEMS_BASIC = "https://api.stacspec.org/v1.0.0-rc.1/ogcapi-features#free-text" + + # https://github.com/stac-api-extensions/freetext-search?tab=readme-ov-file#advanced + SEARCH_ADVANCED = ( + "https://api.stacspec.org/v1.0.0-rc.1/item-search#advanced-free-text" + ) + COLLECTIONS_ADVANCED = ( + "https://api.stacspec.org/v1.0.0-rc.1/collection-search#advanced-free-text" + ) + ITEMS_ADVANCED = ( + "https://api.stacspec.org/v1.0.0-rc.1/ogcapi-features#advanced-free-text" + ) + + +@attr.s +class FreeTextExtension(ApiExtension): + """Free-text Extension. + + The Free-text extension adds an additional `q` parameter to `/search` requests which + allows the caller to perform free-text queries against STAC metadata. + + https://github.com/stac-api-extensions/freetext-search/README.md + + """ + + GET = FreeTextExtensionGetRequest + POST = FreeTextExtensionPostRequest + + conformance_classes: List[str] = attr.ib() + schema_href: Optional[str] = attr.ib(default=None) + + def register(self, app: FastAPI) -> None: + """Register the extension with a FastAPI application. + + Args: + app: target FastAPI application. + + Returns: + None + """ + pass diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/free_text/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/free_text/request.py new file mode 100644 index 000000000..8058fe03a --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/free_text/request.py @@ -0,0 +1,34 @@ +"""Request model for the Free-text extension.""" + +from typing import Optional + +import attr +from fastapi import Query +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from stac_fastapi.types.search import APIRequest + + +@attr.s +class FreeTextExtensionGetRequest(APIRequest): + """Free-text Extension GET request model.""" + + q: Annotated[ + Optional[str], + Query( + description="Parameter to perform free-text queries against STAC metadata", + json_schema_extra={ + "example": "item1,item2", + }, + ), + ] = attr.ib(default=None) + + +class FreeTextExtensionPostRequest(BaseModel): + """Free-text Extension POST request model.""" + + q: Optional[str] = Field( + None, + description="Parameter to perform free-text queries against STAC metadata", + ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py index 296e9ae6a..7959b0357 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/pagination.py @@ -5,9 +5,10 @@ import attr from fastapi import FastAPI -from stac_fastapi.api.models import GETPagination, POSTPagination from stac_fastapi.types.extension import ApiExtension +from .request import GETPagination, POSTPagination + @attr.s class PaginationExtension(ApiExtension): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py new file mode 100644 index 000000000..66391c7f9 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/request.py @@ -0,0 +1,36 @@ +"""Pagination extension request models.""" + +from typing import Optional + +import attr +from fastapi import Query +from pydantic import BaseModel +from typing_extensions import Annotated + +from stac_fastapi.types.search import APIRequest + + +@attr.s +class GETTokenPagination(APIRequest): + """Token pagination for GET requests.""" + + token: Annotated[Optional[str], Query()] = attr.ib(default=None) + + +class POSTTokenPagination(BaseModel): + """Token pagination model for POST requests.""" + + token: Optional[str] = None + + +@attr.s +class GETPagination(APIRequest): + """Page based pagination for GET requests.""" + + page: Annotated[Optional[str], Query()] = attr.ib(default=None) + + +class POSTPagination(BaseModel): + """Page based pagination for POST requests.""" + + page: Optional[str] = None diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py index d3fa10391..11ccfb35b 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/pagination/token_pagination.py @@ -5,9 +5,10 @@ import attr from fastapi import FastAPI -from stac_fastapi.api.models import GETTokenPagination, POSTTokenPagination from stac_fastapi.types.extension import ApiExtension +from .request import GETTokenPagination, POSTTokenPagination + @attr.s class TokenPaginationExtension(ApiExtension): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py index 3e85b406d..472c385b4 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/query.py @@ -1,4 +1,5 @@ """Query extension.""" + from typing import List, Optional import attr @@ -16,7 +17,7 @@ class QueryExtension(ApiExtension): The Query extension adds an additional `query` parameter to `/search` requests which allows the caller to perform queries against item metadata (ex. find all images with cloud cover less than 15%). - https://github.com/radiantearth/stac-api-spec/blob/master/item-search/README.md#query + https://github.com/stac-api-extensions/query """ GET = QueryExtensionGetRequest diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py index 8b282884a..ad7f461c3 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/query/request.py @@ -3,7 +3,9 @@ from typing import Any, Dict, Optional import attr -from pydantic import BaseModel +from fastapi import Query +from pydantic import BaseModel, Field +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest @@ -12,10 +14,24 @@ class QueryExtensionGetRequest(APIRequest): """Query Extension GET request model.""" - query: Optional[str] = attr.ib(default=None) + query: Annotated[ + Optional[str], + Query( + description="Allows additional filtering based on the properties of Item objects", # noqa: E501 + json_schema_extra={ + "example": '{"eo:cloud_cover": {"gte": 95}}', + }, + ), + ] = attr.ib(default=None) class QueryExtensionPostRequest(BaseModel): """Query Extension POST request model.""" - query: Optional[Dict[str, Dict[str, Any]]] + query: Optional[Dict[str, Dict[str, Any]]] = Field( + None, + description="Allows additional filtering based on the properties of Item objects", # noqa: E501 + json_schema_extra={ + "example": {"eo:cloud_cover": {"gte": 95}}, + }, + ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py index c19f40dba..e1c22eea3 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/request.py @@ -1,23 +1,49 @@ -# encoding: utf-8 """Request model for the Sort Extension.""" from typing import List, Optional import attr -from pydantic import BaseModel +from fastapi import Query +from pydantic import BaseModel, Field from stac_pydantic.api.extensions.sort import SortExtension as PostSortModel +from typing_extensions import Annotated from stac_fastapi.types.search import APIRequest, str2list +def _sort_converter( + val: Annotated[ + Optional[str], + Query( + description="An array of property names, prefixed by either '+' for ascending or '-' for descending. If no prefix is provided, '+' is assumed.", # noqa: E501 + json_schema_extra={ + "example": "-gsd,-datetime", + }, + ), + ], +) -> Optional[List[str]]: + return str2list(val) + + @attr.s class SortExtensionGetRequest(APIRequest): """Sortby Parameter for GET requests.""" - sortby: Optional[str] = attr.ib(default=None, converter=str2list) + sortby: Optional[List[str]] = attr.ib(default=None, converter=_sort_converter) class SortExtensionPostRequest(BaseModel): """Sortby parameter for POST requests.""" - sortby: Optional[List[PostSortModel]] + sortby: Optional[List[PostSortModel]] = Field( + None, + description="An array of property (field) names, and direction in form of '{'field': '', 'direction':''}'", # noqa: E501 + json_schema_extra={ + "example": [ + { + "field": "properties.created", + "direction": "asc", + } + ], + }, + ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py index 5dd96cfa6..4b27d8d0e 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/sort/sort.py @@ -1,4 +1,5 @@ """Sort extension.""" + from typing import List, Optional import attr diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py index 0ebcc6194..4e940a0ea 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py @@ -4,12 +4,13 @@ import attr from fastapi import APIRouter, Body, FastAPI -from stac_pydantic import Collection, Item +from stac_pydantic import Collection, Item, ItemCollection +from stac_pydantic.shared import MimeTypes from starlette.responses import JSONResponse, Response +from typing_extensions import Annotated from stac_fastapi.api.models import CollectionUri, ItemUri from stac_fastapi.api.routes import create_async_endpoint -from stac_fastapi.types import stac as stac_types from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.core import AsyncBaseTransactionsClient, BaseTransactionsClient from stac_fastapi.types.extension import ApiExtension @@ -19,16 +20,21 @@ class PostItem(CollectionUri): """Create Item.""" - item: Union[stac_types.Item, stac_types.ItemCollection] = attr.ib( - default=Body(None) - ) + item: Annotated[Union[Item, ItemCollection], Body()] = attr.ib(default=None) @attr.s class PutItem(ItemUri): """Update Item.""" - item: stac_types.Item = attr.ib(default=Body(None)) + item: Annotated[Item, Body()] = attr.ib(default=None) + + +@attr.s +class PutCollection(CollectionUri): + """Update Collection.""" + + collection: Annotated[Collection, Body()] = attr.ib(default=None) @attr.s @@ -44,17 +50,20 @@ class TransactionExtension(ApiExtension): PUT /collections/{collection_id}/items DELETE /collections/{collection_id}/items - https://github.com/radiantearth/stac-api-spec/blob/master/ogcapi-features/extensions/transaction/README.md + https://github.com/stac-api-extensions/transaction + https://github.com/stac-api-extensions/collection-transaction Attributes: client: CRUD application logic + """ client: Union[AsyncBaseTransactionsClient, BaseTransactionsClient] = attr.ib() settings: ApiSettings = attr.ib() conformance_classes: List[str] = attr.ib( factory=lambda: [ - "https://api.stacspec.org/v1.0.0-rc.3/ogcapi-features/extensions/transaction", + "https://api.stacspec.org/v1.0.0/ogcapi-features/extensions/transaction", + "https://api.stacspec.org/v1.0.0/collections/extensions/transaction", ] ) schema_href: Optional[str] = attr.ib(default=None) @@ -66,7 +75,16 @@ def register_create_item(self): self.router.add_api_route( name="Create Item", path="/collections/{collection_id}/items", + status_code=201, response_model=Item if self.settings.enable_response_models else None, + responses={ + 201: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": Item, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -81,6 +99,14 @@ def register_update_item(self): name="Update Item", path="/collections/{collection_id}/items/{item_id}", response_model=Item if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": Item, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -95,6 +121,14 @@ def register_delete_item(self): name="Delete Item", path="/collections/{collection_id}/items/{item_id}", response_model=Item if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.geojson.value: {}, + }, + "model": Item, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, @@ -102,19 +136,31 @@ def register_delete_item(self): endpoint=create_async_endpoint(self.client.delete_item, ItemUri), ) + def register_patch_item(self): + """Register patch item endpoint (PATCH + /collections/{collection_id}/items/{item_id}).""" + raise NotImplementedError + def register_create_collection(self): """Register create collection endpoint (POST /collections).""" self.router.add_api_route( name="Create Collection", path="/collections", + status_code=201, response_model=Collection if self.settings.enable_response_models else None, + responses={ + 201: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Collection, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["POST"], - endpoint=create_async_endpoint( - self.client.create_collection, stac_types.Collection - ), + endpoint=create_async_endpoint(self.client.create_collection, Collection), ) def register_update_collection(self): @@ -123,13 +169,19 @@ def register_update_collection(self): name="Update Collection", path="/collections/{collection_id}", response_model=Collection if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Collection, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["PUT"], - endpoint=create_async_endpoint( - self.client.update_collection, stac_types.Collection - ), + endpoint=create_async_endpoint(self.client.update_collection, PutCollection), ) def register_delete_collection(self): @@ -138,15 +190,25 @@ def register_delete_collection(self): name="Delete Collection", path="/collections/{collection_id}", response_model=Collection if self.settings.enable_response_models else None, + responses={ + 200: { + "content": { + MimeTypes.json.value: {}, + }, + "model": Collection, + } + }, response_class=self.response_class, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["DELETE"], - endpoint=create_async_endpoint( - self.client.delete_collection, CollectionUri - ), + endpoint=create_async_endpoint(self.client.delete_collection, CollectionUri), ) + def register_patch_collection(self): + """Register patch collection endpoint (PATCH /collections/{collection_id}).""" + raise NotImplementedError + def register(self, app: FastAPI) -> None: """Register the extension with a FastAPI application. diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/__init__.py index 0ae3b0b25..d35c4c8f9 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/__init__.py @@ -1,9 +1,5 @@ """stac_api.extensions.third_party module.""" from .bulk_transactions import BulkTransactionExtension -from .free_text import FreeTextExtension -__all__ = ( - "BulkTransactionExtension", - "FreeTextExtension", -) +__all__ = ("BulkTransactionExtension",) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py index 9fa96ff2b..d1faa5c0f 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py @@ -1,4 +1,5 @@ """Bulk transactions extension.""" + import abc from enum import Enum from typing import Any, Dict, List, Optional, Union @@ -109,9 +110,7 @@ class BulkTransactionExtension(ApiExtension): } """ - client: Union[ - AsyncBaseBulkTransactionsClient, BaseBulkTransactionsClient - ] = attr.ib() + client: Union[AsyncBaseBulkTransactionsClient, BaseBulkTransactionsClient] = attr.ib() conformance_classes: List[str] = attr.ib(default=list()) schema_href: Optional[str] = attr.ib(default=None) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/free_text/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/free_text/__init__.py deleted file mode 100644 index 62c0dee1d..000000000 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/free_text/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Query extension module.""" - -from .free_text import FreeTextExtension - -__all__ = ["FreeTextExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/free_text/free_text.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/free_text/free_text.py deleted file mode 100644 index 06177ef79..000000000 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/free_text/free_text.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Free-text extension.""" - -from typing import List, Optional - -import attr -from fastapi import FastAPI - -from stac_fastapi.types.extension import ApiExtension - -from .request import FreeTextExtensionGetRequest, FreeTextExtensionPostRequest - - -@attr.s -class FreeTextExtension(ApiExtension): - """Free-text Extension. - - The Free-text extension adds an additional `q` parameter to `/search` requests which - allows the caller to perform free-text queries against STAC metadata. - https://github.com/stac-api-extensions/freetext-search/README.md - """ - - GET = FreeTextExtensionGetRequest - POST = FreeTextExtensionPostRequest - - conformance_classes: List[str] = attr.ib( - factory=lambda: [ - "https://api.stacspec.org/v1.0.0-rc.1/item-search#free-text", - "https://api.stacspec.org/v1.0.0-rc.1/item-search#advanced-free-text", - "https://api.stacspec.org/v1.0.0-rc.1/collection-search#free-text", - "https://api.stacspec.org/v1.0.0-rc.1/collection-search#advanced-free-text", - "https://api.stacspec.org/v1.0.0-rc.1/ogcapi-features#free-text", - "https://api.stacspec.org/v1.0.0-rc.1/ogcapi-features#advanced-free-text", - ] - ) - schema_href: Optional[str] = attr.ib(default=None) - - def register(self, app: FastAPI) -> None: - """Register the extension with a FastAPI application. - - Args: - app: target FastAPI application. - - Returns: - None - """ - pass diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/free_text/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/free_text/request.py deleted file mode 100644 index bac27049d..000000000 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/free_text/request.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Request model for the Free-text extension.""" - -from typing import Optional - -import attr -from pydantic import BaseModel - -from stac_fastapi.types.search import APIRequest - - -@attr.s -class FreeTextExtensionGetRequest(APIRequest): - """Free-text Extension GET request model.""" - - q: Optional[str] = attr.ib(default=None) - - -class FreeTextExtensionPostRequest(BaseModel): - """Free-text Extension POST request model.""" - - q: Optional[str] = attr.ib(default=None) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/version.py b/stac_fastapi/extensions/stac_fastapi/extensions/version.py index bb0c7c379..7296e8a98 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/version.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/version.py @@ -1,2 +1,2 @@ """Library version.""" -__version__ = "2.4.9" +__version__ = "3.0.0b2" diff --git a/stac_fastapi/extensions/tests/test_aggregation.py b/stac_fastapi/extensions/tests/test_aggregation.py new file mode 100644 index 000000000..480cc669f --- /dev/null +++ b/stac_fastapi/extensions/tests/test_aggregation.py @@ -0,0 +1,134 @@ +from typing import Iterator + +import pytest +from fastapi import Depends, FastAPI +from starlette.testclient import TestClient + +from stac_fastapi.api.app import StacApi +from stac_fastapi.extensions.core import AggregationExtension +from stac_fastapi.extensions.core.aggregation.client import BaseAggregationClient +from stac_fastapi.extensions.core.aggregation.request import ( + AggregationExtensionGetRequest, +) +from stac_fastapi.extensions.core.aggregation.types import ( + Aggregation, + AggregationCollection, +) +from stac_fastapi.types.config import ApiSettings +from stac_fastapi.types.core import BaseCoreClient + + +class DummyCoreClient(BaseCoreClient): + def all_collections(self, *args, **kwargs): + raise NotImplementedError + + def get_collection(self, *args, **kwargs): + raise NotImplementedError + + def get_item(self, *args, **kwargs): + raise NotImplementedError + + def get_search(self, *args, **kwargs): + raise NotImplementedError + + def post_search(self, *args, **kwargs): + raise NotImplementedError + + def item_collection(self, *args, **kwargs): + raise NotImplementedError + + +def test_get_aggregations(client: TestClient) -> None: + response = client.get("/aggregations") + assert response.is_success, response.text + assert response.json()["aggregations"] == [ + {"name": "total_count", "data_type": "integer"} + ] + assert AggregationCollection( + type="AggregationCollection", + aggregations=[Aggregation(**response.json()["aggregations"][0])], + ) + + +def test_get_aggregate(client: TestClient) -> None: + response = client.get("/aggregate") + assert response.is_success, response.text + assert response.json()["aggregations"] == [] + assert AggregationCollection( + type="AggregationCollection", aggregations=response.json()["aggregations"] + ) + + +def test_post_aggregations(client: TestClient) -> None: + response = client.post("/aggregations") + assert response.is_success, response.text + assert response.json()["aggregations"] == [ + {"name": "total_count", "data_type": "integer"} + ] + assert AggregationCollection( + type="AggregationCollection", + aggregations=[Aggregation(**response.json()["aggregations"][0])], + ) + + +def test_post_aggregate(client: TestClient) -> None: + response = client.post("/aggregate", content="{}") + assert response.is_success, response.text + assert response.json()["aggregations"] == [] + assert AggregationCollection( + type="AggregationCollection", aggregations=response.json()["aggregations"] + ) + + +@pytest.fixture +def client( + core_client: DummyCoreClient, aggregations_client: BaseAggregationClient +) -> Iterator[TestClient]: + settings = ApiSettings() + api = StacApi( + settings=settings, + client=core_client, + extensions=[ + AggregationExtension(client=aggregations_client), + ], + ) + with TestClient(api.app) as client: + yield client + + +@pytest.fixture +def core_client() -> DummyCoreClient: + return DummyCoreClient() + + +@pytest.fixture +def aggregations_client() -> BaseAggregationClient: + return BaseAggregationClient() + + +def test_agg_get_query(): + """test AggregationExtensionGetRequest model.""" + app = FastAPI() + + @app.get("/test") + def test(query=Depends(AggregationExtensionGetRequest)): + return query + + with TestClient(app) as client: + response = client.get("/test") + assert response.is_success + params = response.json() + assert not params["collections"] + assert not params["aggregations"] + + response = client.get( + "/test", + params={ + "collections": "collection1,collection2", + "aggregations": "prop1,prop2", + }, + ) + assert response.is_success + params = response.json() + assert params["collections"] == ["collection1", "collection2"] + assert params["aggregations"] == ["prop1", "prop2"] diff --git a/stac_fastapi/extensions/tests/test_filter.py b/stac_fastapi/extensions/tests/test_filter.py new file mode 100644 index 000000000..a13fb14c9 --- /dev/null +++ b/stac_fastapi/extensions/tests/test_filter.py @@ -0,0 +1,119 @@ +from typing import Iterator + +import pytest +from starlette.testclient import TestClient + +from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_get_request_model, create_post_request_model +from stac_fastapi.extensions.core import FilterExtension +from stac_fastapi.types.config import ApiSettings +from stac_fastapi.types.core import BaseCoreClient + + +class DummyCoreClient(BaseCoreClient): + def all_collections(self, *args, **kwargs): + raise NotImplementedError + + def get_collection(self, *args, **kwargs): + raise NotImplementedError + + def get_item(self, *args, **kwargs): + raise NotImplementedError + + def get_search(self, *args, **kwargs): + _ = kwargs.pop("request", None) + return kwargs + + def post_search(self, *args, **kwargs): + return args[0].model_dump() + + def item_collection(self, *args, **kwargs): + raise NotImplementedError + + +@pytest.fixture +def client() -> Iterator[TestClient]: + settings = ApiSettings() + extensions = [FilterExtension()] + api = StacApi( + settings=settings, + client=DummyCoreClient(), + extensions=extensions, + search_get_request_model=create_get_request_model(extensions), + search_post_request_model=create_post_request_model(extensions), + ) + with TestClient(api.app) as client: + yield client + + +def test_search_filter_post_filter_lang_default(client: TestClient): + """Test search POST endpoint with filter ext.""" + response = client.post( + "/search", + json={ + "collections": ["test"], + "filter": {"op": "=", "args": [{"property": "test_property"}, "test-value"]}, + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert response_dict["filter_lang"] == "cql2-json" + + +def test_search_filter_post_filter_lang_non_default(client: TestClient): + """Test search POST endpoint with filter ext.""" + filter_lang_value = "cql2-text" + response = client.post( + "/search", + json={ + "collections": ["test"], + "filter": {"op": "=", "args": [{"property": "test_property"}, "test-value"]}, + "filter-lang": filter_lang_value, + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert response_dict["filter_lang"] == filter_lang_value + + +def test_search_filter_get(client: TestClient): + """Test search GET endpoint with filter ext.""" + response = client.get( + "/search", + params={ + "filter": "id='item_id' AND collection='collection_id'", + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert not response_dict["collections"] + assert response_dict["filter"] == "id='item_id' AND collection='collection_id'" + assert not response_dict["filter_crs"] + assert response_dict["filter_lang"] == "cql2-text" + + response = client.get( + "/search", + params={ + "filter": {"op": "=", "args": [{"property": "id"}, "test-item"]}, + "filter-lang": "cql2-json", + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert not response_dict["collections"] + assert ( + response_dict["filter"] + == "{'op': '=', 'args': [{'property': 'id'}, 'test-item']}" + ) + assert not response_dict["filter_crs"] + assert response_dict["filter_lang"] == "cql2-json" + + response = client.get( + "/search", + params={ + "collections": "collection1,collection2", + }, + ) + assert response.is_success, response.json() + response_dict = response.json() + assert response_dict["collections"] == ["collection1", "collection2"] diff --git a/stac_fastapi/extensions/tests/test_free_text.py b/stac_fastapi/extensions/tests/test_free_text.py new file mode 100644 index 000000000..362d96025 --- /dev/null +++ b/stac_fastapi/extensions/tests/test_free_text.py @@ -0,0 +1,252 @@ +# noqa: E501 +"""test freetext extension.""" + + +from starlette.testclient import TestClient + +from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import ( + ItemCollectionUri, + create_get_request_model, + create_post_request_model, + create_request_model, +) +from stac_fastapi.extensions.core import FreeTextExtension +from stac_fastapi.extensions.core.free_text import FreeTextConformanceClasses +from stac_fastapi.types.config import ApiSettings +from stac_fastapi.types.core import BaseCoreClient + + +class DummyCoreClient(BaseCoreClient): + def all_collections(self, *args, **kwargs): + return kwargs.pop("q", None) + + def get_collection(self, *args, **kwargs): + raise NotImplementedError + + def get_item(self, *args, **kwargs): + raise NotImplementedError + + def get_search(self, *args, **kwargs): + return kwargs.pop("q", None) + + def post_search(self, *args, **kwargs): + return args[0].q + + def item_collection(self, *args, **kwargs): + return kwargs.pop("q", None) + + +def test_search_free_text_search(): + """Test search endpoints with free-text ext.""" + settings = ApiSettings() + extensions = [ + FreeTextExtension( + conformance_classes=[FreeTextConformanceClasses.SEARCH_BASIC.value] + ) + ] + + api = StacApi( + settings=settings, + client=DummyCoreClient(), + extensions=extensions, + search_get_request_model=create_get_request_model(extensions), + search_post_request_model=create_post_request_model(extensions), + ) + with TestClient(api.app) as client: + response = client.get("/conformance") + assert response.is_success, response.json() + response_dict = response.json() + assert ( + FreeTextConformanceClasses.SEARCH_BASIC.value in response_dict["conformsTo"] + ) + + # /search - GET, no free-text + response = client.get( + "/search", + params={"collections": ["test"]}, + ) + assert response.is_success + assert not response.text + + # /search - GET, free-text option + response = client.get( + "/search", + params={ + "collections": ["test"], + "q": "ocean,coast", + }, + ) + assert response.is_success, response.text + assert response.json() == "ocean,coast" + + # /search - POST, no free-text + response = client.post( + "/search", + json={ + "collections": ["test"], + }, + ) + assert response.is_success + assert not response.text + + # /search - POST, free-text option + response = client.post( + "/search", + json={ + "collections": ["test"], + "q": "ocean,coast", + }, + ) + + assert response.is_success, response.text + assert response.json() == "ocean,coast" + + +def test_search_free_text_search_advances(): + """Test search endpoints with free-text ext.""" + settings = ApiSettings() + extensions = [ + FreeTextExtension( + conformance_classes=[FreeTextConformanceClasses.SEARCH_ADVANCED.value] + ) + ] + + api = StacApi( + settings=settings, + client=DummyCoreClient(), + extensions=extensions, + search_get_request_model=create_get_request_model(extensions), + search_post_request_model=create_post_request_model(extensions), + ) + with TestClient(api.app) as client: + response = client.get("/conformance") + assert response.is_success, response.json() + response_dict = response.json() + assert ( + FreeTextConformanceClasses.SEARCH_ADVANCED.value + in response_dict["conformsTo"] + ) + + # /search - GET, no free-text + response = client.get( + "/search", + params={"collections": ["test"]}, + ) + assert response.is_success + assert not response.text + + # /search - GET, free-text option + response = client.get( + "/search", + params={ + "collections": ["test"], + "q": "+ocean,-coast", + }, + ) + assert response.is_success, response.text + assert response.json() == "+ocean,-coast" + + # /search - POST, no free-text + response = client.post( + "/search", + json={ + "collections": ["test"], + }, + ) + assert response.is_success + assert not response.text + + # /search - POST, free-text option + response = client.post( + "/search", + json={ + "collections": ["test"], + "q": "+ocean,-coast", + }, + ) + + assert response.is_success, response.text + assert response.json() == "+ocean,-coast" + + +def test_search_free_text_complete(): + """Test search,collections,items endpoints with free-text ext.""" + settings = ApiSettings() + + free_text = FreeTextExtension( + conformance_classes=[ + FreeTextConformanceClasses.SEARCH_BASIC.value, + FreeTextConformanceClasses.ITEMS_BASIC.value, + FreeTextConformanceClasses.COLLECTIONS_BASIC.value, + ] + ) + + search_get_model = create_get_request_model([free_text]) + search_post_model = create_post_request_model([free_text]) + items_get_model = create_request_model( + "ItemCollectionURI", + base_model=ItemCollectionUri, + mixins=[free_text.GET], + ) + + api = StacApi( + settings=settings, + client=DummyCoreClient(), + extensions=[free_text], + search_get_request_model=search_get_model, + search_post_request_model=search_post_model, + collections_get_request_model=free_text.GET, + items_get_request_model=items_get_model, + ) + with TestClient(api.app) as client: + response = client.get("/conformance") + assert response.is_success, response.json() + response_dict = response.json() + assert ( + FreeTextConformanceClasses.SEARCH_BASIC.value in response_dict["conformsTo"] + ) + assert FreeTextConformanceClasses.ITEMS_BASIC.value in response_dict["conformsTo"] + assert ( + FreeTextConformanceClasses.COLLECTIONS_BASIC.value + in response_dict["conformsTo"] + ) + + # /search - GET, no free-text + response = client.get( + "/search", + params={"collections": ["test"]}, + ) + assert response.is_success + assert not response.text + + # /search - GET, free-text option + response = client.get( + "/search", + params={ + "collections": ["test"], + "q": "ocean,coast", + }, + ) + assert response.is_success, response.text + assert response.json() == "ocean,coast" + + # /collections - GET, free-text option + response = client.get( + "/collections", + params={ + "q": "ocean,coast", + }, + ) + assert response.is_success, response.text + assert response.json() == "ocean,coast" + + # /items - GET, free-text option + response = client.get( + "/collections/test/items", + params={ + "q": "ocean,coast", + }, + ) + assert response.is_success, response.text + assert response.json() == "ocean,coast" diff --git a/stac_fastapi/extensions/tests/test_query.py b/stac_fastapi/extensions/tests/test_query.py new file mode 100644 index 000000000..7674547a1 --- /dev/null +++ b/stac_fastapi/extensions/tests/test_query.py @@ -0,0 +1,95 @@ +import json +from typing import Iterator +from urllib.parse import quote_plus, unquote_plus + +import pytest +from starlette.testclient import TestClient + +from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_get_request_model, create_post_request_model +from stac_fastapi.extensions.core import QueryExtension +from stac_fastapi.types.config import ApiSettings +from stac_fastapi.types.core import BaseCoreClient + + +class DummyCoreClient(BaseCoreClient): + def all_collections(self, *args, **kwargs): + raise NotImplementedError + + def get_collection(self, *args, **kwargs): + raise NotImplementedError + + def get_item(self, *args, **kwargs): + raise NotImplementedError + + def get_search(self, *args, **kwargs): + return kwargs.pop("query", None) + + def post_search(self, *args, **kwargs): + return args[0].query + + def item_collection(self, *args, **kwargs): + raise NotImplementedError + + +@pytest.fixture +def client() -> Iterator[TestClient]: + settings = ApiSettings() + extensions = [QueryExtension()] + + api = StacApi( + settings=settings, + client=DummyCoreClient(), + extensions=extensions, + search_get_request_model=create_get_request_model(extensions), + search_post_request_model=create_post_request_model(extensions), + ) + with TestClient(api.app) as client: + yield client + + +def test_search_query_get(client: TestClient): + """Test search GET endpoints with query ext.""" + response = client.get( + "/search", + params={"collections": ["test"]}, + ) + assert response.is_success + assert not response.text + + response = client.get( + "/search", + params={ + "collections": ["test"], + "query": quote_plus( + json.dumps({"eo:cloud_cover": {"gte": 95}}), + ), + }, + ) + assert response.is_success, response.json() + query = json.loads(unquote_plus(response.json())) + assert query["eo:cloud_cover"] == {"gte": 95} + + +def test_search_query_post(client: TestClient): + """Test search POST endpoints with query ext.""" + response = client.post( + "/search", + json={ + "collections": ["test"], + }, + ) + + assert response.is_success + assert not response.text + + response = client.post( + "/search", + json={ + "collections": ["test"], + "query": {"eo:cloud_cover": {"gte": 95}}, + }, + ) + + assert response.is_success, response.json() + assert response.json()["eo:cloud_cover"] == {"gte": 95} diff --git a/stac_fastapi/extensions/tests/test_transaction.py b/stac_fastapi/extensions/tests/test_transaction.py index fc5acc2cf..689e519d2 100644 --- a/stac_fastapi/extensions/tests/test_transaction.py +++ b/stac_fastapi/extensions/tests/test_transaction.py @@ -2,13 +2,15 @@ from typing import Iterator, Union import pytest +from stac_pydantic import Collection +from stac_pydantic.item import Item +from stac_pydantic.item_collection import ItemCollection from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi from stac_fastapi.extensions.core import TransactionExtension from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.core import BaseCoreClient, BaseTransactionsClient -from stac_fastapi.types.stac import Item, ItemCollection class DummyCoreClient(BaseCoreClient): @@ -32,25 +34,32 @@ def item_collection(self, *args, **kwargs): class DummyTransactionsClient(BaseTransactionsClient): - """Defines a pattern for implementing the STAC transaction extension.""" + """Dummy client returning parts of the request, rather than proper STAC items.""" def create_item(self, item: Union[Item, ItemCollection], *args, **kwargs): - return {"created": True, "type": item["type"]} + return {"created": True, "type": item.type} - def update_item(self, *args, **kwargs): - raise NotImplementedError + def update_item(self, collection_id: str, item_id: str, item: Item, **kwargs): + return { + "path_collection_id": collection_id, + "path_item_id": item_id, + "type": item.type, + } - def delete_item(self, *args, **kwargs): - raise NotImplementedError + def delete_item(self, item_id: str, collection_id: str, **kwargs): + return { + "path_collection_id": collection_id, + "path_item_id": item_id, + } - def create_collection(self, *args, **kwargs): - raise NotImplementedError + def create_collection(self, collection: Collection, **kwargs): + return {"type": collection.type} - def update_collection(self, *args, **kwargs): - raise NotImplementedError + def update_collection(self, collection_id: str, collection: Collection, **kwargs): + return {"path_collection_id": collection_id, "type": collection.type} - def delete_collection(self, *args, **kwargs): - raise NotImplementedError + def delete_collection(self, collection_id: str, **kwargs): + return {"path_collection_id": collection_id} def test_create_item(client: TestClient, item: Item) -> None: @@ -69,6 +78,42 @@ def test_create_item_collection( assert response.json()["type"] == "FeatureCollection" +def test_update_item(client: TestClient, item: Item) -> None: + response = client.put( + "/collections/a-collection/items/an-item", content=json.dumps(item) + ) + assert response.is_success, response.text + assert response.json()["path_collection_id"] == "a-collection" + assert response.json()["path_item_id"] == "an-item" + assert response.json()["type"] == "Feature" + + +def test_delete_item(client: TestClient) -> None: + response = client.delete("/collections/a-collection/items/an-item") + assert response.is_success, response.text + assert response.json()["path_collection_id"] == "a-collection" + assert response.json()["path_item_id"] == "an-item" + + +def test_create_collection(client: TestClient, collection: Collection) -> None: + response = client.post("/collections", content=json.dumps(collection)) + assert response.is_success, response.text + assert response.json()["type"] == "Collection" + + +def test_update_collection(client: TestClient, collection: Collection) -> None: + response = client.put("/collections/a-collection", content=json.dumps(collection)) + assert response.is_success, response.text + assert response.json()["path_collection_id"] == "a-collection" + assert response.json()["type"] == "Collection" + + +def test_delete_collection(client: TestClient, collection: Collection) -> None: + response = client.delete("/collections/a-collection") + assert response.is_success, response.text + assert response.json()["path_collection_id"] == "a-collection" + + @pytest.fixture def client( core_client: DummyCoreClient, transactions_client: DummyTransactionsClient @@ -114,8 +159,26 @@ def item() -> Item: "id": "test_item", "geometry": {"type": "Point", "coordinates": [-105, 40]}, "bbox": [-105, 40, -105, 40], - "properties": {}, + "properties": {"datetime": "2020-06-13T13:00:00Z"}, "links": [], "assets": {}, "collection": "test_collection", } + + +@pytest.fixture +def collection() -> Collection: + return { + "type": "Collection", + "stac_version": "1.0.0", + "stac_extensions": [], + "id": "test_collection", + "description": "A test collection", + "extent": { + "spatial": {"bbox": [[-180, -90, 180, 90]]}, + "temporal": {"interval": [["2000-01-01T00:00:00Z", "2024-01-01T00:00:00Z"]]}, + }, + "links": [], + "assets": {}, + "license": "proprietary", + } diff --git a/stac_fastapi/types/setup.py b/stac_fastapi/types/setup.py index 9a06fda95..9fa0ad9ee 100644 --- a/stac_fastapi/types/setup.py +++ b/stac_fastapi/types/setup.py @@ -6,11 +6,10 @@ desc = f.read() install_requires = [ - "fastapi>=0.73.0", - "attrs", - "pydantic[dotenv]<2", - "stac_pydantic==2.0.*", - "pystac==1.*", + "fastapi-slim>=0.111.0", + "attrs>=23.2.0", + "pydantic-settings>=2", + "stac_pydantic~=3.1", "iso8601>=1.0.2,<2.2.0", ] @@ -37,6 +36,10 @@ "Intended Audience :: Information Technology", "Intended Audience :: Science/Research", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", ], keywords="STAC FastAPI COG", diff --git a/stac_fastapi/types/stac_fastapi/types/config.py b/stac_fastapi/types/stac_fastapi/types/config.py index b3f22fb65..75d0bd399 100644 --- a/stac_fastapi/types/stac_fastapi/types/config.py +++ b/stac_fastapi/types/stac_fastapi/types/config.py @@ -1,7 +1,8 @@ """stac_fastapi.types.config module.""" -from typing import Optional, Set -from pydantic import BaseSettings +from typing import Optional + +from pydantic_settings import BaseSettings, SettingsConfigDict class ApiSettings(BaseSettings): @@ -18,9 +19,10 @@ class ApiSettings(BaseSettings): as distinct columns in the database. """ - # TODO: Remove `default_includes` attribute so we can use - # `pydantic.BaseSettings` instead - default_includes: Optional[Set[str]] = None + stac_fastapi_title: str = "stac-fastapi" + stac_fastapi_description: str = "stac-fastapi" + stac_fastapi_version: str = "0.1" + stac_fastapi_landing_id: str = "stac-fastapi" app_host: str = "0.0.0.0" app_port: int = 8000 @@ -30,11 +32,7 @@ class ApiSettings(BaseSettings): openapi_url: str = "/api" docs_url: str = "/api.html" - class Config: - """Model config (https://pydantic-docs.helpmanual.io/usage/model_config/).""" - - extra = "allow" - env_file = ".env" + model_config = SettingsConfigDict(env_file=".env", extra="allow") class Settings: diff --git a/stac_fastapi/types/stac_fastapi/types/conformance.py b/stac_fastapi/types/stac_fastapi/types/conformance.py index 13836aaf5..840584c1b 100644 --- a/stac_fastapi/types/stac_fastapi/types/conformance.py +++ b/stac_fastapi/types/stac_fastapi/types/conformance.py @@ -1,4 +1,5 @@ """Conformance Classes.""" + from enum import Enum diff --git a/stac_fastapi/types/stac_fastapi/types/core.py b/stac_fastapi/types/stac_fastapi/types/core.py index 79b18f57b..003a765ed 100644 --- a/stac_fastapi/types/stac_fastapi/types/core.py +++ b/stac_fastapi/types/stac_fastapi/types/core.py @@ -1,27 +1,43 @@ """Base clients.""" import abc +import importlib +import warnings from typing import Any, Dict, List, Optional, Union from urllib.parse import urljoin import attr from fastapi import Request +from geojson_pydantic.geometries import Geometry +from stac_pydantic import Collection, Item, ItemCollection +from stac_pydantic.api.version import STAC_API_VERSION from stac_pydantic.links import Relations from stac_pydantic.shared import BBox, MimeTypes -from stac_pydantic.version import STAC_VERSION from starlette.responses import Response -from stac_fastapi.types import stac as stac_types +from stac_fastapi.types import stac +from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES from stac_fastapi.types.extension import ApiExtension from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.rfc3339 import DateTimeType from stac_fastapi.types.search import BaseSearchPostRequest -from stac_fastapi.types.stac import Conformance + +__all__ = [ + "NumType", + "StacType", + "BaseTransactionsClient", + "AsyncBaseTransactionsClient", + "LandingPageMixin", + "BaseCoreClient", + "AsyncBaseCoreClient", +] NumType = Union[float, int] StacType = Dict[str, Any] +api_settings = ApiSettings() + @attr.s # type:ignore class BaseTransactionsClient(abc.ABC): @@ -31,9 +47,9 @@ class BaseTransactionsClient(abc.ABC): def create_item( self, collection_id: str, - item: Union[stac_types.Item, stac_types.ItemCollection], + item: Union[Item, ItemCollection], **kwargs, - ) -> Optional[Union[stac_types.Item, Response, None]]: + ) -> Optional[Union[stac.Item, Response, None]]: """Create a new item. Called with `POST /collections/{collection_id}/items`. @@ -49,8 +65,8 @@ def create_item( @abc.abstractmethod def update_item( - self, collection_id: str, item_id: str, item: stac_types.Item, **kwargs - ) -> Optional[Union[stac_types.Item, Response]]: + self, collection_id: str, item_id: str, item: Item, **kwargs + ) -> Optional[Union[stac.Item, Response]]: """Perform a complete update on an existing item. Called with `PUT /collections/{collection_id}/items`. It is expected @@ -70,7 +86,7 @@ def update_item( @abc.abstractmethod def delete_item( self, item_id: str, collection_id: str, **kwargs - ) -> Optional[Union[stac_types.Item, Response]]: + ) -> Optional[Union[stac.Item, Response]]: """Delete an item from a collection. Called with `DELETE /collections/{collection_id}/items/{item_id}` @@ -86,8 +102,8 @@ def delete_item( @abc.abstractmethod def create_collection( - self, collection: stac_types.Collection, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + self, collection: Collection, **kwargs + ) -> Optional[Union[stac.Collection, Response]]: """Create a new collection. Called with `POST /collections`. @@ -102,14 +118,14 @@ def create_collection( @abc.abstractmethod def update_collection( - self, collection_id: str, collection: stac_types.Collection, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + self, collection_id: str, collection: Collection, **kwargs + ) -> Optional[Union[stac.Collection, Response]]: """Perform a complete update on an existing collection. - Called with `PUT /collections/{collection_id}`. It is expected that this item - already exists. The update should do a diff against the saved collection and - perform any necessary updates. Partial updates are not supported by the - transactions extension. + Called with `PUT /collections/{collection_id}`. It is expected that this + collection already exists. The update should do a diff against the saved + collection and perform any necessary updates. Partial updates are not + supported by the transactions extension. Args: collection_id: id of the existing collection to be updated @@ -123,7 +139,7 @@ def update_collection( @abc.abstractmethod def delete_collection( self, collection_id: str, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + ) -> Optional[Union[stac.Collection, Response]]: """Delete a collection. Called with `DELETE /collections/{collection_id}` @@ -145,9 +161,9 @@ class AsyncBaseTransactionsClient(abc.ABC): async def create_item( self, collection_id: str, - item: Union[stac_types.Item, stac_types.ItemCollection], + item: Union[Item, ItemCollection], **kwargs, - ) -> Optional[Union[stac_types.Item, Response, None]]: + ) -> Optional[Union[stac.Item, Response, None]]: """Create a new item. Called with `POST /collections/{collection_id}/items`. @@ -163,8 +179,8 @@ async def create_item( @abc.abstractmethod async def update_item( - self, collection_id: str, item_id: str, item: stac_types.Item, **kwargs - ) -> Optional[Union[stac_types.Item, Response]]: + self, collection_id: str, item_id: str, item: Item, **kwargs + ) -> Optional[Union[stac.Item, Response]]: """Perform a complete update on an existing item. Called with `PUT /collections/{collection_id}/items`. It is expected @@ -183,7 +199,7 @@ async def update_item( @abc.abstractmethod async def delete_item( self, item_id: str, collection_id: str, **kwargs - ) -> Optional[Union[stac_types.Item, Response]]: + ) -> Optional[Union[stac.Item, Response]]: """Delete an item from a collection. Called with `DELETE /collections/{collection_id}/items/{item_id}` @@ -199,8 +215,8 @@ async def delete_item( @abc.abstractmethod async def create_collection( - self, collection: stac_types.Collection, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + self, collection: Collection, **kwargs + ) -> Optional[Union[stac.Collection, Response]]: """Create a new collection. Called with `POST /collections`. @@ -215,8 +231,8 @@ async def create_collection( @abc.abstractmethod async def update_collection( - self, collection_id: str, collection: stac_types.Collection, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + self, collection_id: str, collection: Collection, **kwargs + ) -> Optional[Union[stac.Collection, Response]]: """Perform a complete update on an existing collection. Called with `PUT /collections/{collection_id}`. It is expected that this item @@ -236,7 +252,7 @@ async def update_collection( @abc.abstractmethod async def delete_collection( self, collection_id: str, **kwargs - ) -> Optional[Union[stac_types.Collection, Response]]: + ) -> Optional[Union[stac.Collection, Response]]: """Delete a collection. Called with `DELETE /collections/{collection_id}` @@ -254,18 +270,18 @@ async def delete_collection( class LandingPageMixin(abc.ABC): """Create a STAC landing page (GET /).""" - stac_version: str = attr.ib(default=STAC_VERSION) - landing_page_id: str = attr.ib(default="stac-fastapi") - title: str = attr.ib(default="stac-fastapi") - description: str = attr.ib(default="stac-fastapi") + stac_version: str = attr.ib(default=STAC_API_VERSION) + landing_page_id: str = attr.ib(default=api_settings.stac_fastapi_landing_id) + title: str = attr.ib(default=api_settings.stac_fastapi_title) + description: str = attr.ib(default=api_settings.stac_fastapi_description) def _landing_page( self, base_url: str, conformance_classes: List[str], extension_schemas: List[str], - ) -> stac_types.LandingPage: - landing_page = stac_types.LandingPage( + ) -> stac.LandingPage: + landing_page = stac.LandingPage( type="Catalog", id=self.landing_page_id, title=self.title, @@ -275,35 +291,35 @@ def _landing_page( links=[ { "rel": Relations.self.value, - "type": MimeTypes.json, + "type": MimeTypes.json.value, "href": base_url, }, { "rel": Relations.root.value, - "type": MimeTypes.json, + "type": MimeTypes.json.value, "href": base_url, }, { - "rel": "data", - "type": MimeTypes.json, + "rel": Relations.data.value, + "type": MimeTypes.json.value, "href": urljoin(base_url, "collections"), }, { "rel": Relations.conformance.value, - "type": MimeTypes.json, + "type": MimeTypes.json.value, "title": "STAC/OGC conformance classes implemented by this server", "href": urljoin(base_url, "conformance"), }, { "rel": Relations.search.value, - "type": MimeTypes.geojson, + "type": MimeTypes.geojson.value, "title": "STAC search", "href": urljoin(base_url, "search"), "method": "GET", }, { "rel": Relations.search.value, - "type": MimeTypes.geojson, + "type": MimeTypes.geojson.value, "title": "STAC search", "href": urljoin(base_url, "search"), "method": "POST", @@ -311,6 +327,7 @@ def _landing_page( ], stac_extensions=extension_schemas, ) + return landing_page @@ -353,7 +370,7 @@ def list_conformance_classes(self): return base_conformance - def landing_page(self, **kwargs) -> stac_types.LandingPage: + def landing_page(self, **kwargs) -> stac.LandingPage: """Landing page. Called with `GET /`. @@ -363,14 +380,46 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: """ request: Request = kwargs["request"] base_url = get_base_url(request) + landing_page = self._landing_page( base_url=base_url, conformance_classes=self.conformance_classes(), extension_schemas=[], ) + # Add Queryables link + if self.extension_is_enabled("FilterExtension"): + landing_page["links"].append( + { + "rel": Relations.queryables.value, + "type": MimeTypes.jsonschema.value, + "title": "Queryables", + "href": urljoin(base_url, "queryables"), + } + ) + + # Add Aggregation links + if self.extension_is_enabled("AggregationExtension"): + landing_page["links"].extend( + [ + { + "rel": "aggregate", + "type": "application/json", + "title": "Aggregate", + "href": urljoin(base_url, "aggregate"), + }, + { + "rel": "aggregations", + "type": "application/json", + "title": "Aggregations", + "href": urljoin(base_url, "aggregations"), + }, + ] + ) + # Add Collections links collections = self.all_collections(request=kwargs["request"]) + for collection in collections["collections"]: landing_page["links"].append( { @@ -384,30 +433,26 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: # Add OpenAPI URL landing_page["links"].append( { - "rel": "service-desc", - "type": "application/vnd.oai.openapi+json;version=3.0", + "rel": Relations.service_desc.value, + "type": MimeTypes.openapi.value, "title": "OpenAPI service description", - "href": urljoin( - str(request.base_url), request.app.openapi_url.lstrip("/") - ), + "href": str(request.url_for("openapi")), } ) # Add human readable service-doc landing_page["links"].append( { - "rel": "service-doc", - "type": "text/html", + "rel": Relations.service_doc.value, + "type": MimeTypes.html.value, "title": "OpenAPI service documentation", - "href": urljoin( - str(request.base_url), request.app.docs_url.lstrip("/") - ), + "href": str(request.url_for("swagger_ui_html")), } ) - return landing_page + return stac.LandingPage(**landing_page) - def conformance(self, **kwargs) -> stac_types.Conformance: + def conformance(self, **kwargs) -> stac.Conformance: """Conformance classes. Called with `GET /conformance`. @@ -415,12 +460,12 @@ def conformance(self, **kwargs) -> stac_types.Conformance: Returns: Conformance classes which the server conforms to. """ - return Conformance(conformsTo=self.conformance_classes()) + return stac.Conformance(conformsTo=self.conformance_classes()) @abc.abstractmethod def post_search( self, search_request: BaseSearchPostRequest, **kwargs - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (POST). Called with `POST /search`. @@ -439,15 +484,11 @@ def get_search( collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, bbox: Optional[BBox] = None, + intersects: Optional[Geometry] = None, datetime: Optional[DateTimeType] = None, limit: Optional[int] = 10, - query: Optional[str] = None, - token: Optional[str] = None, - fields: Optional[List[str]] = None, - sortby: Optional[str] = None, - intersects: Optional[str] = None, **kwargs, - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (GET). Called with `GET /search`. @@ -458,7 +499,7 @@ def get_search( ... @abc.abstractmethod - def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac_types.Item: + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: """Get item by id. Called with `GET /collections/{collection_id}/items/{item_id}`. @@ -473,7 +514,7 @@ def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac_types.Ite ... @abc.abstractmethod - def all_collections(self, **kwargs) -> stac_types.Collections: + def all_collections(self, **kwargs) -> stac.Collections: """Get all available collections. Called with `GET /collections`. @@ -484,7 +525,7 @@ def all_collections(self, **kwargs) -> stac_types.Collections: ... @abc.abstractmethod - def get_collection(self, collection_id: str, **kwargs) -> stac_types.Collection: + def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: """Get collection by id. Called with `GET /collections/{collection_id}`. @@ -506,7 +547,7 @@ def item_collection( limit: int = 10, token: str = None, **kwargs, - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Get all items from a specific collection. Called with `GET /collections/{collection_id}/items` @@ -551,7 +592,7 @@ def extension_is_enabled(self, extension: str) -> bool: """Check if an api extension is enabled.""" return any([type(ext).__name__ == extension for ext in self.extensions]) - async def landing_page(self, **kwargs) -> stac_types.LandingPage: + async def landing_page(self, **kwargs) -> stac.LandingPage: """Landing page. Called with `GET /`. @@ -561,12 +602,47 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: """ request: Request = kwargs["request"] base_url = get_base_url(request) + landing_page = self._landing_page( base_url=base_url, conformance_classes=self.conformance_classes(), extension_schemas=[], ) + + # Add Queryables link + if self.extension_is_enabled("FilterExtension"): + landing_page["links"].append( + { + "rel": Relations.queryables.value, + "type": MimeTypes.jsonschema.value, + "title": "Queryables", + "href": urljoin(base_url, "queryables"), + "method": "GET", + } + ) + + # Add Aggregation links + if self.extension_is_enabled("AggregationExtension"): + landing_page["links"].extend( + [ + { + "rel": "aggregate", + "type": "application/json", + "title": "Aggregate", + "href": urljoin(base_url, "aggregate"), + }, + { + "rel": "aggregations", + "type": "application/json", + "title": "Aggregations", + "href": urljoin(base_url, "aggregations"), + }, + ] + ) + + # Add Collections links collections = await self.all_collections(request=kwargs["request"]) + for collection in collections["collections"]: landing_page["links"].append( { @@ -580,26 +656,26 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: # Add OpenAPI URL landing_page["links"].append( { - "rel": "service-desc", - "type": "application/vnd.oai.openapi+json;version=3.0", + "rel": Relations.service_desc.value, + "type": MimeTypes.openapi.value, "title": "OpenAPI service description", - "href": urljoin(base_url, request.app.openapi_url.lstrip("/")), + "href": str(request.url_for("openapi")), } ) # Add human readable service-doc landing_page["links"].append( { - "rel": "service-doc", - "type": "text/html", + "rel": Relations.service_doc.value, + "type": MimeTypes.html.value, "title": "OpenAPI service documentation", - "href": urljoin(base_url, request.app.docs_url.lstrip("/")), + "href": str(request.url_for("swagger_ui_html")), } ) - return landing_page + return stac.LandingPage(**landing_page) - async def conformance(self, **kwargs) -> stac_types.Conformance: + async def conformance(self, **kwargs) -> stac.Conformance: """Conformance classes. Called with `GET /conformance`. @@ -607,12 +683,12 @@ async def conformance(self, **kwargs) -> stac_types.Conformance: Returns: Conformance classes which the server conforms to. """ - return Conformance(conformsTo=self.conformance_classes()) + return stac.Conformance(conformsTo=self.conformance_classes()) @abc.abstractmethod async def post_search( self, search_request: BaseSearchPostRequest, **kwargs - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (POST). Called with `POST /search`. @@ -631,15 +707,11 @@ async def get_search( collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, bbox: Optional[BBox] = None, + intersects: Optional[Geometry] = None, datetime: Optional[DateTimeType] = None, limit: Optional[int] = 10, - query: Optional[str] = None, - token: Optional[str] = None, - fields: Optional[List[str]] = None, - sortby: Optional[str] = None, - intersects: Optional[str] = None, **kwargs, - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Cross catalog search (GET). Called with `GET /search`. @@ -650,9 +722,7 @@ async def get_search( ... @abc.abstractmethod - async def get_item( - self, item_id: str, collection_id: str, **kwargs - ) -> stac_types.Item: + async def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item: """Get item by id. Called with `GET /collections/{collection_id}/items/{item_id}`. @@ -667,7 +737,7 @@ async def get_item( ... @abc.abstractmethod - async def all_collections(self, **kwargs) -> stac_types.Collections: + async def all_collections(self, **kwargs) -> stac.Collections: """Get all available collections. Called with `GET /collections`. @@ -678,9 +748,7 @@ async def all_collections(self, **kwargs) -> stac_types.Collections: ... @abc.abstractmethod - async def get_collection( - self, collection_id: str, **kwargs - ) -> stac_types.Collection: + async def get_collection(self, collection_id: str, **kwargs) -> stac.Collection: """Get collection by id. Called with `GET /collections/{collection_id}`. @@ -702,7 +770,7 @@ async def item_collection( limit: int = 10, token: str = None, **kwargs, - ) -> stac_types.ItemCollection: + ) -> stac.ItemCollection: """Get all items from a specific collection. Called with `GET /collections/{collection_id}/items` @@ -718,53 +786,16 @@ async def item_collection( ... -@attr.s -class AsyncBaseFiltersClient(abc.ABC): - """Defines a pattern for implementing the STAC filter extension.""" - - async def get_queryables( - self, collection_id: Optional[str] = None, **kwargs - ) -> Dict[str, Any]: - """Get the queryables available for the given collection_id. - - If collection_id is None, returns the intersection of all queryables over all - collections. - - This base implementation returns a blank queryable schema. This is not allowed - under OGC CQL but it is allowed by the STAC API Filter Extension - https://github.com/radiantearth/stac-api-spec/tree/master/fragments/filter#queryables - """ - return { - "$schema": "https://json-schema.org/draft/2019-09/schema", - "$id": "https://example.org/queryables", - "type": "object", - "title": "Queryables for Example STAC API", - "description": "Queryable names for the example STAC API Item Search filter.", - "properties": {}, - } - - -@attr.s -class BaseFiltersClient(abc.ABC): - """Defines a pattern for implementing the STAC filter extension.""" - - def get_queryables( - self, collection_id: Optional[str] = None, **kwargs - ) -> Dict[str, Any]: - """Get the queryables available for the given collection_id. - - If collection_id is None, returns the intersection of all queryables over all - collections. +# TODO: remove for 3.0.0 final release +def __getattr__(name: str) -> Any: + if name in ["AsyncBaseFiltersClient", "BaseFiltersClient"]: + warnings.warn( + f"""importing {name} from `stac_fastapi.types.core` is deprecated, + please import it from `stac_fastapi.extensions.core.filter.client`.""", + DeprecationWarning, + stacklevel=2, + ) + clients = importlib.import_module("stac_fastapi.extensions.core.filter.client") + return getattr(clients, name) - This base implementation returns a blank queryable schema. This is not allowed - under OGC CQL but it is allowed by the STAC API Filter Extension - https://github.com/stac-api-extensions/filter#queryables - """ - return { - "$schema": "https://json-schema.org/draft/2019-09/schema", - "$id": "https://example.org/queryables", - "type": "object", - "title": "Queryables for Example STAC API", - "description": "Queryable names for the example STAC API Item Search filter.", - "properties": {}, - } + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/stac_fastapi/types/stac_fastapi/types/extension.py b/stac_fastapi/types/stac_fastapi/types/extension.py index 732a907bf..55a4a123c 100644 --- a/stac_fastapi/types/stac_fastapi/types/extension.py +++ b/stac_fastapi/types/stac_fastapi/types/extension.py @@ -1,4 +1,5 @@ """Base api extension.""" + import abc from typing import List, Optional diff --git a/stac_fastapi/types/stac_fastapi/types/requests.py b/stac_fastapi/types/stac_fastapi/types/requests.py index c9be8b6f6..4d94736a7 100644 --- a/stac_fastapi/types/stac_fastapi/types/requests.py +++ b/stac_fastapi/types/stac_fastapi/types/requests.py @@ -9,6 +9,4 @@ def get_base_url(request: Request) -> str: if not app.state.router_prefix: return str(request.base_url) else: - return "{}{}/".format( - str(request.base_url), app.state.router_prefix.lstrip("/") - ) + return "{}{}/".format(str(request.base_url), app.state.router_prefix.lstrip("/")) diff --git a/stac_fastapi/types/stac_fastapi/types/rfc3339.py b/stac_fastapi/types/stac_fastapi/types/rfc3339.py index 43baa8d53..77ec993dd 100644 --- a/stac_fastapi/types/stac_fastapi/types/rfc3339.py +++ b/stac_fastapi/types/stac_fastapi/types/rfc3339.py @@ -1,10 +1,11 @@ """rfc3339.""" + import re from datetime import datetime, timezone from typing import Optional, Tuple, Union import iso8601 -from pystac.utils import datetime_to_str +from fastapi import HTTPException RFC33339_PATTERN = ( r"^(\d\d\d\d)\-(\d\d)\-(\d\d)(T|t)(\d\d):(\d\d):(\d\d)([.]\d+)?" @@ -19,6 +20,34 @@ ] +# Borrowed from pystac - https://github.com/stac-utils/pystac/blob/f5e4cf4a29b62e9ef675d4a4dac7977b09f53c8f/pystac/utils.py#L370-L394 +def datetime_to_str(dt: datetime, timespec: str = "auto") -> str: + """Converts a :class:`datetime.datetime` instance to an ISO8601 string in the + `RFC 3339, section 5.6 + `__ format required by + the :stac-spec:`STAC Spec `. + + Args: + dt : The datetime to convert. + timespec: An optional argument that specifies the number of additional + terms of the time to include. Valid options are 'auto', 'hours', + 'minutes', 'seconds', 'milliseconds' and 'microseconds'. The default value + is 'auto'. + + Returns: + str: The ISO8601 (RFC 3339) formatted string representing the datetime. + """ + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + + timestamp = dt.isoformat(timespec=timespec) + zulu = "+00:00" + if timestamp.endswith(zulu): + timestamp = f"{timestamp[: -len(zulu)]}Z" + + return timestamp + + def rfc3339_str_to_datetime(s: str) -> datetime: """Convert a string conforming to RFC 3339 to a :class:`datetime.datetime`. @@ -45,48 +74,78 @@ def rfc3339_str_to_datetime(s: str) -> datetime: return iso8601.parse_date(s) -def str_to_interval( - interval: str, -) -> Optional[DateTimeType]: - """Extract a tuple of datetimes from an interval string. +def parse_single_date(date_str: str) -> datetime: + """ + Parse a single RFC3339 date string into a datetime object. + + Args: + date_str (str): A string representing the date in RFC3339 format. + + Returns: + datetime: A datetime object parsed from the date_str. + + Raises: + ValueError: If the date_str is empty or contains the placeholder '..'. + """ + if ".." in date_str or not date_str: + raise ValueError("Invalid date format.") + return rfc3339_str_to_datetime(date_str) - Interval strings are defined by - OGC API - Features Part 1 for the datetime query parameter value. These follow the - form '1985-04-12T23:20:50.52Z/1986-04-12T23:20:50.52Z', and allow either the start - or end (but not both) to be open-ended with '..' or ''. + +def str_to_interval(interval: Optional[str]) -> Optional[DateTimeType]: + """ + Extract a single datetime object or a tuple of datetime objects from an + interval string defined by the OGC API. The interval can either be a + single datetime or a range with start and end datetime. Args: - interval (str) : The interval string to convert to a :class:`datetime.datetime` - tuple. + interval (Optional[str]): The interval string to convert to datetime objects, + or None if no datetime is specified. + + Returns: + Optional[DateTimeType]: A single datetime.datetime object, a tuple of + datetime.datetime objects, or None if input is None. Raises: - ValueError: If the string is not a valid interval string. + HTTPException: If the string is not valid for various reasons such as being empty, + having more than one slash, or if date formats are invalid. """ + if interval is None: + return None + if not interval: - raise ValueError("Empty interval string is invalid.") + raise HTTPException(status_code=400, detail="Empty interval string is invalid.") values = interval.split("/") - if len(values) == 1: - # Single date for == date case - return rfc3339_str_to_datetime(values[0]) - elif len(values) > 2: - raise ValueError( - f"Interval string '{interval}' contains more than one forward slash." + if len(values) > 2: + raise HTTPException( + status_code=400, + detail="Interval string contains more than one forward slash.", ) - start = None - end = None - if values[0] not in ["..", ""]: - start = rfc3339_str_to_datetime(values[0]) - if values[1] not in ["..", ""]: - end = rfc3339_str_to_datetime(values[1]) + try: + start = parse_single_date(values[0]) if values[0] not in ["..", ""] else None + if len(values) == 1: + return start + + end = ( + parse_single_date(values[1]) + if len(values) > 1 and values[1] not in ["..", ""] + else None + ) + except (ValueError, iso8601.ParseError) as e: + raise HTTPException(status_code=400, detail=str(e)) if start is None and end is None: - raise ValueError("Double open-ended intervals are not allowed.") + raise HTTPException( + status_code=400, detail="Double open-ended intervals are not allowed." + ) if start is not None and end is not None and start > end: - raise ValueError("Start datetime cannot be before end datetime.") - else: - return start, end + raise HTTPException( + status_code=400, detail="Start datetime cannot be before end datetime." + ) + + return start, end def now_in_utc() -> datetime: diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index 0851c1d30..064ae10cb 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -1,87 +1,35 @@ """stac_fastapi.types.search module. -# TODO: replace with stac-pydantic """ -import abc -import operator -from datetime import datetime -from enum import auto -from types import DynamicClassAttribute -from typing import Any, Callable, Dict, Generator, List, Optional, Union +from typing import Dict, List, Optional, Union import attr -from geojson_pydantic.geometries import ( - GeometryCollection, - LineString, - MultiLineString, - MultiPoint, - MultiPolygon, - Point, - Polygon, - _GeometryBase, -) -from pydantic import BaseModel, ConstrainedInt, Field, validator -from pydantic.errors import NumberNotGtError -from pydantic.validators import int_validator +from fastapi import Query +from pydantic import Field, PositiveInt +from pydantic.functional_validators import AfterValidator +from stac_pydantic.api import Search from stac_pydantic.shared import BBox -from stac_pydantic.utils import AutoValueEnum +from typing_extensions import Annotated from stac_fastapi.types.rfc3339 import DateTimeType, str_to_interval -# Be careful: https://github.com/samuelcolvin/pydantic/issues/1423#issuecomment-642797287 -NumType = Union[float, int] - - -class Limit(ConstrainedInt): - """An positive integer that maxes out at 10,000.""" - - ge: int = 1 - le: int = 10_000 - - @classmethod - def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: - """Yield the relevant validators.""" - yield int_validator - yield cls.validate - - @classmethod - def validate(cls, value: int) -> int: - """Validate the integer value.""" - if value < cls.ge: - raise NumberNotGtError(limit_value=cls.ge) - if value > cls.le: - return cls.le - return value - - -class Operator(str, AutoValueEnum): - """Defines the set of operators supported by the API.""" - eq = auto() - ne = auto() - lt = auto() - lte = auto() - gt = auto() - gte = auto() +def crop(v: PositiveInt) -> PositiveInt: + """Crop value to 10,000.""" + limit = 10_000 + if v > limit: + v = limit + return v - # TODO: These are defined in the spec but aren't currently implemented by the api - # startsWith = auto() - # endsWith = auto() - # contains = auto() - # in = auto() - @DynamicClassAttribute - def operator(self) -> Callable[[Any, Any], bool]: - """Return python operator.""" - return getattr(operator, self._value_) - - -def str2list(x: str) -> Optional[List]: +def str2list(x: str) -> Optional[List[str]]: """Convert string to list base on , delimiter.""" if x: return x.split(",") + return None + def str2bbox(x: str) -> Optional[BBox]: """Convert string to BBox based on , delimiter.""" @@ -90,9 +38,76 @@ def str2bbox(x: str) -> Optional[BBox]: assert len(t) == 4 return t + return None + + +def _collection_converter( + val: Annotated[ + Optional[str], + Query( + description="Array of collection Ids to search for items.", + json_schema_extra={ + "example": "collection1,collection2", + }, + ), + ] = None, +) -> Optional[List[str]]: + return str2list(val) + + +def _ids_converter( + val: Annotated[ + Optional[str], + Query( + description="Array of Item ids to return.", + json_schema_extra={ + "example": "item1,item2", + }, + ), + ] = None, +) -> Optional[List[str]]: + return str2list(val) + + +def _bbox_converter( + val: Annotated[ + Optional[str], + Query( + description="Only return items intersecting this bounding box. Mutually exclusive with **intersects**.", # noqa: E501 + json_schema_extra={ + "example": "-175.05,-85.05,175.05,85.05", + }, + ), + ] = None, +) -> Optional[BBox]: + return str2bbox(val) + + +def _datetime_converter( + val: Annotated[ + Optional[str], + Query( + description="""Only return items that have a temporal property that intersects this value.\n +Either a date-time or an interval, open or closed. Date and time expressions adhere to RFC 3339. Open intervals are expressed using double-dots.""", # noqa: E501 + openapi_examples={ + "datetime": {"value": "2018-02-12T23:20:50Z"}, + "closed-interval": {"value": "2018-02-12T00:00:00Z/2018-03-18T12:31:12Z"}, + "open-interval-from": {"value": "2018-02-12T00:00:00Z/.."}, + "open-interval-to": {"value": "../2018-03-18T12:31:12Z"}, + }, + ), + ] = None, +): + return str_to_interval(val) -@attr.s # type:ignore -class APIRequest(abc.ABC): + +# Be careful: https://github.com/samuelcolvin/pydantic/issues/1423#issuecomment-642797287 +NumType = Union[float, int] +Limit = Annotated[PositiveInt, AfterValidator(crop)] + + +@attr.s +class APIRequest: """Generic API Request base class.""" def kwargs(self) -> Dict: @@ -105,118 +120,71 @@ def kwargs(self) -> Dict: class BaseSearchGetRequest(APIRequest): """Base arguments for GET Request.""" - collections: Optional[str] = attr.ib(default=None, converter=str2list) - ids: Optional[str] = attr.ib(default=None, converter=str2list) - bbox: Optional[BBox] = attr.ib(default=None, converter=str2bbox) - intersects: Optional[str] = attr.ib(default=None, converter=str2list) - datetime: Optional[DateTimeType] = attr.ib(default=None, converter=str_to_interval) - limit: Optional[int] = attr.ib(default=10) - - -class BaseSearchPostRequest(BaseModel): - """Search model. - - Replace base model in STAC-pydantic as it includes additional fields, not in the core - model. - https://github.com/radiantearth/stac-api-spec/tree/master/item-search#query-parameter-table - - PR to fix this: - https://github.com/stac-utils/stac-pydantic/pull/100 - """ - - collections: Optional[List[str]] - ids: Optional[List[str]] - bbox: Optional[BBox] - intersects: Optional[ - Union[ - Point, - MultiPoint, - LineString, - MultiLineString, - Polygon, - MultiPolygon, - GeometryCollection, - ] - ] - datetime: Optional[DateTimeType] - limit: Optional[Limit] = Field(default=10) - - @property - def start_date(self) -> Optional[datetime]: - """Extract the start date from the datetime string.""" - return self.datetime[0] if self.datetime else None - - @property - def end_date(self) -> Optional[datetime]: - """Extract the end date from the datetime string.""" - return self.datetime[1] if self.datetime else None - - @validator("intersects") - def validate_spatial(cls, v, values): - """Check bbox and intersects are not both supplied.""" - if v and values["bbox"]: - raise ValueError("intersects and bbox parameters are mutually exclusive") - return v - - @validator("bbox", pre=True) - def validate_bbox(cls, v: Union[str, BBox]) -> BBox: - """Check order of supplied bbox coordinates.""" - if v: - if type(v) == str: - v = str2bbox(v) - # Validate order - if len(v) == 4: - xmin, ymin, xmax, ymax = v - else: - xmin, ymin, min_elev, xmax, ymax, max_elev = v - if max_elev < min_elev: - raise ValueError( - "Maximum elevation must greater than minimum elevation" - ) - - if xmax < xmin: - raise ValueError( - "Maximum longitude must be greater than minimum longitude" - ) - - if ymax < ymin: - raise ValueError( - "Maximum longitude must be greater than minimum longitude" - ) - - # Validate against WGS84 - if xmin < -180 or ymin < -90 or xmax > 180 or ymax > 90: - raise ValueError("Bounding box must be within (-180, -90, 180, 90)") - - return v - - @validator("datetime", pre=True) - def validate_datetime(cls, v: Union[str, DateTimeType]) -> DateTimeType: - """Parse datetime.""" - if type(v) == str: - v = str_to_interval(v) - return v - - @property - def spatial_filter(self) -> Optional[_GeometryBase]: - """Return a geojson-pydantic object representing the spatial filter for the search - request. - - Check for both because the ``bbox`` and ``intersects`` parameters are - mutually exclusive. - """ - if self.bbox: - return Polygon( - coordinates=[ - [ - [self.bbox[0], self.bbox[3]], - [self.bbox[2], self.bbox[3]], - [self.bbox[2], self.bbox[1]], - [self.bbox[0], self.bbox[1]], - [self.bbox[0], self.bbox[3]], - ] - ] - ) - if self.intersects: - return self.intersects - return + collections: Optional[List[str]] = attr.ib( + default=None, converter=_collection_converter + ) + ids: Optional[List[str]] = attr.ib(default=None, converter=_ids_converter) + bbox: Optional[BBox] = attr.ib(default=None, converter=_bbox_converter) + intersects: Annotated[ + Optional[str], + Query( + description="""Only return items intersecting this GeoJSON Geometry. Mutually exclusive with **bbox**. \n +*Remember to URL encode the GeoJSON geometry when using GET request*.""", # noqa: E501 + openapi_examples={ + "madrid": { + "value": { + "type": "Feature", + "properties": {}, + "geometry": { + "coordinates": [ + [ + [-3.8549260500072933, 40.54923557897152], + [-3.8549260500072933, 40.29428000041938], + [-3.516597069715033, 40.29428000041938], + [-3.516597069715033, 40.54923557897152], + [-3.8549260500072933, 40.54923557897152], + ] + ], + "type": "Polygon", + }, + }, + }, + "new-york": { + "value": { + "type": "Feature", + "properties": {}, + "geometry": { + "coordinates": [ + [ + [-74.50117532354284, 41.128266394414055], + [-74.50117532354284, 40.35633909727355], + [-73.46713183168603, 40.35633909727355], + [-73.46713183168603, 41.128266394414055], + [-74.50117532354284, 41.128266394414055], + ] + ], + "type": "Polygon", + }, + }, + }, + }, + ), + ] = attr.ib(default=None) + datetime: Optional[DateTimeType] = attr.ib( + default=None, converter=_datetime_converter + ) + limit: Annotated[ + Optional[int], + Query( + description="Limits the number of results that are included in each page of the response." # noqa: E501 + ), + ] = attr.ib(default=10) + + +class BaseSearchPostRequest(Search): + """Base arguments for POST Request.""" + + limit: Optional[Limit] = Field( + 10, + description="Limits the number of results that are included in each page of the response.", # noqa: E501 + ) diff --git a/stac_fastapi/types/stac_fastapi/types/stac.py b/stac_fastapi/types/stac_fastapi/types/stac.py index 51bb6e652..b9c93fd80 100644 --- a/stac_fastapi/types/stac_fastapi/types/stac.py +++ b/stac_fastapi/types/stac_fastapi/types/stac.py @@ -1,4 +1,5 @@ """STAC types.""" + import sys from typing import Any, Dict, List, Literal, Optional, Union @@ -6,9 +7,9 @@ # Avoids a Pydantic error: # TypeError: You should use `typing_extensions.TypedDict` instead of -# `typing.TypedDict` with Python < 3.9.2. Without it, there is no way to +# `typing.TypedDict` with Python < 3.12.0. Without it, there is no way to # differentiate required and optional fields when subclassed. -if sys.version_info < (3, 9, 2): +if sys.version_info < (3, 12, 0): from typing_extensions import TypedDict else: from typing import TypedDict @@ -16,35 +17,28 @@ NumType = Union[float, int] -class LandingPage(TypedDict, total=False): - """STAC Landing Page.""" +class Catalog(TypedDict, total=False): + """STAC Catalog.""" type: str stac_version: str stac_extensions: Optional[List[str]] id: str - title: str + title: Optional[str] description: str - conformsTo: List[str] links: List[Dict[str, Any]] -class Conformance(TypedDict): - """STAC Conformance Classes.""" +class LandingPage(Catalog, total=False): + """STAC Landing Page.""" conformsTo: List[str] -class Catalog(TypedDict, total=False): - """STAC Catalog.""" +class Conformance(TypedDict): + """STAC Conformance Classes.""" - type: str - stac_version: str - stac_extensions: Optional[List[str]] - id: str - title: Optional[str] - description: str - links: List[Dict[str, Any]] + conformsTo: List[str] class Collection(Catalog, total=False): @@ -84,7 +78,6 @@ class ItemCollection(TypedDict, total=False): class Collections(TypedDict, total=False): """All collections endpoint. - https://github.com/radiantearth/stac-api-spec/tree/master/collections """ diff --git a/stac_fastapi/types/stac_fastapi/types/version.py b/stac_fastapi/types/stac_fastapi/types/version.py index bb0c7c379..7296e8a98 100644 --- a/stac_fastapi/types/stac_fastapi/types/version.py +++ b/stac_fastapi/types/stac_fastapi/types/version.py @@ -1,2 +1,2 @@ """Library version.""" -__version__ = "2.4.9" +__version__ = "3.0.0b2" diff --git a/stac_fastapi/types/tests/test_rfc3339.py b/stac_fastapi/types/tests/test_rfc3339.py index 0a402699a..dc4c897d5 100644 --- a/stac_fastapi/types/tests/test_rfc3339.py +++ b/stac_fastapi/types/tests/test_rfc3339.py @@ -1,6 +1,7 @@ -from datetime import timezone +from datetime import datetime, timezone import pytest +from fastapi import HTTPException from stac_fastapi.types.rfc3339 import ( now_in_utc, @@ -85,14 +86,42 @@ def test_parse_valid_str_to_datetime(test_input): @pytest.mark.parametrize("test_input", invalid_intervals) -def test_parse_invalid_interval_to_datetime(test_input): - with pytest.raises(ValueError): +def test_str_to_interval_with_invalid_interval(test_input): + with pytest.raises(HTTPException) as exc_info: + str_to_interval(test_input) + assert ( + exc_info.value.status_code == 400 + ), "str_to_interval should return a 400 status code for invalid interval" + + +@pytest.mark.parametrize("test_input", invalid_datetimes) +def test_str_to_interval_with_invalid_datetime(test_input): + with pytest.raises(HTTPException) as exc_info: str_to_interval(test_input) + assert ( + exc_info.value.status_code == 400 + ), "str_to_interval should return a 400 status code for invalid datetime" @pytest.mark.parametrize("test_input", valid_intervals) -def test_parse_valid_interval_to_datetime(test_input): - assert str_to_interval(test_input) +def test_str_to_interval_with_valid_interval(test_input): + assert isinstance( + str_to_interval(test_input), tuple + ), "str_to_interval should return tuple for multi-value input" + + +@pytest.mark.parametrize("test_input", valid_datetimes) +def test_str_to_interval_with_valid_datetime(test_input): + assert isinstance( + str_to_interval(test_input), datetime + ), "str_to_interval should return single datetime for single-value input" + + +def test_str_to_interval_with_none(): + """Test that str_to_interval returns None when provided with None.""" + assert ( + str_to_interval(None) is None + ), "str_to_interval should return None when input is None" def test_now_functions() -> None: