Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing the segmentation task to output multiple class labels #393

Merged
merged 3 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions terratorch/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def is_one_band(img):

def write_tiff(img_wrt, filename, metadata):

# Adapting the number of bands to be compatible with the
# output dimensions.
count = img_wrt.shape[0]
metadata['count'] = count

with rasterio.open(filename, "w", **metadata) as dest:
if is_one_band(img_wrt):
img_wrt = img_wrt[None]
Expand Down
13 changes: 12 additions & 1 deletion terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
tiled_inference_parameters: TiledInferenceParameters = None,
test_dataloaders_names: list[str] | None = None,
lr_overrides: dict[str, float] | None = None,
output_most_probable: bool = True,
) -> None:
"""Constructor

Expand Down Expand Up @@ -110,6 +111,8 @@ def __init__(
lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
parameters. The key should be a substring of the parameter names (it will check the substring is
contained in the parameter name)and the value should be the new lr. Defaults to None.
output_most_probable (bool): A boolean to define if the output during the inference will be just
for the most probable class or if it will include all of them.
"""
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
Expand All @@ -136,6 +139,12 @@ def __init__(
self.val_loss_handler = LossHandler(self.val_metrics.prefix)
self.monitor = f"{self.val_metrics.prefix}loss"
self.plot_on_val = int(plot_on_val)
self.output_most_probable = output_most_probable

if output_most_probable:
self.select_classes = lambda y: y.argmax(dim=1)
else:
self.select_classes = lambda y: y

def configure_losses(self) -> None:
"""Initialize the loss criterion.
Expand Down Expand Up @@ -349,5 +358,7 @@ def model_forward(x):
)
else:
y_hat: Tensor = self(x, **rest).output
y_hat = y_hat.argmax(dim=1)

y_hat = self.select_classes(y_hat)

return y_hat, file_names
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ data:
- 2
- 1
- 0
train_data_root: tests/
train_label_data_root: tests/
val_data_root: tests/
val_label_data_root: tests/
test_data_root: tests/
test_label_data_root: tests/
train_data_root: tests/resources/inputs
train_label_data_root: tests/resources/inputs
val_data_root: tests/resources/inputs
val_label_data_root: tests/resources/inputs
test_data_root: tests/resources/inputs
test_label_data_root: tests/resources/inputs
img_grep: "segmentation*input*.tif"
label_grep: "segmentation*label*.tif"
means:
Expand All @@ -83,8 +83,8 @@ model:
decoder: UperNetDecoder
pretrained: true
backbone: prithvi_swin_B
backbone_pretrained_cfg_overlay:
file: tests/prithvi_swin_B.pt
#backbone_pretrained_cfg_overlay:
#file: tests/prithvi_swin_B.pt
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
decoder_channels: 256
Expand All @@ -99,6 +99,7 @@ model:
num_frames: 1
num_classes: 2
head_dropout: 0.5708022831486758
output_most_probable: false
loss: ce
#aux_heads:
# - name: aux_head
Expand Down