Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WiP] Running ViT-MAE with many bands and its specificities. #395

Draft
wants to merge 34 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
9ede646
Merge branch 'cli/predict_bands' of github.com:IBM/terratorch into ou…
Joao-L-S-Almeida Aug 19, 2024
7ed81d7
Cleaning memory and GPU cache in order to avoid 'CUDA out of memory' …
Joao-L-S-Almeida Aug 19, 2024
6d09629
Registering profiling
Joao-L-S-Almeida Aug 21, 2024
050dd7a
Merge branch 'out_of_memory' of github.com:IBM/terratorch into overwr…
Joao-L-S-Almeida Aug 21, 2024
2caaa6f
Modifying slicing
Joao-L-S-Almeida Aug 21, 2024
ca028ec
Merge branch 'overwrite_default_prithvi' of github.com:IBM/terratorch…
Joao-L-S-Almeida Aug 21, 2024
9b6ce6a
Run the tests again
Joao-L-S-Almeida Aug 22, 2024
91ae617
Merge branch 'overwrite_default_prithvi' of github.com:IBM/terratorch…
Joao-L-S-Almeida Aug 23, 2024
1dc2c1e
Adjustments
Joao-L-S-Almeida Aug 23, 2024
6681f81
Here we also need to support others ways to define bands
Joao-L-S-Almeida Sep 2, 2024
c60bb10
list would be better
Joao-L-S-Almeida Sep 2, 2024
b3346e6
merging with main
Joao-L-S-Almeida Sep 9, 2024
ff29037
list is also an acceptable format
Joao-L-S-Almeida Sep 9, 2024
8df33ba
closed interval
Joao-L-S-Almeida Sep 9, 2024
8e98cd9
Testing MLP decoder
Joao-L-S-Almeida Nov 4, 2024
b94e308
Using MLP decoder
Joao-L-S-Almeida Nov 4, 2024
9b31014
missing import
Joao-L-S-Almeida Nov 4, 2024
703cf69
Merge branch 'overwrite_default_prithvi' of github.com:IBM/terratorch…
Joao-L-S-Almeida Nov 4, 2024
055f075
patching along bands dimension
Joao-L-S-Almeida Jan 24, 2025
08706d7
using band patching
Joao-L-S-Almeida Jan 24, 2025
62fd107
minor fixes
Joao-L-S-Almeida Jan 25, 2025
825f7be
debugging
Joao-L-S-Almeida Jan 27, 2025
9c9f59c
Defining the first dimension for the 3D convolution
Joao-L-S-Almeida Jan 27, 2025
4f75596
Is it necessary to transpose the input tensor or not ?
Joao-L-S-Almeida Jan 27, 2025
bf1090b
Evaluaitng the number of c patches
Joao-L-S-Almeida Jan 27, 2025
45ab81b
Using the proper number of channels patches
Joao-L-S-Almeida Jan 27, 2025
92f3942
Testing the patching strategy
Joao-L-S-Almeida Jan 27, 2025
57883a5
merging
Joao-L-S-Almeida Jan 28, 2025
fa53f6d
Merge branch 'patch/spec' into overwrite_default_prithvi_main
Joao-L-S-Almeida Jan 28, 2025
04cfac1
minor fixes
Joao-L-S-Almeida Jan 28, 2025
4b6e5a9
Merge branch 'main' into patch/spec
Joao-L-S-Almeida Jan 31, 2025
8f2fcc9
Merge branch 'main' into overwrite_default_prithvi_main
Joao-L-S-Almeida Jan 31, 2025
97435d4
workaraound to avoid input arguments repetition (specifically 'model_…
Joao-L-S-Almeida Jan 31, 2025
c4398ae
Merge branch 'patch/spec' into overwrite_default_prithvi_main
Joao-L-S-Almeida Jan 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions terratorch/datasets/generic_pixel_wise_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,22 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr
data = data.fillna(nan_replace)
return data

def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None):
if bands_intervals is None:
return None
bands = []
for element in bands_intervals:
# if its an interval
if isinstance(element, tuple):
if len(element) != 2: # noqa: PLR2004
msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive"
raise Exception(msg)
expanded_element = list(range(element[0], element[1]))
bands.extend(expanded_element)
else:
bands.append(element)
return bands


