diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7244eb1..af61341 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,8 +8,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - # - id: check-added-large-files - # args: ['--maxkb=65536'] - id: check-ast - id: check-builtin-literals - id: check-byte-order-marker @@ -43,12 +41,6 @@ repos: # language: system # types: [python] # exclude: "(^experiments/|.*_depr)" - # - id: flake8 - # name: flake8 - # entry: flake8 - # language: system - # types: [python] - # exclude: "(^tasks/|.*_depr)" - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: diff --git a/examples/feature_extractors/quantile_feature_extractor.py b/examples/feature_extractors/quantile_feature_extractor.py new file mode 100644 index 0000000..aafd8b1 --- /dev/null +++ b/examples/feature_extractors/quantile_feature_extractor.py @@ -0,0 +1,149 @@ +"""This file provides the example of a basic feature extractor that computes +indices from preprocessed data and extracts quantiles from those indices. + +Implementing a function in a source file, and then calling the +`apply_feature_extractor` in OpenEO should be enough to run the inference. 
class QuantileIndicesExtractor(PatchFeatureExtractor):
    """Patch feature extractor returning temporal quantile statistics.

    Five normalized-difference indices are derived from the Sentinel-2 bands
    of the input array and, together with the raw ``S2-B11`` and ``S2-B12``
    bands, reduced along the ``t`` dimension to four statistics each (see
    ``output_labels``).
    """

    def output_labels(self) -> list:
        """Return the output band labels, formatted ``"<index>:<statistic>"``.

        Labels are ordered index-major / statistic-minor, which is the order
        in which ``execute`` stacks its results along ``bands``.
        """
        indices_names = ["NDVI", "NDWI", "NDMI", "NDRE", "NDRE5", "B11", "B12"]
        # "IQR" is the spread between the 90% and the 10% quantiles.
        quantile_values = ["10", "50", "90", "IQR"]
        return [
            f"{index_name}:{quantile_value}"
            for index_name in indices_names
            for quantile_value in quantile_values
        ]

    def execute(self, inarr: xr.DataArray) -> xr.DataArray:
        """Compute the indices and their quantiles over time.

        Parameters
        ----------
        inarr:
            Input array with a ``bands`` dimension holding the ``S2-B03``,
            ``S2-B04``, ``S2-B05``, ``S2-B06``, ``S2-B08``, ``S2-B11`` and
            ``S2-B12`` labels, and a ``t`` dimension to reduce.

        Returns
        -------
        xr.DataArray
            Array with dimensions ``(bands, y, x)`` whose ``bands``
            coordinate equals ``self.output_labels()``.
        """
        # Select the individual bands required by the indices
        b03 = inarr.sel(bands="S2-B03")
        b04 = inarr.sel(bands="S2-B04")
        b05 = inarr.sel(bands="S2-B05")
        b06 = inarr.sel(bands="S2-B06")
        b08 = inarr.sel(bands="S2-B08")
        b11 = inarr.sel(bands="S2-B11")
        b12 = inarr.sel(bands="S2-B12")

        ndvi = (b08 - b04) / (b08 + b04)
        ndwi = (b03 - b08) / (b03 + b08)
        ndmi = (b08 - b11) / (b08 + b11)
        # NOTE(review): NDRE/NDRE5 are computed as (red-edge - B08) /
        # (red-edge + B08); this looks sign-inverted compared to the usual
        # (NIR - red-edge) / (NIR + red-edge) form -- confirm this is intended.
        ndre = (b05 - b08) / (b05 + b08)
        ndre5 = (b06 - b08) / (b06 + b08)

        # Keep this order in sync with `output_labels`
        indices = [ndvi, ndwi, ndmi, ndre, ndre5, b11, b12]

        quantile_arrays = []
        for data in indices:
            # `drop_vars` replaces the deprecated `DataArray.drop`; it removes
            # the scalar "quantile" coordinate added by `quantile()`.
            q10 = data.quantile(0.1, dim=["t"]).drop_vars("quantile")
            q50 = data.quantile(0.5, dim=["t"]).drop_vars("quantile")
            q90 = data.quantile(0.9, dim=["t"]).drop_vars("quantile")
            iqr = q90 - q10

            quantile_arrays.extend([q10, q50, q90, iqr])

        # Pack the quantile arrays into a single array
        quantile_array = (
            xr.concat(quantile_arrays, dim="bands")
            .assign_coords({"bands": self.output_labels()})
            .transpose("bands", "y", "x")
        )

        return quantile_array
asset.metadata["type"].startswith("application/x-netcdf"): + asset.download("/data/users/Public/couchard/test_features.nc") diff --git a/src/openeo_gfmap/__init__.py b/src/openeo_gfmap/__init__.py index 0e1dbe7..3748d2b 100644 --- a/src/openeo_gfmap/__init__.py +++ b/src/openeo_gfmap/__init__.py @@ -7,6 +7,7 @@ """ from .backend import Backend, BackendContext +from .fetching import FetchType from .metadata import FakeMetadata from .spatial import BoundingBoxExtent, SpatialContext from .temporal import TemporalContext @@ -18,4 +19,5 @@ "BoundingBoxExtent", "TemporalContext", "FakeMetadata", + "FetchType", ] diff --git a/src/openeo_gfmap/backend.py b/src/openeo_gfmap/backend.py index b40764e..f63e054 100644 --- a/src/openeo_gfmap/backend.py +++ b/src/openeo_gfmap/backend.py @@ -97,5 +97,4 @@ def eodc_connection() -> openeo.Connection: BACKEND_CONNECTIONS: Dict[Backend, Callable] = { Backend.TERRASCOPE: vito_connection, Backend.CDSE: cdse_connection, - Backend.EODC: eodc_connection, } diff --git a/src/openeo_gfmap/features/__init__.py b/src/openeo_gfmap/features/__init__.py new file mode 100644 index 0000000..9205470 --- /dev/null +++ b/src/openeo_gfmap/features/__init__.py @@ -0,0 +1,17 @@ +from openeo_gfmap.features.feature_extractor import ( + LAT_HARMONIZED_NAME, + LON_HARMONIZED_NAME, + PatchFeatureExtractor, + PointFeatureExtractor, + apply_feature_extractor, + apply_feature_extractor_local, +) + +__all__ = [ + "PatchFeatureExtractor", + "PointFeatureExtractor", + "LAT_HARMONIZED_NAME", + "LON_HARMONIZED_NAME", + "apply_feature_extractor", + "apply_feature_extractor_local", +] diff --git a/src/openeo_gfmap/features/feature_extractor.py b/src/openeo_gfmap/features/feature_extractor.py new file mode 100644 index 0000000..28ceb7a --- /dev/null +++ b/src/openeo_gfmap/features/feature_extractor.py @@ -0,0 +1,289 @@ +"""Feature extractor functionalities. Such as a base class to assist the +implementation of feature extractors of a UDF. 
class FeatureExtractor(ABC):
    """Abstract root class shared by every feature extractor UDF.

    Subclasses specialise the behaviour for the kind of cube they receive:
    VectorDataCubes for point based extraction, or dense cubes for
    tile/polygon based extraction.
    """

    def _common_preparations(
        self, inarr: xr.DataArray, parameters: dict
    ) -> xr.DataArray:
        """Store the shared state needed before running the extractor.

        The EPSG entry is removed from ``parameters`` (the dictionary is
        modified in place) and kept on the instance, as is the remaining
        parameter dictionary. The input array is returned untouched. Meant
        to be called from the ``_execute`` method of the subclasses.
        """
        epsg_code = parameters.pop(EPSG_HARMONIZED_NAME)
        self._epsg = epsg_code
        self._parameters = parameters
        return inarr

    @property
    def epsg(self) -> int:
        """EPSG code of the datacube, as stored by `_common_preparations`."""
        return self._epsg

    @abstractmethod
    def output_labels(self) -> list:
        """Return the labels to assign to the output bands.

        Has to be overridden by the user-defined extractor.
        """
        message = (
            "FeatureExtractor is a base abstract class, please implement the "
            "output_labels property."
        )
        raise NotImplementedError(message)

    def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
        """Run the extractor on ``cube``; implemented by the subclasses."""
        message = (
            "FeatureExtractor is a base abstract class, please implement the "
            "_execute method."
        )
        raise NotImplementedError(message)
