Skip to content

Commit

Permalink
Merge pull request #56 from UNDP-Data/feat/vrt-endpoint
Browse files Browse the repository at this point in the history
feat: implement vrt endpoint
  • Loading branch information
Thuhaa authored Mar 6, 2024
2 parents 16d2e74 + 409efd8 commit dbcf687
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 29 deletions.
Empty file.
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from dataclasses import dataclass, field
from typing import List, Optional
from dataclasses import dataclass
from typing import List, Optional

from cogeo_mosaic.mosaic import MosaicJSON
from cogserver.dependencies import SignedDatasetPaths
from typing_extensions import Annotated
from fastapi import Depends, FastAPI, Query
from fastapi import Depends, Query
from pydantic import BaseModel
from titiler.core.factory import BaseTilerFactory, FactoryExtension
from cogeo_mosaic.mosaic import MosaicJSON
from titiler.core.resources.responses import JSONResponse
from pydantic import BaseModel
from typing_extensions import Annotated

urls = Annotated[List[str], Query(..., description="Dataset URLs")]


class MosaicJsonCreateItem(BaseModel):
#url: List[str] = Query(..., description="Dataset URL")
urls:List[str] = urls
# url: List[str] = Query(..., description="Dataset URL")
urls: List[str] = urls
minzoom: int = 0
maxzoom: int = 22
attribution: str = None


@dataclass
class createMosaicJsonExtension(FactoryExtension):
class MosaicJsonExtension(FactoryExtension):

