From fdcf4dc758988354f58d5ad4322f8ca402011a2a Mon Sep 17 00:00:00 2001 From: Rafael Soares Padilha Date: Thu, 1 Aug 2024 14:40:43 -0300 Subject: [PATCH] SAM - Normalization and ChipWindow (#184) This PR replaces the `ChipWindow` named tuple with a tuple type alias. This fixes some serialization/deserialization errors that caused the workflow to break. Additionally, this PR modifies how we scale and offset the raster values during normalization before the SAM image encoder. After applying the raster's scale and offset normalization on the RGB bands, we now clip the values to the range [0, 1] before multiplying them by 255. The lack of a clipping operation was leading to inconsistencies in the segmentation mask outputs. --- .../automatic_segmentation.yaml | 2 +- ops/segment_anything/sam_inference.py | 2 +- .../combine_sam_masks.py | 8 ++++---- .../test_combine_sam_masks.py | 4 ++-- src/vibe_core/vibe_core/data/core_types.py | 17 ++--------------- src/vibe_lib/vibe_lib/segment_anything.py | 18 +++++++++++++----- 6 files changed, 23 insertions(+), 28 deletions(-) diff --git a/ops/segment_anything/automatic_segmentation.yaml b/ops/segment_anything/automatic_segmentation.yaml index 567706bb..58ad0f32 100644 --- a/ops/segment_anything/automatic_segmentation.yaml +++ b/ops/segment_anything/automatic_segmentation.yaml @@ -41,7 +41,7 @@ description: parameters: model_type: SAM's image encoder backbone architecture, among 'vit_h', 'vit_l', or 'vit_b'. Before running the workflow, make sure the desired model has been exported to the cluster by running `scripts/export_sam_models.py`. For more information, refer to the FarmVibes.AI troubleshooting page in the documentation. band_names: Name of raster bands that should be selected to compose the 3-channel images expected by SAM. If not provided, will try to use ["R", "G", "B"]. If only a single band name is provided, will replicate it through all three channels. 
- band_scaling: A list of floats to scale each band by to the range of [0.0, 1.0] or [0.0, 255.0]. If not provided, will default to the raster scaling parameter. If a list with a single value is provided, will use it for all three bands. + band_scaling: A list of floats to scale each band by to the range of [0.0, 1.0]. If not provided, will default to the raster scaling parameter. If a list with a single value is provided, will use it for all three bands. band_offset: A list of floats to offset each band by. If not provided, will default to the raster offset value. If a list with a single value is provided, will use it for all three bands. spatial_overlap: Percentage of spatial overlap between chips in the range of [0.0, 1.0). points_per_side: The number of points to be sampled along one side of the chip to be prompts. The total number of points is points_per_side**2. diff --git a/ops/segment_anything/sam_inference.py b/ops/segment_anything/sam_inference.py index 5e833062..f6f13020 100644 --- a/ops/segment_anything/sam_inference.py +++ b/ops/segment_anything/sam_inference.py @@ -497,7 +497,7 @@ def generate_masks_from_grid( meta = cast(Dict[str, Any], write_info_list[0]["meta"]) meta.update({**INT_COMPRESSION_KWARGS}) - write_window = ChipWindow( + write_window = ( int(read_window.col_off - dataset.offset.width), int(read_window.row_off - dataset.offset.height), int(read_window.width), diff --git a/ops/segment_anything_combine_masks/combine_sam_masks.py b/ops/segment_anything_combine_masks/combine_sam_masks.py index e041023f..e457b98d 100644 --- a/ops/segment_anything_combine_masks/combine_sam_masks.py +++ b/ops/segment_anything_combine_masks/combine_sam_masks.py @@ -12,10 +12,10 @@ def touch_chip_boundaries(bbox: BBox, chip_window: ChipWindow) -> bool: return ( - bbox[0] <= chip_window.col_offset - or bbox[1] <= chip_window.row_offset - or bbox[2] >= chip_window.col_offset + chip_window.width - or bbox[3] >= chip_window.row_offset + chip_window.height + bbox[0] <= 
chip_window[0] # col_offset + or bbox[1] <= chip_window[1] # row_offset + or bbox[2] >= chip_window[0] + chip_window[2] # col_offset + width + or bbox[3] >= chip_window[1] + chip_window[3] # row_offset + height ) diff --git a/ops/segment_anything_combine_masks/test_combine_sam_masks.py b/ops/segment_anything_combine_masks/test_combine_sam_masks.py index a3d26354..febd5f62 100644 --- a/ops/segment_anything_combine_masks/test_combine_sam_masks.py +++ b/ops/segment_anything_combine_masks/test_combine_sam_masks.py @@ -8,7 +8,7 @@ import xarray as xr from shapely import geometry as shpg -from vibe_core.data.core_types import ChipWindow, gen_guid +from vibe_core.data.core_types import gen_guid from vibe_core.data.rasters import CategoricalRaster, SamMaskRaster from vibe_dev.testing.op_tester import OpTester from vibe_lib.raster import save_raster_to_asset @@ -59,7 +59,7 @@ def create_segmented_raster( categories=["background", "foreground"], mask_score=[mask_score], mask_bbox=[tuple([float(c) for c in mask_bbox])], # type: ignore - chip_window=ChipWindow(0.0, 0.0, float(raster_size), float(raster_size)), + chip_window=(0.0, 0.0, float(raster_size), float(raster_size)), ) diff --git a/src/vibe_core/vibe_core/data/core_types.py b/src/vibe_core/vibe_core/data/core_types.py index 55271080..da6b4fa2 100644 --- a/src/vibe_core/vibe_core/data/core_types.py +++ b/src/vibe_core/vibe_core/data/core_types.py @@ -15,7 +15,6 @@ ClassVar, Dict, List, - NamedTuple, Optional, Tuple, Type, @@ -51,20 +50,8 @@ """Type alias for a time range, as a tuple of two `datetime` objects (start, end).""" -class ChipWindow(NamedTuple): - """Represent a window of a raster chip. - - Attributes: - col_offset: The column offset of the window with relation to the raster chip. - row_offset: The row offset of the window with relation to the raster chip. - width: The width of the window. - height: The height of the window. 
- """ - - col_offset: float - row_offset: float - width: float - height: float +ChipWindow = Tuple[float, float, float, float] +"""Type alias representing a raster chip window, as (col_offset, row_offset, width, height).""" def gen_guid(): diff --git a/src/vibe_lib/vibe_lib/segment_anything.py b/src/vibe_lib/vibe_lib/segment_anything.py index 648acb74..0d4267c3 100644 --- a/src/vibe_lib/vibe_lib/segment_anything.py +++ b/src/vibe_lib/vibe_lib/segment_anything.py @@ -11,12 +11,13 @@ from geopandas import GeoDataFrame from numpy.typing import NDArray from rasterio import Affine +from rasterio.windows import Window from shapely.geometry.base import BaseGeometry from torchvision.transforms.functional import resize from vibe_core.data import GeometryCollection, Raster from vibe_core.data.core_types import BBox, Point -from vibe_lib.spaceeye.chip import ChipDataset, Dims, Window +from vibe_lib.spaceeye.chip import ChipDataset, Dims LOGGER = logging.getLogger(__name__) @@ -473,7 +474,7 @@ def build_chip_preprocessing_operation( elif len(band_scaling) != len(band_names): raise ValueError(f"Expected one or three scaling parameters. Got {band_scaling}") else: - band_scaling = [raster.scale] * 3 + band_scaling = [float(raster.scale)] * 3 scale = np.array(band_scaling).reshape(1, 3, 1, 1) if band_offset: @@ -483,13 +484,20 @@ def build_chip_preprocessing_operation( elif len(band_offset) != len(band_names): raise ValueError(f"Expected one or three offset parameters. 
Got {band_offset}") else: - band_offset = [raster.offset] * 3 + band_offset = [float(raster.offset)] * 3 offset = np.array(band_offset).reshape(1, 3, 1, 1) def preprocessing_operation(chip: NDArray[Any]) -> NDArray[Any]: normalized_chip = chip[:, band_idx, :, :] * scale + offset - if np.min(normalized_chip) >= 0 and np.max(normalized_chip) <= 1: - normalized_chip = normalized_chip * 255.0 + if np.min(normalized_chip) < 0 or np.max(normalized_chip) > 1: + LOGGER.warning( + "Chip values are outside the expected range [0, 1] after scaling and offset. " + f"Found max of {np.max(normalized_chip)} and min of {np.min(normalized_chip)}." + "Will clip to [0, 1] and normalize to [0, 255]. Please, verify the band_scaling " + "and band_offset parameters of the workflow." + ) + normalized_chip = np.clip(normalized_chip, 0, 1) + normalized_chip = normalized_chip * 255.0 return normalized_chip.astype(np.float32) return preprocessing_operation