Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add collection search extension #136

Merged
merged 19 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions stac_fastapi/pgstac/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,12 @@
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)

# will only use parameters defined in collections_get_request_model
collection_search_model = create_post_request_model(extensions, base_model=PgstacSearch)

api = StacApi(
settings=settings,
extensions=extensions + [collection_search_extension],
hrodmn marked this conversation as resolved.
Show resolved Hide resolved
client=CoreCrudClient(
post_request_model=post_request_model, # type: ignore
collection_request_model=collection_search_model, # type: ignore
collections_get_request_model=collections_get_request_model, # type: ignore
),
response_class=ORJSONResponse,
items_get_request_model=items_get_request_model,
Expand Down
20 changes: 13 additions & 7 deletions stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Item crud client."""

import json
import re
from typing import Any, Dict, List, Optional, Set, Union
from urllib.parse import unquote_plus, urljoin
Expand All @@ -13,7 +14,7 @@
from pygeofilter.backends.cql2_json import to_cql2
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
from pypgstac.hydration import hydrate
from stac_fastapi.api.models import JSONResponse
from stac_fastapi.api.models import APIRequest, EmptyRequest, JSONResponse
from stac_fastapi.types.core import AsyncBaseCoreClient, Relations
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
from stac_fastapi.types.requests import get_base_url
Expand All @@ -38,7 +39,7 @@
class CoreCrudClient(AsyncBaseCoreClient):
"""Client for core endpoints defined by stac."""

collection_request_model = attr.ib(default=PgstacSearch)
collections_get_request_model: APIRequest = attr.ib(default=EmptyRequest)

async def all_collections( # noqa: C901
self,
Expand Down Expand Up @@ -83,7 +84,8 @@ async def all_collections( # noqa: C901

# Do the request
try:
search_request = self.collection_request_model(**clean)
search_request = self.collections_get_request_model(**clean)
print(search_request)
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved
except ValidationError as e:
raise HTTPException(
status_code=400, detail=f"Invalid parameters provided {e}"
Expand All @@ -93,7 +95,7 @@ async def all_collections( # noqa: C901

async def _collection_search_base( # noqa: C901
self,
search_request: PgstacSearch,
search_request: APIRequest,
request: Request,
) -> Collections:
"""Cross catalog search (GET).
Expand All @@ -107,8 +109,12 @@ async def _collection_search_base( # noqa: C901
All collections which match the search criteria.
"""
base_url = get_base_url(request)
search_request_json = search_request.model_dump_json(
exclude_none=True, by_alias=True
search_request_json = json.dumps(
{
key: value
for key, value in search_request.__dict__.items()
if value is not None
}
)

try:
Expand Down Expand Up @@ -533,7 +539,7 @@ def clean_search_args( # noqa: C901
filter_lang = "cql2-json"

base_args["filter"] = orjson.loads(filter_query)
base_args["filter-lang"] = filter_lang
base_args["filter_lang"] = filter_lang

if datetime:
base_args["datetime"] = format_datetime_range(datetime)
Expand Down
10 changes: 4 additions & 6 deletions tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,23 +730,21 @@ async def get_collection(
]
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)
collection_search_model = create_post_request_model(
extensions, base_model=PgstacSearch
)

collection_search_extension = CollectionSearchExtension.from_extensions(
extensions=extensions
)
collections_get_request_model = collection_search_extension.GET

api = StacApi(
client=Client(
post_request_model=post_request_model,
collection_request_model=collection_search_model,
collections_get_request_model=collection_search_extension.GET,
),
settings=settings,
extensions=extensions,
search_post_request_model=post_request_model,
search_get_request_model=get_request_model,
collections_get_request_model=collections_get_request_model,
collections_get_request_model=collection_search_extension.GET,
)
app = api.app
await connect_to_db(app)
Expand Down
5 changes: 1 addition & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,13 @@ def api_client(request, database):
)

collections_get_request_model = collection_search_extension.GET
collection_search_model = create_post_request_model(
extensions, base_model=PgstacSearch
)

api = StacApi(
settings=api_settings,
extensions=extensions + [collection_search_extension],
client=CoreCrudClient(
post_request_model=search_post_request_model,
collection_request_model=collection_search_model,
collections_get_request_model=collections_get_request_model,
),
items_get_request_model=items_get_request_model,
search_get_request_model=search_get_request_model,
Expand Down
Loading