Skip to content

Commit

Permalink
add collection_request_model to the client
Browse files Browse the repository at this point in the history
  • Loading branch information
hrodmn committed Sep 25, 2024
1 parent fd5c48b commit bfd94f4
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 4 deletions.
8 changes: 7 additions & 1 deletion stac_fastapi/pgstac/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,16 @@
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)

# will only use parameters defined in collections_get_request_model
collection_search_model = create_post_request_model(extensions, base_model=PgstacSearch)

api = StacApi(
settings=settings,
extensions=extensions + [collection_search_extension],
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
client=CoreCrudClient(
post_request_model=post_request_model, # type: ignore
collection_request_model=collection_search_model, # type: ignore
),
response_class=ORJSONResponse,
items_get_request_model=items_get_request_model,
search_get_request_model=get_request_model,
Expand Down
3 changes: 3 additions & 0 deletions stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,9 @@ def clean_search_args( # noqa: C901
filter_query = to_cql2(parse_cql2_text(filter_query))
filter_lang = "cql2-json"

base_args["filter"] = orjson.loads(filter_query)
base_args["filter-lang"] = filter_lang

if datetime:
base_args["datetime"] = format_datetime_range(datetime)

Expand Down
19 changes: 17 additions & 2 deletions tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent
from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension
from stac_fastapi.extensions.core import (
CollectionSearchExtension,
FieldsExtension,
TransactionExtension,
)
from stac_fastapi.types import stac as stac_types

from stac_fastapi.pgstac.core import CoreCrudClient, Settings
Expand Down Expand Up @@ -726,12 +730,23 @@ async def get_collection(
]
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)
collection_search_model = create_post_request_model(
extensions, base_model=PgstacSearch
)
collection_search_extension = CollectionSearchExtension.from_extensions(
extensions=extensions
)
collections_get_request_model = collection_search_extension.GET
api = StacApi(
client=Client(post_request_model=post_request_model),
client=Client(
post_request_model=post_request_model,
collection_request_model=collection_search_model,
),
settings=settings,
extensions=extensions,
search_post_request_model=post_request_model,
search_get_request_model=get_request_model,
collections_get_request_model=collections_get_request_model,
)
app = api.app
await connect_to_db(app)
Expand Down
8 changes: 7 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,17 @@ def api_client(request, database):
)

collections_get_request_model = collection_search_extension.GET
collection_search_model = create_post_request_model(
extensions, base_model=PgstacSearch
)

api = StacApi(
settings=api_settings,
extensions=extensions + [collection_search_extension],
client=CoreCrudClient(post_request_model=search_post_request_model),
client=CoreCrudClient(
post_request_model=search_post_request_model,
collection_request_model=collection_search_model,
),
items_get_request_model=items_get_request_model,
search_get_request_model=search_get_request_model,
search_post_request_model=search_post_request_model,
Expand Down

0 comments on commit bfd94f4

Please sign in to comment.