def create_mosaic_json(self, urls=None, minzoom=None, maxzoom=None, attribution=None):
mosaicjson = MosaicJSON.from_urls(urls=urls, minzoom=minzoom, maxzoom=maxzoom, )
Expand All @@ -37,7 +40,7 @@ def register(self, factory: BaseTilerFactory):
200: {"description": "Return a MosaicJSON from multiple COGs."}},
)
def build_mosaicJSON(
url = Depends(SignedDatasetPaths),
url=Depends(SignedDatasetPaths),
minzoom: Optional[int] = 0,
maxzoom: Optional[int] = 22,
attribution: Optional[str] = None
Expand All @@ -59,6 +62,4 @@ def build_mosaicJSON(payload: MosaicJsonCreateItem):
maxzoom = payload.maxzoom
attribution = payload.attribution

return self.create_mosaic_json(urls=url,minzoom=minzoom, maxzoom=maxzoom, attribution=attribution)


return self.create_mosaic_json(urls=url, minzoom=minzoom, maxzoom=maxzoom, attribution=attribution)
156 changes: 156 additions & 0 deletions src/cogserver/extensions/vrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import re
import tempfile
from typing import List, Literal, Optional

from fastapi import Query, Response
from osgeo import gdal
from titiler.core.factory import FactoryExtension
from cogserver.vrt import VRTFactory
from xml.etree import ElementTree as ET


async def create_vrt_from_urls(
urls: List[str],
resolution: Literal["highest", "lowest", "average", "user"] = "average",
xRes: float = 0.1,
yRes: float = 0.1,
vrtNoData: List[str] = 0,
srcNoData: List[str] = 0,
resamplingAlg: Literal["nearest", "bilinear", "cubic", "cubicspline", "lanczos", "average", "mode"] = "nearest",

):
"""
Create a VRT from multiple COGs supplied as URLs
Args:
urls (List[str]): List of URLs
resolution (Literal["highest", "lowest", "average", "user"], optional): Resolution to use for the resulting VRT. Defaults to "average".
xRes (float, optional): X resolution. Defaults to 0.1. Ignored if resolution is not "user".
yRes (float, optional): Y resolution. Defaults to 0.1. Ignored if resolution is not "user".
vrtNoData (List[str], optional): Set nodata values at the VRT band level (different values can be supplied for each band). If the option is not specified, intrinsic nodata settings on the first dataset will be used (if they exist). The value set by this option is written in the NoDataValue element of each VRTRasterBand element. Use a value of None to ignore intrinsic nodata settings on the source datasets. Defaults to 0.
srcNoData (List[str], optional): Set nodata values for input bands (different values can be supplied for each band). If the option is not specified, the intrinsic nodata settings on the source datasets will be used (if they exist). The value set by this option is written in the NODATA element of each ComplexSource element. Use a value of None to ignore intrinsic nodata settings on the source datasets. Defaults to 0.
resamplingAlg (Literal["nearest", "bilinear", "cubic", "cubicspline", "lanczos", "average", "mode"], optional): Resampling algorithm. Defaults to "nearest".
Returns:
str: VRT XML
"""
urls = [f"/vsicurl/{url}" for url in urls]
if vrtNoData:
vrtNoData = " ".join(vrtNoData)
if srcNoData:
srcNoData = " ".join(srcNoData)
options = gdal.BuildVRTOptions(
separate=True,
bandList=list(range(1, len(urls) + 1)),
xRes=xRes,
yRes=yRes,
resampleAlg=resamplingAlg,
VRTNodata=vrtNoData,
srcNodata=srcNoData,
resolution=resolution,
)

with tempfile.NamedTemporaryFile() as temp:
gdal.BuildVRT(temp.name, urls, options=options)
with open(temp.name, "r") as file:
file_text = ET.fromstring(file.read())
available_bands = file_text.findall("VRTRasterBand")
for source_band in available_bands:
source = None
complex_source = source_band.find("ComplexSource")
simple_source = source_band.find("SimpleSource")
if complex_source is not None:
source = complex_source.find("SourceFilename").text
elif simple_source is not None:
source = simple_source.find("SourceFilename").text
if source is not None:
gdalInfo = gdal.Info(source)
color_interp = gdalInfo.split("ColorInterp=")[-1].split("\n")[0]
dataset_metadata = gdalInfo.split("Metadata:")[-1]
metadata_dict = {}
pattern = r"(.*?)=(.*)"

matches = re.findall(pattern, dataset_metadata)
for match in matches:
metadata_dict[match[0].strip()] = match[1].strip()

source_band.append(ET.Element("Metadata"))
source_band.append(ET.Element("ColorInterp"))
source_band.find("ColorInterp").text = color_interp
for key, value in metadata_dict.items():
metadata = ET.Element("MDI")
metadata.set("key", key)
metadata.text = value
source_band.find("Metadata").append(metadata)
return ET.tostring(file_text, encoding="unicode")


class VRTExtension(FactoryExtension):
"""
VRT Extension for the VRTFactory
"""

def register(self, factory: VRTFactory):
"""
Register the VRT extension to the VRTFactory
Args:
factory (VRTFactory): VRTFactory instance
Returns:
None
"""

@factory.router.get(
"",
response_class=Response,
responses={200: {"description": "Return a VRT from multiple COGs."}},
summary="Create a VRT from multiple COGs",
)
async def create_vrt(
urls: List[str] = Query(..., description="Dataset URLs"),

srcNoData: List[str] = Query(None,
description="Set nodata values for input bands (different values can be supplied for each band). If the option is not specified, the intrinsic nodata settings on the source datasets will be used (if they exist). The value set by this option is written in the NODATA element of each ComplexSource element. Use a value of None to ignore intrinsic nodata settings on the source datasets."),
vrtNoData: List[str] = Query(None,
description="Set nodata values at the VRT band level (different values can be supplied for each band). If the option is not specified, intrinsic nodata settings on the first dataset will be used (if they exist). The value set by this option is written in the NoDataValue element of each VRTRasterBand element. Use a value of None to ignore intrinsic nodata settings on the source datasets."),
resamplingAlg: Literal[
"nearest", "bilinear", "cubic", "cubicspline", "lanczos", "average", "mode"] = Query("nearest",
description="Resampling algorithm"),
resolution: Literal["highest", "lowest", "average", "user"] = Query("average",
description="Resolution to use for the resulting VRT"),
xRes: Optional[float] = Query(None,
description="X resolution. Applicable only when `resolution` is `user`"),
yRes: Optional[float] = Query(None,
description="Y resolution. Applicable only when `resolution` is `user`")
):
"""
Create a VRT from multiple COGs supplied as URLs
Args:
urls (List[str]): List of URLs
srcNoData (List[str], optional): Set nodata values for input bands (different values can be supplied for each band). If the option is not specified, the intrinsic nodata settings on the source datasets will be used (if they exist). The value set by this option is written in the NODATA element of each ComplexSource element. Use a value of None to ignore intrinsic nodata settings on the source datasets. Defaults to None.
vrtNoData (List[str], optional): Set nodata values at the VRT band level (different values can be supplied for each band). If the option is not specified, intrinsic nodata settings on the first dataset will be used (if they exist). The value set by this option is written in the NoDataValue element of each VRTRasterBand element. Use a value of None to ignore intrinsic nodata settings on the source datasets. Defaults to None.
resamplingAlg (Literal["nearest", "bilinear", "cubic", "cubicspline", "lanczos", "average", "mode"], optional): Resampling algorithm. Defaults to "nearest".
resolution (Literal["highest", "lowest", "average", "user"], optional): Resolution to use for the resulting VRT. Defaults to "average".
xRes (Optional[float], optional): X resolution. Defaults to None.
yRes (Optional[float], optional): Y resolution. Defaults to None.
Returns:
Response: VRT XML
"""
if len(urls) < 1:
return Response("Please provide at least two URLs", status_code=400)

if resolution == "user" and (not xRes or not yRes):
return Response("Please provide xRes and yRes for user resolution", status_code=400)

return Response(await create_vrt_from_urls(
urls=urls,
xRes=xRes,
yRes=yRes,
srcNoData=srcNoData,
vrtNoData=vrtNoData,
resamplingAlg=resamplingAlg,
resolution=resolution
), media_type="application/xml")
42 changes: 29 additions & 13 deletions src/cogserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from rio_tiler.io import STACReader
import logging
from fastapi import FastAPI
from titiler.core.factory import TilerFactory, MultiBandTilerFactory, MultiBaseTilerFactory, AlgorithmFactory
from titiler.core.factory import TilerFactory, MultiBaseTilerFactory, AlgorithmFactory
from titiler.application import __version__ as titiler_version
from cogserver.landing import setup_landing
from starlette.middleware.cors import CORSMiddleware
Expand All @@ -13,13 +13,14 @@
from titiler.mosaic.errors import MOSAIC_STATUS_CODES
from titiler.extensions.stac import stacExtension

from cogserver.vrt import VRTFactory
from cogserver.extensions.mosaicjson import MosaicJsonExtension
from cogserver.extensions.vrt import VRTExtension

logger = logging.getLogger(__name__)

api_settings = default.api_settings




#################################### APP ######################################
app = FastAPI(
title=api_settings.name,
Expand Down Expand Up @@ -57,28 +58,30 @@
###############################################################################

############################# MosaicJSON ######################################
from cogserver.extensions import createMosaicJsonExtension


mosaic = MosaicTilerFactory(
router_prefix="/mosaicjson",
path_dependency=SignedDatasetPath,
process_dependency=algorithms.dependency,
extensions=[
createMosaicJsonExtension()
MosaicJsonExtension()
]
)
app.include_router(mosaic.router, prefix="/mosaicjson", tags=["MosaicJSON"])


###############################################################################



############################# STAC #######################################
# STAC endpoints

stac = MultiBaseTilerFactory(
reader=STACReader,
router_prefix="/stac",
extensions=[
#stacViewerExtension(),
# stacViewerExtension(),
],
path_dependency=SignedDatasetPath,
process_dependency=algorithms.dependency,
Expand All @@ -94,8 +97,6 @@
############################# MultiBand #######################################




###############################################################################


Expand All @@ -108,12 +109,25 @@
)


###############################################################################
############################## VRT ###################################


############################# TileMatrixSets ##################################
vrt = VRTFactory(
router_prefix="/vrt",
path_dependency=SignedDatasetPath,
extensions=[
VRTExtension()
]
)

app.include_router(vrt.router, prefix="/vrt", tags=["VRT"])



###############################################################################


############################# TileMatrixSets ##################################


###############################################################################
Expand All @@ -123,6 +137,7 @@ def ping():
"""Health check."""
return {"ping": "pong!"}


setup_landing(app)

add_exception_handlers(app, DEFAULT_STATUS_CODES)
Expand All @@ -136,4 +151,5 @@ def ping():
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
)

6 changes: 3 additions & 3 deletions src/cogserver/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from fastapi import FastAPI
from fastapi.dependencies.utils import get_parameterless_sub_dependant
from fastapi import Depends


def get_path_dependency(app:FastAPI=None, arg_name=None):
"""
Extract the first dependency of any kind whose arg name is arg_name
Expand Down Expand Up @@ -38,6 +40,4 @@ def replace_dependency(app=None, new_dependency=None, arg_name=None):
depends=depends, path=r.path_format # type: ignore
),
)
r.dependencies.extend([depends])
#print([[e.call for n in e.query_params if n.name == arg_name] for e in r.dependant.dependencies if e])
#print(r.dependencies)
r.dependencies.extend([depends])
15 changes: 15 additions & 0 deletions src/cogserver/vrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from fastapi import APIRouter
from titiler.core.factory import MultiBaseTilerFactory


router = APIRouter()


class VRTFactory(MultiBaseTilerFactory):
"""
Override the MultiBaseTilerFactory to add a VRT endpoint
Empty register_routes method to override all routes and have no routes
"""

def register_routes(self):
pass

0 comments on commit dbcf687

Please sign in to comment.