@app.on_event("startup")
async def startup_event() -> None:
    """Populate ``app.state.cmr_auth`` when the application starts.

    With the ``environment`` auth strategy an earthaccess login is performed
    and the resulting auth object is stored on the application state; for any
    other strategy the state attribute is set to ``None``.
    """
    use_environment_login = auth_config.strategy == "environment"
    app.state.cmr_auth = (
        earthaccess.login(strategy="environment") if use_environment_login else None
    )
def _hashable(value: Any) -> Any:
    """Return a hashable stand-in for *value* so it can be part of a cache key.

    Lists/tuples become tuples, sets become frozensets and dicts become sorted
    tuples of items (recursively). Any other value is returned unchanged.
    """
    if isinstance(value, (list, tuple)):
        return tuple(_hashable(v) for v in value)
    if isinstance(value, (set, frozenset)):
        return frozenset(_hashable(v) for v in value)
    if isinstance(value, dict):
        return tuple(sorted((k, _hashable(v)) for k, v in value.items()))
    return value


@cached(  # type: ignore
    TTLCache(maxsize=100, ttl=60),
    # The key function receives the same arguments as the wrapped function, so
    # its parameter names must match (`provider`) for keyword calls to work.
    key=lambda auth, provider: hashkey(auth.tokens[0]["access_token"], provider),
)
def aws_s3_credential(auth: Auth, provider: str) -> Dict:
    """Get AWS S3 credential through earthaccess.

    Results are cached for 60 seconds per (access token, provider) pair.

    Args:
        auth: earthaccess ``Auth`` object used to request the credentials.
        provider: CMR provider id the credentials are requested for.

    Returns:
        The credential mapping returned by ``Auth.get_s3_credentials``.
    """
    return auth.get_s3_credentials(provider=provider)


class Asset(TypedDict, total=False):
    """Simple Asset model."""

    # Link to the granule data
    url: str
    type: str
    # CMR provider id of the granule (used to fetch S3 credentials)
    provider: str


