diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index 257d19a35..d9e5f2235 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -7,13 +7,15 @@ from fastapi import Body, Path from pydantic import BaseModel, create_model from pydantic.fields import UndefinedType +from stac_pydantic.shared import BBox from stac_fastapi.types.extension import ApiExtension +from stac_fastapi.types.rfc3339 import DateTimeType from stac_fastapi.types.search import ( APIRequest, BaseSearchGetRequest, BaseSearchPostRequest, - str2list, + str2bbox, ) @@ -124,8 +126,8 @@ class ItemCollectionUri(CollectionUri): """Get item collection.""" limit: int = attr.ib(default=10) - bbox: Optional[str] = attr.ib(default=None, converter=str2list) - datetime: Optional[str] = attr.ib(default=None) + bbox: Optional[BBox] = attr.ib(default=None, converter=str2bbox) + datetime: Optional[DateTimeType] = attr.ib(default=None) class POSTTokenPagination(BaseModel): diff --git a/stac_fastapi/types/stac_fastapi/types/core.py b/stac_fastapi/types/stac_fastapi/types/core.py index 258cb93a0..739dcdb09 100644 --- a/stac_fastapi/types/stac_fastapi/types/core.py +++ b/stac_fastapi/types/stac_fastapi/types/core.py @@ -1,13 +1,12 @@ """Base clients.""" import abc -from datetime import datetime from typing import Any, Dict, List, Optional, Union from urllib.parse import urljoin import attr from fastapi import Request from stac_pydantic.links import Relations -from stac_pydantic.shared import MimeTypes +from stac_pydantic.shared import BBox, MimeTypes from stac_pydantic.version import STAC_VERSION from starlette.responses import Response @@ -15,6 +14,7 @@ from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES from stac_fastapi.types.extension import ApiExtension from stac_fastapi.types.requests import get_base_url +from stac_fastapi.types.rfc3339 import DateTimeType from stac_fastapi.types.search import BaseSearchPostRequest from stac_fastapi.types.stac import Conformance @@ -429,8 +429,8 @@ def get_search( self, collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, - bbox: Optional[List[NumType]] = None, - datetime: Optional[Union[str, datetime]] = None, + bbox: Optional[BBox] = None, + datetime: Optional[DateTimeType] = None, limit: Optional[int] = 10, query: Optional[str] = None, token: Optional[str] = None, @@ -491,8 +491,8 @@ def get_collection(self, collection_id: str, **kwargs) -> stac_types.Collection: def item_collection( self, collection_id: str, - bbox: Optional[List[NumType]] = None, - datetime: Optional[Union[str, datetime]] = None, + bbox: Optional[BBox] = None, + datetime: Optional[DateTimeType] = None, limit: int = 10, token: str = None, **kwargs, @@ -626,8 +626,8 @@ async def get_search( self, collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, - bbox: Optional[List[NumType]] = None, - datetime: Optional[Union[str, datetime]] = None, + bbox: Optional[BBox] = None, + datetime: Optional[DateTimeType] = None, limit: Optional[int] = 10, query: Optional[str] = None, token: Optional[str] = None, @@ -692,8 +692,8 @@ async def get_collection( async def item_collection( self, collection_id: str, - bbox: Optional[List[NumType]] = None, - datetime: Optional[Union[str, datetime]] = None, + bbox: Optional[BBox] = None, + datetime: Optional[DateTimeType] = None, limit: int = 10, token: str = None, **kwargs, diff --git a/stac_fastapi/types/stac_fastapi/types/rfc3339.py b/stac_fastapi/types/stac_fastapi/types/rfc3339.py index 6e3f97761..b81d109ca 100644 --- a/stac_fastapi/types/stac_fastapi/types/rfc3339.py +++ b/stac_fastapi/types/stac_fastapi/types/rfc3339.py @@ -1,13 +1,20 @@ """rfc3339.""" import re from datetime import datetime, timezone -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import iso8601 from pystac.utils import datetime_to_str RFC33339_PATTERN = r"^(\d\d\d\d)\-(\d\d)\-(\d\d)(T|t)(\d\d):(\d\d):(\d\d)([.]\d+)?(Z|([-+])(\d\d):(\d\d))$" +DateTimeType = Union[ + datetime, + Tuple[datetime, datetime], + Tuple[datetime, None], + Tuple[None, datetime], +] + def rfc3339_str_to_datetime(s: str) -> datetime: """Convert a string conforming to RFC 3339 to a :class:`datetime.datetime`. @@ -37,7 +44,7 @@ def rfc3339_str_to_datetime(s: str) -> datetime: def str_to_interval( interval: str, -) -> Optional[Tuple[Optional[datetime], Optional[datetime]]]: +) -> Optional[DateTimeType]: """Extract a tuple of datetimes from an interval string. Interval strings are defined by @@ -56,7 +63,10 @@ def str_to_interval( raise ValueError("Empty interval string is invalid.") values = interval.split("/") - if len(values) != 2: + if len(values) == 1: + # Single date for == date case + return rfc3339_str_to_datetime(values[0]) + elif len(values) > 2: raise ValueError( f"Interval string '{interval}' contains more than one forward slash." ) diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index f12c3c518..aafa85979 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -24,7 +24,7 @@ from stac_pydantic.shared import BBox from stac_pydantic.utils import AutoValueEnum -from stac_fastapi.types.rfc3339 import rfc3339_str_to_datetime, str_to_interval +from stac_fastapi.types.rfc3339 import DateTimeType, str_to_interval # Be careful: https://github.com/samuelcolvin/pydantic/issues/1423#issuecomment-642797287 NumType = Union[float, int] @@ -58,6 +58,14 @@ def str2list(x: str) -> Optional[List]: return x.split(",") +def str2bbox(x: str) -> Optional[BBox]: + """Convert string to BBox based on , delimiter.""" + if x: + t = tuple(float(v) for v in str2list(x)) + assert len(t) == 4 + return t + + @attr.s # type:ignore class APIRequest(abc.ABC): """Generic API Request base class.""" @@ -73,9 +81,9 @@ class BaseSearchGetRequest(APIRequest): collections: Optional[str] = attr.ib(default=None, converter=str2list) ids: Optional[str] = attr.ib(default=None, converter=str2list) - bbox: Optional[str] = attr.ib(default=None, converter=str2list) + bbox: Optional[BBox] = attr.ib(default=None, converter=str2bbox) intersects: Optional[str] = attr.ib(default=None, converter=str2list) - datetime: Optional[str] = attr.ib(default=None) + datetime: Optional[DateTimeType] = attr.ib(default=None, converter=str_to_interval) limit: Optional[int] = attr.ib(default=10) @@ -96,20 +104,18 @@ class BaseSearchPostRequest(BaseModel): intersects: Optional[ Union[Point, MultiPoint, LineString, MultiLineString, Polygon, MultiPolygon] ] - datetime: Optional[str] + datetime: Optional[DateTimeType] limit: Optional[conint(gt=0, le=10000)] = 10 @property def start_date(self) -> Optional[datetime]: """Extract the start date from the datetime string.""" - interval = str_to_interval(self.datetime) - return interval[0] if interval else None + return self.datetime[0] if self.datetime else None @property def end_date(self) -> Optional[datetime]: """Extract the end date from the datetime string.""" - interval = str_to_interval(self.datetime) - return interval[1] if interval else None + return self.datetime[1] if self.datetime else None @validator("intersects") def validate_spatial(cls, v, values): @@ -118,10 +124,12 @@ def validate_spatial(cls, v, values): raise ValueError("intersects and bbox parameters are mutually exclusive") return v - @validator("bbox") - def validate_bbox(cls, v: BBox): + @validator("bbox", pre=True) + def validate_bbox(cls, v: Union[str, BBox]) -> BBox: """Check order of supplied bbox coordinates.""" if v: + if type(v) == str: + v = str2bbox(v) # Validate order if len(v) == 4: xmin, ymin, xmax, ymax = v @@ -148,34 +156,11 @@ def validate_bbox(cls, v: BBox): return v - @validator("datetime") - def validate_datetime(cls, v): - """Validate datetime.""" - if "/" in v: - values = v.split("/") - else: - # Single date is interpreted as end date - values = ["..", v] - - dates = [] - for value in values: - if value == ".." or value == "": - dates.append("..") - continue - - # throws ValueError if invalid RFC 3339 string - dates.append(rfc3339_str_to_datetime(value)) - - if dates[0] == ".." and dates[1] == "..": - raise ValueError( - "Invalid datetime range, both ends of range may not be open" - ) - - if ".." not in dates and dates[0] > dates[1]: - raise ValueError( - "Invalid datetime range, must match format (begin_date, end_date)" - ) - + @validator("datetime", pre=True) + def validate_datetime(cls, v: Union[str, DateTimeType]) -> DateTimeType: + """Parse datetime.""" + if type(v) == str: + v = str_to_interval(v) return v @property diff --git a/stac_fastapi/types/stac_fastapi/types/stac.py b/stac_fastapi/types/stac_fastapi/types/stac.py index ef61c2f32..eaf451560 100644 --- a/stac_fastapi/types/stac_fastapi/types/stac.py +++ b/stac_fastapi/types/stac_fastapi/types/stac.py @@ -2,6 +2,8 @@ import sys from typing import Any, Dict, List, Optional, Union +from stac_pydantic.shared import BBox + # Avoids a Pydantic error: # TypeError: You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.9.2. # Without it, there is no way to differentiate required and optional fields when subclassed. @@ -63,7 +65,7 @@ class Item(TypedDict, total=False): stac_extensions: Optional[List[str]] id: str geometry: Dict[str, Any] - bbox: List[NumType] + bbox: BBox properties: Dict[str, Any] links: List[Dict[str, Any]] assets: Dict[str, Any]