From 3c1e618eae45ef9c74a45e13282015dc3ca92169 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Tue, 23 Jan 2024 19:11:37 +0100 Subject: [PATCH 1/6] start sketching the CMR mosaic backend --- pyproject.toml | 1 + titiler/cmr/backend.py | 227 +++++++++++++++++++++++++++++++++ titiler/cmr/reader.py | 279 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 507 insertions(+) create mode 100644 titiler/cmr/backend.py create mode 100644 titiler/cmr/reader.py diff --git a/pyproject.toml b/pyproject.toml index a00f632..d5b8e28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ dependencies = [ "orjson", "titiler.core>=0.17.0,<0.18", + "titiler.mosaic>0.17.0,<0.18", "xarray", "rioxarray", "earthaccess", diff --git a/titiler/cmr/backend.py b/titiler/cmr/backend.py new file mode 100644 index 0000000..2f38de4 --- /dev/null +++ b/titiler/cmr/backend.py @@ -0,0 +1,227 @@ +"""TiTiler.cmr custom Mosaic Backend.""" + +import itertools +from typing import Any, Dict, List, Optional, Tuple, Type + +import attr +import earthaccess +from cachetools import TTLCache, cached +from cachetools.keys import hashkey +from cogeo_mosaic.backends import BaseBackend +from cogeo_mosaic.errors import NoAssetFoundError +from cogeo_mosaic.mosaic import MosaicJSON +from morecantile import Tile, TileMatrixSet +from rasterio.crs import CRS +from rasterio.warp import transform_bounds +from rio_tiler.constants import WEB_MERCATOR_TMS, WGS84_CRS +from rio_tiler.io import Reader +from rio_tiler.models import ImageData +from rio_tiler.mosaic import mosaic_reader +from rio_tiler.types import BBox + +from titiler.cmr.reader import ZarrReader +from titiler.pgstac.settings import CacheSettings, RetrySettings +from titiler.pgstac.utils import retry + +cache_config = CacheSettings() +retry_config = RetrySettings() + + +@attr.s +class CMRBackend(BaseBackend): + """CMR Mosaic Backend.""" + + # ConceptID + input: str = attr.ib() + + tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS) + minzoom: int = attr.ib() + maxzoom: int = attr.ib() + + # default values for bounds + bounds: BBox = attr.ib(default=(-180, -90, 180, 90)) + + crs: CRS = attr.ib(default=WGS84_CRS) + geographic_crs: CRS = attr.ib(default=WGS84_CRS) + + # The reader is read-only (outside init) + mosaic_def: MosaicJSON = attr.ib(init=False) + + # NOT USED + # Use reader (outside init) + reader: Type[Reader] = attr.ib(default=Reader, init=False) + reader_options: Dict = attr.ib(factory=dict, init=False) + + _backend_name = "CMR" + + def __attrs_post_init__(self) -> None: + """Post Init.""" + # Construct a FAKE mosaicJSON + # mosaic_def has to be defined. + # we set `tiles` to an empty list. + self.mosaic_def = MosaicJSON( + mosaicjson="0.0.3", + name=self.input, + bounds=self.bounds, + minzoom=self.minzoom, + maxzoom=self.maxzoom, + tiles={}, + ) + + @minzoom.default + def _minzoom(self): + return self.tms.minzoom + + @maxzoom.default + def _maxzoom(self): + return self.tms.maxzoom + + def write(self, overwrite: bool = True) -> None: + """This method is not used but is required by the abstract class.""" + pass + + def update(self) -> None: + """We overwrite the default method.""" + pass + + def _read(self) -> MosaicJSON: + """This method is not used but is required by the abstract class.""" + pass + + def assets_for_tile(self, x: int, y: int, z: int, **kwargs: Any) -> List[str]: + """Retrieve assets for tile.""" + bbox = self.tms.bounds(Tile(x, y, z)) + return self.get_assets(*bbox, **kwargs) + + def assets_for_point( + self, + lng: float, + lat: float, + coord_crs: CRS = WGS84_CRS, + **kwargs: Any, + ) -> List[str]: + """Retrieve assets for point.""" + raise NotImplementedError + + def assets_for_bbox( + self, + xmin: float, + ymin: float, + xmax: float, + ymax: float, + coord_crs: CRS = WGS84_CRS, + **kwargs: Any, + ) -> List[Dict]: + """Retrieve assets for bbox.""" + if coord_crs != WGS84_CRS: + xmin, ymin, xmax, ymax = transform_bounds( + coord_crs, + WGS84_CRS, + xmin, + ymin, + xmax, + ymax, + ) + + return self.get_assets(xmin, ymin, xmax, ymax, **kwargs) + + @cached( # type: ignore + TTLCache(maxsize=cache_config.maxsize, ttl=cache_config.ttl), + key=lambda self, xmin, ymin, xmax, ymax, **kwargs: hashkey( + self.input, str(xmin), str(ymin), str(xmax), str(ymax), **kwargs + ), + ) + @retry( + tries=retry_config.retry, + delay=retry_config.delay, + exceptions=(), + ) + def get_assets( + self, + xmin: float, + ymin: float, + xmax: float, + ymax: float, + limit: int = 100, + **kwargs: Any, + ) -> List[str]: + """Find assets.""" + results = earthaccess.search_data( + concept_id=self.input, + bounding_box=(xmin, ymin, xmax, ymax), + count=limit, + **kwargs, + ) + return list( + itertools.chain.from_iterable([res.data_links() for res in results]) + ) + + @property + def _quadkeys(self) -> List[str]: + return [] + + def tile( + self, + tile_x: int, + tile_y: int, + tile_z: int, + cmr_query: Dict, + **kwargs: Any, + ) -> Tuple[ImageData, List[str]]: + """Get Tile from multiple observation.""" + mosaic_assets = self.assets_for_tile( + tile_x, + tile_y, + tile_z, + **cmr_query, + ) + + if not mosaic_assets: + raise NoAssetFoundError( + f"No assets found for tile {tile_z}-{tile_x}-{tile_y}" + ) + + def _reader(src_path: str, x: int, y: int, z: int, **kwargs: Any) -> ImageData: + if src_path.endswith(".tif"): + reader = Reader + else: + reader = ZarrReader + + with reader(src_path, tms=self.tms) as src_dst: + return src_dst.tile(x, y, z, **kwargs) + + return mosaic_reader(mosaic_assets, _reader, tile_x, tile_y, tile_z, **kwargs) + + def point( + self, + lon: float, + lat: float, + cmr_query: Dict, + coord_crs: CRS = WGS84_CRS, + **kwargs: Any, + ) -> List: + """Get Point value from multiple observation.""" + raise NotImplementedError + + def part( + self, + bbox: BBox, + cmr_query: Dict, + dst_crs: Optional[CRS] = None, + bounds_crs: CRS = WGS84_CRS, + **kwargs: Any, + ) -> Tuple[ImageData, List[str]]: + """Create an Image from multiple items for a bbox.""" + raise NotImplementedError + + def feature( + self, + shape: Dict, + cmr_query: Dict, + dst_crs: Optional[CRS] = None, + shape_crs: CRS = WGS84_CRS, + max_size: int = 1024, + **kwargs: Any, + ) -> Tuple[ImageData, List[str]]: + """Create an Image from multiple items for a GeoJSON feature.""" + raise NotImplementedError diff --git a/titiler/cmr/reader.py b/titiler/cmr/reader.py new file mode 100644 index 0000000..9cc59ac --- /dev/null +++ b/titiler/cmr/reader.py @@ -0,0 +1,279 @@ +"""ZarrReader. + +Originaly from titiler-xarray +""" + +import contextlib +import pickle +import re +from typing import Any, Dict, List, Optional + +import attr +import fsspec +import numpy +import s3fs +import xarray +from cachetools import TTLCache +from morecantile import TileMatrixSet +from rasterio.crs import CRS +from rio_tiler.constants import WEB_MERCATOR_TMS, WGS84_CRS +from rio_tiler.io.xarray import XarrayReader +from rio_tiler.types import BBox + +from titiler.pgstac.settings import CacheSettings + +# Use simple in-memory cache for now (we can switch to redis later) +cache_config = CacheSettings() +cache_client: Any = TTLCache(maxsize=cache_config.maxsize, ttl=cache_config.ttl) + + +def parse_protocol(src_path: str, reference: Optional[bool] = False): + """ + Parse protocol from path. + """ + match = re.match(r"^(s3|https|http)", src_path) + protocol = "file" + if match: + protocol = match.group(0) + + # override protocol if reference + if reference: + protocol = "reference" + + return protocol + + +def xarray_engine(src_path: str): + """ + Parse xarray engine from path. + """ + # ".hdf", ".hdf5", ".h5" will be supported once we have tests + expand the type permitted for the group parameter + H5NETCDF_EXTENSIONS = [".nc", ".nc4"] + lower_filename = src_path.lower() + if any(lower_filename.endswith(ext) for ext in H5NETCDF_EXTENSIONS): + return "h5netcdf" + else: + return "zarr" + + +def get_filesystem( + src_path: str, + protocol: str, + xr_engine: str, + anon: bool = True, +): + """ + Get the filesystem for the given source path. + """ + if protocol == "s3": + s3_filesystem = s3fs.S3FileSystem() + return ( + s3_filesystem.open(src_path) + if xr_engine == "h5netcdf" + else s3fs.S3Map(root=src_path, s3=s3_filesystem) + ) + + elif protocol == "reference": + reference_args = {"fo": src_path, "remote_options": {"anon": anon}} + return fsspec.filesystem("reference", **reference_args).get_mapper("") + + elif protocol in ["https", "http", "file"]: + filesystem = fsspec.filesystem(protocol) # type: ignore + return ( + filesystem.open(src_path) + if xr_engine == "h5netcdf" + else filesystem.get_mapper(src_path) + ) + + else: + raise ValueError(f"Unsupported protocol: {protocol}") + + +def xarray_open_dataset( + src_path: str, + group: Optional[Any] = None, + reference: Optional[bool] = False, + decode_times: Optional[bool] = True, + consolidated: Optional[bool] = True, +) -> xarray.Dataset: + """Open dataset.""" + # Generate cache key and attempt to fetch the dataset from cache + cache_key = f"{src_path}_{group}" if group is not None else src_path + data_bytes = cache_client.get(cache_key, None) + if data_bytes: + return pickle.loads(data_bytes) + + protocol = parse_protocol(src_path, reference=reference) + xr_engine = xarray_engine(src_path) + file_handler = get_filesystem(src_path, protocol, xr_engine) + + # Arguments for xarray.open_dataset + # Default args + xr_open_args: Dict[str, Any] = { + "decode_coords": "all", + "decode_times": decode_times, + } + + # Argument if we're opening a datatree + if isinstance(group, int): + xr_open_args["group"] = group + + # NetCDF arguments + if xr_engine == "h5netcdf": + xr_open_args["engine"] = "h5netcdf" + xr_open_args["lock"] = False + + else: + # Zarr arguments + xr_open_args["engine"] = "zarr" + xr_open_args["consolidated"] = consolidated + + # Additional arguments when dealing with a reference file. + if reference: + xr_open_args["consolidated"] = False + xr_open_args["backend_kwargs"] = {"consolidated": False} + + ds = xarray.open_dataset(file_handler, **xr_open_args) + + # Serialize the dataset to bytes using pickle + cache_client[cache_key] = pickle.dumps(ds) + + return ds + + +def arrange_coordinates(da: xarray.DataArray) -> xarray.DataArray: + """ + Arrange coordinates to DataArray. + An rioxarray.exceptions.InvalidDimensionOrder error is raised if the coordinates are not in the correct order time, y, and x. + See: https://github.com/corteva/rioxarray/discussions/674 + We conform to using x and y as the spatial dimension names. You can do this a bit more elegantly with metpy but that is a heavy dependency. + """ + if "x" not in da.dims and "y" not in da.dims: + latitude_var_name = "lat" + longitude_var_name = "lon" + if "latitude" in da.dims: + latitude_var_name = "latitude" + if "longitude" in da.dims: + longitude_var_name = "longitude" + da = da.rename({latitude_var_name: "y", longitude_var_name: "x"}) + if "time" in da.dims: + da = da.transpose("time", "y", "x") + else: + da = da.transpose("y", "x") + return da + + +def get_variable( + ds: xarray.Dataset, + variable: str, + time_slice: Optional[str] = None, + drop_dim: Optional[str] = None, +) -> xarray.DataArray: + """Get Xarray variable as DataArray.""" + da = ds[variable] + da = arrange_coordinates(da) + # TODO: add test + if drop_dim: + dim_to_drop, dim_val = drop_dim.split("=") + da = da.sel({dim_to_drop: dim_val}).drop(dim_to_drop) + da = arrange_coordinates(da) + + if (da.x > 180).any(): + # Adjust the longitude coordinates to the -180 to 180 range + da = da.assign_coords(x=(da.x + 180) % 360 - 180) + + # Sort the dataset by the updated longitude coordinates + da = da.sortby(da.x) + + # Make sure we have a valid CRS + crs = da.rio.crs or "epsg:4326" + da.rio.write_crs(crs, inplace=True) + + if "time" in da.dims: + if time_slice: + time_as_str = time_slice.split("T")[0] + if da["time"].dtype == "O": + da["time"] = da["time"].astype("datetime64[ns]") + da = da.sel( + time=numpy.array(time_as_str, dtype=numpy.datetime64), method="nearest" + ) + else: + da = da.isel(time=0) + + return da + + +@attr.s +class ZarrReader(XarrayReader): + """ZarrReader: Open Zarr file and access DataArray.""" + + src_path: str = attr.ib() + variable: str = attr.ib() + + # xarray.Dataset options + reference: bool = attr.ib(default=False) + decode_times: bool = attr.ib(default=False) + group: Optional[Any] = attr.ib(default=None) + consolidated: Optional[bool] = attr.ib(default=True) + + # xarray.DataArray options + time_slice: Optional[str] = attr.ib(default=None) + drop_dim: Optional[str] = attr.ib(default=None) + + tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS) + geographic_crs: CRS = attr.ib(default=WGS84_CRS) + + ds: xarray.Dataset = attr.ib(init=False) + input: xarray.DataArray = attr.ib(init=False) + + bounds: BBox = attr.ib(init=False) + crs: CRS = attr.ib(init=False) + + _minzoom: int = attr.ib(init=False, default=None) + _maxzoom: int = attr.ib(init=False, default=None) + + _dims: List = attr.ib(init=False, factory=list) + _ctx_stack = attr.ib(init=False, factory=contextlib.ExitStack) + + def __attrs_post_init__(self): + """Set bounds and CRS.""" + self.ds = self._ctx_stack.enter_context( + xarray_open_dataset( + self.src_path, + group=self.group, + reference=self.reference, + consolidated=self.consolidated, + ), + ) + self.input = get_variable( + self.ds, + self.variable, + time_slice=self.time_slice, + drop_dim=self.drop_dim, + ) + + self.bounds = tuple(self.input.rio.bounds()) + self.crs = self.input.rio.crs + + self._dims = [ + d + for d in self.input.dims + if d not in [self.input.rio.x_dim, self.input.rio.y_dim] + ] + + @classmethod + def list_variables( + cls, + src_path: str, + group: Optional[Any] = None, + reference: Optional[bool] = False, + consolidated: Optional[bool] = True, + ) -> List[str]: + """List available variable in a dataset.""" + with xarray_open_dataset( + src_path, + group=group, + reference=reference, + consolidated=consolidated, + ) as ds: + return list(ds.data_vars) # type: ignore From 34411e2aaaa7f857f671c37f5d84ed4d5382ca0b Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Wed, 24 Jan 2024 13:23:16 +0100 Subject: [PATCH 2/6] add tiles endpoints --- pyproject.toml | 1 + titiler/cmr/backend.py | 20 +- titiler/cmr/dependencies.py | 102 ++------- titiler/cmr/factory.py | 423 +++++++++++++++++++++++++++++++++++- 4 files changed, 447 insertions(+), 99 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d5b8e28..b15653e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "orjson", "titiler.core>=0.17.0,<0.18", "titiler.mosaic>0.17.0,<0.18", + "rio_tiler>=6.4.0,<7.0", "xarray", "rioxarray", "earthaccess", diff --git a/titiler/cmr/backend.py b/titiler/cmr/backend.py index 2f38de4..554c100 100644 --- a/titiler/cmr/backend.py +++ b/titiler/cmr/backend.py @@ -19,7 +19,6 @@ from rio_tiler.mosaic import mosaic_reader from rio_tiler.types import BBox -from titiler.cmr.reader import ZarrReader from titiler.pgstac.settings import CacheSettings, RetrySettings from titiler.pgstac.utils import retry @@ -38,6 +37,9 @@ class CMRBackend(BaseBackend): minzoom: int = attr.ib() maxzoom: int = attr.ib() + reader: Type[Reader] = attr.ib(default=Reader) + reader_options: Dict = attr.ib(factory=dict) + # default values for bounds bounds: BBox = attr.ib(default=(-180, -90, 180, 90)) @@ -47,11 +49,6 @@ class CMRBackend(BaseBackend): # The reader is read-only (outside init) mosaic_def: MosaicJSON = attr.ib(init=False) - # NOT USED - # Use reader (outside init) - reader: Type[Reader] = attr.ib(default=Reader, init=False) - reader_options: Dict = attr.ib(factory=dict, init=False) - _backend_name = "CMR" def __attrs_post_init__(self) -> None: @@ -182,12 +179,11 @@ def tile( ) def _reader(src_path: str, x: int, y: int, z: int, **kwargs: Any) -> ImageData: - if src_path.endswith(".tif"): - reader = Reader - else: - reader = ZarrReader - - with reader(src_path, tms=self.tms) as src_dst: + with self.reader( + src_path, + tms=self.tms, + **self.reader_options, + ) as src_dst: return src_dst.tile(x, y, z, **kwargs) return mosaic_reader(mosaic_assets, _reader, tile_x, tile_y, tile_z, **kwargs) diff --git a/titiler/cmr/dependencies.py b/titiler/cmr/dependencies.py index 2c34d48..9f5495d 100644 --- a/titiler/cmr/dependencies.py +++ b/titiler/cmr/dependencies.py @@ -1,6 +1,6 @@ """titiler-cmr dependencies.""" -from typing import List, Literal, Optional, get_args +from typing import Dict, List, Literal, Optional, get_args from ciso8601 import parse_rfc3339 from fastapi import HTTPException, Query @@ -8,50 +8,10 @@ from typing_extensions import Annotated from titiler.cmr.enums import MediaType -from titiler.cmr.errors import InvalidBBox ResponseType = Literal["json", "html"] -def s_intersects(bbox: List[float], spatial_extent: List[float]) -> bool: - """Check if bbox intersects with spatial extent.""" - return ( - (bbox[0] < spatial_extent[2]) - and (bbox[2] > spatial_extent[0]) - and (bbox[3] > spatial_extent[1]) - and (bbox[1] < spatial_extent[3]) - ) - - -def t_intersects(interval: List[str], temporal_extent: List[Optional[str]]) -> bool: - """Check if dates intersect with temporal extent.""" - if len(interval) == 1: - start = end = parse_rfc3339(interval[0]) - - else: - start = parse_rfc3339(interval[0]) if interval[0] not in ["..", ""] else None - end = parse_rfc3339(interval[1]) if interval[1] not in ["..", ""] else None - - mint, maxt = temporal_extent - min_ext = parse_rfc3339(mint) if mint is not None else None - max_ext = parse_rfc3339(maxt) if maxt is not None else None - - if len(interval) == 1: - if start == min_ext or start == max_ext: - return True - - if not start: - return max_ext <= end or min_ext <= end - - elif not end: - return min_ext >= start or max_ext >= start - - else: - return min_ext >= start and max_ext <= end - - return False - - def accept_media_type(accept: str, mediatypes: List[MediaType]) -> Optional[MediaType]: """Return MediaType based on accept header and available mediatype. @@ -116,42 +76,8 @@ def OutputType( return accept_media_type(request.headers.get("accept", ""), accepted_media) -def bbox_query( - bbox: Annotated[ - Optional[str], - Query( - description="A bounding box, expressed in WGS84 (westLong,southLat,eastLong,northLat) or WGS84h (westLong,southLat,minHeight,eastLong,northLat,maxHeight) CRS, by which to filter out all collections whose spatial extent does not intersect with the bounding box.", - openapi_examples={ - "simple": {"value": "160.6,-55.95,-170,-25.89"}, - }, - ), - ] = None -) -> Optional[List[float]]: - """BBox dependency.""" - if bbox: - bounds = list(map(float, bbox.split(","))) - if len(bounds) == 4: - if abs(bounds[0]) > 180 or abs(bounds[2]) > 180: - raise InvalidBBox(f"Invalid longitude in bbox: {bounds}") - if abs(bounds[1]) > 90 or abs(bounds[3]) > 90: - raise InvalidBBox(f"Invalid latitude in bbox: {bounds}") - - elif len(bounds) == 6: - if abs(bounds[0]) > 180 or abs(bounds[3]) > 180: - raise InvalidBBox(f"Invalid longitude in bbox: {bounds}") - if abs(bounds[1]) > 90 or abs(bounds[4]) > 90: - raise InvalidBBox(f"Invalid latitude in bbox: {bounds}") - - else: - raise InvalidBBox(f"Invalid bbox: {bounds}") - - return bounds - - return None - - -def datetime_query( - datetime: Annotated[ +def cmr_query( + temporal: Annotated[ Optional[str], Query( description="Either a date-time or an interval. Date and time expressions adhere to [RFC 3339](https://www.rfc-editor.org/rfc/rfc3339). Intervals may be bounded or half-bounded (double-dots at start or end).", @@ -165,13 +91,21 @@ def datetime_query( }, ), ] = None, -) -> Optional[List[str]]: - """Datetime dependency.""" - if datetime: - dt = datetime.split("/") +) -> Dict: + """CMR Query options.""" + query = {} + if temporal: + dt = temporal.split("/") if len(dt) > 2: - raise HTTPException(status_code=422, detail="Invalid datetime: {datetime}") + raise HTTPException(status_code=422, detail="Invalid temporal: {temporal}") - return dt + if len(dt) == 1: + start = end = parse_rfc3339(dt[0]) - return None + else: + start = parse_rfc3339(dt[0]) if dt[0] not in ["..", ""] else None + end = parse_rfc3339(dt[1]) if dt[1] not in ["..", ""] else None + + query["temporal"] = [start, end] + + return query diff --git a/titiler/cmr/factory.py b/titiler/cmr/factory.py index be9e850..cc24589 100644 --- a/titiler/cmr/factory.py +++ b/titiler/cmr/factory.py @@ -3,22 +3,36 @@ import json import re from dataclasses import dataclass, field -from typing import Any, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union +from urllib.parse import urlencode import jinja2 +import numpy import orjson -from fastapi import APIRouter, Depends, Path +from fastapi import APIRouter, Depends, Path, Query from fastapi.responses import ORJSONResponse from morecantile import tms as default_tms from morecantile.defaults import TileMatrixSets +from pydantic import conint +from rio_tiler.io import Reader +from rio_tiler.types import RIOResampling, WarpResampling from starlette.requests import Request +from starlette.responses import Response from starlette.routing import compile_path, replace_params from starlette.templating import Jinja2Templates, _TemplateResponse from typing_extensions import Annotated from titiler.cmr import models -from titiler.cmr.dependencies import OutputType +from titiler.cmr.backend import CMRBackend +from titiler.cmr.dependencies import OutputType, cmr_query from titiler.cmr.enums import MediaType +from titiler.cmr.reader import ZarrReader +from titiler.core import dependencies +from titiler.core.algorithm import algorithms as available_algorithms +from titiler.core.factory import img_endpoint_params +from titiler.core.models.mapbox import TileJSON +from titiler.core.resources.enums import ImageType +from titiler.core.utils import render_image jinja2_env = jinja2.Environment( loader=jinja2.ChoiceLoader([jinja2.PackageLoader(__package__, "templates")]) @@ -127,6 +141,7 @@ def __post_init__(self): self.register_landing() self.register_conformance() self.register_tilematrixsets() + self.register_tiles() def register_landing(self) -> None: """register landing page endpoint.""" @@ -350,3 +365,405 @@ async def tilematrixset( ) return data + + def register_tiles(self): # noqa: C901 + """Register tileset endpoints.""" + + @self.router.get( + "/collections/{collectionId}/tiles/{tileMatrixSetId}/{z}/{x}/{y}", + **img_endpoint_params, + tags=["Raster Tiles"], + ) + @self.router.get( + "/collections/{collectionId}/tiles/{tileMatrixSetId}/{z}/{x}/{y}.{format}", + **img_endpoint_params, + tags=["Raster Tiles"], + ) + @self.router.get( + "/collections/{collectionId}/tiles/{tileMatrixSetId}/{z}/{x}/{y}@{scale}x", + **img_endpoint_params, + tags=["Raster Tiles"], + ) + @self.router.get( + "/collections/{collectionId}/tiles/{tileMatrixSetId}/{z}/{x}/{y}@{scale}x.{format}", + **img_endpoint_params, + tags=["Raster Tiles"], + ) + def tiles_endpoint( + collectionId: Annotated[ + str, + Path( + description="A CMR concept id, in the format '-' " + ), + ], + tileMatrixSetId: Annotated[ + Literal[tuple(self.supported_tms.list())], + Path(description="Identifier for a supported TileMatrixSet"), + ], + z: Annotated[ + int, + Path( + description="Identifier (Z) selecting one of the scales defined in the TileMatrixSet and representing the scaleDenominator the tile.", + ), + ], + x: Annotated[ + int, + Path( + description="Column (X) index of the tile on the selected TileMatrix. It cannot exceed the MatrixHeight-1 for the selected TileMatrix.", + ), + ], + y: Annotated[ + int, + Path( + description="Row (Y) index of the tile on the selected TileMatrix. It cannot exceed the MatrixWidth-1 for the selected TileMatrix.", + ), + ], + scale: Annotated[ # type: ignore + conint(gt=0, le=4), "Tile size scale. 1=256x256, 2=512x512..." + ] = 1, + format: Annotated[ + ImageType, + "Default will be automatically defined if the output image needs a mask (png) or not (jpeg).", + ] = None, + ################################################################### + # CMR options + query=Depends(cmr_query), + ################################################################### + backend: Annotated[ + Literal["cog", "xarray"], + Query(description="Backend to read the CMR dataset"), + ] = "cog", + ################################################################### + # ZarrReader Options + ################################################################### + variable: Annotated[ + Optional[str], + Query(description="Xarray Variable"), + ] = None, + drop_dim: Annotated[ + Optional[str], + Query(description="Dimension to drop"), + ] = None, + time_slice: Annotated[ + Optional[str], Query(description="Slice of time to read (if available)") + ] = None, + decode_times: Annotated[ + Optional[bool], + Query( + title="decode_times", + description="Whether to decode times", + ), + ] = None, + ################################################################### + # COG Reader Options + ################################################################### + indexes: Annotated[ + Optional[List[int]], + Query( + title="Band indexes", + alias="bidx", + description="Dataset band indexes", + ), + ] = None, + expression: Annotated[ + Optional[str], + Query( + title="Band Math expression", + description="rio-tiler's band math expression", + ), + ] = None, + unscale: Annotated[ + Optional[bool], + Query( + title="Apply internal Scale/Offset", + description="Apply internal Scale/Offset. Defaults to `False`.", + ), + ] = None, + resampling_method: Annotated[ + Optional[RIOResampling], + Query( + alias="resampling", + description="RasterIO resampling algorithm. Defaults to `nearest`.", + ), + ] = None, + ################################################################### + # Reader options + ################################################################### + nodata: Annotated[ + Optional[Union[str, int, float]], + Query( + title="Nodata value", + description="Overwrite internal Nodata value", + ), + ] = None, + reproject_method: Annotated[ + Optional[WarpResampling], + Query( + alias="reproject", + description="WarpKernel resampling algorithm (only used when doing re-projection). Defaults to `nearest`.", + ), + ] = None, + ################################################################### + # Rendering Options + ################################################################### + post_process=Depends(available_algorithms.dependency), + rescale=Depends(dependencies.RescalingParams), + color_formula=Depends(dependencies.ColorFormulaParams), + colormap=Depends(dependencies.ColorMapParams), + render_params=Depends(dependencies.ImageRenderingParams), + ) -> Response: + """Create map tile from a dataset.""" + resampling_method = resampling_method or "nearest" + reproject_method = reproject_method or "nearest" + if nodata is not None: + nodata = numpy.nan if nodata == "nan" else float(nodata) + + tms = self.supported_tms.get(tileMatrixSetId) + + read_options: Dict[str, Any] = {} + reader_options: Dict[str, Any] = {} + + if backend != "cog": + reader = ZarrReader + read_options = {} + + options = { + "variable": variable, + "decode_times": decode_times, + "drop_dim": drop_dim, + "time_slice": time_slice, + } + reader_options = {k: v for k, v in options.items() if v is not None} + else: + reader = Reader + options = { + "indexes": indexes, # type: ignore + "expression": expression, + "unscale": unscale, + "resampling_method": resampling_method, + } + read_options = {k: v for k, v in options.items() if v is not None} + + reader_options = {} + + with CMRBackend( + collectionId, + tms=tms, + reader=reader, + reader_options=reader_options, + ) as src_dst: + image = src_dst.tile( + x, + y, + z, + tilesize=scale * 256, + cmr_query=cmr_query, + nodata=nodata, + reproject_method=reproject_method, + **read_options, + ) + + if post_process: + image = post_process(image) + + if rescale: + image.rescale(rescale) + + if color_formula: + image.apply_color_formula(color_formula) + + content, media_type = render_image( + image, + output_format=format, + colormap=colormap, + **render_params, + ) + + return Response(content, media_type=media_type) + + @self.router.get( + "/collections/{collectionId}/{tileMatrixSetId}/tilejson.json", + response_model=TileJSON, + responses={200: {"description": "Return a tilejson"}}, + response_model_exclude_none=True, + tags=["TileJSON"], + ) + def tilejson_endpoint( # type: ignore + request: Request, + collectionId: Annotated[ + str, + Path( + description="A CMR concept id, in the format '-' " + ), + ], + tileMatrixSetId: Annotated[ + Literal[tuple(self.supported_tms.list())], + Path(description="Identifier for a supported TileMatrixSet"), + ], + tile_format: Annotated[ + Optional[ImageType], + Query( + description="Default will be automatically defined if the output image needs a mask (png) or not (jpeg).", + ), + ] = None, + tile_scale: Annotated[ + int, + Query( + gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..." + ), + ] = 1, + minzoom: Annotated[ + Optional[int], + Query(description="Overwrite default minzoom."), + ] = None, + maxzoom: Annotated[ + Optional[int], + Query(description="Overwrite default maxzoom."), + ] = None, + ################################################################### + # CMR options + query=Depends(cmr_query), + ################################################################### + backend: Annotated[ + Literal["cog", "xarray"], + Query(description="Backend to read the CMR dataset"), + ] = "cog", + ################################################################### + # ZarrReader Options + ################################################################### + variable: Annotated[ + Optional[str], + Query(description="Xarray Variable"), + ] = None, + drop_dim: Annotated[ + Optional[str], + Query(description="Dimension to drop"), + ] = None, + time_slice: Annotated[ + Optional[str], Query(description="Slice of time to read (if available)") + ] = None, + decode_times: Annotated[ + Optional[bool], + Query( + title="decode_times", + description="Whether to decode times", + ), + ] = None, + ################################################################### + # COG Reader Options + ################################################################### + indexes: Annotated[ + Optional[List[int]], + Query( + title="Band indexes", + alias="bidx", + description="Dataset band indexes", + ), + ] = None, + expression: Annotated[ + Optional[str], + Query( + title="Band Math expression", + description="rio-tiler's band math expression", + ), + ] = None, + unscale: Annotated[ + Optional[bool], + Query( + title="Apply internal Scale/Offset", + description="Apply internal Scale/Offset. Defaults to `False`.", + ), + ] = None, + resampling_method: Annotated[ + Optional[RIOResampling], + Query( + alias="resampling", + description="RasterIO resampling algorithm. Defaults to `nearest`.", + ), + ] = None, + ################################################################### + # Reader options + ################################################################### + nodata: Annotated[ + Optional[Union[str, int, float]], + Query( + title="Nodata value", + description="Overwrite internal Nodata value", + ), + ] = None, + reproject_method: Annotated[ + Optional[WarpResampling], + Query( + alias="reproject", + description="WarpKernel resampling algorithm (only used when doing re-projection). Defaults to `nearest`.", + ), + ] = None, + ################################################################### + # Rendering Options + ################################################################### + post_process=Depends(available_algorithms.dependency), + rescale=Depends(dependencies.RescalingParams), + color_formula=Depends(dependencies.ColorFormulaParams), + colormap=Depends(dependencies.ColorMapParams), + render_params=Depends(dependencies.ImageRenderingParams), + ) -> Dict: + """Return TileJSON document for a dataset.""" + route_params = { + "z": "{z}", + "x": "{x}", + "y": "{y}", + "scale": tile_scale, + "tileMatrixSetId": tileMatrixSetId, + } + if tile_format: + route_params["format"] = tile_format.value + + tiles_url = self.url_for(request, "tiles_endpoint", **route_params) + + qs_key_to_remove = [ + "tilematrixsetid", + "tile_format", + "tile_scale", + "minzoom", + "maxzoom", + ] + qs = [ + (key, value) + for (key, value) in request.query_params._list + if key.lower() not in qs_key_to_remove + ] + if qs: + tiles_url += f"?{urlencode(qs)}" + + tms = self.supported_tms.get(tileMatrixSetId) + + if backend != "cog": + reader = ZarrReader + options = { + "variable": variable, + "decode_times": decode_times, + "drop_dim": drop_dim, + "time_slice": time_slice, + } + reader_options = {k: v for k, v in options.items() if v is not None} + else: + reader = Reader + reader_options = {} + + with CMRBackend( + collectionId, + tms=tms, + reader=reader, + reader_options=reader_options, + ) as src_dst: + minx, miny, maxx, maxy = zip( + [-180, -90, 180, 90], list(src_dst.geographic_bounds) + ) + bounds = [max(minx), max(miny), min(maxx), min(maxy)] + + return { + "bounds": bounds, + "minzoom": minzoom if minzoom is not None else src_dst.minzoom, + "maxzoom": maxzoom if maxzoom is not None else src_dst.maxzoom, + "tiles": [tiles_url], + } From ee8c20e2b12247d4e7c2e54e05a0afe6ae9b876b Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Thu, 25 Jan 2024 14:42:33 +0100 Subject: [PATCH 3/6] change asset type in backend --- pyproject.toml | 6 +- titiler/cmr/backend.py | 44 ++++++--- titiler/cmr/factory.py | 3 +- titiler/cmr/main.py | 11 +++ titiler/cmr/templates/collection.html | 77 ---------------- titiler/cmr/templates/collections.html | 118 ------------------------- 6 files changed, 49 insertions(+), 210 deletions(-) delete mode 100644 titiler/cmr/templates/collection.html delete mode 100644 titiler/cmr/templates/collections.html diff --git a/pyproject.toml b/pyproject.toml index b15653e..026e1fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,10 +25,14 @@ classifiers = [ dependencies = [ "orjson", "titiler.core>=0.17.0,<0.18", - "titiler.mosaic>0.17.0,<0.18", + "titiler.mosaic>=0.17.0,<0.18", "rio_tiler>=6.4.0,<7.0", "xarray", "rioxarray", + "cftime", + "h5netcdf", + "fsspec", + "s3fs", "earthaccess", "ciso8601~=2.3", "pydantic>=2.4,<3.0", diff --git a/titiler/cmr/backend.py b/titiler/cmr/backend.py index 554c100..9c40893 100644 --- a/titiler/cmr/backend.py +++ b/titiler/cmr/backend.py @@ -1,7 +1,6 @@ """TiTiler.cmr custom Mosaic Backend.""" -import itertools -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, TypedDict import attr import earthaccess @@ -14,7 +13,7 @@ from rasterio.crs import CRS from rasterio.warp import transform_bounds from rio_tiler.constants import WEB_MERCATOR_TMS, WGS84_CRS -from rio_tiler.io import Reader +from rio_tiler.io import BaseReader, Reader from rio_tiler.models import ImageData from rio_tiler.mosaic import mosaic_reader from rio_tiler.types import BBox @@ -26,6 +25,14 @@ retry_config = RetrySettings() +class Asset(TypedDict, total=False): + """Simple Asset model.""" + + url: str + type: Optional[str] + provider: Optional[str] + + @attr.s class CMRBackend(BaseBackend): """CMR Mosaic Backend.""" @@ -37,7 +44,7 @@ class CMRBackend(BaseBackend): minzoom: int = attr.ib() maxzoom: int = attr.ib() - reader: Type[Reader] = attr.ib(default=Reader) + reader: Type[BaseReader] = attr.ib(default=Reader) reader_options: Dict = attr.ib(factory=dict) # default values for bounds @@ -85,7 +92,7 @@ def _read(self) -> MosaicJSON: """This method is not used but is required by the abstract class.""" pass - def assets_for_tile(self, x: int, y: int, z: int, **kwargs: Any) -> List[str]: + def assets_for_tile(self, x: int, y: int, z: int, **kwargs: Any) -> List[Asset]: """Retrieve assets for tile.""" bbox = self.tms.bounds(Tile(x, y, z)) return self.get_assets(*bbox, **kwargs) @@ -96,7 +103,7 @@ def assets_for_point( lat: float, coord_crs: CRS = WGS84_CRS, **kwargs: Any, - ) -> List[str]: + ) -> List[Asset]: """Retrieve assets for point.""" raise NotImplementedError @@ -108,7 +115,7 @@ def assets_for_bbox( ymax: float, coord_crs: CRS = WGS84_CRS, **kwargs: Any, - ) -> List[Dict]: + ) -> List[Asset]: """Retrieve assets for bbox.""" if coord_crs != WGS84_CRS: xmin, ymin, xmax, ymax = transform_bounds( @@ -141,7 +148,7 @@ def get_assets( ymax: float, limit: int = 100, **kwargs: Any, - ) -> List[str]: + ) -> List[Asset]: """Find assets.""" results = earthaccess.search_data( concept_id=self.input, @@ -149,9 +156,20 @@ def get_assets( count=limit, **kwargs, ) - return list( - itertools.chain.from_iterable([res.data_links() for res in results]) - ) + + assets: List[Asset] = [] + for r in results: + assets.append( + { + "url": r.data_links(access="direct")[ + 0 + ], # NOTE: should we not do this? + "type": r["meta"].get("concept-type"), + "provider": r["meta"].get("provider-id"), + } + ) + + return assets @property def _quadkeys(self) -> List[str]: @@ -178,9 +196,9 @@ def tile( f"No assets found for tile {tile_z}-{tile_x}-{tile_y}" ) - def _reader(src_path: str, x: int, y: int, z: int, **kwargs: Any) -> ImageData: + def _reader(asset: Asset, x: int, y: int, z: int, **kwargs: Any) -> ImageData: with self.reader( - src_path, + asset["url"], tms=self.tms, **self.reader_options, ) as src_dst: diff --git a/titiler/cmr/factory.py b/titiler/cmr/factory.py index cc24589..a9f11ab 100644 --- a/titiler/cmr/factory.py +++ b/titiler/cmr/factory.py @@ -14,7 +14,7 @@ from morecantile import tms as default_tms from morecantile.defaults import TileMatrixSets from pydantic import conint -from rio_tiler.io import Reader +from rio_tiler.io import BaseReader, Reader from rio_tiler.types import RIOResampling, WarpResampling from starlette.requests import Request from starlette.responses import Response @@ -737,6 +737,7 @@ def tilejson_endpoint( # type: ignore tms = self.supported_tms.get(tileMatrixSetId) + reader: BaseReader if backend != "cog": reader = ZarrReader options = { diff --git a/titiler/cmr/main.py b/titiler/cmr/main.py index bada398..f1e0dcf 100644 --- a/titiler/cmr/main.py +++ b/titiler/cmr/main.py @@ -1,5 +1,8 @@ """TiTiler+cmr FastAPI application.""" +from contextlib import asynccontextmanager + +import earthaccess import jinja2 from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware @@ -22,6 +25,13 @@ settings = ApiSettings() +@asynccontextmanager +async def lifespan(app: FastAPI): + """FastAPI Lifespan.""" + app.state.cmr_auth = earthaccess.login(strategy="netrc") + yield + + app = FastAPI( title=settings.name, openapi_url="/api", @@ -38,6 +48,7 @@ """, version=titiler_cmr_version, root_path=settings.root_path, + lifespan=lifespan, ) diff --git a/titiler/cmr/templates/collection.html b/titiler/cmr/templates/collection.html deleted file mode 100644 index 5ebc174..0000000 --- a/titiler/cmr/templates/collection.html +++ /dev/null @@ -1,77 +0,0 @@ -{% include "header.html" %} - - - -

Collection: {{ response.title or response.id }}

- -
-
-

{{ response.description or response.title or response.id }}

- {% if "keywords" in response and length(response.keywords) > 0 %} -
-

- {% for keyword in response.keywords %} - {{ keyword }} - {% endfor %} -

-
- {% endif %} - -

Links

- -
-
-
- Loading... -
-
-
- - - -{% include "footer.html" %} diff --git a/titiler/cmr/templates/collections.html b/titiler/cmr/templates/collections.html deleted file mode 100644 index 45bb4ef..0000000 --- a/titiler/cmr/templates/collections.html +++ /dev/null @@ -1,118 +0,0 @@ -{% include "header.html" %} - -{% set show_prev_link = false %} -{% set show_next_link = false %} -{% if 'items?' in url %} - {% set urlq = url + '&' %} - {% else %} - {% set urlq = url + '?' %} -{% endif %} - - - -

Collections

- -

- Number of matching collections: {{ response.numberMatched }}
- Number of returned collections: {{ response.numberReturned }}
- Page: of
-

- -
- {% for link in response.links %} - {% if link.rel == 'prev' %} - - {% endif %} - {% endfor %} -
- -
- {% for link in response.links %} - {% if link.rel == 'next' %} - - {% endif %} - {% endfor %} -
- -
- - - - - - - - - -{% for collection in response.collections %} - - - - - -{% endfor %} - -
TitleTypeDescription
{{ collection.title or collection.id }}{{ collection.itemType }}{{ collection.description or collection.title or collection.id }}
-
- - - -{% include "footer.html" %} From 4b383a5fc07c684c44f443e1190c7b60f1f079a3 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Fri, 26 Jan 2024 21:02:55 +0100 Subject: [PATCH 4/6] implement auto/iam S3 credential --- pyproject.toml | 2 +- titiler/cmr/backend.py | 61 ++++++++++++++++++++++++++++++++++++----- titiler/cmr/factory.py | 22 ++++----------- titiler/cmr/reader.py | 13 +++++++-- titiler/cmr/settings.py | 53 ++++++++++++++++++++++++++++++++++- titiler/cmr/utils.py | 35 +++++++++++++++++++++++ 6 files changed, 157 insertions(+), 29 deletions(-) create mode 100644 titiler/cmr/utils.py diff --git a/pyproject.toml b/pyproject.toml index 026e1fa..b9056a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "orjson", "titiler.core>=0.17.0,<0.18", "titiler.mosaic>=0.17.0,<0.18", - "rio_tiler>=6.4.0,<7.0", + "rio_tiler[s3]>=6.4.0,<7.0", "xarray", "rioxarray", "cftime", diff --git a/titiler/cmr/backend.py b/titiler/cmr/backend.py index 9c40893..4ae5458 100644 --- a/titiler/cmr/backend.py +++ b/titiler/cmr/backend.py @@ -4,11 +4,13 @@ import attr import earthaccess +import rasterio from cachetools import TTLCache, cached from cachetools.keys import hashkey from cogeo_mosaic.backends import BaseBackend from cogeo_mosaic.errors import NoAssetFoundError from cogeo_mosaic.mosaic import MosaicJSON +from earthaccess.auth import Auth from morecantile import Tile, TileMatrixSet from rasterio.crs import CRS from rasterio.warp import transform_bounds @@ -18,19 +20,29 @@ from rio_tiler.mosaic import mosaic_reader from rio_tiler.types import BBox -from titiler.pgstac.settings import CacheSettings, RetrySettings -from titiler.pgstac.utils import retry +from titiler.cmr.settings import AuthSettings, CacheSettings, RetrySettings +from titiler.cmr.utils import retry cache_config = CacheSettings() retry_config = RetrySettings() +s3_auth_config = AuthSettings() + + +@cached( # type: ignore + TTLCache(maxsize=100, ttl=60), + key=lambda auth, daac: hashkey(auth.tokens[0]["access_token"], daac), +) +def aws_s3_credential(auth: Auth, provider: str) -> Dict: + """Get AWS S3 credential through earthaccess.""" + return auth.get_s3_credentials(provider=provider) class Asset(TypedDict, total=False): """Simple Asset model.""" url: str - type: Optional[str] - provider: Optional[str] + type: str + provider: str @attr.s @@ -39,6 +51,7 @@ class CMRBackend(BaseBackend): # ConceptID input: str = attr.ib() + auth: Auth = attr.ib() tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS) minzoom: int = attr.ib() @@ -164,8 +177,7 @@ def get_assets( "url": r.data_links(access="direct")[ 0 ], # NOTE: should we not do this? - "type": r["meta"].get("concept-type"), - "provider": r["meta"].get("provider-id"), + "provider": r["meta"]["provider-id"], } ) @@ -197,10 +209,45 @@ def tile( ) def _reader(asset: Asset, x: int, y: int, z: int, **kwargs: Any) -> ImageData: + if s3_auth_config.type == "auto": + s3_credentials = aws_s3_credential(self.auth, asset["provider"]) + + else: + s3_credentials = None + + if isinstance(self.reader, Reader): + aws_session = None + if s3_credentials: + aws_session = rasterio.session.AWSSession( + aws_access_key_id=s3_credentials["accessKeyId"], + aws_secret_access_key=s3_credentials["secretAccessKey"], + aws_session_token=s3_credentials["sessionToken"], + ) + + with rasterio.Env(aws_session): + with self.reader( + asset["url"], + tms=self.tms, + **self.reader_options, + ) as src_dst: + return src_dst.tile(x, y, z, **kwargs) + + if s3_credentials: + options = { + **self.reader_options, + "s3_credentials": { + "key": s3_credentials["accessKeyId"], + "secret": s3_credentials["secretAccessKey"], + "token": s3_credentials["sessionToken"], + }, + } + else: + options = self.reader_options + with self.reader( asset["url"], tms=self.tms, - **self.reader_options, + **options, ) as src_dst: return src_dst.tile(x, y, z, **kwargs) diff --git a/titiler/cmr/factory.py b/titiler/cmr/factory.py index a9f11ab..31d4f38 100644 --- a/titiler/cmr/factory.py +++ b/titiler/cmr/factory.py @@ -14,7 +14,7 @@ from morecantile import tms as default_tms from morecantile.defaults import TileMatrixSets from pydantic import conint -from rio_tiler.io import BaseReader, Reader +from rio_tiler.io import Reader from rio_tiler.types import RIOResampling, WarpResampling from starlette.requests import Request from starlette.responses import Response @@ -390,6 +390,7 @@ def register_tiles(self): # noqa: C901 tags=["Raster Tiles"], ) def tiles_endpoint( + request: Request, collectionId: Annotated[ str, Path( @@ -548,6 +549,7 @@ def tiles_endpoint( with CMRBackend( collectionId, + auth=request.app.cmr_auth, tms=tms, reader=reader, reader_options=reader_options, @@ -737,25 +739,11 @@ def tilejson_endpoint( # type: ignore tms = self.supported_tms.get(tileMatrixSetId) - reader: BaseReader - if backend != "cog": - reader = ZarrReader - options = { - "variable": variable, - "decode_times": decode_times, - "drop_dim": drop_dim, - "time_slice": time_slice, - } - reader_options = {k: v for k, v in options.items() if v is not None} - else: - reader = Reader - reader_options = {} - + # TODO: can we get metadata from the collection? with CMRBackend( collectionId, + auth=request.app.cmr_auth, tms=tms, - reader=reader, - reader_options=reader_options, ) as src_dst: minx, miny, maxx, maxy = zip( [-180, -90, 180, 90], list(src_dst.geographic_bounds) diff --git a/titiler/cmr/reader.py b/titiler/cmr/reader.py index 9cc59ac..76871de 100644 --- a/titiler/cmr/reader.py +++ b/titiler/cmr/reader.py @@ -20,7 +20,7 @@ from rio_tiler.io.xarray import XarrayReader from rio_tiler.types import BBox -from titiler.pgstac.settings import CacheSettings +from titiler.cmr.settings import CacheSettings # Use simple in-memory cache for now (we can switch to redis later) cache_config = CacheSettings() @@ -61,12 +61,14 @@ def get_filesystem( protocol: str, xr_engine: str, anon: bool = True, + s3_credentials: Optional[Dict] = None, ): """ Get the filesystem for the given source path. """ if protocol == "s3": - s3_filesystem = s3fs.S3FileSystem() + s3_credentials = s3_credentials or {} + s3_filesystem = s3fs.S3FileSystem(**s3_credentials) return ( s3_filesystem.open(src_path) if xr_engine == "h5netcdf" @@ -95,6 +97,7 @@ def xarray_open_dataset( reference: Optional[bool] = False, decode_times: Optional[bool] = True, consolidated: Optional[bool] = True, + s3_credentials: Optional[Dict] = None, ) -> xarray.Dataset: """Open dataset.""" # Generate cache key and attempt to fetch the dataset from cache @@ -105,7 +108,9 @@ def xarray_open_dataset( protocol = parse_protocol(src_path, reference=reference) xr_engine = xarray_engine(src_path) - file_handler = get_filesystem(src_path, protocol, xr_engine) + file_handler = get_filesystem( + src_path, protocol, xr_engine, s3_credentials=s3_credentials + ) # Arguments for xarray.open_dataset # Default args @@ -215,6 +220,7 @@ class ZarrReader(XarrayReader): decode_times: bool = attr.ib(default=False) group: Optional[Any] = attr.ib(default=None) consolidated: Optional[bool] = attr.ib(default=True) + s3_credentials: Optional[Dict] = attr.ib(default=None) # xarray.DataArray options time_slice: Optional[str] = attr.ib(default=None) @@ -243,6 +249,7 @@ def __attrs_post_init__(self): group=self.group, reference=self.reference, consolidated=self.consolidated, + s3_credentials=self.s3_credentials, ), ) self.input = get_variable( diff --git a/titiler/cmr/settings.py b/titiler/cmr/settings.py index 3d43592..c460f8f 100644 --- a/titiler/cmr/settings.py +++ b/titiler/cmr/settings.py @@ -1,7 +1,10 @@ """API settings.""" -from pydantic import field_validator +from typing import Literal + +from pydantic import Field, field_validator, model_validator from pydantic_settings import BaseSettings +from typing_extensions import Annotated class ApiSettings(BaseSettings): @@ -23,3 +26,51 @@ class ApiSettings(BaseSettings): def parse_cors_origin(cls, v): """Parse CORS origins.""" return [origin.strip() for origin in v.split(",")] + + +class CacheSettings(BaseSettings): + """Cache settings""" + + # TTL of the cache in seconds + ttl: int = 300 + + # Maximum size of the cache in Number of element + maxsize: int = 512 + + # Whether or not caching is enabled + disable: bool = False + + model_config = {"env_prefix": "TITILER_CMR_CACHE_", "env_file": ".env"} + + @model_validator(mode="after") + def check_enable(self): + """Check if cache is disabled.""" + if self.disable: + self.ttl = 0 + self.maxsize = 0 + + return self + + +class RetrySettings(BaseSettings): + """Retry settings""" + + retry: Annotated[int, Field(ge=0)] = 3 + delay: Annotated[float, Field(ge=0.0)] = 0.0 + + model_config = { + "env_prefix": "TITILER_CMR_API_", + "env_file": ".env", + "extra": "ignore", + } + + +class AuthSettings(BaseSettings): + """AWS credential settings.""" + + type: Literal["auto", "iam"] = "auto" + + model_config = { + "env_prefix": "TITILER_CMR_S3_AUTH_", + "env_file": ".env", + } diff --git a/titiler/cmr/utils.py b/titiler/cmr/utils.py new file mode 100644 index 0000000..320c7d6 --- /dev/null +++ b/titiler/cmr/utils.py @@ -0,0 +1,35 @@ +"""titiler.cmr utilities. + +Code from titiler.pgstac, MIT License. + +""" + +import time +from typing import Any, Sequence, Type, Union + + +def retry( + tries: int, + exceptions: Union[Type[Exception], Sequence[Type[Exception]]] = Exception, + delay: float = 0.0, +): + """Retry Decorator""" + + def _decorator(func: Any): + def _newfn(*args: Any, **kwargs: Any): + + attempt = 0 + while attempt < tries: + try: + return func(*args, **kwargs) + + except exceptions: # type: ignore + + attempt += 1 + time.sleep(delay) + + return func(*args, **kwargs) + + return _newfn + + return _decorator From 2d46b49cd7c4348144ac797de86b05b72677090e Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Mon, 29 Jan 2024 21:09:09 +0100 Subject: [PATCH 5/6] optional cmr login --- titiler/cmr/backend.py | 5 +++-- titiler/cmr/factory.py | 2 +- titiler/cmr/main.py | 9 +++++++-- titiler/cmr/settings.py | 2 +- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/titiler/cmr/backend.py b/titiler/cmr/backend.py index 4ae5458..04ae190 100644 --- a/titiler/cmr/backend.py +++ b/titiler/cmr/backend.py @@ -51,7 +51,6 @@ class CMRBackend(BaseBackend): # ConceptID input: str = attr.ib() - auth: Auth = attr.ib() tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS) minzoom: int = attr.ib() @@ -69,6 +68,8 @@ class CMRBackend(BaseBackend): # The reader is read-only (outside init) mosaic_def: MosaicJSON = attr.ib(init=False) + auth: Optional[Auth] = attr.ib(default=None) + _backend_name = "CMR" def __attrs_post_init__(self) -> None: @@ -209,7 +210,7 @@ def tile( ) def _reader(asset: Asset, x: int, y: int, z: int, **kwargs: Any) -> ImageData: - if s3_auth_config.type == "auto": + if s3_auth_config.type == "environment" and self.auth: s3_credentials = aws_s3_credential(self.auth, asset["provider"]) else: diff --git a/titiler/cmr/factory.py b/titiler/cmr/factory.py index 31d4f38..c1b6c93 100644 --- a/titiler/cmr/factory.py +++ b/titiler/cmr/factory.py @@ -549,10 +549,10 @@ def tiles_endpoint( with CMRBackend( collectionId, - auth=request.app.cmr_auth, tms=tms, reader=reader, reader_options=reader_options, + auth=request.app.cmr_auth, ) as src_dst: image = src_dst.tile( x, diff --git a/titiler/cmr/main.py b/titiler/cmr/main.py index f1e0dcf..3d1d60a 100644 --- a/titiler/cmr/main.py +++ b/titiler/cmr/main.py @@ -10,7 +10,7 @@ from titiler.cmr import __version__ as titiler_cmr_version from titiler.cmr.factory import Endpoints -from titiler.cmr.settings import ApiSettings +from titiler.cmr.settings import ApiSettings, AuthSettings from titiler.core.middleware import CacheControlMiddleware jinja2_env = jinja2.Environment( @@ -23,12 +23,17 @@ templates = Jinja2Templates(env=jinja2_env) settings = ApiSettings() +auth_config = AuthSettings() @asynccontextmanager async def lifespan(app: FastAPI): """FastAPI Lifespan.""" - app.state.cmr_auth = earthaccess.login(strategy="netrc") + if auth_config.strategy == "environment": + app.state.cmr_auth = earthaccess.login(strategy="environment") + else: + app.state.cmr_auth = None + yield diff --git a/titiler/cmr/settings.py b/titiler/cmr/settings.py index c460f8f..0fe4868 100644 --- a/titiler/cmr/settings.py +++ b/titiler/cmr/settings.py @@ -68,7 +68,7 @@ class RetrySettings(BaseSettings): class AuthSettings(BaseSettings): """AWS credential settings.""" - type: Literal["auto", "iam"] = "auto" + strategy: Literal["environment", "iam"] = "environment" model_config = { "env_prefix": "TITILER_CMR_S3_AUTH_", From ff90ddb0e9d32128f92aecb46238c4371d736151 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Thu, 1 Feb 2024 09:44:28 +0100 Subject: [PATCH 6/6] update deployment --- .github/workflows/ci.yml | 2 ++ infrastructure/aws/lambda/handler.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c22ab8a..5ff0aff 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -99,6 +99,7 @@ jobs: # STACK_ALARM_EMAIL: ${{ secrets.ALARM_EMAIL }} STACK_ROLE_ARN: ${{ secrets.lambda_role_arn }} STACK_STAGE: staging + STACK_ADDITIONAL_ENV: '{"TITILER_CMR_S3_AUTH_STRATEGY":"iam"}' # Build and deploy to production deployment whenever there a new tag is pushed - name: Build & Deploy Production @@ -108,3 +109,4 @@ jobs: # STACK_ALARM_EMAIL: ${{ secrets.ALARM_EMAIL }} STACK_ROLE_ARN: ${{ secrets.lambda_role_arn }} STACK_STAGE: production + STACK_ADDITIONAL_ENV: '{"TITILER_CMR_S3_AUTH_STRATEGY":"iam"}' diff --git a/infrastructure/aws/lambda/handler.py b/infrastructure/aws/lambda/handler.py index 1ace204..1c5924a 100644 --- a/infrastructure/aws/lambda/handler.py +++ b/infrastructure/aws/lambda/handler.py @@ -1,12 +1,32 @@ """AWS Lambda handler.""" +import asyncio import logging +import os +import earthaccess from mangum import Mangum from titiler.cmr.main import app +from titiler.cmr.settings import AuthSettings + +auth_config = AuthSettings() logging.getLogger("mangum.lifespan").setLevel(logging.ERROR) logging.getLogger("mangum.http").setLevel(logging.ERROR) + +@app.on_event("startup") +async def startup_event() -> None: + """startup.""" + if auth_config.strategy == "environment": + app.state.cmr_auth = earthaccess.login(strategy="environment") + else: + app.state.cmr_auth = None + + handler = Mangum(app, lifespan="off") + +if "AWS_EXECUTION_ENV" in os.environ: + loop = asyncio.get_event_loop() + loop.run_until_complete(app.router.startup())