@attr.s
class CMRBackend(BaseBackend):
    """CMR Mosaic Backend.

    The backend does not hold a real MosaicJSON document: assets are looked up
    on the fly through ``earthaccess.search_data`` for the requested extent.
    """

    # ConceptID
    input: str = attr.ib()

    tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS)
    minzoom: int = attr.ib()
    maxzoom: int = attr.ib()

    reader: Type[BaseReader] = attr.ib(default=Reader)
    reader_options: Dict = attr.ib(factory=dict)

    # default values for bounds
    bounds: BBox = attr.ib(default=(-180, -90, 180, 90))

    crs: CRS = attr.ib(default=WGS84_CRS)
    geographic_crs: CRS = attr.ib(default=WGS84_CRS)

    # The reader is read-only (outside init)
    mosaic_def: MosaicJSON = attr.ib(init=False)

    auth: Optional[Auth] = attr.ib(default=None)

    _backend_name = "CMR"

    def __attrs_post_init__(self) -> None:
        """Post Init."""
        # Construct a FAKE mosaicJSON
        # mosaic_def has to be defined.
        # we set `tiles` to an empty mapping.
        self.mosaic_def = MosaicJSON(
            mosaicjson="0.0.3",
            name=self.input,
            bounds=self.bounds,
            minzoom=self.minzoom,
            maxzoom=self.maxzoom,
            tiles={},
        )

    @minzoom.default
    def _minzoom(self):
        return self.tms.minzoom

    @maxzoom.default
    def _maxzoom(self):
        return self.tms.maxzoom

    def write(self, overwrite: bool = True) -> None:
        """This method is not used but is required by the abstract class."""
        pass

    def update(self) -> None:
        """We overwrite the default method."""
        pass

    def _read(self) -> MosaicJSON:
        """This method is not used but is required by the abstract class."""
        pass

    def assets_for_tile(self, x: int, y: int, z: int, **kwargs: Any) -> List[Asset]:
        """Retrieve assets for tile.

        The tile bounds are computed with the backend TMS and forwarded,
        together with ``kwargs``, to :meth:`get_assets`.
        """
        bbox = self.tms.bounds(Tile(x, y, z))
        return self.get_assets(*bbox, **kwargs)

    def assets_for_point(
        self,
        lng: float,
        lat: float,
        coord_crs: CRS = WGS84_CRS,
        **kwargs: Any,
    ) -> List[Asset]:
        """Retrieve assets for point."""
        raise NotImplementedError

    def assets_for_bbox(
        self,
        xmin: float,
        ymin: float,
        xmax: float,
        ymax: float,
        coord_crs: CRS = WGS84_CRS,
        **kwargs: Any,
    ) -> List[Asset]:
        """Retrieve assets for bbox.

        Bounds expressed in another CRS are first transformed to WGS84.
        """
        if coord_crs != WGS84_CRS:
            xmin, ymin, xmax, ymax = transform_bounds(
                coord_crs,
                WGS84_CRS,
                xmin,
                ymin,
                xmax,
                ymax,
            )

        return self.get_assets(xmin, ymin, xmax, ymax, **kwargs)

    @cached(  # type: ignore
        TTLCache(maxsize=cache_config.maxsize, ttl=cache_config.ttl),
        # Query options may hold lists/dicts (e.g. a `temporal` range); they are
        # converted to hashable equivalents so the key can always be hashed.
        key=lambda self, xmin, ymin, xmax, ymax, **kwargs: hashkey(
            self.input,
            str(xmin),
            str(ymin),
            str(xmax),
            str(ymax),
            **{name: _hashable(value) for name, value in kwargs.items()},
        ),
    )
    @retry(
        tries=retry_config.retry,
        delay=retry_config.delay,
        # NOTE(review): an empty tuple presumably means no exception type
        # triggers a retry -- confirm against `titiler.cmr.utils.retry`.
        exceptions=(),
    )
    def get_assets(
        self,
        xmin: float,
        ymin: float,
        xmax: float,
        ymax: float,
        limit: int = 100,
        **kwargs: Any,
    ) -> List[Asset]:
        """Find assets.

        Args:
            xmin, ymin, xmax, ymax: bounding box forwarded to the CMR search.
            limit: maximum number of granules requested (``count``).
            kwargs: extra options forwarded to ``earthaccess.search_data``.

        Returns:
            One asset (first direct-access link + provider id) per granule
            that exposes at least one direct-access link.
        """
        results = earthaccess.search_data(
            concept_id=self.input,
            bounding_box=(xmin, ymin, xmax, ymax),
            count=limit,
            **kwargs,
        )

        assets: List[Asset] = []
        for r in results:
            links = r.data_links(access="direct")
            if not links:
                # Granule without direct-access link: nothing we can read.
                continue

            assets.append(
                {
                    # NOTE: only the first link is used -- should we not do this?
                    "url": links[0],
                    "provider": r["meta"]["provider-id"],
                }
            )

        return assets

    @property
    def _quadkeys(self) -> List[str]:
        return []

    def tile(
        self,
        tile_x: int,
        tile_y: int,
        tile_z: int,
        cmr_query: Dict,
        **kwargs: Any,
    ) -> Tuple[ImageData, List[str]]:
        """Get Tile from multiple observation.

        Args:
            tile_x, tile_y, tile_z: tile indexes in the backend TMS.
            cmr_query: options forwarded to the CMR granule search.
            kwargs: options forwarded to ``mosaic_reader`` / the reader's
                ``tile`` method.

        Raises:
            NoAssetFoundError: when the search returns no asset for the tile.
        """
        mosaic_assets = self.assets_for_tile(
            tile_x,
            tile_y,
            tile_z,
            **cmr_query,
        )

        if not mosaic_assets:
            raise NoAssetFoundError(
                f"No assets found for tile {tile_z}-{tile_x}-{tile_y}"
            )

        def _reader(asset: Asset, x: int, y: int, z: int, **kwargs: Any) -> ImageData:
            # Temporary S3 credentials are only requested for the
            # `environment` strategy (AuthSettings exposes `strategy`).
            s3_credentials = None
            if s3_auth_config.strategy == "environment" and self.auth:
                s3_credentials = aws_s3_credential(self.auth, asset["provider"])

            # `self.reader` is a class, not an instance: use issubclass.
            if isinstance(self.reader, type) and issubclass(self.reader, Reader):
                aws_session = None
                if s3_credentials:
                    aws_session = rasterio.session.AWSSession(
                        aws_access_key_id=s3_credentials["accessKeyId"],
                        aws_secret_access_key=s3_credentials["secretAccessKey"],
                        aws_session_token=s3_credentials["sessionToken"],
                    )

                with rasterio.Env(aws_session):
                    with self.reader(
                        asset["url"],
                        tms=self.tms,
                        **self.reader_options,
                    ) as src_dst:
                        return src_dst.tile(x, y, z, **kwargs)

            # Other readers receive the credentials as a reader option.
            if s3_credentials:
                options = {
                    **self.reader_options,
                    "s3_credentials": {
                        "key": s3_credentials["accessKeyId"],
                        "secret": s3_credentials["secretAccessKey"],
                        "token": s3_credentials["sessionToken"],
                    },
                }
            else:
                options = self.reader_options

            with self.reader(
                asset["url"],
                tms=self.tms,
                **options,
            ) as src_dst:
                return src_dst.tile(x, y, z, **kwargs)

        return mosaic_reader(mosaic_assets, _reader, tile_x, tile_y, tile_z, **kwargs)

    def point(
        self,
        lon: float,
        lat: float,
        cmr_query: Dict,
        coord_crs: CRS = WGS84_CRS,
        **kwargs: Any,
    ) -> List:
        """Get Point value from multiple observation."""
        raise NotImplementedError

    def part(
        self,
        bbox: BBox,
        cmr_query: Dict,
        dst_crs: Optional[CRS] = None,
        bounds_crs: CRS = WGS84_CRS,
        **kwargs: Any,
    ) -> Tuple[ImageData, List[str]]:
        """Create an Image from multiple items for a bbox."""
        raise NotImplementedError

    def feature(
        self,
        shape: Dict,
        cmr_query: Dict,
        dst_crs: Optional[CRS] = None,
        shape_crs: CRS = WGS84_CRS,
        max_size: int = 1024,
        **kwargs: Any,
    ) -> Tuple[ImageData, List[str]]:
        """Create an Image from multiple items for a GeoJSON feature."""
        raise NotImplementedError
