Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update fields extension and make sure the app can work without any extension #123

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## [Unreleased]

- Update stac-fastapi libraries to `~=3.0.0a3`
- make sure the application can work without any extension

## [3.0.0a1] - 2024-05-22

- Update stac-fastapi libraries to `~=3.0.0a1`
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
"orjson",
"pydantic",
"stac_pydantic==3.1.*",
"stac-fastapi.api~=3.0.0a1",
"stac-fastapi.extensions~=3.0.0a1",
"stac-fastapi.types~=3.0.0a1",
"stac-fastapi.api~=3.0.0a3",
"stac-fastapi.extensions~=3.0.0a3",
"stac-fastapi.types~=3.0.0a3",
"asyncpg",
"buildpg",
"brotli_asgi",
Expand Down
4 changes: 2 additions & 2 deletions stac_fastapi/pgstac/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@
extensions = list(extensions_map.values())

post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)

get_request_model = create_get_request_model(extensions)
api = StacApi(
settings=settings,
extensions=extensions,
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
response_class=ORJSONResponse,
search_get_request_model=create_get_request_model(extensions),
search_get_request_model=get_request_model,
search_post_request_model=post_request_model,
)
app = api.app
Expand Down
27 changes: 11 additions & 16 deletions stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Item crud client."""

import re
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Set, Union
from urllib.parse import unquote_plus, urljoin

import attr
Expand Down Expand Up @@ -184,12 +184,9 @@ async def _search_base( # noqa: C901
prev: Optional[str] = items.pop("prev", None)
collection = ItemCollection(**items)

exclude = search_request.fields.exclude
if exclude and len(exclude) == 0:
exclude = None
include = search_request.fields.include
if include and len(include) == 0:
include = None
fields = getattr(search_request, "fields", None)
include: Set[str] = fields.include if fields and fields.include else set()
exclude: Set[str] = fields.exclude if fields and fields.exclude else set()

async def _add_item_links(
feature: Item,
Expand All @@ -204,11 +201,7 @@ async def _add_item_links(
collection_id = feature.get("collection") or collection_id
item_id = feature.get("id") or item_id

if (
search_request.fields.exclude is None
or "links" not in search_request.fields.exclude
and all([collection_id, item_id])
):
if not exclude or "links" not in exclude and all([collection_id, item_id]):
feature["links"] = await ItemLinks(
collection_id=collection_id, # type: ignore
item_id=item_id, # type: ignore
Expand Down Expand Up @@ -252,6 +245,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]:
next=next,
prev=prev,
).get_links()

return collection

async def item_collection(
Expand Down Expand Up @@ -295,14 +289,14 @@ async def item_collection(
if v is not None and v != []:
clean[k] = v

search_request = self.post_request_model(
**clean,
)
search_request = self.post_request_model(**clean)
item_collection = await self._search_base(search_request, request=request)

links = await ItemCollectionLinks(
collection_id=collection_id, request=request
).get_links(extra_links=item_collection["links"])
item_collection["links"] = links

return item_collection

async def get_item(
Expand Down Expand Up @@ -355,15 +349,16 @@ async def get_search( # noqa: C901
collections: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
bbox: Optional[BBox] = None,
intersects: Optional[str] = None,
datetime: Optional[DateTimeType] = None,
limit: Optional[int] = None,
# Extensions
query: Optional[str] = None,
token: Optional[str] = None,
fields: Optional[List[str]] = None,
sortby: Optional[str] = None,
filter: Optional[str] = None,
filter_lang: Optional[str] = None,
intersects: Optional[str] = None,
**kwargs,
) -> ItemCollection:
"""Cross catalog search (GET).
Expand Down
2 changes: 1 addition & 1 deletion stac_fastapi/pgstac/extensions/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from buildpg import render
from fastapi import Request
from stac_fastapi.types.core import AsyncBaseFiltersClient
from stac_fastapi.extensions.core.filter.client import AsyncBaseFiltersClient
from stac_fastapi.types.errors import NotFoundError