class PointFeatureExtractor(FeatureExtractor):
    """Base class for point based feature extractors.

    Point based extraction on vector cubes is not supported yet, so creating
    an instance always raises ``NotImplementedError``.
    """

    def __init__(self):
        raise NotImplementedError(
            "Point based feature extraction on Vector Cubes is not supported yet."
        )

    def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
        """Run the user-defined ``execute`` on the array held by ``cube``.

        The array is transposed to ``(bands, t)``, passed through the common
        preparations and handed to ``execute``; the result is transposed back
        to ``(bands, t)`` and wrapped in a new ``XarrayDataCube``.
        """
        arr = cube.get_array().transpose("bands", "t")

        arr = self._common_preparations(arr, parameters)

        # Pass the prepared array: the cube object exposes `get_array()` and
        # has no `to_array()` method, so the previous call could not succeed
        # and also ignored the transposed/prepared array built above.
        outarr = self.execute(arr).transpose("bands", "t")
        return XarrayDataCube(outarr)

    @abstractmethod
    def execute(self, inarr: xr.DataArray) -> xr.DataArray:
        """Compute the features of ``inarr``; implemented by the user."""
        pass
def apply_feature_extractor(
    feature_extractor_class: FeatureExtractor,
    cube: openeo.rest.datacube.DataCube,
    parameters: dict,
    size: list,
    overlap: list = None,
) -> openeo.rest.datacube.DataCube:
    """Applies an user-defined feature extractor on the cube by using the
    `openeo.DataCube.apply_neighborhood` method. The defined class as well as
    the required subclasses will be packed into a generated UDF file that will
    be executed.

    Optimization can be achieved by passing integer values for the cube. By
    default, the feature extractor expects to receive S1 and S2 data stored in
    uint16 with the harmonized naming as implemented in the fetching module.

    Parameters
    ----------
    feature_extractor_class:
        The feature extractor class itself (not an instance); it is
        instantiated here without arguments to read its output labels.
    cube:
        The datacube on which the UDF is applied.
    parameters:
        Context dictionary forwarded to the UDF.
    size:
        Neighborhood size specification forwarded to `apply_neighborhood`.
    overlap:
        Neighborhood overlap specification forwarded to `apply_neighborhood`.
        Defaults to no overlap (an empty list).

    Returns
    -------
    openeo.rest.datacube.DataCube
        The resulting cube, with its `bands` labels renamed to the output
        labels of the feature extractor.
    """
    # A `None` sentinel avoids sharing one mutable default list across calls.
    if overlap is None:
        overlap = []

    udf_code = generate_udf_code(feature_extractor_class)

    udf = openeo.UDF(code=udf_code, context=parameters)

    cube = cube.apply_neighborhood(process=udf, size=size, overlap=overlap)
    return cube.rename_labels(
        dimension="bands", target=feature_extractor_class().output_labels()
    )