class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
"""GenericNonGeoSegmentationDataset"""
Expand Down
2 changes: 1 addition & 1 deletion terratorch/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def generate_bands_intervals(bands_intervals: list[int | str | HLSBands | tuple[
msg = "When defining an interval, a tuple of two integers should be passed,\
defining start and end indices inclusive"
raise Exception(msg)
expanded_element = list(range(element[0], element[1] + 1))
expanded_element = list(range(element[0], element[1]))
bands.extend(expanded_element)
else:
bands.append(element)
Expand Down
55 changes: 46 additions & 9 deletions terratorch/models/backbones/prithvi_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,35 @@ def __init__(
input_size: tuple[int, int, int] = (1, 224, 224),
patch_size: tuple[int, int, int] = (1, 16, 16),
in_chans: int = 3,
tub_size: int = 1,
embed_dim: int = 768,
band_patch_size: int = None,
norm_layer: nn.Module | None = None,
flatten: bool = True,
bias: bool = True,
):
super().__init__()
self.input_size = input_size
self.patch_size = patch_size
self.tub_size = tub_size
self.in_chans = in_chans
self.band_patch_size = band_patch_size
self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
assert self.grid_size >= [1,1,1], "Patch size is bigger than input size."
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
self.flatten = flatten

self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
# When spectral patching is used, some adaptations are required
if self.band_patch_size:
kernel_size = (self.band_patch_size, self.patch_size[1], self.patch_size[2])
first_conv_dim = tub_size
self.dim_transposer = lambda x: x.transpose(2, 1)
else:
kernel_size = self.patch_size
first_conv_dim = in_chans
self.dim_transposer = lambda x: x

self.proj = nn.Conv3d(first_conv_dim, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def forward(self, x):
Expand All @@ -159,6 +174,9 @@ def forward(self, x):
warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
f"The border will be ignored, add backbone_padding for pixel-wise tasks.")

# When spectral patching is used the tensor must be transposed in order
# to operate over the proper dimension.
x = self.dim_transposer(x)
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
Expand Down Expand Up @@ -234,6 +252,7 @@ class PrithviViT(nn.Module):
def __init__(self,
img_size: int | tuple[int, int] = 224,
patch_size: int | tuple[int, int, int] = (1, 16, 16),
band_patch_size: int = None,
num_frames: int = 1,
in_chans: int = 3,
embed_dim: int = 1024,
Expand All @@ -256,10 +275,20 @@ def __init__(self,
if isinstance(patch_size, int):
patch_size = (1, patch_size, patch_size)

self.band_patch_size = band_patch_size

# If spectral patching is being used, we need a way to evaluate the
# extra number of patches.
if self.band_patch_size:
self.eval_c_patches = lambda c: c // self.patch_embed.band_patch_size
else:
self.eval_c_patches = lambda c: 1

# 3D patch embedding
self.patch_embed = PatchEmbed(
input_size=(num_frames,) + self.img_size,
patch_size=patch_size,
band_patch_size=band_patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
Expand Down Expand Up @@ -336,7 +365,7 @@ def random_masking(self, sequence, mask_ratio, noise=None):

return sequence_unmasked, mask, ids_restore

def interpolate_pos_encoding(self, x, t, w, h):
def interpolate_pos_encoding(self, x, t, c, w, h):
"""
Adapted from:
- transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
Expand All @@ -348,6 +377,7 @@ def interpolate_pos_encoding(self, x, t, w, h):

class_pos_embed = self.pos_embed[:, :1]
patch_pos_embed = self.pos_embed[:, 1:]
c_patches = self.eval_c_patches(c)
t_patches = t // self.patch_embed.patch_size[0]
w_patches = w // self.patch_embed.patch_size[1]
h_patches = h // self.patch_embed.patch_size[2]
Expand All @@ -357,10 +387,11 @@ def interpolate_pos_encoding(self, x, t, w, h):

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(h_patches, w_patches),
size=(c_patches*h_patches, w_patches), # Accounting the extra patches produced by the spectral patching
mode='bicubic',
align_corners=True,
)

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.embed_dim)
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

Expand All @@ -373,12 +404,12 @@ def forward(
if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
# add time dim
x = x.unsqueeze(2)
t, h, w = x.shape[-3:]
t, c, h, w = x.shape[-4:]

# embed patches
x = self.patch_embed(x)

pos_embed = self.interpolate_pos_encoding(x, t, h, w)
pos_embed = self.interpolate_pos_encoding(x, t, c, h, w)
# add pos embed w/o cls token
x = x + pos_embed[:, 1:, :]

Expand Down Expand Up @@ -414,12 +445,13 @@ def forward_features(
if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
# add time dim
x = x.unsqueeze(2)
t, h, w = x.shape[-3:]
c, t, h, w = x.shape[-4:]

# embed patches

x = self.patch_embed(x)

pos_embed = self.interpolate_pos_encoding(x, t, h, w)
pos_embed = self.interpolate_pos_encoding(x, t, c, h, w)
# add pos embed w/o cls token
x = x + pos_embed[:, 1:, :]

Expand All @@ -444,22 +476,27 @@ def forward_features(

x = self.norm(x)
out[-1] = x

return out

def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
out = []
effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
c = self.eval_c_patches(self.patch_embed.in_chans)

for x in features:
x_no_token = x[:, 1:, :]
number_of_tokens = x_no_token.shape[1]
tokens_per_timestep = number_of_tokens // effective_time_dim
tokens_per_timestep = number_of_tokens // effective_time_dim // c
h = int(np.sqrt(tokens_per_timestep))

encoded = rearrange(
x_no_token,
"batch (t h w) e -> batch (t e) h w",
"batch (t h w c) e -> batch (t e) (c h) w",
e=self.embed_dim,
t=effective_time_dim,
h=h,
c=c,
)
out.append(encoded)
return out
Expand Down
36 changes: 29 additions & 7 deletions terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
PRITHVI_V2_MEAN = [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0]
PRITHVI_V2_STD = [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0]

# TODO This operation is probably a workaround. For some reason the variable
# "model_bands" is being repeated. It's necessary to check the reason for it.
def _overwrite_with_kwargs(extra_kwargs, kwargs):

for k in extra_kwargs.keys():
if k in kwargs.keys():
extra_kwargs[k] = kwargs.pop(k)
return extra_kwargs, kwargs

def _cfg(**kwargs):
return {
Expand Down Expand Up @@ -255,7 +263,9 @@ def prithvi_eo_tiny(
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs)
pretrained = vars_updated["pretrained"]
bands = vars_updated["model_bands"]
return _create_prithvi("prithvi_eo_tiny", pretrained=pretrained, model_bands=bands, **kwargs)


Expand All @@ -265,7 +275,9 @@ def prithvi_eo_v1_100(
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs)
pretrained = vars_updated["pretrained"]
bands = vars_updated["model_bands"]
return _create_prithvi("prithvi_eo_v1_100", pretrained=pretrained, model_bands=bands, **kwargs)


Expand All @@ -275,7 +287,9 @@ def prithvi_eo_v2_300(
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs)
pretrained = vars_updated["pretrained"]
bands = vars_updated["model_bands"]
return _create_prithvi("prithvi_eo_v2_300", pretrained=pretrained, model_bands=bands, **kwargs)


Expand All @@ -285,7 +299,9 @@ def prithvi_eo_v2_600(
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs)
pretrained = vars_updated["pretrained"]
bands = vars_updated["model_bands"]
return _create_prithvi("prithvi_eo_v2_600", pretrained=pretrained, model_bands=bands, **kwargs)


Expand All @@ -295,7 +311,9 @@ def prithvi_eo_v2_300_tl(
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs)
pretrained = vars_updated["pretrained"]
bands = vars_updated["model_bands"]
return _create_prithvi("prithvi_eo_v2_300_tl", pretrained=pretrained, model_bands=bands, **kwargs)


Expand All @@ -305,7 +323,9 @@ def prithvi_eo_v2_600_tl(
bands: list[HLSBands] | None = None,
**kwargs,
) -> PrithviViT:

vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": bands}, kwargs)
pretrained = vars_updated["pretrained"]
bands = vars_updated["model_bands"]
return _create_prithvi("prithvi_eo_v2_600_tl", pretrained=pretrained, model_bands=bands, **kwargs)


Expand All @@ -319,7 +339,9 @@ def prithvi_vit_tiny(

warnings.warn(f"The model prithvi_vit_tiny was renamed to prithvi_eo_tiny. "
f"prithvi_vit_tiny will be removed in a future version.", FutureWarning)

vars_updated, kwargs = _overwrite_with_kwargs({"pretrained": pretrained, "model_bands": model_bands}, kwargs)
pretrained = vars_updated["pretrained"]
bands = vars_updated["model_bands"]
return prithvi_eo_tiny(pretrained=pretrained, model_bands=bands, **kwargs)


Expand Down
6 changes: 4 additions & 2 deletions terratorch/models/decoders/mlp_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,29 @@

from torch import Tensor, nn
import torch

from terratorch.registry import TERRATORCH_DECODER_REGISTRY


@TERRATORCH_DECODER_REGISTRY.register

class MLPDecoder(nn.Module):
"""Identity decoder. Useful to pass the feature straight to the head."""

def __init__(self, embed_dim: int, channels: int = 100, out_dim:int = 100, activation: str = "ReLU", out_index=-1) -> None:
"""Constructor

Args:
embed_dim (int): Input embedding dimension
out_index (int, optional): Index of the input list to take.. Defaults to -1.
"""

super().__init__()
self.embed_dim = embed_dim
self.channels = channels
self.dim = out_index
self.n_inputs = len(self.embed_dim)
self.out_channels = self.embed_dim[self.dim]
self.output_embed_dim = self.out_channels
self.hidden_layer = torch.nn.Linear(self.out_channels*self.n_inputs, self.out_channels)
self.activation = getattr(nn, activation)()

Expand Down
23 changes: 23 additions & 0 deletions terratorch/models/prithvi_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,27 @@

@MODEL_FACTORY_REGISTRY.register
class PrithviModelFactory(ModelFactory):

@staticmethod
def _generate_bands_intervals(bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None):
if bands_intervals is None:
return None
bands = []
for element in bands_intervals:
# if its an interval
if isinstance(element, list) or isinstance(element, tuple):
if len(element) != 2: # noqa: PLR2004
msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive"
raise Exception(msg)
expanded_element = list(range(element[0], element[1]))
bands.extend(expanded_element)
else:
bands.append(element)
return bands

def __init__(self) -> None:
self._factory: EncoderDecoderFactory = EncoderDecoderFactory()

def build_model(
self,
task: str,
Expand Down Expand Up @@ -72,7 +91,11 @@ def build_model(
Returns:
nn.Module: Full model with encoder, decoder and head.
"""
bands = self._generate_bands_intervals(bands)
print(bands)

warnings.warn("PrithviModelFactory is deprecated. Please switch to EncoderDecoderFactory.", stacklevel=1)

if in_channels is None:
in_channels = len(bands)
# TODO: support auxiliary heads
Expand Down
6 changes: 6 additions & 0 deletions terratorch/tasks/classification_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)

# Avoiding GPU memory overloading
# Removing GPU cache
torch.cuda.empty_cache()
# Forcing the Python garbage collector
gc.collect()

y_hat = self(x).output
y_hat = y_hat.argmax(dim=1)
return y_hat, file_names
5 changes: 4 additions & 1 deletion terratorch/tasks/regression_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Sequence
from functools import partial
from typing import Any
import gc

import logging
import lightning
Expand Down Expand Up @@ -386,10 +387,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T

def model_forward(x):
return self(x).output

if self.tiled_inference_parameters:
# TODO: tiled inference does not work with additional input data (**rest)
y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters)
else:

y_hat: Tensor = self(x, **rest).output

return y_hat, file_names
Loading
Loading