Expand Down
112 changes: 111 additions & 1 deletion tests/api/test_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from datetime import datetime, timedelta
from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar
from urllib.parse import quote_plus
Expand All @@ -6,9 +7,11 @@
import pytest
from fastapi import Request
from httpx import ASGITransport, AsyncClient
from pypgstac.db import PgstacDB
from pypgstac.load import Loader
from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent
from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import create_post_request_model
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension
from stac_fastapi.types import stac as stac_types

Expand All @@ -17,6 +20,9 @@
from stac_fastapi.pgstac.transactions import TransactionsClient
from stac_fastapi.pgstac.types.search import PgstacSearch

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")


STAC_CORE_ROUTES = [
"GET /",
"GET /collections",
Expand Down Expand Up @@ -669,11 +675,13 @@ async def get_collection(
FieldsExtension(),
]
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)
api = StacApi(
client=Client(post_request_model=post_request_model),
settings=settings,
extensions=extensions,
search_post_request_model=post_request_model,
search_get_request_model=get_request_model,
)
app = api.app
await connect_to_db(app)
Expand All @@ -695,3 +703,105 @@ async def get_collection(
assert response.status_code == 200
finally:
await close_db_connection(app)


@pytest.mark.asyncio
@pytest.mark.parametrize("validation", [True, False])
@pytest.mark.parametrize("hydrate", [True, False])
async def test_no_extension(
hydrate, validation, load_test_data, database, pgstac
) -> None:
"""test PgSTAC with no extension."""
connection = f"postgresql://{database.user}:{database.password}@{database.host}:{database.port}/{database.dbname}"
with PgstacDB(dsn=connection) as db:
loader = Loader(db=db)
loader.load_collections(os.path.join(DATA_DIR, "test_collection.json"))
loader.load_items(os.path.join(DATA_DIR, "test_item.json"))

settings = Settings(
postgres_user=database.user,
postgres_pass=database.password,
postgres_host_reader=database.host,
postgres_host_writer=database.host,
postgres_port=database.port,
postgres_dbname=database.dbname,
testing=True,
use_api_hydrate=hydrate,
enable_response_models=validation,
)
extensions = []
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
api = StacApi(
client=CoreCrudClient(post_request_model=post_request_model),
settings=settings,
extensions=extensions,
search_post_request_model=post_request_model,
)
app = api.app
await connect_to_db(app)
try:
async with AsyncClient(transport=ASGITransport(app=app)) as client:
landing = await client.get("http://test/")
assert landing.status_code == 200, landing.text

collection = await client.get("http://test/collections/test-collection")
assert collection.status_code == 200, collection.text

collections = await client.get("http://test/collections")
assert collections.status_code == 200, collections.text

item = await client.get(
"http://test/collections/test-collection/items/test-item"
)
assert item.status_code == 200, item.text

item_collection = await client.get(
"http://test/collections/test-collection/items",
params={"limit": 10},
)
assert item_collection.status_code == 200, item_collection.text

get_search = await client.get(
"http://test/search",
params={
"collections": ["test-collection"],
},
)
assert get_search.status_code == 200, get_search.text

post_search = await client.post(
"http://test/search",
json={
"collections": ["test-collection"],
},
)
assert post_search.status_code == 200, post_search.text

get_search = await client.get(
"http://test/search",
params={
"collections": ["test-collection"],
"fields": "properties.datetime",
},
)
# fields should be ignored
assert get_search.status_code == 200, get_search.text
props = get_search.json()["features"][0]["properties"]
assert len(props) > 1

post_search = await client.post(
"http://test/search",
json={
"collections": ["test-collection"],
"fields": {
"include": ["properties.datetime"],
},
},
)
# fields should be ignored
assert post_search.status_code == 200, post_search.text
props = get_search.json()["features"][0]["properties"]
assert len(props) > 1
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved

finally:
await close_db_connection(app)
Loading