+ + return ( + output_cubes[0] + .get_array() + .assign_coords({"bands": feature_extractor_class().output_labels()}) + ) diff --git a/src/openeo_gfmap/fetching/__init__.py b/src/openeo_gfmap/fetching/__init__.py index f1bbd9f..fccda37 100644 --- a/src/openeo_gfmap/fetching/__init__.py +++ b/src/openeo_gfmap/fetching/__init__.py @@ -10,6 +10,8 @@ from .s2 import build_sentinel2_l2a_extractor __all__ = [ - "build_sentinel2_l2a_extractor", "CollectionFetcher", "FetchType", - "build_sentinel1_grd_extractor" + "build_sentinel2_l2a_extractor", + "CollectionFetcher", + "FetchType", + "build_sentinel1_grd_extractor", ] diff --git a/src/openeo_gfmap/fetching/s1.py b/src/openeo_gfmap/fetching/s1.py index 9765078..4fcee4f 100644 --- a/src/openeo_gfmap/fetching/s1.py +++ b/src/openeo_gfmap/fetching/s1.py @@ -127,28 +127,24 @@ def s1_grd_default_processor(cube: openeo.DataCube, **params): SENTINEL1_GRD_BACKEND_MAP = { Backend.TERRASCOPE: { - "default": partial( - get_s1_grd_default_fetcher, collection_name="SENTINEL1_GRD" - ), + "default": partial(get_s1_grd_default_fetcher, collection_name="SENTINEL1_GRD"), "preprocessor": partial( get_s1_grd_default_processor, collection_name="SENTINEL1_GRD" - ) + ), }, Backend.CDSE: { - "default": partial( - get_s1_grd_default_fetcher, collection_name="SENTINEL1_GRD" - ), + "default": partial(get_s1_grd_default_fetcher, collection_name="SENTINEL1_GRD"), "preprocessor": partial( get_s1_grd_default_processor, collection_name="SENTINEL1_GRD" - ) - } + ), + }, } def build_sentinel1_grd_extractor( backend_context: BackendContext, bands: list, fetch_type: FetchType, **params ) -> CollectionFetcher: - """ Creates a S1 GRD collection extractor for the given backend.""" + """Creates a S1 GRD collection extractor for the given backend.""" backend_functions = SENTINEL1_GRD_BACKEND_MAP.get(backend_context.backend) fetcher, preprocessor = ( @@ -156,6 +152,4 @@ def build_sentinel1_grd_extractor( backend_functions["preprocessor"](fetch_type=fetch_type), ) 
class ModelInference(ABC):
    """Abstract root class shared by every model inference UDF.

    Holds the methods and attributes that are common to the concrete model
    inference classes.
    """

    def _common_preparations(
        self, inarr: xr.DataArray, parameters: dict
    ) -> xr.DataArray:
        """Preparation step common to all inference models, meant to run at
        the very beginning of the process.

        Not available yet: calling it always raises ``NotImplementedError``.
        """
        message = "Inference UDF are not implemented yet."
        raise NotImplementedError(message)
kernel1_size = params.get("kernel1_size", 17) kernel2_size = params.get("kernel2_size", 3) @@ -28,9 +29,7 @@ def mask_scl_dilation(cube: openeo.DataCube, **params: dict) -> openeo.DataCube: # Only applies the filtering to the optical part of the cube optical_cube = cube.filter_bands( - bands=list( - filter(lambda band: band.startswith("S2"), cube.metadata.band_names) - ) + bands=list(filter(lambda band: band.startswith("S2"), cube.metadata.band_names)) ) nonoptical_cube = cube.filter_bands( @@ -47,7 +46,7 @@ def mask_scl_dilation(cube: openeo.DataCube, **params: dict) -> openeo.DataCube: kernel2_size=kernel2_size, mask1_values=[2, 4, 5, 6, 7], mask2_values=[3, 8, 9, 10, 11], - erosion_kernel_size=erosion_kernel_size + erosion_kernel_size=erosion_kernel_size, ) if len(nonoptical_cube.metadata.band_names) == 0: @@ -55,6 +54,7 @@ def mask_scl_dilation(cube: openeo.DataCube, **params: dict) -> openeo.DataCube: return optical_cube.merge_cubes(nonoptical_cube) + def get_bap_score(cube: openeo.DataCube, **params: dict) -> openeo.DataCube: """Calculates the Best Available Pixel (BAP) score for the given datacube, computed from the SCL layer. 
@@ -113,7 +113,7 @@ def get_bap_score(cube: openeo.DataCube, **params: dict) -> openeo.DataCube: kernel2_size=kernel2_size, mask1_values=[2, 4, 5, 6, 7], mask2_values=[3, 8, 9, 10, 11], - erosion_kernel_size=erosion_kernel_size + erosion_kernel_size=erosion_kernel_size, ) # Replace NaN to 0 to avoid issues in the UDF @@ -121,15 +121,22 @@ def get_bap_score(cube: openeo.DataCube, **params: dict) -> openeo.DataCube: score = scl_cube.apply_neighborhood( process=openeo.UDF.from_file(str(udf_path)), - size=[{'dimension': 'x', 'unit': 'px', 'value': 256}, {'dimension': 'y', 'unit': 'px', 'value': 256}], - overlap=[{'dimension': 'x', 'unit': 'px', 'value': 16}, {'dimension': 'y', 'unit': 'px', 'value': 16}], + size=[ + {"dimension": "x", "unit": "px", "value": 256}, + {"dimension": "y", "unit": "px", "value": 256}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 16}, + {"dimension": "y", "unit": "px", "value": 16}, + ], ) - score = score.rename_labels('bands', [BAPSCORE_HARMONIZED_NAME]) + score = score.rename_labels("bands", [BAPSCORE_HARMONIZED_NAME]) # Merge the score to the scl cube return score + def get_bap_mask(cube: openeo.DataCube, period: Union[str, list], **params: dict): """Computes the bap score and masks the optical bands of the datacube using the best scores for each pixel on a given time period. This method both @@ -155,13 +162,14 @@ def get_bap_mask(cube: openeo.DataCube, period: Union[str, list], **params: dict The datacube with the BAP mask applied. """ # Checks if the S2-SCL band is present in the datacube - assert SCL_HARMONIZED_NAME in cube.metadata.band_names, ( - f"The {SCL_HARMONIZED_NAME} band is not present in the datacube." - ) + assert ( + SCL_HARMONIZED_NAME in cube.metadata.band_names + ), f"The {SCL_HARMONIZED_NAME} band is not present in the datacube." 
bap_score = get_bap_score(cube, **params) if isinstance(period, str): + def max_score_selection(score): max_score = score.max() return score.array_apply(lambda x: x != max_score) @@ -171,27 +179,26 @@ def max_score_selection(score): size=[ {"dimension": "x", "unit": "px", "value": 1}, {"dimension": "y", "unit": "px", "value": 1}, - {"dimension": "t", "value": period} + {"dimension": "t", "value": period}, ], - overlap=[] + overlap=[], ) elif isinstance(period, list): udf_path = Path(__file__).parent / "udf_rank.py" rank_mask = bap_score.apply_neighborhood( - process=openeo.UDF.from_file( - str(udf_path), - context={"intervals": period} - ), + process=openeo.UDF.from_file(str(udf_path), context={"intervals": period}), size=[ - {'dimension': 'x', 'unit': 'px', 'value': 256}, - {'dimension': 'y', 'unit': 'px', 'value': 256} + {"dimension": "x", "unit": "px", "value": 256}, + {"dimension": "y", "unit": "px", "value": 256}, ], overlap=[], ) else: - raise ValueError(f"'period' must be a string or a list of dates (in YYYY-mm-dd format), got {period}.") + raise ValueError( + f"'period' must be a string or a list of dates (in YYYY-mm-dd format), got {period}." + ) - return rank_mask.rename_labels('bands', ['S2-BAPMASK']) + return rank_mask.rename_labels("bands", ["S2-BAPMASK"]) def bap_masking(cube: openeo.DataCube, period: Union[str, list], **params: dict): @@ -213,9 +220,7 @@ def bap_masking(cube: openeo.DataCube, period: Union[str, list], **params: dict) The datacube with the BAP mask applied. 
""" optical_cube = cube.filter_bands( - bands=list( - filter(lambda band: band.startswith("S2"), cube.metadata.band_names) - ) + bands=list(filter(lambda band: band.startswith("S2"), cube.metadata.band_names)) ) nonoptical_cube = cube.filter_bands( @@ -226,9 +231,7 @@ def bap_masking(cube: openeo.DataCube, period: Union[str, list], **params: dict) rank_mask = get_bap_mask(optical_cube, period, **params) - optical_cube = optical_cube.mask( - rank_mask.resample_cube_spatial(cube) - ) + optical_cube = optical_cube.mask(rank_mask.resample_cube_spatial(cube)) # Do not merge if bands are empty! if len(nonoptical_cube.metadata.band_names) == 0: diff --git a/src/openeo_gfmap/preprocessing/interpolation.py b/src/openeo_gfmap/preprocessing/interpolation.py index 79fe256..bffaa49 100644 --- a/src/openeo_gfmap/preprocessing/interpolation.py +++ b/src/openeo_gfmap/preprocessing/interpolation.py @@ -6,6 +6,4 @@ def linear_interpolation(cube: openeo.DataCube,) -> openeo.DataCube: """Perform linear interpolation on the given datacube.""" - return cube.apply_dimension( - dimension="t", process="array_interpolate_linear" - ) \ No newline at end of file + return cube.apply_dimension(dimension="t", process="array_interpolate_linear") diff --git a/src/openeo_gfmap/preprocessing/udf_rank.py b/src/openeo_gfmap/preprocessing/udf_rank.py index 1fe65cd..b68dad1 100644 --- a/src/openeo_gfmap/preprocessing/udf_rank.py +++ b/src/openeo_gfmap/preprocessing/udf_rank.py @@ -15,7 +15,7 @@ def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube: """ # First check if the period is defined in the context intervals = context.get("intervals", None) - array = cube.get_array().transpose('t', 'bands', 'y', 'x') + array = cube.get_array().transpose("t", "bands", "y", "x") bap_score = array.sel(bands="S2-BAPSCORE") @@ -23,7 +23,6 @@ def select_maximum(score: xr.DataArray): max_score = score.max(dim="t") return score == max_score - if isinstance(intervals, str): raise NotImplementedError( 
"Period as string is not implemented yet, please provide a list of interval tuples." diff --git a/src/openeo_gfmap/preprocessing/udf_score.py b/src/openeo_gfmap/preprocessing/udf_score.py index c2873d2..578a5b9 100644 --- a/src/openeo_gfmap/preprocessing/udf_score.py +++ b/src/openeo_gfmap/preprocessing/udf_score.py @@ -8,11 +8,12 @@ def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube: - cube_array: xr.DataArray = cube.get_array() - cube_array = cube_array.transpose('t', 'bands', 'y', 'x') + cube_array = cube_array.transpose("t", "bands", "y", "x") - clouds = np.logical_or(np.logical_and(cube_array < 11, cube_array >= 8), cube_array == 3).isel(bands=0) + clouds = np.logical_or( + np.logical_and(cube_array < 11, cube_array >= 8), cube_array == 3 + ).isel(bands=0) weights = [1, 0.8, 0.5] @@ -20,14 +21,21 @@ def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube: times = cube_array.t.dt.day.values # returns day of the month for each date sigma = 5 mu = 15 - score_doy = 1 / (sigma * math.sqrt(2 * math.pi)) * np.exp(-0.5 * ((times - mu) / sigma) ** 2) - score_doy = np.broadcast_to(score_doy[:, np.newaxis, np.newaxis], - [cube_array.sizes['t'], cube_array.sizes['y'], cube_array.sizes['x']]) + score_doy = ( + 1 + / (sigma * math.sqrt(2 * math.pi)) + * np.exp(-0.5 * ((times - mu) / sigma) ** 2) + ) + score_doy = np.broadcast_to( + score_doy[:, np.newaxis, np.newaxis], + [cube_array.sizes["t"], cube_array.sizes["y"], cube_array.sizes["x"]], + ) # Calculate the Distance To Cloud score # Erode # Source: https://github.com/dzanaga/satio-pc/blob/e5fc46c0c14bba77e01dca409cf431e7ef22c077/src/satio_pc/preprocessing/clouds.py#L127 e = footprints.disk(3) + # Define a function to apply binary erosion def erode(image, selem): return ~binary_erosion(image, selem) @@ -36,12 +44,12 @@ def erode(image, selem): eroded = xr.apply_ufunc( erode, # function to apply clouds, # input DataArray - input_core_dims=[['y', 'x']], # dimensions over which to 
apply function - output_core_dims=[['y', 'x']], # dimensions of the output + input_core_dims=[["y", "x"]], # dimensions over which to apply function + output_core_dims=[["y", "x"]], # dimensions of the output vectorize=True, # vectorize the function over non-core dimensions dask="parallelized", # enable dask parallelization output_dtypes=[np.int32], # data type of the output - kwargs={'selem': e} # additional keyword arguments to pass to erode + kwargs={"selem": e}, # additional keyword arguments to pass to erode ) # Distance to cloud = dilation @@ -50,33 +58,38 @@ def erode(image, selem): d = xr.apply_ufunc( distance_transform_cdt, eroded, - input_core_dims=[['y', 'x']], - output_core_dims=[['y', 'x']], + input_core_dims=[["y", "x"]], + output_core_dims=[["y", "x"]], vectorize=True, dask="parallelized", - output_dtypes=[np.int32] + output_dtypes=[np.int32], ) d = xr.where(d == -1, d_req, d) score_clouds = 1 / (1 + np.exp(-0.2 * (np.minimum(d, d_req) - (d_req - d_min) / 2))) # Calculate the Coverage score - score_cov = 1 - clouds.sum(dim='x').sum(dim='y') / ( - cube_array.sizes['x'] * cube_array.sizes['y']) - score_cov = np.broadcast_to(score_cov.values[:, np.newaxis, np.newaxis], - [cube_array.sizes['t'], cube_array.sizes['y'], cube_array.sizes['x']]) + score_cov = 1 - clouds.sum(dim="x").sum(dim="y") / ( + cube_array.sizes["x"] * cube_array.sizes["y"] + ) + score_cov = np.broadcast_to( + score_cov.values[:, np.newaxis, np.newaxis], + [cube_array.sizes["t"], cube_array.sizes["y"], cube_array.sizes["x"]], + ) # Final score is weighted average - score = (weights[0] * score_clouds + weights[1] * score_doy + weights[2] * score_cov) / sum(weights) - score = np.where(cube_array.values[:,0,:,:]==0, 0, score) + score = ( + weights[0] * score_clouds + weights[1] * score_doy + weights[2] * score_cov + ) / sum(weights) + score = np.where(cube_array.values[:, 0, :, :] == 0, 0, score) score_da = xr.DataArray( score, coords={ - 't': cube_array.coords['t'], - 'y': 
cube_array.coords['y'], - 'x': cube_array.coords['x'], + "t": cube_array.coords["t"], + "y": cube_array.coords["y"], + "x": cube_array.coords["x"], }, - dims=['t', 'y', 'x'] + dims=["t", "y", "x"], ) score_da = score_da.expand_dims( @@ -85,6 +98,6 @@ def erode(image, selem): }, ) - score_da = score_da.transpose('t', 'bands', 'y', 'x') + score_da = score_da.transpose("t", "bands", "y", "x") return XarrayDataCube(score_da) diff --git a/src/openeo_gfmap/spatial.py b/src/openeo_gfmap/spatial.py index 1a38c49..9968ad9 100644 --- a/src/openeo_gfmap/spatial.py +++ b/src/openeo_gfmap/spatial.py @@ -29,13 +29,15 @@ def __dict__(self): } def __iter__(self): - return iter([ - ("west", self.west), - ("south", self.south), - ("east", self.east), - ("north", self.north), - ("crs", self.epsg) - ]) + return iter( + [ + ("west", self.west), + ("south", self.south), + ("east", self.east), + ("north", self.north), + ("crs", self.epsg), + ] + ) SpatialContext = Union[GeoJSON, BoundingBoxExtent] diff --git a/src/openeo_gfmap/utils/__init__.py b/src/openeo_gfmap/utils/__init__.py index b9f6608..9b9efd3 100644 --- a/src/openeo_gfmap/utils/__init__.py +++ b/src/openeo_gfmap/utils/__init__.py @@ -11,6 +11,11 @@ ) __all__ = [ - "load_json", "normalize_array", "select_optical_bands", "array_bounds", - "select_sar_bands", "arrays_cosine_similarity", "quintad_intervals" + "load_json", + "normalize_array", + "select_optical_bands", + "array_bounds", + "select_sar_bands", + "arrays_cosine_similarity", + "quintad_intervals", ] diff --git a/src/openeo_gfmap/utils/catalogue.py b/src/openeo_gfmap/utils/catalogue.py index 19a1a41..8dfd28a 100644 --- a/src/openeo_gfmap/utils/catalogue.py +++ b/src/openeo_gfmap/utils/catalogue.py @@ -10,7 +10,7 @@ def _check_cdse_catalogue( collection: str, spatial_extent: SpatialContext, temporal_extent: TemporalContext, - **additional_parameters: dict + **additional_parameters: dict, ) -> bool: """Checks if there is at least one product available in the given 
spatio-temporal context for a collection in the CDSE catalogue, @@ -37,9 +37,16 @@ def _check_cdse_catalogue( # Transform geojson into shapely geometry and compute bounds bounds = shape(spatial_extent).bounds elif isinstance(spatial_extent, SpatialContext): - bounds = [spatial_extent.west, spatial_extent.south, spatial_extent.east, spatial_extent.north] + bounds = [ + spatial_extent.west, + spatial_extent.south, + spatial_extent.east, + spatial_extent.north, + ] else: - raise ValueError('Provided spatial extent is not a valid GeoJSON or SpatialContext object.') + raise ValueError( + "Provided spatial extent is not a valid GeoJSON or SpatialContext object." + ) minx, miny, maxx, maxy = bounds @@ -66,7 +73,10 @@ def _check_cdse_catalogue( body = response.json() grd_tiles = list( - filter(lambda feature: feature["properties"]["productType"].contains("GRD"), body["features"]) + filter( + lambda feature: feature["properties"]["productType"].contains("GRD"), + body["features"], + ) ) return len(grd_tiles) > 0 diff --git a/src/openeo_gfmap/utils/intervals.py b/src/openeo_gfmap/utils/intervals.py index ad258ae..707b945 100644 --- a/src/openeo_gfmap/utils/intervals.py +++ b/src/openeo_gfmap/utils/intervals.py @@ -8,14 +8,14 @@ def quintad_intervals(temporal_extent: TemporalContext) -> list: - """ Returns a list of tuples (start_date, end_date) of quintad intervals - from the input temporal extent. Quintad intervals are intervals of - generally 5 days, that never overlap two months. + """Returns a list of tuples (start_date, end_date) of quintad intervals + from the input temporal extent. Quintad intervals are intervals of + generally 5 days, that never overlap two months. - All months are divided in 6 quintads, where the 6th quintad might - contain 6 days for months of 31 days. - For the month of February, the 6th quintad is only of three days, or - four days for the leap year. 
+ All months are divided in 6 quintads, where the 6th quintad might + contain 6 days for months of 31 days. + For the month of February, the 6th quintad is only of three days, or + four days for the leap year. """ start_date, end_date = temporal_extent.to_datetime() quintads = [] diff --git a/src/openeo_gfmap/utils/tile_processing.py b/src/openeo_gfmap/utils/tile_processing.py index 0bd3897..b763be4 100644 --- a/src/openeo_gfmap/utils/tile_processing.py +++ b/src/openeo_gfmap/utils/tile_processing.py @@ -4,11 +4,8 @@ import xarray as xr -def normalize_array( - inarr: xr.DataArray, - percentile: float = 0.99 -) -> xr.DataArray: - """ Performs normalization between 0.0 and 1.0 using the given +def normalize_array(inarr: xr.DataArray, percentile: float = 0.99) -> xr.DataArray: + """Performs normalization between 0.0 and 1.0 using the given percentile. """ quantile_value = inarr.quantile(percentile, dim=["x", "y", "t"]) @@ -19,21 +16,17 @@ def normalize_array( # Perform clipping on values that are higher than the computed quantile return inarr.where(inarr < 1.0, 1.0) -def select_optical_bands( - inarr: xr.DataArray -) -> xr.DataArray: + +def select_optical_bands(inarr: xr.DataArray) -> xr.DataArray: """Filters and keep only the optical bands for a given array.""" return inarr.sel( bands=[ - band - for band in inarr.coords["bands"].to_numpy() - if band.startswith("S2-B") + band for band in inarr.coords["bands"].to_numpy() if band.startswith("S2-B") ] ) -def select_sar_bands( - inarr: xr.DataArray -) -> xr.DataArray: + +def select_sar_bands(inarr: xr.DataArray) -> xr.DataArray: """Filters and keep only the SAR bands for a given array.""" return inarr.sel( bands=[ @@ -43,20 +36,19 @@ def select_sar_bands( ] ) -def array_bounds( - inarr: xr.DataArray -) -> tuple: + +def array_bounds(inarr: xr.DataArray) -> tuple: """Returns the 4 bounds values for the x and y coordinates of the tile""" return ( inarr.coords["x"].min().item(), inarr.coords["y"].min().item(), 
inarr.coords["x"].max().item(), - inarr.coords["y"].max().item() + inarr.coords["y"].max().item(), ) + def arrays_cosine_similarity( - first_array: xr.DataArray, - second_array: xr.DataArray + first_array: xr.DataArray, second_array: xr.DataArray ) -> float: """Returns a similarity score based on normalized cosine distance. The input arrays must have similar ranges to obtain a valid score. @@ -68,4 +60,3 @@ def arrays_cosine_similarity( similarity = (dot_product / (first_norm * second_norm)).item() return similarity - diff --git a/tests/test_openeo_gfmap/resources/test_optical_cube.nc b/tests/test_openeo_gfmap/resources/test_optical_cube.nc new file mode 100644 index 0000000..ce7f8fb Binary files /dev/null and b/tests/test_openeo_gfmap/resources/test_optical_cube.nc differ diff --git a/tests/test_openeo_gfmap/test_cloud_masking.py b/tests/test_openeo_gfmap/test_cloud_masking.py index 8f06a2f..e0c3db4 100644 --- a/tests/test_openeo_gfmap/test_cloud_masking.py +++ b/tests/test_openeo_gfmap/test_cloud_masking.py @@ -22,13 +22,12 @@ south=51.215806593713, east=5.060320484557499, north=51.22149744530769, - epsg=4326 + epsg=4326, ) # November 2022 to February 2023 -temporal_extent = TemporalContext( - start_date="2022-11-01", end_date="2023-02-28" -) +temporal_extent = TemporalContext(start_date="2022-11-01", end_date="2023-02-28") + @pytest.mark.parametrize("backend", backends) def test_bap_score(backend: Backend): @@ -36,32 +35,26 @@ def test_bap_score(backend: Backend): backend_context = BackendContext(backend=backend) # Additional parameters - fetching_parameters = { - "fetching_resolution": 10.0 - } + fetching_parameters = {"fetching_resolution": 10.0} - preprocessing_parameters = { - "apply_scl_dilation": True - } + preprocessing_parameters = {"apply_scl_dilation": True} # Fetch the datacube s2_extractor = build_sentinel2_l2a_extractor( backend_context=backend_context, bands=["S2-B04", "S2-B08", "S2-SCL"], fetch_type=FetchType.TILE, - **fetching_parameters + 
**fetching_parameters, ) - cube = s2_extractor.get_cube( - connection, spatial_extent, temporal_extent - ) + cube = s2_extractor.get_cube(connection, spatial_extent, temporal_extent) # Compute the BAP score bap_score = get_bap_score(cube, **preprocessing_parameters) ndvi = cube.ndvi(nir="S2-B08", red="S2-B04") cube = bap_score.merge_cubes(ndvi).rename_labels( - 'bands', ['S2-BAPSCORE', 'S2-NDVI'] + "bands", ["S2-BAPSCORE", "S2-NDVI"] ) job = cube.create_job( @@ -77,27 +70,24 @@ def test_bap_score(backend: Backend): Path(__file__).parent / f"results/bap_score_{backend.value}.nc" ) + @pytest.mark.parametrize("backend", backends) def test_bap_masking(backend: Backend): connection = BACKEND_CONNECTIONS[backend]() backend_context = BackendContext(backend=backend) # Additional parameters - fetching_parameters = { - "fetching_resolution": 10.0 - } + fetching_parameters = {"fetching_resolution": 10.0} # Fetch the datacube s2_extractor = build_sentinel2_l2a_extractor( backend_context=backend_context, bands=["S2-B04", "S2-B03", "S2-B02", "S2-SCL"], fetch_type=FetchType.TILE, - **fetching_parameters + **fetching_parameters, ) - cube = s2_extractor.get_cube( - connection, spatial_extent, temporal_extent - ) + cube = s2_extractor.get_cube(connection, spatial_extent, temporal_extent) cube = cube.linear_scale_range(0, 65535, 0, 65535) @@ -127,6 +117,7 @@ def test_bap_masking(backend: Backend): Path(__file__).parent / f"results/bap_composited_{backend.value}.nc" ) + @pytest.mark.parametrize("backend", backends) def test_bap_quintad(backend: Backend): connection = BACKEND_CONNECTIONS[backend]() @@ -145,66 +136,62 @@ def test_bap_quintad(backend: Backend): backend_context=backend_context, bands=["S2-SCL"], fetch_type=FetchType.TILE, - **fetching_parameters + **fetching_parameters, ) - cube = s2_extractor.get_cube( - connection, spatial_extent, temporal_extent - ) + cube = s2_extractor.get_cube(connection, spatial_extent, temporal_extent) - compositing_intervals = quintad_intervals( 
- temporal_extent - ) + compositing_intervals = quintad_intervals(temporal_extent) expected_intervals = [ - ('2022-11-01', '2022-11-05'), - ('2022-11-06', '2022-11-10'), - ('2022-11-11', '2022-11-15'), - ('2022-11-16', '2022-11-20'), - ('2022-11-21', '2022-11-25'), - ('2022-11-26', '2022-11-30'), - ('2022-12-01', '2022-12-05'), - ('2022-12-06', '2022-12-10'), - ('2022-12-11', '2022-12-15'), - ('2022-12-16', '2022-12-20'), - ('2022-12-21', '2022-12-25'), - ('2022-12-26', '2022-12-31'), - ('2023-01-01', '2023-01-05'), - ('2023-01-06', '2023-01-10'), - ('2023-01-11', '2023-01-15'), - ('2023-01-16', '2023-01-20'), - ('2023-01-21', '2023-01-25'), - ('2023-01-26', '2023-01-31'), - ('2023-02-01', '2023-02-05'), - ('2023-02-06', '2023-02-10'), - ('2023-02-11', '2023-02-15'), - ('2023-02-16', '2023-02-20'), - ('2023-02-21', '2023-02-25'), - ('2023-02-26', '2023-02-28'), + ("2022-11-01", "2022-11-05"), + ("2022-11-06", "2022-11-10"), + ("2022-11-11", "2022-11-15"), + ("2022-11-16", "2022-11-20"), + ("2022-11-21", "2022-11-25"), + ("2022-11-26", "2022-11-30"), + ("2022-12-01", "2022-12-05"), + ("2022-12-06", "2022-12-10"), + ("2022-12-11", "2022-12-15"), + ("2022-12-16", "2022-12-20"), + ("2022-12-21", "2022-12-25"), + ("2022-12-26", "2022-12-31"), + ("2023-01-01", "2023-01-05"), + ("2023-01-06", "2023-01-10"), + ("2023-01-11", "2023-01-15"), + ("2023-01-16", "2023-01-20"), + ("2023-01-21", "2023-01-25"), + ("2023-01-26", "2023-01-31"), + ("2023-02-01", "2023-02-05"), + ("2023-02-06", "2023-02-10"), + ("2023-02-11", "2023-02-15"), + ("2023-02-16", "2023-02-20"), + ("2023-02-21", "2023-02-25"), + ("2023-02-26", "2023-02-28"), ] assert compositing_intervals == expected_intervals # Perform masking with BAP, masking optical bands - bap_mask = get_bap_mask(cube, period=compositing_intervals, **preprocessing_parameters) + bap_mask = get_bap_mask( + cube, period=compositing_intervals, **preprocessing_parameters + ) # Create a new extractor for the whole data now fetching_parameters 
= { "fetching_resolution": 10.0, - "pre_mask": bap_mask # Use of the pre-computed bap mask to load inteligently the data + "pre_mask": bap_mask, # Use of the pre-computed bap mask to load inteligently the data } s2_extractor = build_sentinel2_l2a_extractor( backend_context=backend_context, bands=["S2-B04", "S2-B03", "S2-B02", "S2-B08", "S2-SCL"], fetch_type=FetchType.TILE, - **fetching_parameters + **fetching_parameters, ) # Performs quintal compositing - cube = s2_extractor.get_cube( - connection, spatial_extent, temporal_extent - ) + cube = s2_extractor.get_cube(connection, spatial_extent, temporal_extent) cube = median_compositing(cube, period=compositing_intervals) diff --git a/tests/test_openeo_gfmap/test_feature_extractors.py b/tests/test_openeo_gfmap/test_feature_extractors.py new file mode 100644 index 0000000..d984bf6 --- /dev/null +++ b/tests/test_openeo_gfmap/test_feature_extractors.py @@ -0,0 +1,176 @@ +"""Test on feature extractors implementations, both local and remote.""" +from pathlib import Path +from typing import Callable + +import pytest +import xarray as xr + +from openeo_gfmap import BoundingBoxExtent, FetchType, TemporalContext +from openeo_gfmap.backend import BACKEND_CONNECTIONS, Backend, BackendContext +from openeo_gfmap.features import ( + PatchFeatureExtractor, + apply_feature_extractor, + apply_feature_extractor_local, +) +from openeo_gfmap.fetching import build_sentinel2_l2a_extractor + +SPATIAL_CONTEXT = BoundingBoxExtent( + west=4.261, + south=51.309, + east=4.267, + north=51.313, + epsg=4326, +) +TEMPORAL_EXTENT = TemporalContext("2023-10-01", "2024-01-01") + + +class DummyPatchExtractor(PatchFeatureExtractor): + def output_labels(self) -> list: + return ["red", "green", "blue"] + + def execute(self, inarr: xr.DataArray): + # Make the imports WITHIN the class + import xarray as xr # noqa: F401 + from scipy.ndimage import gaussian_filter + + # Performs some gaussian filtering to blur the RGB bands + rgb_bands = 
inarr.sel(bands=["S2-B04", "S2-B03", "S2-B02"]) + + for band in rgb_bands.bands: + for timestamp in rgb_bands.t: + rgb_bands.loc[{"bands": band, "t": timestamp}] = gaussian_filter( + rgb_bands.loc[{"bands": band, "t": timestamp}], sigma=1.0 + ) + + # Compute the median on the time band + rgb_bands = rgb_bands.median(dim="t").assign_coords( + {"bands": ["red", "green", "blue"]} + ) + + # Returns the rgb bands only in the feature, y, x order + return rgb_bands.transpose("bands", "y", "x") + + +class LatLonExtractor(PatchFeatureExtractor): + """Sample extractor that compute the latitude and longitude values + and concatenates them in a new array. + """ + + def output_labels(self) -> list: + return ["red", "lat", "lon"] + + def execute(self, inarr: xr.DataArray) -> xr.DataArray: + # Compute the latitude and longitude as bands in the input array + latlon = self.get_latlons(inarr) + + # Only select the first time for the input array + inarr = inarr.isel(t=0) + + # Add the bands in the input array + inarr = xr.concat([inarr, latlon], dim="bands").assign_coords( + {"bands": ["red", "lat", "lon"]} + ) + + return inarr.transpose("bands", "y", "x") + + +@pytest.mark.parametrize("backend, connection_fn", BACKEND_CONNECTIONS.items()) +def test_patch_feature_udf(backend: Backend, connection_fn: Callable): + backend_context = BackendContext(backend=backend) + connection = connection_fn() + output_path = Path(__file__).parent / f"results/patch_features_{backend.value}.nc/" + + bands_to_extract = ["S2-B04", "S2-B03", "S2-B02"] + + # Setup the RGB cube extraction + extractor = build_sentinel2_l2a_extractor( + backend_context, bands_to_extract, FetchType.TILE + ) + + rgb_cube = extractor.get_cube(connection, SPATIAL_CONTEXT, TEMPORAL_EXTENT) + + # Run the feature extractor + features = apply_feature_extractor( + DummyPatchExtractor, + rgb_cube, + parameters={}, + size=[ + {"dimension": "x", "unit": "px", "value": 128}, + {"dimension": "y", "unit": "px", "value": 128}, + ], + ) + + 
job = features.create_job(title="patch_feature_extractor", out_format="NetCDF") + job.start_and_wait() + + for asset in job.get_results().get_assets(): + if asset.metadata["type"].startswith("application/x-netcdf"): + asset.download(output_path) + break + + assert output_path.exists() + + # Read the output path and checks for the expected band names + output_cube = xr.open_dataset(output_path) + + assert set(output_cube.keys()) == set(["red", "green", "blue", "crs"]) + + +@pytest.mark.parametrize("backend, connection_fn", BACKEND_CONNECTIONS.items()) +def test_latlon_extractor(backend: Backend, connection_fn: Callable): + backend_context = BackendContext(backend=backend) + connection = connection_fn() + output_path = Path(__file__).parent / f"results/latlon_features_{backend.value}.nc" + + REDUCED_TEMPORAL_CONTEXT = TemporalContext( + start_date="2023-06-01", end_date="2023-06-30" + ) + + bands_to_extract = ["S2-B04"] + + extractor = build_sentinel2_l2a_extractor( + backend_context, bands_to_extract, FetchType.TILE + ) + + cube = extractor.get_cube(connection, SPATIAL_CONTEXT, REDUCED_TEMPORAL_CONTEXT) + + features = apply_feature_extractor( + LatLonExtractor, + cube, + parameters={}, + size=[ + {"dimension": "x", "unit": "px", "value": 128}, + {"dimension": "y", "unit": "px", "value": 128}, + ], + ) + + job = features.create_job(title="latlon_feature_extractor", out_format="NetCDF") + job.start_and_wait() + + for asset in job.get_results().get_assets(): + if asset.metadata["type"].startswith("application/x-netcdf"): + asset.download(output_path) + break + + assert output_path.exists() + + # Read the output path and checks for the expected band names + output_cube = xr.open_dataset(output_path) + + assert set(output_cube.keys()) == set(["red", "lat", "lon", "crs"]) + + +def test_patch_feature_local(): + input_path = Path(__file__).parent / "resources/test_optical_cube.nc" + + inds = xr.open_dataset(input_path).to_array(dim="bands") + + inds = inds.sel( + 
bands=[band for band in inds.bands.to_numpy() if band != "crs"] + ).transpose("bands", "t", "y", "x") + + features = apply_feature_extractor_local(DummyPatchExtractor, inds, parameters={}) + + features.to_netcdf(Path(__file__).parent / "results/patch_features_local.nc") + + assert set(features.bands.values) == set(["red", "green", "blue"]) diff --git a/tests/test_openeo_gfmap/test_intervals.py b/tests/test_openeo_gfmap/test_intervals.py index e0f3711..876d0d0 100644 --- a/tests/test_openeo_gfmap/test_intervals.py +++ b/tests/test_openeo_gfmap/test_intervals.py @@ -19,6 +19,7 @@ def test_quintad_january(): assert quintad_intervals(temporal_extent) == expected + def test_quintad_april(): start_date = "2023-04-01" end_date = "2023-04-30" @@ -36,6 +37,7 @@ def test_quintad_april(): assert quintad_intervals(temporal_extent) == expected + def test_quintad_february_nonleap(): start_date = "2023-02-01" end_date = "2023-02-28" @@ -53,6 +55,7 @@ def test_quintad_february_nonleap(): assert quintad_intervals(temporal_extent) == expected + def test_quitad_february_leapyear(): start_date = "2024-02-01" end_date = "2024-02-29" @@ -70,6 +73,7 @@ def test_quitad_february_leapyear(): assert quintad_intervals(temporal_extent) == expected + def test_quintad_four_months(): start_date = "2023-01-01" end_date = "2023-04-30" @@ -105,6 +109,7 @@ def test_quintad_four_months(): assert quintad_intervals(temporal_extent) == expected + def test_quintad_july_august(): start_date = "2023-07-01" end_date = "2023-08-31" @@ -128,6 +133,7 @@ def test_quintad_july_august(): assert quintad_intervals(temporal_extent) == expected + def test_quintad_mid_month(): start_date = "2023-01-02" end_date = "2023-01-31" @@ -145,6 +151,7 @@ def test_quintad_mid_month(): assert quintad_intervals(temporal_extent) == expected + def test_quintad_full_year(): # non-leap year start_date = "2023-01-01" @@ -162,6 +169,7 @@ def test_quintad_full_year(): assert len(quintad_intervals(temporal_extent)) == 72 + def 
test_quintad_mid_month_february(): start_date = "2024-01-31" end_date = "2024-03-02" @@ -181,6 +189,7 @@ def test_quintad_mid_month_february(): assert quintad_intervals(temporal_extent) == expected + def test_quintad_single_day(): start_date = "2024-02-29" end_date = "2024-02-29" @@ -193,6 +202,7 @@ def test_quintad_single_day(): assert quintad_intervals(temporal_extent) == expected + def test_quintad_end_month(): start_date = "2024-02-14" end_date = "2024-03-01" @@ -209,6 +219,7 @@ def test_quintad_end_month(): assert quintad_intervals(temporal_extent) == expected + def test_quintad_new_year(): start_date = "2023-12-04" end_date = "2024-01-01" @@ -227,4 +238,4 @@ def test_quintad_new_year(): print(quintad_intervals(temporal_extent)) - assert quintad_intervals(temporal_extent) == expected \ No newline at end of file + assert quintad_intervals(temporal_extent) == expected diff --git a/tests/test_openeo_gfmap/test_s1_fetchers.py b/tests/test_openeo_gfmap/test_s1_fetchers.py index 604eaf2..d0bbcf2 100644 --- a/tests/test_openeo_gfmap/test_s1_fetchers.py +++ b/tests/test_openeo_gfmap/test_s1_fetchers.py @@ -42,7 +42,7 @@ def sentinel1_grd( spatial_extent: SpatialContext, temporal_extent: TemporalContext, backend: Backend, - connection=openeo.Connection + connection=openeo.Connection, ): context = BackendContext(backend) country = spatial_extent["country"] @@ -55,7 +55,7 @@ def sentinel1_grd( "coefficient": "gamma0-ellipsoid", "load_collection": { "polarization": lambda polar: (polar == "VV") or (polar == "VH"), - } + }, } extractor: CollectionFetcher = build_sentinel1_grd_extractor( @@ -74,9 +74,7 @@ def sentinel1_grd( start_date=temporal_extent[0], end_date=temporal_extent[1] ) - cube = extractor.get_cube( - connection, spatial_extent, temporal_extent - ) + cube = extractor.get_cube(connection, spatial_extent, temporal_extent) output_file = ( Path(__file__).parent @@ -147,16 +145,14 @@ def compare_sentinel1_tiles(): first_tile = normalized_tiles[0] for tile_idx in 
range(1, len(normalized_tiles)): tile_to_compare = normalized_tiles[tile_idx] - similarity_score = arrays_cosine_similarity( - first_tile, tile_to_compare - ) + similarity_score = arrays_cosine_similarity(first_tile, tile_to_compare) assert similarity_score >= 0.95 def sentinel1_grd_point_based( spatial_context: SpatialContext, temporal_context: TemporalContext, backend: Backend, - connection: openeo.Connection + connection: openeo.Connection, ): """Test the point based extraction from the spatial aggregation of the given polygons. @@ -173,18 +169,16 @@ def sentinel1_grd_point_based( "coefficient": "gamma0-ellipsoid", "load_collection": { "polarization": lambda polar: (polar == "VV") or (polar == "VH"), - } + }, } extractor = build_sentinel1_grd_extractor( backend_context=context, bands=bands, fetch_type=FetchType.POINT, - **fetching_parameters + **fetching_parameters, ) - cube = extractor.get_cube( - connection, spatial_context, temporal_context - ) + cube = extractor.get_cube(connection, spatial_context, temporal_context) cube = cube.aggregate_spatial(spatial_context, reducer="mean") @@ -214,7 +208,7 @@ def sentinel1_grd_polygon_based( spatial_context: SpatialContext, temporal_context: TemporalContext, backend: Backend, - connection: openeo.Connection + connection: openeo.Connection, ): context = BackendContext(backend) bands = ["S1-VV", "S1-VH"] @@ -226,14 +220,14 @@ def sentinel1_grd_polygon_based( "coefficient": "gamma0-ellipsoid", "load_collection": { "polarization": lambda polar: (polar == "VV") or (polar == "VH"), - } + }, } extractor = build_sentinel1_grd_extractor( backend_context=context, bands=bands, fetch_type=FetchType.POLYGON, - **fetching_parameters + **fetching_parameters, ) cube = extractor.get_cube(connection, spatial_context, temporal_context) @@ -242,7 +236,9 @@ def sentinel1_grd_polygon_based( output_folder.mkdir(exist_ok=True, parents=True) job = cube.create_job( - title="test_extract_polygons_s1", out_format="NetCDF", sample_by_feature=True 
+ title="test_extract_polygons_s1", + out_format="NetCDF", + sample_by_feature=True, ) job.start_and_wait() @@ -258,23 +254,23 @@ def sentinel1_grd_polygon_based( assert len(extracted_files) == len(spatial_context["features"]) - @pytest.mark.parametrize( "spatial_context, temporal_context, backend", test_configurations ) def test_sentinel1_grd( - spatial_context: SpatialContext, temporal_context: TemporalContext, - backend: Backend + spatial_context: SpatialContext, temporal_context: TemporalContext, backend: Backend ): connection = BACKEND_CONNECTIONS[backend]() TestS1Extractors.sentinel1_grd( spatial_context, temporal_context, backend, connection ) + @pytest.mark.depends(on=["test_sentinel1_grd"]) def test_compare_sentinel1_tiles(): TestS1Extractors.compare_sentinel1_tiles() + @pytest.mark.parametrize("backend", test_backends) def test_sentinel1_grd_point_based(backend: Backend): connection = BACKEND_CONNECTIONS[backend]() diff --git a/tests/test_openeo_gfmap/test_s2_fetchers.py b/tests/test_openeo_gfmap/test_s2_fetchers.py index fed1bc3..59e9749 100644 --- a/tests/test_openeo_gfmap/test_s2_fetchers.py +++ b/tests/test_openeo_gfmap/test_s2_fetchers.py @@ -59,12 +59,12 @@ Path(__file__).parent / "resources/puglia_extraction_polygons.gpkg" ) -#test_backends = [Backend.TERRASCOPE, Backend.CDSE] +# test_backends = [Backend.TERRASCOPE, Backend.CDSE] test_backends = [Backend.CDSE] test_spatio_temporal_extends = [ (SPATIAL_EXTENT_1, TEMPORAL_EXTENT_1), -# (SPATIAL_EXTENT_2, TEMPORAL_EXTENT_2), + (SPATIAL_EXTENT_2, TEMPORAL_EXTENT_2), ] test_configurations = [ @@ -99,10 +99,11 @@ def sentinel2_l2a( "S2-SCL", "S2-AOT", ] - fetching_parameters = { - "target_resolution": 10.0, - "target_crs": 3035 - } if country == "Belgium" else {} + fetching_parameters = ( + {"target_resolution": 10.0, "target_crs": 3035} + if country == "Belgium" + else {} + ) extractor: CollectionFetcher = build_sentinel2_l2a_extractor( context=context, bands=bands, @@ -115,7 +116,7 @@ def 
sentinel2_l2a( south=spatial_extent["south"], east=spatial_extent["east"], north=spatial_extent["north"], - epsg=spatial_extent["crs"] + epsg=spatial_extent["crs"], ) temporal_extent = TemporalContext( @@ -186,9 +187,7 @@ def compare_sentinel2_tiles(): first_tile = normalized_tiles[0] for tile_idx in range(1, len(normalized_tiles)): tile_to_compare = normalized_tiles[tile_idx] - similarity_score = arrays_cosine_similarity( - first_tile, tile_to_compare - ) + similarity_score = arrays_cosine_similarity(first_tile, tile_to_compare) assert similarity_score >= 0.95 def sentinel2_l2a_point_based( @@ -263,7 +262,9 @@ def sentinel2_l2a_polygon_based( output_folder.mkdir(exist_ok=True, parents=True) job = cube.create_job( - title="test_extract_polygons_s2", out_format="NetCDF", sample_by_feature=True + title="test_extract_polygons_s2", + out_format="NetCDF", + sample_by_feature=True, ) job.start_and_wait() @@ -295,6 +296,7 @@ def test_sentinel2_l2a( def test_compare_sentinel2_tiles(): TestS2Extractors.compare_sentinel2_tiles() + @pytest.mark.parametrize("backend", test_backends) def test_sentinel2_l2a_point_based(backend: Backend): connection = BACKEND_CONNECTIONS[backend]()