diff --git a/src/cogserver/extensions/__init__.py b/src/cogserver/extensions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/cogserver/extensions.py b/src/cogserver/extensions/mosaicjson.py similarity index 82% rename from src/cogserver/extensions.py rename to src/cogserver/extensions/mosaicjson.py index 13a71e2..c3456f1 100644 --- a/src/cogserver/extensions.py +++ b/src/cogserver/extensions/mosaicjson.py @@ -1,24 +1,27 @@ -from dataclasses import dataclass, field -from typing import List, Optional +from dataclasses import dataclass +from typing import List, Optional + +from cogeo_mosaic.mosaic import MosaicJSON from cogserver.dependencies import SignedDatasetPaths -from typing_extensions import Annotated -from fastapi import Depends, FastAPI, Query +from fastapi import Depends, Query +from pydantic import BaseModel from titiler.core.factory import BaseTilerFactory, FactoryExtension -from cogeo_mosaic.mosaic import MosaicJSON from titiler.core.resources.responses import JSONResponse -from pydantic import BaseModel +from typing_extensions import Annotated urls = Annotated[List[str], Query(..., description="Dataset URLs")] + + class MosaicJsonCreateItem(BaseModel): - #url: List[str] = Query(..., description="Dataset URL") - urls:List[str] = urls + # url: List[str] = Query(..., description="Dataset URL") + urls: List[str] = urls minzoom: int = 0 maxzoom: int = 22 attribution: str = None @dataclass -class createMosaicJsonExtension(FactoryExtension): +class MosaicJsonExtension(FactoryExtension): def create_mosaic_json(self, urls=None, minzoom=None, maxzoom=None, attribution=None): mosaicjson = MosaicJSON.from_urls(urls=urls, minzoom=minzoom, maxzoom=maxzoom, ) @@ -37,7 +40,7 @@ def register(self, factory: BaseTilerFactory): 200: {"description": "Return a MosaicJSON from multiple COGs."}}, ) def build_mosaicJSON( - url = Depends(SignedDatasetPaths), + url=Depends(SignedDatasetPaths), minzoom: Optional[int] = 0, maxzoom: Optional[int] = 22, attribution: Optional[str] = None @@ -59,6 +62,4 @@ def build_mosaicJSON(payload: MosaicJsonCreateItem): maxzoom = payload.maxzoom attribution = payload.attribution - return self.create_mosaic_json(urls=url,minzoom=minzoom, maxzoom=maxzoom, attribution=attribution) - - + return self.create_mosaic_json(urls=url, minzoom=minzoom, maxzoom=maxzoom, attribution=attribution) diff --git a/src/cogserver/extensions/vrt.py b/src/cogserver/extensions/vrt.py new file mode 100644 index 0000000..a920370 --- /dev/null +++ b/src/cogserver/extensions/vrt.py @@ -0,0 +1,156 @@ +import re +import tempfile +from typing import List, Literal, Optional + +from fastapi import Query, Response +from osgeo import gdal +from titiler.core.factory import FactoryExtension +from cogserver.vrt import VRTFactory +from xml.etree import ElementTree as ET + + +async def create_vrt_from_urls( + urls: List[str], + resolution: Literal["highest", "lowest", "average", "user"] = "average", + xRes: float = 0.1, + yRes: float = 0.1, + vrtNoData: List[str] = 0, + srcNoData: List[str] = 0, + resamplingAlg: Literal["nearest", "bilinear", "cubic", "cubicspline", "lanczos", "average", "mode"] = "nearest", + +): + """ + Create a VRT from multiple COGs supplied as URLs + + Args: + urls (List[str]): List of URLs + resolution (Literal["highest", "lowest", "average", "user"], optional): Resolution to use for the resulting VRT. Defaults to "average". + xRes (float, optional): X resolution. Defaults to 0.1. Ignored if resolution is not "user". + yRes (float, optional): Y resolution. Defaults to 0.1. Ignored if resolution is not "user". + vrtNoData (List[str], optional): Set nodata values at the VRT band level (different values can be supplied for each band). If the option is not specified, intrinsic nodata settings on the first dataset will be used (if they exist). The value set by this option is written in the NoDataValue element of each VRTRasterBand element. Use a value of None to ignore intrinsic nodata settings on the source datasets. Defaults to 0. + srcNoData (List[str], optional): Set nodata values for input bands (different values can be supplied for each band). If the option is not specified, the intrinsic nodata settings on the source datasets will be used (if they exist). The value set by this option is written in the NODATA element of each ComplexSource element. Use a value of None to ignore intrinsic nodata settings on the source datasets. Defaults to 0. + resamplingAlg (Literal["nearest", "bilinear", "cubic", "cubicspline", "lanczos", "average", "mode"], optional): Resampling algorithm. Defaults to "nearest". + + Returns: + str: VRT XML + """ + urls = [f"/vsicurl/{url}" for url in urls] + if vrtNoData: + vrtNoData = " ".join(vrtNoData) + if srcNoData: + srcNoData = " ".join(srcNoData) + options = gdal.BuildVRTOptions( + separate=True, + bandList=list(range(1, len(urls) + 1)), + xRes=xRes, + yRes=yRes, + resampleAlg=resamplingAlg, + VRTNodata=vrtNoData, + srcNodata=srcNoData, + resolution=resolution, + ) + + with tempfile.NamedTemporaryFile() as temp: + gdal.BuildVRT(temp.name, urls, options=options) + with open(temp.name, "r") as file: + file_text = ET.fromstring(file.read()) + available_bands = file_text.findall("VRTRasterBand") + for source_band in available_bands: + source = None + complex_source = source_band.find("ComplexSource") + simple_source = source_band.find("SimpleSource") + if complex_source is not None: + source = complex_source.find("SourceFilename").text + elif simple_source is not None: + source = simple_source.find("SourceFilename").text + if source is not None: + gdalInfo = gdal.Info(source) + color_interp = gdalInfo.split("ColorInterp=")[-1].split("\n")[0] + dataset_metadata = gdalInfo.split("Metadata:")[-1] + metadata_dict = {} + pattern = r"(.*?)=(.*)" + + matches = re.findall(pattern, dataset_metadata) + for match in matches: + metadata_dict[match[0].strip()] = match[1].strip() + + source_band.append(ET.Element("Metadata")) + source_band.append(ET.Element("ColorInterp")) + source_band.find("ColorInterp").text = color_interp + for key, value in metadata_dict.items(): + metadata = ET.Element("MDI") + metadata.set("key", key) + metadata.text = value + source_band.find("Metadata").append(metadata) + return ET.tostring(file_text, encoding="unicode") + + +class VRTExtension(FactoryExtension): + """ + VRT Extension for the VRTFactory + """ + + def register(self, factory: VRTFactory): + """ + Register the VRT extension to the VRTFactory + + Args: + factory (VRTFactory): VRTFactory instance + + Returns: + None + """ + + @factory.router.get( + "", + response_class=Response, + responses={200: {"description": "Return a VRT from multiple COGs."}}, + summary="Create a VRT from multiple COGs", + ) + async def create_vrt( + urls: List[str] = Query(..., description="Dataset URLs"), + + srcNoData: List[str] = Query(None, + description="Set nodata values for input bands (different values can be supplied for each band). If the option is not specified, the intrinsic nodata settings on the source datasets will be used (if they exist). The value set by this option is written in the NODATA element of each ComplexSource element. Use a value of None to ignore intrinsic nodata settings on the source datasets."), + vrtNoData: List[str] = Query(None, + description="Set nodata values at the VRT band level (different values can be supplied for each band). If the option is not specified, intrinsic nodata settings on the first dataset will be used (if they exist). The value set by this option is written in the NoDataValue element of each VRTRasterBand element. Use a value of None to ignore intrinsic nodata settings on the source datasets."), + resamplingAlg: Literal[ + "nearest", "bilinear", "cubic", "cubicspline", "lanczos", "average", "mode"] = Query("nearest", + description="Resampling algorithm"), + resolution: Literal["highest", "lowest", "average", "user"] = Query("average", + description="Resolution to use for the resulting VRT"), + xRes: Optional[float] = Query(None, + description="X resolution. Applicable only when `resolution` is `user`"), + yRes: Optional[float] = Query(None, + description="Y resolution. Applicable only when `resolution` is `user`") + ): + """ + Create a VRT from multiple COGs supplied as URLs + + Args: + urls (List[str]): List of URLs + srcNoData (List[str], optional): Set nodata values for input bands (different values can be supplied for each band). If the option is not specified, the intrinsic nodata settings on the source datasets will be used (if they exist). The value set by this option is written in the NODATA element of each ComplexSource element. Use a value of None to ignore intrinsic nodata settings on the source datasets. Defaults to None. + vrtNoData (List[str], optional): Set nodata values at the VRT band level (different values can be supplied for each band). If the option is not specified, intrinsic nodata settings on the first dataset will be used (if they exist). The value set by this option is written in the NoDataValue element of each VRTRasterBand element. Use a value of None to ignore intrinsic nodata settings on the source datasets. Defaults to None. + resamplingAlg (Literal["nearest", "bilinear", "cubic", "cubicspline", "lanczos", "average", "mode"], optional): Resampling algorithm. Defaults to "nearest". + resolution (Literal["highest", "lowest", "average", "user"], optional): Resolution to use for the resulting VRT. Defaults to "average". + xRes (Optional[float], optional): X resolution. Defaults to None. + yRes (Optional[float], optional): Y resolution. Defaults to None. + + Returns: + Response: VRT XML + """ + if len(urls) < 1: + return Response("Please provide at least two URLs", status_code=400) + + if resolution == "user" and (not xRes or not yRes): + return Response("Please provide xRes and yRes for user resolution", status_code=400) + + return Response(await create_vrt_from_urls( + urls=urls, + xRes=xRes, + yRes=yRes, + srcNoData=srcNoData, + vrtNoData=vrtNoData, + resamplingAlg=resamplingAlg, + resolution=resolution + ), media_type="application/xml") diff --git a/src/cogserver/server.py b/src/cogserver/server.py index 0f79cb2..553c9cb 100644 --- a/src/cogserver/server.py +++ b/src/cogserver/server.py @@ -4,7 +4,7 @@ from rio_tiler.io import STACReader import logging from fastapi import FastAPI -from titiler.core.factory import TilerFactory, MultiBandTilerFactory, MultiBaseTilerFactory, AlgorithmFactory +from titiler.core.factory import TilerFactory, MultiBaseTilerFactory, AlgorithmFactory from titiler.application import __version__ as titiler_version from cogserver.landing import setup_landing from starlette.middleware.cors import CORSMiddleware @@ -13,13 +13,14 @@ from titiler.mosaic.errors import MOSAIC_STATUS_CODES from titiler.extensions.stac import stacExtension +from cogserver.vrt import VRTFactory +from cogserver.extensions.mosaicjson import MosaicJsonExtension +from cogserver.extensions.vrt import VRTExtension + logger = logging.getLogger(__name__) api_settings = default.api_settings - - - #################################### APP ###################################### app = FastAPI( title=api_settings.name, @@ -57,20 +58,22 @@ ############################################################################### ############################# MosaicJSON ###################################### -from cogserver.extensions import createMosaicJsonExtension + + mosaic = MosaicTilerFactory( router_prefix="/mosaicjson", path_dependency=SignedDatasetPath, process_dependency=algorithms.dependency, extensions=[ - createMosaicJsonExtension() + MosaicJsonExtension() ] ) app.include_router(mosaic.router, prefix="/mosaicjson", tags=["MosaicJSON"]) - ############################################################################### + + ############################# STAC ####################################### # STAC endpoints @@ -78,7 +81,7 @@ reader=STACReader, router_prefix="/stac", extensions=[ - #stacViewerExtension(), + # stacViewerExtension(), ], path_dependency=SignedDatasetPath, process_dependency=algorithms.dependency, @@ -94,8 +97,6 @@ ############################# MultiBand ####################################### - - ############################################################################### @@ -108,12 +109,25 @@ ) -############################################################################### +############################## VRT ################################### -############################# TileMatrixSets ################################## +vrt = VRTFactory( + router_prefix="/vrt", + path_dependency=SignedDatasetPath, + extensions=[ + VRTExtension() + ] +) + +app.include_router(vrt.router, prefix="/vrt", tags=["VRT"]) + +############################################################################### + + +############################# TileMatrixSets ################################## ############################################################################### @@ -123,6 +137,7 @@ def ping(): """Health check.""" return {"ping": "pong!"} + setup_landing(app) add_exception_handlers(app, DEFAULT_STATUS_CODES) @@ -136,4 +151,5 @@ def ping(): allow_credentials=True, allow_methods=['*'], allow_headers=['*'], - ) \ No newline at end of file + ) + diff --git a/src/cogserver/util.py b/src/cogserver/util.py index 68551f9..7ae9045 100644 --- a/src/cogserver/util.py +++ b/src/cogserver/util.py @@ -1,6 +1,8 @@ from fastapi import FastAPI from fastapi.dependencies.utils import get_parameterless_sub_dependant from fastapi import Depends + + def get_path_dependency(app:FastAPI=None, arg_name=None): """ Extract the first dependency of any kind whose arg name is arg_name @@ -38,6 +40,4 @@ def replace_dependency(app=None, new_dependency=None, arg_name=None): depends=depends, path=r.path_format # type: ignore ), ) - r.dependencies.extend([depends]) - #print([[e.call for n in e.query_params if n.name == arg_name] for e in r.dependant.dependencies if e]) - #print(r.dependencies) + r.dependencies.extend([depends]) \ No newline at end of file diff --git a/src/cogserver/vrt.py b/src/cogserver/vrt.py new file mode 100644 index 0000000..69a76b0 --- /dev/null +++ b/src/cogserver/vrt.py @@ -0,0 +1,15 @@ +from fastapi import APIRouter +from titiler.core.factory import MultiBaseTilerFactory + + +router = APIRouter() + + +class VRTFactory(MultiBaseTilerFactory): + """ + Override the MultiBaseTilerFactory to add a VRT endpoint + Empty register_routes method to override all routes and have no routes + """ + + def register_routes(self): + pass \ No newline at end of file