diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py
index 50b6d697..6a9c7377 100644
--- a/terratorch/cli_tools.py
+++ b/terratorch/cli_tools.py
@@ -81,6 +81,11 @@ def is_one_band(img):
 
 
 def write_tiff(img_wrt, filename, metadata):
+    # Adapt the band count in the metadata so it matches
+    # the number of bands in the array being written.
+    count = img_wrt.shape[0]
+    metadata['count'] = count
+
     with rasterio.open(filename, "w", **metadata) as dest:
         if is_one_band(img_wrt):
             img_wrt = img_wrt[None]
diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py
index aff17ab0..1286ff57 100644
--- a/terratorch/tasks/segmentation_tasks.py
+++ b/terratorch/tasks/segmentation_tasks.py
@@ -65,6 +65,7 @@ def __init__(
         tiled_inference_parameters: TiledInferenceParameters = None,
         test_dataloaders_names: list[str] | None = None,
         lr_overrides: dict[str, float] | None = None,
+        output_most_probable: bool = True,
     ) -> None:
         """Constructor
 
@@ -110,6 +111,8 @@ def __init__(
            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
                parameters. The key should be a substring of the parameter names (it will check the substring is
                contained in the parameter name)and the value should be the new lr. Defaults to None.
+           output_most_probable (bool): Whether the prediction output contains only the most probable class
+               (argmax over the class dimension) or the scores for all classes. Defaults to True.
        """
        self.tiled_inference_parameters = tiled_inference_parameters
        self.aux_loss = aux_loss
@@ -136,6 +139,12 @@ def __init__(
        self.val_loss_handler = LossHandler(self.val_metrics.prefix)
        self.monitor = f"{self.val_metrics.prefix}loss"
        self.plot_on_val = int(plot_on_val)
+       self.output_most_probable = output_most_probable
+
+       if output_most_probable:
+           self.select_classes = lambda y: y.argmax(dim=1)
+       else:
+           self.select_classes = lambda y: y
 
    def configure_losses(self) -> None:
        """Initialize the loss criterion.
@@ -349,5 +358,7 @@ def model_forward(x):
            )
        else:
            y_hat: Tensor = self(x, **rest).output
-       y_hat = y_hat.argmax(dim=1)
+
+       y_hat = self.select_classes(y_hat)
+
        return y_hat, file_names
diff --git a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml
index 52b9ee58..586aa6b1 100644
--- a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml
+++ b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml
@@ -51,12 +51,12 @@ data:
       - 2
       - 1
       - 0
-    train_data_root: tests/
-    train_label_data_root: tests/
-    val_data_root: tests/
-    val_label_data_root: tests/
-    test_data_root: tests/
-    test_label_data_root: tests/
+    train_data_root: tests/resources/inputs
+    train_label_data_root: tests/resources/inputs
+    val_data_root: tests/resources/inputs
+    val_label_data_root: tests/resources/inputs
+    test_data_root: tests/resources/inputs
+    test_label_data_root: tests/resources/inputs
     img_grep: "segmentation*input*.tif"
     label_grep: "segmentation*label*.tif"
     means:
@@ -83,8 +83,8 @@ model:
       decoder: UperNetDecoder
       pretrained: true
       backbone: prithvi_swin_B
-      backbone_pretrained_cfg_overlay:
-        file: tests/prithvi_swin_B.pt
+      #backbone_pretrained_cfg_overlay:
+        #file: tests/prithvi_swin_B.pt
       backbone_drop_path_rate: 0.3
       # backbone_window_size: 8
       decoder_channels: 256
@@ -99,6 +99,7 @@ model:
       num_frames: 1
       num_classes: 2
       head_dropout: 0.5708022831486758
+    output_most_probable: false
     loss: ce
     #aux_heads:
     #  - name: aux_head
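
Below is a minimal sketch, not part of the patch, of the two prediction outputs the task can now produce. The batch size, class count, and tile size are illustrative assumptions; the point is that with output_most_probable: false the class dimension is kept, which is why write_tiff now rewrites count in the rasterio metadata before writing.

import torch

# Hypothetical logits from the segmentation head: (batch, num_classes, height, width)
y_hat = torch.randn(4, 2, 224, 224)

# output_most_probable: true  -> one class index per pixel, shape (batch, height, width)
most_probable = y_hat.argmax(dim=1)

# output_most_probable: false -> the per-class scores are kept, shape (batch, num_classes, height, width)
all_classes = y_hat

print(most_probable.shape)  # torch.Size([4, 224, 224])
print(all_classes.shape)    # torch.Size([4, 2, 224, 224])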