Skip to content

Commit

Permalink
add collection search extension
Browse files Browse the repository at this point in the history
  • Loading branch information
hrodmn committed Aug 2, 2024
1 parent 1f36485 commit 1cd6833
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 61 deletions.
4 changes: 3 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ coverage.xml
*.log
.git
.envrc
*egg-info

venv
venv
env
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

## [Unreleased]

<<<<<<< HEAD
- Enable filter extension for `GET /items` requests and add `Queryables` links in `/collections` and `/collections/{collection_id}` responses ([#89](https://github.com/stac-utils/stac-fastapi-pgstac/pull/89))
=======
- Add collection search extension
>>>>>>> 9b2050a (add collection search extension)
## [3.0.0a4] - 2024-07-10

Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
"orjson",
"pydantic",
"stac_pydantic==3.1.*",
"stac-fastapi.api~=3.0.0b2",
"stac-fastapi.extensions~=3.0.0b2",
"stac-fastapi.types~=3.0.0b2",
"stac-fastapi.api~=3.0.0b3",
"stac-fastapi.extensions~=3.0.0b3",
"stac-fastapi.types~=3.0.0b3",
"asyncpg",
"buildpg",
"brotli_asgi",
Expand Down
26 changes: 24 additions & 2 deletions stac_fastapi/pgstac/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from fastapi.responses import ORJSONResponse
from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import (
EmptyRequest,
ItemCollectionUri,
create_get_request_model,
create_post_request_model,
create_request_model,
)
from stac_fastapi.extensions.core import (
CollectionSearchExtension,
FieldsExtension,
FilterExtension,
SortExtension,
Expand Down Expand Up @@ -47,12 +49,26 @@
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
}

collections_extensions_map = {
"collection_search": CollectionSearchExtension(),
}

if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"):
_enabled_extensions = enabled_extensions.split(",")
extensions = [
extensions_map[extension_name] for extension_name in enabled_extensions.split(",")
extension
for key, extension in extensions_map.items()
if key in _enabled_extensions
]
collection_extensions = [
extension
for key, extension in collections_extensions_map.items()
if key in _enabled_extensions
]
else:
extensions = list(extensions_map.values())
collection_extensions = list(collections_extensions_map.values())


if any(isinstance(ext, TokenPaginationExtension) for ext in extensions):
items_get_request_model = create_request_model(
Expand All @@ -64,17 +80,23 @@
else:
items_get_request_model = ItemCollectionUri

if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions):
collections_get_request_model = CollectionSearchExtension().GET
else:
collections_get_request_model = EmptyRequest

post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)

api = StacApi(
settings=settings,
extensions=extensions,
extensions=extensions + collection_extensions,
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
response_class=ORJSONResponse,
items_get_request_model=items_get_request_model,
search_get_request_model=get_request_model,
search_post_request_model=post_request_model,
collections_get_request_model=collections_get_request_model,
)
app = api.app

Expand Down
235 changes: 181 additions & 54 deletions stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,113 @@
class CoreCrudClient(AsyncBaseCoreClient):
"""Client for core endpoints defined by stac."""

async def all_collections(self, request: Request, **kwargs) -> Collections:
"""Read all collections from the database."""
async def all_collections( # noqa: C901
self,
request: Request,
bbox: Optional[BBox] = None,
datetime: Optional[DateTimeType] = None,
limit: Optional[int] = None,
# Extensions
query: Optional[str] = None,
token: Optional[str] = None,
fields: Optional[List[str]] = None,
sortby: Optional[str] = None,
filter: Optional[str] = None,
filter_lang: Optional[str] = None,
**kwargs,
) -> Collections:
"""Cross catalog search (GET).
Called with `GET /collections`.
Returns:
Collections which match the search criteria, returns all
collections by default.
"""
query_params = str(request.query_params)

# Kludgy fix because using factory does not allow alias for filter-lang
if filter_lang is None:
match = re.search(r"filter-lang=([a-z0-9-]+)", query_params, re.IGNORECASE)
if match:
filter_lang = match.group(1)

# Parse request parameters
base_args = {
"bbox": bbox,
"limit": limit,
"token": token,
"query": orjson.loads(unquote_plus(query)) if query else query,
}

clean = clean_search_args(
base_args=base_args,
datetime=datetime,
fields=fields,
sortby=sortby,
filter=filter,
filter_lang=filter_lang,
)

# Do the request
try:
search_request = self.post_request_model(**clean)
except ValidationError as e:
raise HTTPException(
status_code=400, detail=f"Invalid parameters provided {e}"
) from e

return await self._collection_search_base(search_request, request=request)

async def _collection_search_base( # noqa: C901
self,
search_request: PgstacSearch,
request: Request,
) -> Collections:
"""Cross catalog search (POST).
Called with `POST /search`.
Args:
search_request: search request parameters.
Returns:
All collections which match the search criteria.
"""

base_url = get_base_url(request)

async with request.app.state.get_connection(request, "r") as conn:
collections = await conn.fetchval(
"""
SELECT * FROM all_collections();
"""
)
settings: Settings = request.app.state.settings

if search_request.datetime:
search_request.datetime = format_datetime_range(search_request.datetime)

search_request.conf = search_request.conf or {}
search_request.conf["nohydrate"] = settings.use_api_hydrate

search_request_json = search_request.model_dump_json(
exclude_none=True, by_alias=True
)

try:
async with request.app.state.get_connection(request, "r") as conn:
q, p = render(
"""
SELECT * FROM collection_search(:req::text::jsonb);
""",
req=search_request_json,
)
collections_result: Collections = await conn.fetchval(q, *p)
except InvalidDatetimeFormatError as e:
raise InvalidQueryParameter(
f"Datetime parameter {search_request.datetime} is invalid."
) from e

# next: Optional[str] = collections_result["links"].pop("next")
# prev: Optional[str] = collections_result["links"].pop("prev")

linked_collections: List[Collection] = []
collections = collections_result["collections"]
if collections is not None and len(collections) > 0:
for c in collections:
coll = Collection(**c)
Expand All @@ -71,6 +167,12 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:

linked_collections.append(coll)

# paging_links = await PagingLinks(
# request=request,
# next=next,
# prev=prev,
# ).get_links()

links = [
{
"rel": Relations.root.value,
Expand All @@ -88,8 +190,10 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
"href": urljoin(base_url, "collections"),
},
]
collection_list = Collections(collections=linked_collections or [], links=links)
return collection_list
return Collections(
collections=linked_collections or [],
links=links, # + paging_links
)

async def get_collection(
self, collection_id: str, request: Request, **kwargs
Expand Down Expand Up @@ -383,7 +487,7 @@ async def post_search(

return ItemCollection(**item_collection)

async def get_search( # noqa: C901
async def get_search(
self,
request: Request,
collections: Optional[List[str]] = None,
Expand Down Expand Up @@ -418,49 +522,15 @@ async def get_search( # noqa: C901
"query": orjson.loads(unquote_plus(query)) if query else query,
}

if filter:
if filter_lang == "cql2-text":
ast = parse_cql2_text(filter)
base_args["filter"] = orjson.loads(to_cql2(ast))
base_args["filter-lang"] = "cql2-json"

if datetime:
base_args["datetime"] = format_datetime_range(datetime)

if intersects:
base_args["intersects"] = orjson.loads(unquote_plus(intersects))

if sortby:
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
sort_param = []
for sort in sortby:
sortparts = re.match(r"^([+-]?)(.*)$", sort)
if sortparts:
sort_param.append(
{
"field": sortparts.group(2).strip(),
"direction": "desc" if sortparts.group(1) == "-" else "asc",
}
)
base_args["sortby"] = sort_param

if fields:
includes = set()
excludes = set()
for field in fields:
if field[0] == "-":
excludes.add(field[1:])
elif field[0] == "+":
includes.add(field[1:])
else:
includes.add(field)
base_args["fields"] = {"include": includes, "exclude": excludes}

# Remove None values from dict
clean = {}
for k, v in base_args.items():
if v is not None and v != []:
clean[k] = v
clean = clean_search_args(
base_args=base_args,
intersects=intersects,
datetime=datetime,
fields=fields,
sortby=sortby,
filter=filter,
filter_lang=filter_lang,
)

# Do the request
try:
Expand All @@ -471,3 +541,60 @@ async def get_search( # noqa: C901
) from e

return await self.post_search(search_request, request=request)


def clean_search_args( # noqa: C901
base_args: dict[str, Any],
intersects: Optional[str] = None,
datetime: Optional[DateTimeType] = None,
fields: Optional[List[str]] = None,
sortby: Optional[str] = None,
filter: Optional[str] = None,
filter_lang: Optional[str] = None,
) -> dict[str, Any]:
"""Clean up search arguments to match format expected by pgstac"""
if filter:
if filter_lang == "cql2-text":
ast = parse_cql2_text(filter)
base_args["filter"] = orjson.loads(to_cql2(ast))
base_args["filter-lang"] = "cql2-json"

if datetime:
base_args["datetime"] = format_datetime_range(datetime)

if intersects:
base_args["intersects"] = orjson.loads(unquote_plus(intersects))

if sortby:
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
sort_param = []
for sort in sortby:
sortparts = re.match(r"^([+-]?)(.*)$", sort)
if sortparts:
sort_param.append(
{
"field": sortparts.group(2).strip(),
"direction": "desc" if sortparts.group(1) == "-" else "asc",
}
)
base_args["sortby"] = sort_param

if fields:
includes = set()
excludes = set()
for field in fields:
if field[0] == "-":
excludes.add(field[1:])
elif field[0] == "+":
includes.add(field[1:])
else:
includes.add(field)
base_args["fields"] = {"include": includes, "exclude": excludes}

# Remove None values from dict
clean = {}
for k, v in base_args.items():
if v is not None and v != []:
clean[k] = v

return clean
Loading

0 comments on commit 1cd6833

Please sign in to comment.