List[str], temporal_extent: List[Optional[str]]) -> bool: - """Check if dates intersect with temporal extent.""" - if len(interval) == 1: - start = end = parse_rfc3339(interval[0]) - - else: - start = parse_rfc3339(interval[0]) if interval[0] not in ["..", ""] else None - end = parse_rfc3339(interval[1]) if interval[1] not in ["..", ""] else None - - mint, maxt = temporal_extent - min_ext = parse_rfc3339(mint) if mint is not None else None - max_ext = parse_rfc3339(maxt) if maxt is not None else None - - if len(interval) == 1: - if start == min_ext or start == max_ext: - return True - - if not start: - return max_ext <= end or min_ext <= end - - elif not end: - return min_ext >= start or max_ext >= start - - else: - return min_ext >= start and max_ext <= end - - return False - - def accept_media_type(accept: str, mediatypes: List[MediaType]) -> Optional[MediaType]: """Return MediaType based on accept header and available mediatype. @@ -116,42 +76,8 @@ def OutputType( return accept_media_type(request.headers.get("accept", ""), accepted_media) -def bbox_query( - bbox: Annotated[ - Optional[str], - Query( - description="A bounding box, expressed in WGS84 (westLong,southLat,eastLong,northLat) or WGS84h (westLong,southLat,minHeight,eastLong,northLat,maxHeight) CRS, by which to filter out all collections whose spatial extent does not intersect with the bounding box.", - openapi_examples={ - "simple": {"value": "160.6,-55.95,-170,-25.89"}, - }, - ), - ] = None -) -> Optional[List[float]]: - """BBox dependency.""" - if bbox: - bounds = list(map(float, bbox.split(","))) - if len(bounds) == 4: - if abs(bounds[0]) > 180 or abs(bounds[2]) > 180: - raise InvalidBBox(f"Invalid longitude in bbox: {bounds}") - if abs(bounds[1]) > 90 or abs(bounds[3]) > 90: - raise InvalidBBox(f"Invalid latitude in bbox: {bounds}") - - elif len(bounds) == 6: - if abs(bounds[0]) > 180 or abs(bounds[3]) > 180: - raise InvalidBBox(f"Invalid longitude in bbox: {bounds}") - if abs(bounds[1]) > 90 or 
abs(bounds[4]) > 90: - raise InvalidBBox(f"Invalid latitude in bbox: {bounds}") - - else: - raise InvalidBBox(f"Invalid bbox: {bounds}") - - return bounds - - return None - - -def datetime_query( - datetime: Annotated[ +def cmr_query( + temporal: Annotated[ Optional[str], Query( description="Either a date-time or an interval. Date and time expressions adhere to [RFC 3339](https://www.rfc-editor.org/rfc/rfc3339). Intervals may be bounded or half-bounded (double-dots at start or end).", @@ -165,13 +91,21 @@ def datetime_query( }, ), ] = None, -) -> Optional[List[str]]: - """Datetime dependency.""" - if datetime: - dt = datetime.split("/") +) -> Dict: + """CMR Query options.""" + query = {} + if temporal: + dt = temporal.split("/") if len(dt) > 2: - raise HTTPException(status_code=422, detail="Invalid datetime: {datetime}") + raise HTTPException(status_code=422, detail="Invalid temporal: {temporal}") - return dt + if len(dt) == 1: + start = end = parse_rfc3339(dt[0]) - return None + else: + start = parse_rfc3339(dt[0]) if dt[0] not in ["..", ""] else None + end = parse_rfc3339(dt[1]) if dt[1] not in ["..", ""] else None + + query["temporal"] = [start, end] + + return query diff --git a/titiler/cmr/factory.py b/titiler/cmr/factory.py index be9e850..c1b6c93 100644 --- a/titiler/cmr/factory.py +++ b/titiler/cmr/factory.py @@ -3,22 +3,36 @@ import json import re from dataclasses import dataclass, field -from typing import Any, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union +from urllib.parse import urlencode import jinja2 +import numpy import orjson -from fastapi import APIRouter, Depends, Path +from fastapi import APIRouter, Depends, Path, Query from fastapi.responses import ORJSONResponse from morecantile import tms as default_tms from morecantile.defaults import TileMatrixSets +from pydantic import conint +from rio_tiler.io import Reader +from rio_tiler.types import RIOResampling, WarpResampling from starlette.requests import Request 
+from starlette.responses import Response from starlette.routing import compile_path, replace_params from starlette.templating import Jinja2Templates, _TemplateResponse from typing_extensions import Annotated from titiler.cmr import models -from titiler.cmr.dependencies import OutputType +from titiler.cmr.backend import CMRBackend +from titiler.cmr.dependencies import OutputType, cmr_query from titiler.cmr.enums import MediaType +from titiler.cmr.reader import ZarrReader +from titiler.core import dependencies +from titiler.core.algorithm import algorithms as available_algorithms +from titiler.core.factory import img_endpoint_params +from titiler.core.models.mapbox import TileJSON +from titiler.core.resources.enums import ImageType +from titiler.core.utils import render_image jinja2_env = jinja2.Environment( loader=jinja2.ChoiceLoader([jinja2.PackageLoader(__package__, "templates")]) @@ -127,6 +141,7 @@ def __post_init__(self): self.register_landing() self.register_conformance() self.register_tilematrixsets() + self.register_tiles() def register_landing(self) -> None: """register landing page endpoint.""" @@ -350,3 +365,394 @@ async def tilematrixset( ) return data + + def register_tiles(self): # noqa: C901 + """Register tileset endpoints.""" + + @self.router.get( + "/collections/{collectionId}/tiles/{tileMatrixSetId}/{z}/{x}/{y}", + **img_endpoint_params, + tags=["Raster Tiles"], + ) + @self.router.get( + "/collections/{collectionId}/tiles/{tileMatrixSetId}/{z}/{x}/{y}.{format}", + **img_endpoint_params, + tags=["Raster Tiles"], + ) + @self.router.get( + "/collections/{collectionId}/tiles/{tileMatrixSetId}/{z}/{x}/{y}@{scale}x", + **img_endpoint_params, + tags=["Raster Tiles"], + ) + @self.router.get( + "/collections/{collectionId}/tiles/{tileMatrixSetId}/{z}/{x}/{y}@{scale}x.{format}", + **img_endpoint_params, + tags=["Raster Tiles"], + ) + def tiles_endpoint( + request: Request, + collectionId: Annotated[ + str, + Path( + description="A CMR concept id, in the 
format '-' " + ), + ], + tileMatrixSetId: Annotated[ + Literal[tuple(self.supported_tms.list())], + Path(description="Identifier for a supported TileMatrixSet"), + ], + z: Annotated[ + int, + Path( + description="Identifier (Z) selecting one of the scales defined in the TileMatrixSet and representing the scaleDenominator the tile.", + ), + ], + x: Annotated[ + int, + Path( + description="Column (X) index of the tile on the selected TileMatrix. It cannot exceed the MatrixHeight-1 for the selected TileMatrix.", + ), + ], + y: Annotated[ + int, + Path( + description="Row (Y) index of the tile on the selected TileMatrix. It cannot exceed the MatrixWidth-1 for the selected TileMatrix.", + ), + ], + scale: Annotated[ # type: ignore + conint(gt=0, le=4), "Tile size scale. 1=256x256, 2=512x512..." + ] = 1, + format: Annotated[ + ImageType, + "Default will be automatically defined if the output image needs a mask (png) or not (jpeg).", + ] = None, + ################################################################### + # CMR options + query=Depends(cmr_query), + ################################################################### + backend: Annotated[ + Literal["cog", "xarray"], + Query(description="Backend to read the CMR dataset"), + ] = "cog", + ################################################################### + # ZarrReader Options + ################################################################### + variable: Annotated[ + Optional[str], + Query(description="Xarray Variable"), + ] = None, + drop_dim: Annotated[ + Optional[str], + Query(description="Dimension to drop"), + ] = None, + time_slice: Annotated[ + Optional[str], Query(description="Slice of time to read (if available)") + ] = None, + decode_times: Annotated[ + Optional[bool], + Query( + title="decode_times", + description="Whether to decode times", + ), + ] = None, + ################################################################### + # COG Reader Options + 
################################################################### + indexes: Annotated[ + Optional[List[int]], + Query( + title="Band indexes", + alias="bidx", + description="Dataset band indexes", + ), + ] = None, + expression: Annotated[ + Optional[str], + Query( + title="Band Math expression", + description="rio-tiler's band math expression", + ), + ] = None, + unscale: Annotated[ + Optional[bool], + Query( + title="Apply internal Scale/Offset", + description="Apply internal Scale/Offset. Defaults to `False`.", + ), + ] = None, + resampling_method: Annotated[ + Optional[RIOResampling], + Query( + alias="resampling", + description="RasterIO resampling algorithm. Defaults to `nearest`.", + ), + ] = None, + ################################################################### + # Reader options + ################################################################### + nodata: Annotated[ + Optional[Union[str, int, float]], + Query( + title="Nodata value", + description="Overwrite internal Nodata value", + ), + ] = None, + reproject_method: Annotated[ + Optional[WarpResampling], + Query( + alias="reproject", + description="WarpKernel resampling algorithm (only used when doing re-projection). 
Defaults to `nearest`.", + ), + ] = None, + ################################################################### + # Rendering Options + ################################################################### + post_process=Depends(available_algorithms.dependency), + rescale=Depends(dependencies.RescalingParams), + color_formula=Depends(dependencies.ColorFormulaParams), + colormap=Depends(dependencies.ColorMapParams), + render_params=Depends(dependencies.ImageRenderingParams), + ) -> Response: + """Create map tile from a dataset.""" + resampling_method = resampling_method or "nearest" + reproject_method = reproject_method or "nearest" + if nodata is not None: + nodata = numpy.nan if nodata == "nan" else float(nodata) + + tms = self.supported_tms.get(tileMatrixSetId) + + read_options: Dict[str, Any] = {} + reader_options: Dict[str, Any] = {} + + if backend != "cog": + reader = ZarrReader + read_options = {} + + options = { + "variable": variable, + "decode_times": decode_times, + "drop_dim": drop_dim, + "time_slice": time_slice, + } + reader_options = {k: v for k, v in options.items() if v is not None} + else: + reader = Reader + options = { + "indexes": indexes, # type: ignore + "expression": expression, + "unscale": unscale, + "resampling_method": resampling_method, + } + read_options = {k: v for k, v in options.items() if v is not None} + + reader_options = {} + + with CMRBackend( + collectionId, + tms=tms, + reader=reader, + reader_options=reader_options, + auth=request.app.cmr_auth, + ) as src_dst: + image = src_dst.tile( + x, + y, + z, + tilesize=scale * 256, + cmr_query=cmr_query, + nodata=nodata, + reproject_method=reproject_method, + **read_options, + ) + + if post_process: + image = post_process(image) + + if rescale: + image.rescale(rescale) + + if color_formula: + image.apply_color_formula(color_formula) + + content, media_type = render_image( + image, + output_format=format, + colormap=colormap, + **render_params, + ) + + return Response(content, 
media_type=media_type) + + @self.router.get( + "/collections/{collectionId}/{tileMatrixSetId}/tilejson.json", + response_model=TileJSON, + responses={200: {"description": "Return a tilejson"}}, + response_model_exclude_none=True, + tags=["TileJSON"], + ) + def tilejson_endpoint( # type: ignore + request: Request, + collectionId: Annotated[ + str, + Path( + description="A CMR concept id, in the format '-' " + ), + ], + tileMatrixSetId: Annotated[ + Literal[tuple(self.supported_tms.list())], + Path(description="Identifier for a supported TileMatrixSet"), + ], + tile_format: Annotated[ + Optional[ImageType], + Query( + description="Default will be automatically defined if the output image needs a mask (png) or not (jpeg).", + ), + ] = None, + tile_scale: Annotated[ + int, + Query( + gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..." + ), + ] = 1, + minzoom: Annotated[ + Optional[int], + Query(description="Overwrite default minzoom."), + ] = None, + maxzoom: Annotated[ + Optional[int], + Query(description="Overwrite default maxzoom."), + ] = None, + ################################################################### + # CMR options + query=Depends(cmr_query), + ################################################################### + backend: Annotated[ + Literal["cog", "xarray"], + Query(description="Backend to read the CMR dataset"), + ] = "cog", + ################################################################### + # ZarrReader Options + ################################################################### + variable: Annotated[ + Optional[str], + Query(description="Xarray Variable"), + ] = None, + drop_dim: Annotated[ + Optional[str], + Query(description="Dimension to drop"), + ] = None, + time_slice: Annotated[ + Optional[str], Query(description="Slice of time to read (if available)") + ] = None, + decode_times: Annotated[ + Optional[bool], + Query( + title="decode_times", + description="Whether to decode times", + ), + ] = None, + 
################################################################### + # COG Reader Options + ################################################################### + indexes: Annotated[ + Optional[List[int]], + Query( + title="Band indexes", + alias="bidx", + description="Dataset band indexes", + ), + ] = None, + expression: Annotated[ + Optional[str], + Query( + title="Band Math expression", + description="rio-tiler's band math expression", + ), + ] = None, + unscale: Annotated[ + Optional[bool], + Query( + title="Apply internal Scale/Offset", + description="Apply internal Scale/Offset. Defaults to `False`.", + ), + ] = None, + resampling_method: Annotated[ + Optional[RIOResampling], + Query( + alias="resampling", + description="RasterIO resampling algorithm. Defaults to `nearest`.", + ), + ] = None, + ################################################################### + # Reader options + ################################################################### + nodata: Annotated[ + Optional[Union[str, int, float]], + Query( + title="Nodata value", + description="Overwrite internal Nodata value", + ), + ] = None, + reproject_method: Annotated[ + Optional[WarpResampling], + Query( + alias="reproject", + description="WarpKernel resampling algorithm (only used when doing re-projection). 
Defaults to `nearest`.", + ), + ] = None, + ################################################################### + # Rendering Options + ################################################################### + post_process=Depends(available_algorithms.dependency), + rescale=Depends(dependencies.RescalingParams), + color_formula=Depends(dependencies.ColorFormulaParams), + colormap=Depends(dependencies.ColorMapParams), + render_params=Depends(dependencies.ImageRenderingParams), + ) -> Dict: + """Return TileJSON document for a dataset.""" + route_params = { + "z": "{z}", + "x": "{x}", + "y": "{y}", + "scale": tile_scale, + "tileMatrixSetId": tileMatrixSetId, + } + if tile_format: + route_params["format"] = tile_format.value + + tiles_url = self.url_for(request, "tiles_endpoint", **route_params) + + qs_key_to_remove = [ + "tilematrixsetid", + "tile_format", + "tile_scale", + "minzoom", + "maxzoom", + ] + qs = [ + (key, value) + for (key, value) in request.query_params._list + if key.lower() not in qs_key_to_remove + ] + if qs: + tiles_url += f"?{urlencode(qs)}" + + tms = self.supported_tms.get(tileMatrixSetId) + + # TODO: can we get metadata from the collection? 
+ with CMRBackend( + collectionId, + auth=request.app.cmr_auth, + tms=tms, + ) as src_dst: + minx, miny, maxx, maxy = zip( + [-180, -90, 180, 90], list(src_dst.geographic_bounds) + ) + bounds = [max(minx), max(miny), min(maxx), min(maxy)] + + return { + "bounds": bounds, + "minzoom": minzoom if minzoom is not None else src_dst.minzoom, + "maxzoom": maxzoom if maxzoom is not None else src_dst.maxzoom, + "tiles": [tiles_url], + } diff --git a/titiler/cmr/main.py b/titiler/cmr/main.py index bada398..3d1d60a 100644 --- a/titiler/cmr/main.py +++ b/titiler/cmr/main.py @@ -1,5 +1,8 @@ """TiTiler+cmr FastAPI application.""" +from contextlib import asynccontextmanager + +import earthaccess import jinja2 from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware @@ -7,7 +10,7 @@ from titiler.cmr import __version__ as titiler_cmr_version from titiler.cmr.factory import Endpoints -from titiler.cmr.settings import ApiSettings +from titiler.cmr.settings import ApiSettings, AuthSettings from titiler.core.middleware import CacheControlMiddleware jinja2_env = jinja2.Environment( @@ -20,6 +23,18 @@ templates = Jinja2Templates(env=jinja2_env) settings = ApiSettings() +auth_config = AuthSettings() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """FastAPI Lifespan.""" + if auth_config.strategy == "environment": + app.state.cmr_auth = earthaccess.login(strategy="environment") + else: + app.state.cmr_auth = None + + yield app = FastAPI( @@ -38,6 +53,7 @@ """, version=titiler_cmr_version, root_path=settings.root_path, + lifespan=lifespan, ) diff --git a/titiler/cmr/reader.py b/titiler/cmr/reader.py new file mode 100644 index 0000000..76871de --- /dev/null +++ b/titiler/cmr/reader.py @@ -0,0 +1,286 @@ +"""ZarrReader. 
def parse_protocol(src_path: str, reference: Optional[bool] = False):
    """Parse protocol from path.

    Args:
        src_path: dataset location (URL or local path).
        reference: when truthy, the protocol is always ``"reference"``.

    Returns:
        ``"reference"``, ``"s3"``, ``"https"``, ``"http"`` or ``"file"``.
    """
    # A reference file overrides whatever scheme the path carries.
    if reference:
        return "reference"

    # Require the `://` separator so that local paths which merely start with
    # "s3" or "http" (e.g. "s3_exports/data.zarr") are treated as files.
    match = re.match(r"^(s3|https|http)://", src_path)
    return match.group(1) if match else "file"


def xarray_engine(src_path: str):
    """Parse xarray engine from path.

    Returns ``"h5netcdf"`` for NetCDF extensions (case-insensitive) and
    ``"zarr"`` for everything else.
    """
    # ".hdf", ".hdf5", ".h5" will be supported once we have tests + expand the type permitted for the group parameter
    H5NETCDF_EXTENSIONS = (".nc", ".nc4")
    if src_path.lower().endswith(H5NETCDF_EXTENSIONS):
        return "h5netcdf"
    return "zarr"
def xarray_open_dataset(
    src_path: str,
    group: Optional[Any] = None,
    reference: Optional[bool] = False,
    decode_times: Optional[bool] = True,
    consolidated: Optional[bool] = True,
    s3_credentials: Optional[Dict] = None,
) -> xarray.Dataset:
    """Open dataset.

    The opened dataset is pickled into the module-level cache so that later
    calls with the same options can skip the open step.

    Args:
        src_path: dataset location (URL or local path).
        group: group to open; only forwarded to xarray when it is an ``int``.
        reference: open ``src_path`` as a reference file.
        decode_times: forwarded to ``xarray.open_dataset``.
        consolidated: forwarded to ``xarray.open_dataset`` for zarr stores.
        s3_credentials: options forwarded to the S3 filesystem.

    Returns:
        The opened ``xarray.Dataset``.
    """
    # Every option that changes the resulting dataset is part of the key,
    # otherwise e.g. a dataset opened with `decode_times=True` would be
    # returned for a `decode_times=False` request.
    cache_key = (src_path, repr(group), bool(reference), decode_times, consolidated)

    # The cache only ever holds bytes written by this function below.
    data_bytes = cache_client.get(cache_key, None)
    if data_bytes:
        return pickle.loads(data_bytes)

    protocol = parse_protocol(src_path, reference=reference)
    xr_engine = xarray_engine(src_path)
    file_handler = get_filesystem(
        src_path, protocol, xr_engine, s3_credentials=s3_credentials
    )

    # Arguments for xarray.open_dataset
    # Default args
    xr_open_args: Dict[str, Any] = {
        "decode_coords": "all",
        "decode_times": decode_times,
    }

    # Argument if we're opening a datatree
    if isinstance(group, int):
        xr_open_args["group"] = group

    # NetCDF arguments
    if xr_engine == "h5netcdf":
        xr_open_args["engine"] = "h5netcdf"
        xr_open_args["lock"] = False

    else:
        # Zarr arguments
        xr_open_args["engine"] = "zarr"
        xr_open_args["consolidated"] = consolidated

    # Additional arguments when dealing with a reference file.
    if reference:
        xr_open_args["consolidated"] = False
        xr_open_args["backend_kwargs"] = {"consolidated": False}

    ds = xarray.open_dataset(file_handler, **xr_open_args)

    # Serialize the dataset to bytes using pickle
    try:
        cache_client[cache_key] = pickle.dumps(ds)
    except ValueError:
        # The cache rejects the item when it cannot hold it (e.g. caching is
        # disabled and `maxsize` is 0); caching is best-effort, so carry on.
        pass

    return ds
"O": + da["time"] = da["time"].astype("datetime64[ns]") + da = da.sel( + time=numpy.array(time_as_str, dtype=numpy.datetime64), method="nearest" + ) + else: + da = da.isel(time=0) + + return da + + +@attr.s +class ZarrReader(XarrayReader): + """ZarrReader: Open Zarr file and access DataArray.""" + + src_path: str = attr.ib() + variable: str = attr.ib() + + # xarray.Dataset options + reference: bool = attr.ib(default=False) + decode_times: bool = attr.ib(default=False) + group: Optional[Any] = attr.ib(default=None) + consolidated: Optional[bool] = attr.ib(default=True) + s3_credentials: Optional[Dict] = attr.ib(default=None) + + # xarray.DataArray options + time_slice: Optional[str] = attr.ib(default=None) + drop_dim: Optional[str] = attr.ib(default=None) + + tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS) + geographic_crs: CRS = attr.ib(default=WGS84_CRS) + + ds: xarray.Dataset = attr.ib(init=False) + input: xarray.DataArray = attr.ib(init=False) + + bounds: BBox = attr.ib(init=False) + crs: CRS = attr.ib(init=False) + + _minzoom: int = attr.ib(init=False, default=None) + _maxzoom: int = attr.ib(init=False, default=None) + + _dims: List = attr.ib(init=False, factory=list) + _ctx_stack = attr.ib(init=False, factory=contextlib.ExitStack) + + def __attrs_post_init__(self): + """Set bounds and CRS.""" + self.ds = self._ctx_stack.enter_context( + xarray_open_dataset( + self.src_path, + group=self.group, + reference=self.reference, + consolidated=self.consolidated, + s3_credentials=self.s3_credentials, + ), + ) + self.input = get_variable( + self.ds, + self.variable, + time_slice=self.time_slice, + drop_dim=self.drop_dim, + ) + + self.bounds = tuple(self.input.rio.bounds()) + self.crs = self.input.rio.crs + + self._dims = [ + d + for d in self.input.dims + if d not in [self.input.rio.x_dim, self.input.rio.y_dim] + ] + + @classmethod + def list_variables( + cls, + src_path: str, + group: Optional[Any] = None, + reference: Optional[bool] = False, + consolidated: 
Optional[bool] = True, + ) -> List[str]: + """List available variable in a dataset.""" + with xarray_open_dataset( + src_path, + group=group, + reference=reference, + consolidated=consolidated, + ) as ds: + return list(ds.data_vars) # type: ignore diff --git a/titiler/cmr/settings.py b/titiler/cmr/settings.py index 236afac..11a18bf 100644 --- a/titiler/cmr/settings.py +++ b/titiler/cmr/settings.py @@ -1,7 +1,10 @@ """API settings.""" -from pydantic import field_validator +from typing import Literal + +from pydantic import Field, field_validator, model_validator from pydantic_settings import BaseSettings +from typing_extensions import Annotated class ApiSettings(BaseSettings): @@ -22,3 +25,51 @@ class ApiSettings(BaseSettings): def parse_cors_origin(cls, v): """Parse CORS origins.""" return [origin.strip() for origin in v.split(",")] + + +class CacheSettings(BaseSettings): + """Cache settings""" + + # TTL of the cache in seconds + ttl: int = 300 + + # Maximum size of the cache in Number of element + maxsize: int = 512 + + # Whether or not caching is enabled + disable: bool = False + + model_config = {"env_prefix": "TITILER_CMR_CACHE_", "env_file": ".env"} + + @model_validator(mode="after") + def check_enable(self): + """Check if cache is disabled.""" + if self.disable: + self.ttl = 0 + self.maxsize = 0 + + return self + + +class RetrySettings(BaseSettings): + """Retry settings""" + + retry: Annotated[int, Field(ge=0)] = 3 + delay: Annotated[float, Field(ge=0.0)] = 0.0 + + model_config = { + "env_prefix": "TITILER_CMR_API_", + "env_file": ".env", + "extra": "ignore", + } + + +class AuthSettings(BaseSettings): + """AWS credential settings.""" + + strategy: Literal["environment", "iam"] = "environment" + + model_config = { + "env_prefix": "TITILER_CMR_S3_AUTH_", + "env_file": ".env", + } diff --git a/titiler/cmr/templates/collection.html b/titiler/cmr/templates/collection.html deleted file mode 100644 index 5ebc174..0000000 --- a/titiler/cmr/templates/collection.html 
+++ /dev/null @@ -1,77 +0,0 @@ -{% include "header.html" %} - - - -

Collection: {{ response.title or response.id }}

- -
-
-

{{ response.description or response.title or response.id }}

- {% if "keywords" in response and length(response.keywords) > 0 %} -
-

- {% for keyword in response.keywords %} - {{ keyword }} - {% endfor %} -

-
- {% endif %} - -

Links

- -
-
-
- Loading... -
-
-
- - - -{% include "footer.html" %} diff --git a/titiler/cmr/templates/collections.html b/titiler/cmr/templates/collections.html deleted file mode 100644 index 45bb4ef..0000000 --- a/titiler/cmr/templates/collections.html +++ /dev/null @@ -1,118 +0,0 @@ -{% include "header.html" %} - -{% set show_prev_link = false %} -{% set show_next_link = false %} -{% if 'items?' in url %} - {% set urlq = url + '&' %} - {% else %} - {% set urlq = url + '?' %} -{% endif %} - - - -

Collections

- -

- Number of matching collections: {{ response.numberMatched }}
- Number of returned collections: {{ response.numberReturned }}
- Page: of
-

- -
- {% for link in response.links %} - {% if link.rel == 'prev' %} - - {% endif %} - {% endfor %} -
- -
- {% for link in response.links %} - {% if link.rel == 'next' %} - - {% endif %} - {% endfor %} -
- -
- - - - - - - - - -{% for collection in response.collections %} - - - - - -{% endfor %} - -
TitleTypeDescription
{{ collection.title or collection.id }}{{ collection.itemType }}{{ collection.description or collection.title or collection.id }}
-
"""titiler.cmr utilities.

