From 7ed81d7c2992243cabc1c01da7175c975357b90f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 19 Aug 2024 10:46:06 -0400 Subject: [PATCH 01/23] Cleaning memory and GPU cache in order to avoid 'CUDA out of memory' error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/tasks/classification_tasks.py | 6 ++++++ terratorch/tasks/regression_tasks.py | 7 +++++++ terratorch/tasks/segmentation_tasks.py | 8 +++++++- 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py index 5820e23d..f503d5a2 100644 --- a/terratorch/tasks/classification_tasks.py +++ b/terratorch/tasks/classification_tasks.py @@ -269,6 +269,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T x = batch["image"] file_names = batch["filename"] + # Avoiding GPU memory overloading + # Removing GPU cache + torch.cuda.empty_cache() + # Forcing the Python garbage collector + gc.collect() + y_hat = self(x).output y_hat = y_hat.argmax(dim=1) return y_hat, file_names diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index a83a4ea3..1a3a41ba 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from typing import Any +import gc import lightning import matplotlib.pyplot as plt @@ -368,6 +369,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T def model_forward(x): return self(x).output + # Avoiding GPU memory overloading + # Removing GPU cache + torch.cuda.empty_cache() + # Forcing the Python garbage collector + gc.collect() + if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters) else: diff --git 
a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 5f123351..e85f5d97 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -301,7 +301,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) y_hat_hard = to_segmentation_prediction(model_output) self.test_metrics.update(y_hat_hard, y) - + torch.cuda.memory_summary(device=None, abbreviated=False) def on_test_epoch_end(self) -> None: self.log_dict(self.test_metrics.compute(), sync_dist=True) self.test_metrics.reset() @@ -324,6 +324,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T def model_forward(x): return self(x).output + # Avoiding GPU memory overloading + # Removing GPU cache + torch.cuda.empty_cache() + # Forcing the Python garbage collector + gc.collect() + if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters From 6d09629c770492c2e15689055a9e93685ac5ace8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Wed, 21 Aug 2024 12:25:58 -0300 Subject: [PATCH 02/23] Registering profiling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/tasks/segmentation_tasks.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index e85f5d97..f61666b9 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -16,6 +16,8 @@ from terratorch.tasks.optimizer_factory import optimizer_factory from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference +from torch.profiler import profile, 
record_function, ProfilerActivity + BATCH_IDX_FOR_VALIDATION_PLOTTING = 10 @@ -329,12 +331,13 @@ def model_forward(x): torch.cuda.empty_cache() # Forcing the Python garbage collector gc.collect() - - if self.tiled_inference_parameters: - y_hat: Tensor = tiled_inference( - model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters - ) - else: - y_hat: Tensor = self(x).output + with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("model_inference"): + if self.tiled_inference_parameters: + y_hat: Tensor = tiled_inference( + model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters + ) + else: + y_hat: Tensor = self(x).output y_hat = y_hat.argmax(dim=1) return y_hat, file_names From 2caaa6f016ba2013c050adc8246d0bcd4bee89fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Wed, 21 Aug 2024 17:30:58 -0300 Subject: [PATCH 03/23] Modifying slicing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datasets/generic_pixel_wise_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 2bf4319e..19723f13 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -182,7 +182,7 @@ def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | if len(element) != 2: # noqa: PLR2004 msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive" raise Exception(msg) - expanded_element = list(range(element[0], element[1] + 1)) + expanded_element = list(range(element[0], element[1])) bands.extend(expanded_element) else: bands.append(element) From 
9b6ce6af3f66702eca356d9a82cbb17bd1e3f6e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 22 Aug 2024 17:51:39 -0300 Subject: [PATCH 04/23] Run the tests again MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- tests/manufactured-finetune_prithvi_swin_B.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/manufactured-finetune_prithvi_swin_B.yaml b/tests/manufactured-finetune_prithvi_swin_B.yaml index 03cd7ea7..1c577544 100644 --- a/tests/manufactured-finetune_prithvi_swin_B.yaml +++ b/tests/manufactured-finetune_prithvi_swin_B.yaml @@ -20,7 +20,7 @@ trainer: init_args: monitor: val/loss patience: 100 - max_epochs: 5 + max_epochs: 3 check_val_every_n_epoch: 1 log_every_n_steps: 20 enable_checkpointing: true From 1dc2c1e0487d268c688512e0b9ffc349a59b8b7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 23 Aug 2024 19:07:47 -0400 Subject: [PATCH 05/23] Adjustments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../generic_pixel_wise_data_module.py | 4 +++- terratorch/tasks/regression_tasks.py | 9 ++------ terratorch/tasks/segmentation_tasks.py | 22 ++++++------------- 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index ea6657b2..820e1bc4 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -93,6 +93,7 @@ def __init__( allow_substring_split_file: bool = True, dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + predict_output_bands: list[HLSBands | int | 
tuple[int, int] | str ] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, @@ -185,6 +186,7 @@ def __init__( self.dataset_bands = dataset_bands self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands + self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands self.output_bands = output_bands self.rgb_indices = rgb_indices self.expand_temporal_dimension = expand_temporal_dimension @@ -427,7 +429,7 @@ def __init__( self.dataset_bands = dataset_bands self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands - self.predict_output_bands = predict_output_bands if predict_output_bands else dataset_bands + self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands self.output_bands = output_bands self.rgb_indices = rgb_indices diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 1a3a41ba..a03d2131 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -368,15 +368,10 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T def model_forward(x): return self(x).output - - # Avoiding GPU memory overloading - # Removing GPU cache - torch.cuda.empty_cache() - # Forcing the Python garbage collector - gc.collect() - + if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters) else: y_hat: Tensor = self(x).output + return y_hat, file_names diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index f61666b9..7bd57abc 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -16,8 +16,6 @@ from terratorch.tasks.optimizer_factory import optimizer_factory from terratorch.tasks.tiled_inference 
import TiledInferenceParameters, tiled_inference -from torch.profiler import profile, record_function, ProfilerActivity - BATCH_IDX_FOR_VALIDATION_PLOTTING = 10 @@ -326,18 +324,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T def model_forward(x): return self(x).output - # Avoiding GPU memory overloading - # Removing GPU cache - torch.cuda.empty_cache() - # Forcing the Python garbage collector - gc.collect() - with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("model_inference"): - if self.tiled_inference_parameters: - y_hat: Tensor = tiled_inference( - model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters - ) - else: - y_hat: Tensor = self(x).output + if self.tiled_inference_parameters: + y_hat: Tensor = tiled_inference( + model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters + ) + else: + y_hat: Tensor = self(x).output + y_hat = y_hat.argmax(dim=1) return y_hat, file_names From 6681f819a91d5627098802e1bbc13697df44708e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 2 Sep 2024 11:42:18 -0300 Subject: [PATCH 06/23] Here we also need to support others ways to define bands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/prithvi_model_factory.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/terratorch/models/prithvi_model_factory.py b/terratorch/models/prithvi_model_factory.py index 73ce819b..635333e7 100644 --- a/terratorch/models/prithvi_model_factory.py +++ b/terratorch/models/prithvi_model_factory.py @@ -31,6 +31,24 @@ class DecoderNotFoundError(Exception): @register_factory class PrithviModelFactory(ModelFactory): + + @staticmethod + def _generate_bands_intervals(bands_intervals: list[int | str | HLSBands | 
tuple[int]] | None = None): + if bands_intervals is None: + return None + bands = [] + for element in bands_intervals: + # if its an interval + if isinstance(element, tuple): + if len(element) != 2: # noqa: PLR2004 + msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive" + raise Exception(msg) + expanded_element = list(range(element[0], element[1])) + bands.extend(expanded_element) + else: + bands.append(element) + return bands + def build_model( self, task: str, @@ -80,7 +98,8 @@ def build_model( Returns: nn.Module: Full model with encoder, decoder and head. """ - bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands] + bands = self._generate_bands_intervals(bands) + if in_channels is None: in_channels = len(bands) # TODO: support auxiliary heads From c60bb10fd57a3169fa356384b268a655ace28c9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 2 Sep 2024 12:16:20 -0300 Subject: [PATCH 07/23] list would be better MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/prithvi_model_factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/terratorch/models/prithvi_model_factory.py b/terratorch/models/prithvi_model_factory.py index 635333e7..5bbd5b37 100644 --- a/terratorch/models/prithvi_model_factory.py +++ b/terratorch/models/prithvi_model_factory.py @@ -39,7 +39,7 @@ def _generate_bands_intervals(bands_intervals: list[int | str | HLSBands | tuple bands = [] for element in bands_intervals: # if its an interval - if isinstance(element, tuple): + if isinstance(element, list): if len(element) != 2: # noqa: PLR2004 msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive" raise Exception(msg) @@ -99,7 +99,7 @@ def build_model( nn.Module: Full model with encoder, decoder and 
head. """ bands = self._generate_bands_intervals(bands) - + print(bands) if in_channels is None: in_channels = len(bands) # TODO: support auxiliary heads From ff29037f5e68442037de25d36a14b5262dfbbe4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 9 Sep 2024 10:50:17 -0400 Subject: [PATCH 08/23] list is also an acceptable format MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/prithvi_model_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terratorch/models/prithvi_model_factory.py b/terratorch/models/prithvi_model_factory.py index 5bbd5b37..a031b608 100644 --- a/terratorch/models/prithvi_model_factory.py +++ b/terratorch/models/prithvi_model_factory.py @@ -39,7 +39,7 @@ def _generate_bands_intervals(bands_intervals: list[int | str | HLSBands | tuple bands = [] for element in bands_intervals: # if its an interval - if isinstance(element, list): + if isinstance(element, list) or isinstance(element, tuple): if len(element) != 2: # noqa: PLR2004 msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive" raise Exception(msg) From 8df33bac09bb252adbc94ad311f804689d01758f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 9 Sep 2024 10:50:47 -0400 Subject: [PATCH 09/23] closed interval MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datasets/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terratorch/datasets/utils.py b/terratorch/datasets/utils.py index ad9fe6d6..d6bd5537 100644 --- a/terratorch/datasets/utils.py +++ b/terratorch/datasets/utils.py @@ -48,7 +48,7 @@ def generate_bands_intervals(bands_intervals: list[int | str | HLSBands | tuple[ msg = "When defining an 
interval, a tuple of two integers should be passed,\ defining start and end indices inclusive" raise Exception(msg) - expanded_element = list(range(element[0], element[1] + 1)) + expanded_element = list(range(element[0], element[1])) bands.extend(expanded_element) else: bands.append(element) From 8e98cd9dbdbb9ff179388345878ddb7bdb5d940e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 4 Nov 2024 12:51:02 -0300 Subject: [PATCH 10/23] Testing MLP decoder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/decoders/mlp_decoder.py | 38 +++++ ...actured-finetune-finetune-mlp_decoder.yaml | 146 ++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 terratorch/models/decoders/mlp_decoder.py create mode 100644 tests/resources/configs/manufactured-finetune-finetune-mlp_decoder.yaml diff --git a/terratorch/models/decoders/mlp_decoder.py b/terratorch/models/decoders/mlp_decoder.py new file mode 100644 index 00000000..61999475 --- /dev/null +++ b/terratorch/models/decoders/mlp_decoder.py @@ -0,0 +1,38 @@ +# Copyright contributors to the Terratorch project + +"""Pass the features straight through +""" + +from torch import Tensor, nn +import torch +from terratorch.registry import TERRATORCH_DECODER_REGISTRY + + +@TERRATORCH_DECODER_REGISTRY.register +class MLPDecoder(nn.Module): + """Identity decoder. Useful to pass the feature straight to the head.""" + + def __init__(self, embed_dim: int, channels: int = 100, out_dim:int = 100, activation: str = "ReLU", out_index=-1) -> None: + """Constructor + Args: + embed_dim (int): Input embedding dimension + out_index (int, optional): Index of the input list to take.. Defaults to -1. 
+ """ + + super().__init__() + self.embed_dim = embed_dim + self.channels = channels + self.dim = out_index + self.n_inputs = len(self.embed_dim) + self.out_channels = self.embed_dim[self.dim] + self.hidden_layer = torch.nn.Linear(self.out_channels*self.n_inputs, self.out_channels) + self.activation = getattr(nn, activation)() + + def forward(self, x: list[Tensor]): + + data_ = torch.cat(x, axis=1) + data_ = data_.permute(0, 2, 3, 1) + data_ = self.activation(self.hidden_layer(data_)) + data_ = data_.permute(0, 3, 1, 2) + + return data_ diff --git a/tests/resources/configs/manufactured-finetune-finetune-mlp_decoder.yaml b/tests/resources/configs/manufactured-finetune-finetune-mlp_decoder.yaml new file mode 100644 index 00000000..0b022733 --- /dev/null +++ b/tests/resources/configs/manufactured-finetune-finetune-mlp_decoder.yaml @@ -0,0 +1,146 @@ +# lightning.pytorch==2.1.1 +seed_everything: 42 +trainer: + accelerator: cpu + strategy: auto + devices: auto + num_nodes: 1 + # precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: tests/ + name: all_ecos_random + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 100 + max_epochs: 2 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_checkpointing: true + default_root_dir: tests/ +data: + class_path: GenericNonGeoPixelwiseRegressionDataModule + init_args: + batch_size: 2 + num_workers: 4 + train_transform: + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: albumentations.Rotate + init_args: + limit: 30 + border_mode: 0 # cv2.BORDER_CONSTANT + value: 0 + # mask_value: 1 + p: 0.5 + - class_path: ToTensorV2 + dataset_bands: + - 0 + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + - 1 + - 2 + - 3 + - 4 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: 
+ - 2 + - 1 + - 0 + train_data_root: tests/resources/inputs + train_label_data_root: tests/resources/inputs + val_data_root: tests/resources/inputs + val_label_data_root: tests/resources/inputs + test_data_root: tests/resources/inputs + test_label_data_root: tests/resources/inputs + img_grep: "regression*input*.tif" + label_grep: "regression*label*.tif" + means: + - 547.36707 + - 898.5121 + - 1020.9082 + - 2665.5352 + - 2340.584 + - 1610.1407 + stds: + - 411.4701 + - 558.54065 + - 815.94025 + - 812.4403 + - 1113.7145 + - 1067.641 + no_label_replace: -1 + no_data_replace: 0 + +model: + class_path: terratorch.tasks.PixelwiseRegressionTask + init_args: + model_args: + decoder: MLPDecoder + pretrained: false + backbone: prithvi_vit_100 + decoder_activation: ReLU + backbone_drop_path_rate: 0.3 + num_frames: 1 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + head_dropout: 0.5708022831486758 + head_final_act: torch.nn.Identity + head_learned_upscale_layers: 2 + loss: rmse + #aux_heads: + # - name: aux_head + # decoder: IdentityDecoder + # decoder_args: + # decoder_out_index: 2 + # head_dropout: 0,5 + # head_channel_list: + # - 64 + # head_final_act: torch.nn.ReLU + #aux_loss: + # aux_head: 0.4 + ignore_index: -1 + freeze_backbone: true + freeze_decoder: false + model_factory: PrithviModelFactory + + # uncomment this block for tiled inference + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 192 + # w_crop: 224 + # w_stride: 192 + # average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.00013524680528283027 + weight_decay: 0.047782217873995426 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss From b94e3080ba22c2f7e74e85494495837e362492d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 4 Nov 2024 11:44:44 -0500 Subject: [PATCH 11/23] Using MLP decoder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/decoders/__init__.py | 3 ++- terratorch/models/decoders/mlp_decoder.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/terratorch/models/decoders/__init__.py b/terratorch/models/decoders/__init__.py index c4b6465a..97d7585a 100644 --- a/terratorch/models/decoders/__init__.py +++ b/terratorch/models/decoders/__init__.py @@ -4,5 +4,6 @@ from terratorch.models.decoders.identity_decoder import IdentityDecoder from terratorch.models.decoders.satmae_head import SatMAEHead, SatMAEHeadViT from terratorch.models.decoders.upernet_decoder import UperNetDecoder +from terratorch.models.decoders.mlp_decoder import MLPDecoder -__all__ = ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "SatMAEHead", "SatMAEHeadViT"] +__all__ = ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "SatMAEHead", "SatMAEHeadViT", "MLPDecoder"] diff --git a/terratorch/models/decoders/mlp_decoder.py b/terratorch/models/decoders/mlp_decoder.py index 61999475..5f092346 100644 --- a/terratorch/models/decoders/mlp_decoder.py +++ b/terratorch/models/decoders/mlp_decoder.py @@ -5,10 +5,10 @@ from torch import Tensor, nn import torch -from terratorch.registry import TERRATORCH_DECODER_REGISTRY +#from terratorch.registry import TERRATORCH_DECODER_REGISTRY -@TERRATORCH_DECODER_REGISTRY.register +#@TERRATORCH_DECODER_REGISTRY.register class MLPDecoder(nn.Module): """Identity decoder. 
Useful to pass the feature straight to the head.""" @@ -25,6 +25,7 @@ def __init__(self, embed_dim: int, channels: int = 100, out_dim:int = 100, activ self.dim = out_index self.n_inputs = len(self.embed_dim) self.out_channels = self.embed_dim[self.dim] + self.output_embed_dim = self.out_channels self.hidden_layer = torch.nn.Linear(self.out_channels*self.n_inputs, self.out_channels) self.activation = getattr(nn, activation)() From 9b31014d83d045c96886c02eb1a85eef4aa60115 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 4 Nov 2024 14:22:46 -0300 Subject: [PATCH 12/23] missing import MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/decoders/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/terratorch/models/decoders/__init__.py b/terratorch/models/decoders/__init__.py index c4b6465a..97d7585a 100644 --- a/terratorch/models/decoders/__init__.py +++ b/terratorch/models/decoders/__init__.py @@ -4,5 +4,6 @@ from terratorch.models.decoders.identity_decoder import IdentityDecoder from terratorch.models.decoders.satmae_head import SatMAEHead, SatMAEHeadViT from terratorch.models.decoders.upernet_decoder import UperNetDecoder +from terratorch.models.decoders.mlp_decoder import MLPDecoder -__all__ = ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "SatMAEHead", "SatMAEHeadViT"] +__all__ = ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "SatMAEHead", "SatMAEHeadViT", "MLPDecoder"] From 055f0754993b180301efb7a6a5f1d2a6e325ebd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 24 Jan 2025 13:03:21 -0300 Subject: [PATCH 13/23] patching along bands dimension MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_mae.py | 14 
++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index bf0dde12..0aad26fd 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -137,6 +137,7 @@ def __init__( patch_size: tuple[int, int, int] = (1, 16, 16), in_chans: int = 3, embed_dim: int = 768, + band_patch_size: int = None, norm_layer: nn.Module | None = None, flatten: bool = True, bias: bool = True, @@ -144,12 +145,18 @@ def __init__( super().__init__() self.input_size = input_size self.patch_size = patch_size + self.band_patch_size = band_patch_size self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)] assert self.grid_size >= [1,1,1], "Patch size is bigger than input size." self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] self.flatten = flatten - self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + if self.band_patch_size: + kernel_size = (self.band_patch_size, self.patch_size[1], self.patch_size[2]) + else: + kernel_size = self.patch_size + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): @@ -158,8 +165,9 @@ def forward(self, x): if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1: warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." 
f"The border will be ignored, add backbone_padding for pixel-wise tasks.") - + print(x.shape) x = self.proj(x) + print(x.shape) if self.flatten: x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C x = self.norm(x) @@ -417,7 +425,9 @@ def forward_features( t, h, w = x.shape[-3:] # embed patches + print(x.shape) x = self.patch_embed(x) + print(x.shape) pos_embed = self.interpolate_pos_encoding(x, t, h, w) # add pos embed w/o cls token From 08706d73e75a52aec3d9d7cb17660ff4753d1032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 24 Jan 2025 17:19:16 -0300 Subject: [PATCH 14/23] using band patching MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida number of patches for the bands direction Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_mae.py | 23 +++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 0aad26fd..73d5576a 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -136,8 +136,9 @@ def __init__( input_size: tuple[int, int, int] = (1, 224, 224), patch_size: tuple[int, int, int] = (1, 16, 16), in_chans: int = 3, + tub_size: int = 1, embed_dim: int = 768, - band_patch_size: int = None, + band_patch_size: int = 2, norm_layer: nn.Module | None = None, flatten: bool = True, bias: bool = True, @@ -145,6 +146,7 @@ def __init__( super().__init__() self.input_size = input_size self.patch_size = patch_size + self.tub_size = tub_size self.band_patch_size = band_patch_size self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)] assert self.grid_size >= [1,1,1], "Patch size is bigger than input size." 
@@ -156,7 +158,7 @@ def __init__( else: kernel_size = self.patch_size - self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=bias) + self.proj = nn.Conv3d(tub_size, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): @@ -165,6 +167,7 @@ def forward(self, x): if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1: warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." f"The border will be ignored, add backbone_padding for pixel-wise tasks.") + x = x.transpose(2, 1) print(x.shape) x = self.proj(x) print(x.shape) @@ -344,7 +347,7 @@ def random_masking(self, sequence, mask_ratio, noise=None): return sequence_unmasked, mask, ids_restore - def interpolate_pos_encoding(self, x, t, w, h): + def interpolate_pos_encoding(self, x, t, c, w, h): """ Adapted from: - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding, @@ -356,16 +359,18 @@ def interpolate_pos_encoding(self, x, t, w, h): class_pos_embed = self.pos_embed[:, :1] patch_pos_embed = self.pos_embed[:, 1:] + c_patches = c // self.patch_embed.band_patch_size t_patches = t // self.patch_embed.patch_size[0] w_patches = w // self.patch_embed.patch_size[1] h_patches = h // self.patch_embed.patch_size[2] n_sqrt = int((patch_pos_embed.shape[1] / t_patches) ** 0.5) + print(n_sqrt, c_patches, t_patches) patch_pos_embed = patch_pos_embed.reshape(t_patches, n_sqrt, n_sqrt, self.embed_dim).permute(0, 3, 1, 2) - + print(patch_pos_embed.shape) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - size=(h_patches, w_patches), + size=(c_patches, h_patches, w_patches), mode='bicubic', align_corners=True, ) @@ -381,12 +386,12 @@ def forward( if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: # add time dim x = x.unsqueeze(2) - t, h, w = x.shape[-3:] + t, c, h, w = x.shape[-4:] # embed 
patches x = self.patch_embed(x) - pos_embed = self.interpolate_pos_encoding(x, t, h, w) + pos_embed = self.interpolate_pos_encoding(x, t, c, h, w) # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] @@ -422,14 +427,14 @@ def forward_features( if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: # add time dim x = x.unsqueeze(2) - t, h, w = x.shape[-3:] + t, c, h, w = x.shape[-4:] # embed patches print(x.shape) x = self.patch_embed(x) print(x.shape) - pos_embed = self.interpolate_pos_encoding(x, t, h, w) + pos_embed = self.interpolate_pos_encoding(x, t, c, h, w) # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] From 62fd107615e6029271c196c6ea90fc762536ec04 Mon Sep 17 00:00:00 2001 From: Joao Lucas de Sousa Almeida Date: Fri, 24 Jan 2025 21:32:17 -0300 Subject: [PATCH 15/23] minor fixes Signed-off-by: Joao Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_mae.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 73d5576a..2fd7a07d 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -365,12 +365,12 @@ def interpolate_pos_encoding(self, x, t, c, w, h): h_patches = h // self.patch_embed.patch_size[2] n_sqrt = int((patch_pos_embed.shape[1] / t_patches) ** 0.5) - print(n_sqrt, c_patches, t_patches) + print(c_patches) patch_pos_embed = patch_pos_embed.reshape(t_patches, n_sqrt, n_sqrt, self.embed_dim).permute(0, 3, 1, 2) print(patch_pos_embed.shape) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - size=(c_patches, h_patches, w_patches), + size=(c_patches * h_patches, w_patches), mode='bicubic', align_corners=True, ) @@ -387,7 +387,7 @@ def forward( # add time dim x = x.unsqueeze(2) t, c, h, w = x.shape[-4:] - + print(t, c, h, w) # embed patches x = self.patch_embed(x) @@ -427,7 +427,7 @@ def forward_features( if len(x.shape) == 4 and 
self.patch_embed.input_size[0] == 1: # add time dim x = x.unsqueeze(2) - t, c, h, w = x.shape[-4:] + c, t, h, w = x.shape[-4:] # embed patches print(x.shape) @@ -469,6 +469,7 @@ def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list number_of_tokens = x_no_token.shape[1] tokens_per_timestep = number_of_tokens // effective_time_dim h = int(np.sqrt(tokens_per_timestep)) + print(f"Shape:{x_no_token.shape}") encoded = rearrange( x_no_token, "batch (t h w) e -> batch (t e) h w", From 825f7be92f4b796ed093ae184567167fa336e377 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 27 Jan 2025 10:03:26 -0300 Subject: [PATCH 16/23] debugging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_mae.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 2fd7a07d..09ca733b 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -365,15 +365,16 @@ def interpolate_pos_encoding(self, x, t, c, w, h): h_patches = h // self.patch_embed.patch_size[2] n_sqrt = int((patch_pos_embed.shape[1] / t_patches) ** 0.5) - print(c_patches) patch_pos_embed = patch_pos_embed.reshape(t_patches, n_sqrt, n_sqrt, self.embed_dim).permute(0, 3, 1, 2) - print(patch_pos_embed.shape) + + print(f"patch_pos_embed: {patch_pos_embed.shape}") patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - size=(c_patches * h_patches, w_patches), + size=(c_patches*h_patches, w_patches), mode='bicubic', align_corners=True, ) + print(f"patch_pos_embed: {patch_pos_embed.shape}") patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.embed_dim) return torch.cat((class_pos_embed, patch_pos_embed), dim=1) @@ -430,12 +431,14 @@ def forward_features( 
c, t, h, w = x.shape[-4:] # embed patches - print(x.shape) + print(f"Before: {x.shape}") x = self.patch_embed(x) - print(x.shape) + print(f"After: {x.shape}") pos_embed = self.interpolate_pos_encoding(x, t, c, h, w) # add pos embed w/o cls token + print(f"x: {x.shape}") + print(f"pos_embed: {pos_embed.shape}") x = x + pos_embed[:, 1:, :] if self.temporal_encoding and temporal_coords is not None: From 9c9f59c27abf33a473d63580968342a106b5dc66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 27 Jan 2025 11:38:15 -0300 Subject: [PATCH 17/23] Defining the first dimension for the 3D convolution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_mae.py | 27 +++++++++++----------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 09ca733b..2636721d 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -147,6 +147,7 @@ def __init__( self.input_size = input_size self.patch_size = patch_size self.tub_size = tub_size + self.in_chans = in_chans self.band_patch_size = band_patch_size self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)] assert self.grid_size >= [1,1,1], "Patch size is bigger than input size." 
@@ -155,10 +156,12 @@ def __init__( if self.band_patch_size: kernel_size = (self.band_patch_size, self.patch_size[1], self.patch_size[2]) + first_conv_dim = tub_size else: kernel_size = self.patch_size + first_conv_dim = in_chans - self.proj = nn.Conv3d(tub_size, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=bias) + self.proj = nn.Conv3d(first_conv_dim, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): @@ -168,9 +171,7 @@ def forward(self, x): warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." f"The border will be ignored, add backbone_padding for pixel-wise tasks.") x = x.transpose(2, 1) - print(x.shape) x = self.proj(x) - print(x.shape) if self.flatten: x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C x = self.norm(x) @@ -367,14 +368,13 @@ def interpolate_pos_encoding(self, x, t, c, w, h): n_sqrt = int((patch_pos_embed.shape[1] / t_patches) ** 0.5) patch_pos_embed = patch_pos_embed.reshape(t_patches, n_sqrt, n_sqrt, self.embed_dim).permute(0, 3, 1, 2) - print(f"patch_pos_embed: {patch_pos_embed.shape}") patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(c_patches*h_patches, w_patches), mode='bicubic', align_corners=True, ) - print(f"patch_pos_embed: {patch_pos_embed.shape}") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.embed_dim) return torch.cat((class_pos_embed, patch_pos_embed), dim=1) @@ -388,7 +388,7 @@ def forward( # add time dim x = x.unsqueeze(2) t, c, h, w = x.shape[-4:] - print(t, c, h, w) + # embed patches x = self.patch_embed(x) @@ -431,14 +431,11 @@ def forward_features( c, t, h, w = x.shape[-4:] # embed patches - print(f"Before: {x.shape}") + x = self.patch_embed(x) - print(f"After: {x.shape}") pos_embed = self.interpolate_pos_encoding(x, t, c, h, w) # add pos embed w/o cls token - print(f"x: {x.shape}") - print(f"pos_embed: 
{pos_embed.shape}") x = x + pos_embed[:, 1:, :] if self.temporal_encoding and temporal_coords is not None: @@ -462,23 +459,27 @@ def forward_features( x = self.norm(x) out[-1] = x + return out def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]: out = [] effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0] + c = self.patch_embed.in_chans // self.patch_embed.band_patch_size + for x in features: x_no_token = x[:, 1:, :] number_of_tokens = x_no_token.shape[1] - tokens_per_timestep = number_of_tokens // effective_time_dim + tokens_per_timestep = number_of_tokens // effective_time_dim // 3 h = int(np.sqrt(tokens_per_timestep)) - print(f"Shape:{x_no_token.shape}") + encoded = rearrange( x_no_token, - "batch (t h w) e -> batch (t e) h w", + "batch (t h w c) e -> batch (t e) (c h) w", e=self.embed_dim, t=effective_time_dim, h=h, + c=c, ) out.append(encoded) return out From 4f7559649be9355bee1b39fc45c40fdce75b4c00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 27 Jan 2025 11:41:06 -0300 Subject: [PATCH 18/23] Is it necessary to transpose the input tensor or not ? 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_mae.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 2636721d..93f81a0f 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -157,9 +157,11 @@ def __init__( if self.band_patch_size: kernel_size = (self.band_patch_size, self.patch_size[1], self.patch_size[2]) first_conv_dim = tub_size + self.dim_transposer = lambda x: x.transpose(2, 1) else: kernel_size = self.patch_size first_conv_dim = in_chans + self.dim_transposer = lambda x: x self.proj = nn.Conv3d(first_conv_dim, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() @@ -170,7 +172,8 @@ def forward(self, x): if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1: warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." 
f"The border will be ignored, add backbone_padding for pixel-wise tasks.") - x = x.transpose(2, 1) + + x = slef.dim_transposer(x) x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C From bf1090bf691f7cea4e92986a4eb14525fcfc4c2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 27 Jan 2025 11:56:47 -0300 Subject: [PATCH 19/23] Evaluating the number of c patches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_mae.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 93f81a0f..f87f71cc 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -138,7 +138,7 @@ def __init__( in_chans: int = 3, tub_size: int = 1, embed_dim: int = 768, - band_patch_size: int = 2, + band_patch_size: int = None, norm_layer: nn.Module | None = None, flatten: bool = True, bias: bool = True, @@ -154,6 +154,7 @@ def __init__( self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] self.flatten = flatten + # When spectral patching is used, some adaptations are required if self.band_patch_size: kernel_size = (self.band_patch_size, self.patch_size[1], self.patch_size[2]) first_conv_dim = tub_size @@ -173,6 +174,8 @@ def forward(self, x): warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." f"The border will be ignored, add backbone_padding for pixel-wise tasks.") + # When spectral patching is used the tensor must be transposed in order + # to operate over the proper dimension.
x = slef.dim_transposer(x) x = self.proj(x) if self.flatten: @@ -249,6 +252,7 @@ class PrithviViT(nn.Module): def __init__(self, img_size: int | tuple[int, int] = 224, patch_size: int | tuple[int, int, int] = (1, 16, 16), + band_patch_size: int = None, num_frames: int = 1, in_chans: int = 3, embed_dim: int = 1024, @@ -271,10 +275,20 @@ def __init__(self, if isinstance(patch_size, int): patch_size = (1, patch_size, patch_size) + self.band_patch_size = band_patch_size + + # If spectral patching is being used, we need a way to evaluate the + # extra number of patches. + if self.band_patch_size: + self.eval_c_patches = lambda c: c // self.patch_embed.band_patch_size + else: + self.eval_c_patches = lambda c: 1 + # 3D patch embedding self.patch_embed = PatchEmbed( input_size=(num_frames,) + self.img_size, patch_size=patch_size, + band_patch_size=band_patch_size, in_chans=in_chans, embed_dim=embed_dim, ) @@ -363,7 +377,7 @@ def interpolate_pos_encoding(self, x, t, c, w, h): class_pos_embed = self.pos_embed[:, :1] patch_pos_embed = self.pos_embed[:, 1:] - c_patches = c // self.patch_embed.band_patch_size + c_patches = self.eval_c_patches(c) t_patches = t // self.patch_embed.patch_size[0] w_patches = w // self.patch_embed.patch_size[1] h_patches = h // self.patch_embed.patch_size[2] @@ -373,7 +387,7 @@ def interpolate_pos_encoding(self, x, t, c, w, h): patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - size=(c_patches*h_patches, w_patches), + size=(c_patches*h_patches, w_patches), # Accounting the extra patches produced by the spectral patching mode='bicubic', align_corners=True, ) From 45ab81b8ac3d3ba6c4a6025cc7424c97af1e66b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 27 Jan 2025 12:26:04 -0300 Subject: [PATCH 20/23] Using the proper number of channels patches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- 
terratorch/models/backbones/prithvi_mae.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index f87f71cc..c0f414b8 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -176,7 +176,7 @@ def forward(self, x): # When spectral patching is used the tensor must be transposed in order # to operate over the proper dimension. - x = slef.dim_transposer(x) + x = self.dim_transposer(x) x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C @@ -482,12 +482,12 @@ def forward_features( def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]: out = [] effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0] - c = self.patch_embed.in_chans // self.patch_embed.band_patch_size + c = self.eval_c_patches(self.patch_embed.in_chans) for x in features: x_no_token = x[:, 1:, :] number_of_tokens = x_no_token.shape[1] - tokens_per_timestep = number_of_tokens // effective_time_dim // 3 + tokens_per_timestep = number_of_tokens // effective_time_dim // c h = int(np.sqrt(tokens_per_timestep)) encoded = rearrange( From 92f3942d7d67b2243472e0fcb5f2f7d8e1a5834e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 27 Jan 2025 15:02:02 -0300 Subject: [PATCH 21/23] Testing the patching strategy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- tests/test_backbones.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_backbones.py b/tests/test_backbones.py index 546250d9..7110a7b5 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -104,6 +104,7 @@ def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, patch_ {backbone.embed_dim} = 
{expected_t * backbone.embed_dim} but was {e.shape[1]}" gc.collect() + @pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) def test_out_indices(model_name, input_224): out_indices = (2, 4, 8, 10) @@ -117,6 +118,28 @@ def test_out_indices(model_name, input_224): assert torch.allclose(full_output[full_index], output[filtered_index]) gc.collect() +@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"]) +@pytest.mark.parametrize("band_patch_size", [2, 3, None]) +def test_band_patch_size(model_name, band_patch_size, input_224): + backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False, band_patch_size=band_patch_size) + + n_channels = input_224.shape[1] + img_size = input_224.shape[-1] + + full_output = backbone.forward_features(input_224) + patch_size = backbone.patch_embed.patch_size[-1] + band_patch_size = backbone.patch_embed.band_patch_size + + if band_patch_size: + c_patches = n_channels // band_patch_size + else: + c_patches = 1 + + n_patches = c_patches * (img_size // patch_size)**2 + + assert full_output[-1].shape[1] - 1 == n_patches, "The number of patches does not correspond to the expected one." 
+ + gc.collect() @pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"]) def test_scale_mae(model_name): From 04cfac1b2b86ce54dceb143c68b697b99a7552e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 28 Jan 2025 15:44:59 -0300 Subject: [PATCH 22/23] minor fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datamodules/generic_pixel_wise_data_module.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index b79d1850..2413672f 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -96,8 +96,6 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, - predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, - predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, From 97435d4771b2122017158dde1781897453b46ac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 31 Jan 2025 14:30:32 -0300 Subject: [PATCH 23/23] workaround to avoid input arguments repetition (specifically 'model_bands') MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_vit.py | 36 +++++++++++++++++----- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git
a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 9f85b82d..4c923b6c 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -27,6 +27,14 @@ PRITHVI_V2_MEAN = [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0] PRITHVI_V2_STD = [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0] +# TODO This operation is probably a workaround. For some reason the variable +# "model_bands" is being repeated. It's necessary to check the reason for it. +def _overwrite_with_kwargs(extra_kwargs, kwargs): + + for k in extra_kwargs.keys(): + if k in kwargs.keys(): + extra_kwargs[k] = kwargs.pop(k) + return extra_kwargs, kwargs def _cfg(**kwargs): return { @@ -255,7 +263,9 @@ def prithvi_eo_tiny( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - + vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs) + pretrained = vars_updated["pretrained"] + bands = vars_updated["model_bands"] return _create_prithvi("prithvi_eo_tiny", pretrained=pretrained, model_bands=bands, **kwargs) @@ -265,7 +275,9 @@ def prithvi_eo_v1_100( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - + vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs) + pretrained = vars_updated["pretrained"] + bands = vars_updated["model_bands"] return _create_prithvi("prithvi_eo_v1_100", pretrained=pretrained, model_bands=bands, **kwargs) @@ -275,7 +287,9 @@ def prithvi_eo_v2_300( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - + vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs) + pretrained = vars_updated["pretrained"] + bands = vars_updated["model_bands"] return _create_prithvi("prithvi_eo_v2_300", pretrained=pretrained, model_bands=bands, **kwargs) @@ -285,7 +299,9 @@ def prithvi_eo_v2_600( bands: list[HLSBands] | None = None, **kwargs, ) -> 
PrithviViT: - + vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs) + pretrained = vars_updated["pretrained"] + bands = vars_updated["model_bands"] return _create_prithvi("prithvi_eo_v2_600", pretrained=pretrained, model_bands=bands, **kwargs) @@ -295,7 +311,9 @@ def prithvi_eo_v2_300_tl( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - + vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs) + pretrained = vars_updated["pretrained"] + bands = vars_updated["model_bands"] return _create_prithvi("prithvi_eo_v2_300_tl", pretrained=pretrained, model_bands=bands, **kwargs) @@ -305,7 +323,9 @@ def prithvi_eo_v2_600_tl( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - + vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs) + pretrained = vars_updated["pretrained"] + bands = vars_updated["model_bands"] return _create_prithvi("prithvi_eo_v2_600_tl", pretrained=pretrained, model_bands=bands, **kwargs) @@ -319,7 +339,9 @@ def prithvi_vit_tiny( warnings.warn(f"The model prithvi_vit_tiny was renamed to prithvi_eo_tiny. " f"prithvi_vit_tiny will be removed in a future version.", FutureWarning) - + vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": model_bands}, kwargs) + pretrained = vars_updated["pretrained"] + bands = vars_updated["model_bands"] return prithvi_eo_tiny(pretrained=pretrained, model_bands=bands, **kwargs)