diff --git a/odc/geo/_xr_interop.py b/odc/geo/_xr_interop.py index 2ca3cb0f..9c7ad0b6 100644 --- a/odc/geo/_xr_interop.py +++ b/odc/geo/_xr_interop.py @@ -48,6 +48,15 @@ resolve_fill_value, resolve_nodata, ) +from .masking import ( + bits_to_bool, + enum_to_bool, + mask_invalid_data, + mask_clouds, + mask_ls, + mask_s2, + scale_and_offset, +) from .overlap import compute_output_geobox from .roi import roi_is_empty from .types import Nodata, Resolution, SomeNodata, SomeResolution, SomeShape, xy_ @@ -1053,6 +1062,14 @@ def nodata(self, value: Nodata): colorize = _wrap_op(colorize) + scale_and_offset = _wrap_op(scale_and_offset) + + bits_to_bool = _wrap_op(bits_to_bool) + + enum_to_bool = _wrap_op(enum_to_bool) + + mask_invalid_data = _wrap_op(mask_invalid_data) + if have.rasterio: write_cog = _wrap_op(write_cog) to_cog = _wrap_op(to_cog) @@ -1093,6 +1110,16 @@ def to_rgba( ) -> xarray.DataArray: return to_rgba(self._xx, bands=bands, vmin=vmin, vmax=vmax) + scale_and_offset = _wrap_op(scale_and_offset) + + mask_invalid_data = _wrap_op(mask_invalid_data) + + mask_clouds = _wrap_op(mask_clouds) + + mask_ls = _wrap_op(mask_ls) + + mask_s2 = _wrap_op(mask_s2) + ODCExtensionDs.to_rgba.__doc__ = to_rgba.__doc__ diff --git a/odc/geo/masking.py b/odc/geo/masking.py new file mode 100644 index 00000000..a8266d4f --- /dev/null +++ b/odc/geo/masking.py @@ -0,0 +1,340 @@ +# This file is part of the Open Data Cube, see https://opendatacube.org for more information +# +# Copyright (c) 2015-2020 ODC Contributors +# SPDX-License-Identifier: Apache-2.0 +""" +Functions around supporting cloud masking. +""" + +from typing import Annotated, Any, Callable, Sequence +import numpy as np +from xarray import DataArray, Dataset + +from enum import Enum + + +class SENTINEL2_L2A_SCL(Enum): + """ + Sentinel-2 Scene Classification Layer (SCL) values. + """ + + NO_DATA = 0 + SATURATED_OR_DEFECTIVE = 1 + DARK_AREA_PIXELS = 2 + CLOUD_SHADOWS = 3 + VEGETATION = 4 + NOT_VEGETATED = 5 + WATER = 6 + UNCLASSIFIED = 7 + CLOUD_MEDIUM_PROBABILITY = 8 + CLOUD_HIGH_PROBABILITY = 9 + THIN_CIRRUS = 10 + SNOW = 11 + + +SENTINEL2_L2A_SCALE = 0.0001 +SENTINEL2_L2A_OFFSET = -0.1 + + +class LANDSAT_C2L2_PIXEL_QA(Enum): + """ + Landsat Collection 2 Surface Reflectance Pixel Quality values. + """ + + NO_DATA = 0 + DILATED_CLOUD = 1 + CIRRUS = 2 + CLOUD = 3 + CLOUD_SHADOW = 4 + SNOW = 5 + CLEAR = 6 + WATER = 7 + # Not sure how to implement these yet... + # CLOUD_CONFIDENCE = [8, 9] + # CLOUD_SHADOW_CONFIDENCE = [10, 11] + # SNOW_ICE_CONFIDENCE = [12, 13] + # CIRRUS_CONFIDENCE = [14, 15] + + +LANDSAT_C2L2_SCALE = 0.0000275 +LANDSAT_C2L2_OFFSET = -0.2 + +# TODO: QA_RADSAT and QA_AEROSOL for Landsat Collection 2 Surface Reflectance + + +def bits_to_bool( + xx: DataArray, + bits: Sequence[int] | None = None, + bitflags: int | None = None, + invert: bool = False, +) -> DataArray: + """ + Convert integer array into boolean array using bitmasks. + + :param xx: DataArray with integer values + :param bits: List of bit positions to convert to a bitflag mask (e.g. [0, 1, 2] -> 0b111) + :param bitflags: Integer value with bits set that will be used to extract the boolean mask (e.g. 0b00011000) + :param invert: Invert the mask + :return: DataArray with boolean values + """ + assert not ( + bits is None and bitflags is None + ), "Either bits or bitflags must be provided" + assert not ( + bits is not None and bitflags is not None + ), "Only one of bits or bitflags can be provided" + + if bitflags is None: + bitflags = 0 + + if bits is not None: + for b in bits: + bitflags |= 1 << b + + mask = (xx & bitflags) != 0 + + if invert: + mask = ~mask + + return mask + + +def enum_to_bool( + xx: DataArray, values: Sequence[Any], invert: bool = False +) -> DataArray: + """ + Convert array into boolean array using a list of invalid values. + + :param xx: DataArray with integer values + :param values: List of valid values to convert to a boolean mask + :param invert: Invert the mask + :return: DataArray with boolean values + """ + + mask = xx.isin(values) + + if invert: + mask = ~mask + + return mask + + +def scale_and_offset( + xx: DataArray | Dataset, + scale: float | None = None, + offset: float | None = None, + clip: Annotated[Sequence[int | float], 2] | None = None, +) -> DataArray | Dataset: + """ + Apply scale and offset to the DataArray. Leave scale and offset blank to use + the values from the DataArray's attrs. + + :param xx: DataArray with integer values + :param scale: Scale factor + :param offset: Offset + :return: DataArray with scaled and offset values + """ + + # For the Dataset case, we do this recursively for all variables. + if isinstance(xx, Dataset): + for var in xx.data_vars: + xx[var] = scale_and_offset(xx[var], scale, offset, clip=clip) + + return xx + + # "Scales" and "offsets" is used by GDAL. + if scale is None: + scale = xx.attrs.get("scales") + + if offset is None: + offset = xx.attrs.get("offsets") + + # Catch the case where one is provided and not the other... + if scale is None and offset is not None: + scale = 1.0 + + if offset is None and scale is not None: + offset = 0.0 + + # Store the nodata values to apply to the result + nodata = xx.odc.nodata + + # Stash the attributes + attrs = dict(xx.attrs.items()) + + if nodata is not None: + nodata_mask = xx == nodata + + # If both are missing, we can just return the original array. + if scale is not None and offset is not None: + xx = (xx * scale) + offset + + if clip is not None: + assert len(clip) == 2, "Clip must be a list of two values" + xx = xx.clip(clip[0], clip[1]) + + # Re-attach nodata + if nodata is not None: + xx = xx.where(~nodata_mask, other=nodata) + + xx.attrs = attrs # Not sure if this is required + + return xx + + +# pylint: disable-next=dangerous-default-value +def mask_invalid_data( + xx: DataArray | Dataset, + nodata: int | float | None = None, + skip_bands: Sequence[str] = [], +) -> DataArray | Dataset: + """ + Mask out invalid data values. + + :param xx: DataArray + :return: DataArray with invalid data values converted to np.nan. Note this will change the dtype to float. + """ + if isinstance(xx, Dataset): + for var in xx.data_vars: + if var not in skip_bands: + xx[var] = mask_invalid_data(xx[var], nodata) + return xx + + if nodata is None: + nodata = xx.odc.nodata + + assert nodata is not None, "Nodata value must be provided or available in attrs" + + xx = xx.where(xx != nodata) + xx.odc.nodata = np.nan + + return xx + + +# pylint: disable-next=too-many-arguments, dangerous-default-value +def mask_clouds( + xx: Dataset, + qa_name: str, + scale: float, + offset: float, + clip: tuple, + mask_func: Callable = enum_to_bool, # Pass the function for enum-based masks (bits_to_bool or enum_to_bool) + mask_func_args: dict = {}, + apply_mask: bool = True, + keep_qa: bool = False, + return_mask: bool = False, +) -> Dataset: + """ + General cloud masking function for both Landsat and Sentinel-2 products. + + :param xx: Dataset or DataArray + :param qa_name: QA band to use for masking + :param mask_classes: List of mask class values (e.g., cloud, cloud shadow) + :param scale: Scale value for the dataset + :param offset: Offset value for the dataset + :param clip: Clip range for the data + :param includ_cirrus: Whether to include cirrus in the mask + :param apply_mask: Apply the cloud mask to the data, erasing data where clouds are present + :param keep_qa: Keep the QA band in the output + :param return_mask: Return the mask as a variable called "mask" + :param enum_to_bool_func: Function to convert bit values to boolean mask (either bits_to_bool or enum_to_bool) + :return: Dataset or DataArray with invalid data values converted to np.nan. This will change the dtype to float. + """ + attrs = dict(xx.attrs.items()) + + # Retrieve the QA band + qa = xx[qa_name] + + # Drop the QA band and apply other preprocessing steps + xx = xx.drop_vars(qa_name) + xx = mask_invalid_data(xx) + xx = scale_and_offset(xx, scale=scale, offset=offset, clip=clip) + + # Generate the mask + mask = mask_func(qa, **mask_func_args) + + # Apply the mask if required + if apply_mask: + xx = xx.where(~mask) + + # Set 'nodata' to np.nan for all variables + for var in xx.data_vars: + xx[var].odc.nodata = np.nan + + # Optionally keep the QA band + if keep_qa: + xx[qa_name] = qa + + # Optionally return the mask + if return_mask: + xx["mask"] = mask + + xx.attrs = attrs + + return xx # type: ignore + + +def mask_ls( + xx: Dataset, + qa_name: str = "pixel_qa", + include_cirrus: bool = False, + apply_mask: bool = True, + keep_qa: bool = False, + return_mask: bool = False, +) -> Dataset: + """ + Perform cloud masking for Landsat Collection 2 products. + """ + mask_bits = [ + LANDSAT_C2L2_PIXEL_QA.CLOUD.value, + LANDSAT_C2L2_PIXEL_QA.CLOUD_SHADOW.value, + ] + if include_cirrus: + mask_bits.append(LANDSAT_C2L2_PIXEL_QA.CIRRUS.value) + + return mask_clouds( + xx=xx, + qa_name=qa_name, + scale=LANDSAT_C2L2_SCALE, + offset=LANDSAT_C2L2_OFFSET, + clip=(0.0, 1.0), + mask_func=bits_to_bool, + mask_func_args={"bits": mask_bits}, + apply_mask=apply_mask, + keep_qa=keep_qa, + return_mask=return_mask, + ) + + +def mask_s2( + xx: Dataset, + qa_name: str = "scl", + include_cirrus: bool = False, + apply_mask: bool = True, + keep_qa: bool = False, + return_mask: bool = False, +) -> Dataset: + """ + Perform cloud masking for Sentinel-2 L2A products. + """ + mask_values = [ + SENTINEL2_L2A_SCL.SATURATED_OR_DEFECTIVE.value, + SENTINEL2_L2A_SCL.CLOUD_MEDIUM_PROBABILITY.value, + SENTINEL2_L2A_SCL.CLOUD_HIGH_PROBABILITY.value, + SENTINEL2_L2A_SCL.CLOUD_SHADOWS.value, + ] + if include_cirrus: + mask_values.append(SENTINEL2_L2A_SCL.THIN_CIRRUS.value) + + return mask_clouds( + xx=xx, + qa_name=qa_name, + scale=SENTINEL2_L2A_SCALE, + offset=SENTINEL2_L2A_OFFSET, + mask_func=enum_to_bool, + mask_func_args={"values": mask_values}, + clip=(0.0, 1.0), + apply_mask=apply_mask, + keep_qa=keep_qa, + return_mask=return_mask, + ) diff --git a/tests/test_masking.py b/tests/test_masking.py new file mode 100644 index 00000000..ad0fac21 --- /dev/null +++ b/tests/test_masking.py @@ -0,0 +1,100 @@ +import numpy as np +from odc.geo.masking import ( + bits_to_bool, + enum_to_bool, + scale_and_offset, + mask_invalid_data, +) + + +from xarray import DataArray, Dataset + +# Top left is cloud, top right is cloud shadow +# Bottom left is both cloud and cloud shadow, bottom right is neither +xx_bits = DataArray( + [[0b00010000, 0b00001000], [0b00011000, 0b00000000]], + dims=("y", "x"), + attrs={"nodata": 0}, +) + +# Test some values, so 3 is cloud, 9 is cloud shadow +xx_values = DataArray([[3, 9], [3, 0]], dims=("y", "x"), attrs={"nodata": 0}) + +# Array with at least one zero to test nodata +xx_with_nodata = DataArray([[0, 1], [2, 3]], dims=("y", "x"), attrs={"nodata": 0}) + + +# Test bits_to_bool +def test_bits_to_bool(): + # Test with bits + mask = bits_to_bool(xx_bits, bits=[4, 3], bitflags=None) + assert mask.equals(DataArray([[True, True], [True, False]], dims=("y", "x"))) + + # Test with bitflags + mask = bits_to_bool(xx_bits, bits=None, bitflags=0b00011000) + assert mask.equals(DataArray([[True, True], [True, False]], dims=("y", "x"))) + + # Test with invert + mask = bits_to_bool(xx_bits, bits=[4, 3], bitflags=None, invert=True) + assert mask.equals(DataArray([[False, False], [False, True]], dims=("y", "x"))) + + mask = bits_to_bool(xx_bits, bits=None, bitflags=0b00010000, invert=True) + assert mask.equals(DataArray([[False, True], [False, True]], dims=("y", "x"))) + + +# Test enum_to_bool +def test_enum_to_bool(): + mask = enum_to_bool(xx_values, values=[3, 9]) + assert mask.equals(DataArray([[True, True], [True, False]], dims=("y", "x"))) + + mask = enum_to_bool(xx_values, values=[3, 9], invert=True) + assert mask.equals(DataArray([[False, False], [False, True]], dims=("y", "x"))) + + +# Test apply_scale_and_offset +def test_scale_and_offset(): + mask = scale_and_offset(xx_values, scale=1.0, offset=0.0) + assert mask.equals(DataArray([[3, 9], [3, 0]], dims=("y", "x"))) + + mask = scale_and_offset(xx_values) + assert mask.equals(DataArray([[3, 9], [3, 0]], dims=("y", "x"))) + + mask = scale_and_offset(xx_values, scale=2.0, offset=1.0) + assert mask.equals(DataArray([[7, 19], [7, 0]], dims=("y", "x"))) + + +# Test mask_invalid +def test_mask_invalid_data(): + mask = mask_invalid_data(xx_with_nodata) + assert mask.equals(DataArray([[np.nan, 1.0], [2.0, 3.0]], dims=("y", "x"))) + + mask = mask_invalid_data(xx_with_nodata, nodata=1) + assert mask.equals(DataArray([[0, np.nan], [2, 3]], dims=("y", "x"))) + + +# Test landsat masking +def test_mask_landsat(): + xx = Dataset( + {"pixel_qa": xx_bits, "red": scale_and_offset(xx_with_nodata, offset=20000)} + ) + print(xx) + + xx = xx.odc.mask_ls() + + assert xx["red"].equals( + DataArray([[np.nan, np.nan], [np.nan, 0.3500825]], dims=("y", "x")) + ) + + +def test_mask_sentinel2(): + xx = Dataset( + {"scl": xx_values, "red": scale_and_offset(xx_with_nodata, offset=8000)} + ) + + xx = xx.odc.mask_s2() + + assert xx["red"].equals( + DataArray([[np.nan, np.nan], [np.nan, 0.7003]], dims=("y", "x")) + ) + + assert xx.red.odc.nodata is not None