Code from titiler.pgstac, MIT License.

"""

import functools
import time
from typing import Any, Sequence, Type, Union


def retry(
    tries: int,
    exceptions: Union[Type[Exception], Sequence[Type[Exception]]] = Exception,
    delay: float = 0.0,
):
    """Build a decorator that re-invokes a function when it raises.

    The decorated function is called up to ``tries`` times with the listed
    exceptions caught; after each caught failure the wrapper sleeps ``delay``
    seconds. If every guarded attempt fails, the function is called one final
    time without any guard, so the last error propagates to the caller.

    Args:
        tries: Number of guarded attempts made before the final, unguarded call.
        exceptions: A single exception class, or any sequence of exception
            classes (tuple, list, ...), that should trigger a retry.
        delay: Seconds to sleep after each caught failure.

    Returns:
        A decorator wrapping the target callable with the retry behaviour.

    Raises:
        Whatever the wrapped function raises: exceptions not listed in
        ``exceptions`` propagate immediately, listed ones propagate from the
        final unguarded call.
    """
    # An ``except`` clause only accepts a class or a tuple of classes; handing
    # it a list raises TypeError at catch time and hides the original error.
    # Normalise once so every sequence type allowed by the signature works.
    if isinstance(exceptions, type):
        caught = (exceptions,)
    else:
        caught = tuple(exceptions)

    def _decorator(func: Any):
        # Keep the wrapped function's name, docstring and signature metadata.
        @functools.wraps(func)
        def _newfn(*args: Any, **kwargs: Any):
            attempt = 0
            while attempt < tries:
                try:
                    return func(*args, **kwargs)
                except caught:
                    attempt += 1
                    time.sleep(delay)

            # Guarded attempts exhausted: final call, errors are not swallowed.
            return func(*args, **kwargs)

        return _newfn

    return _decorator