From 3526942a44365207382f13fdfb1f6df0f3694ced Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 3 Apr 2024 14:05:28 +0400 Subject: [PATCH 1/5] models: Switch to kornia AugmentationSequential --- torchgeo/models/dofa.py | 4 +--- torchgeo/models/resnet.py | 18 ++++++++---------- torchgeo/models/swin.py | 14 ++++++-------- torchgeo/models/vit.py | 10 ++++------ 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/torchgeo/models/dofa.py b/torchgeo/models/dofa.py index bb6c043778c..e7c147d76fd 100644 --- a/torchgeo/models/dofa.py +++ b/torchgeo/models/dofa.py @@ -15,8 +15,6 @@ from torch import Tensor from torchvision.models._api import Weights, WeightsEnum -from ..transforms import AugmentationSequential - __all__ = ["DOFABase16_Weights", "DOFALarge16_Weights"] @@ -375,7 +373,7 @@ def forward(self, x: Tensor, wavelengths: list[float]) -> Tensor: # https://github.com/zhu-xlab/DOFA/blob/master/normalize_dataset.py # Normalization is sensor-dependent and is therefore left out -_dofa_transforms = AugmentationSequential(K.CenterCrop((224, 224)), data_keys=["image"]) +_dofa_transforms = K.AugmentationSequential(K.CenterCrop((224, 224)), data_keys=None) # https://github.com/pytorch/vision/pull/6883 # https://github.com/pytorch/vision/pull/7107 diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index e78c9394c65..4c3c4ec54c3 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -11,19 +11,17 @@ from timm.models import ResNet from torchvision.models._api import Weights, WeightsEnum -from ..transforms import AugmentationSequential - __all__ = ["ResNet50_Weights", "ResNet18_Weights"] # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 # Normalization either by 10K or channel-wise with band statistics -_zhu_xlab_transforms = AugmentationSequential( +_zhu_xlab_transforms = K.AugmentationSequential( K.Resize(256), K.CenterCrop(224), K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)), - data_keys=["image"], + data_keys=None, ) # Normalization only available for RGB dataset, defined here: @@ -32,31 +30,31 @@ _max = torch.tensor([88, 103, 129]) _mean = torch.tensor([0.485, 0.456, 0.406]) _std = torch.tensor([0.229, 0.224, 0.225]) -_seco_transforms = AugmentationSequential( +_seco_transforms = K.AugmentationSequential( K.Resize(256), K.CenterCrop(224), K.Normalize(mean=_min, std=_max - _min), K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), K.Normalize(mean=_mean, std=_std), - data_keys=["image"], + data_keys=None, ) # Normalization only available for RGB dataset, defined here: # https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 # noqa: E501 _mean = torch.tensor([0.485, 0.456, 0.406]) _std = torch.tensor([0.229, 0.224, 0.225]) -_gassl_transforms = AugmentationSequential( +_gassl_transforms = K.AugmentationSequential( K.Resize(224), K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.Normalize(mean=_mean, std=_std), - data_keys=["image"], + data_keys=None, ) # https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501 -_ssl4eo_l_transforms = AugmentationSequential( +_ssl4eo_l_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.CenterCrop((224, 224)), - data_keys=["image"], + data_keys=None, ) # https://github.com/pytorch/vision/pull/6883 diff --git a/torchgeo/models/swin.py b/torchgeo/models/swin.py index b5c38acce6f..a7acac1ffc4 100644 --- a/torchgeo/models/swin.py +++ b/torchgeo/models/swin.py @@ -12,16 +12,14 @@ from torchvision.models import SwinTransformer from torchvision.models._api import Weights, WeightsEnum -from ..transforms import AugmentationSequential - __all__ = ["Swin_V2_B_Weights"] # https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501 # Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). # See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501 # Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. # noqa: E501 -_satlas_transforms = AugmentationSequential( - K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=["image"] +_satlas_transforms = K.AugmentationSequential( + K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None ) # Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). @@ -31,17 +29,17 @@ [255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0] ) # noqa: E501 _mean = torch.zeros_like(_std) -_sentinel2_ms_satlas_transforms = AugmentationSequential( +_sentinel2_ms_satlas_transforms = K.AugmentationSequential( K.Normalize(mean=_mean, std=_std), Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)), - data_keys=["image"], + data_keys=None, ) # Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501 -_landsat_satlas_transforms = AugmentationSequential( +_landsat_satlas_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)), Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)), - data_keys=["image"], + data_keys=None, ) # https://github.com/pytorch/vision/pull/6883 diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py index 2c7e37fc2ed..60c875c203c 100644 --- a/torchgeo/models/vit.py +++ b/torchgeo/models/vit.py @@ -11,25 +11,23 @@ from timm.models.vision_transformer import VisionTransformer from torchvision.models._api import Weights, WeightsEnum -from ..transforms import AugmentationSequential - __all__ = ["ViTSmall16_Weights"] # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 # Normalization either by 10K or channel-wise with band statistics -_zhu_xlab_transforms = AugmentationSequential( +_zhu_xlab_transforms = K.AugmentationSequential( K.Resize(256), K.CenterCrop(224), K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)), - data_keys=["image"], + data_keys=None, ) # https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501 -_ssl4eo_l_transforms = AugmentationSequential( +_ssl4eo_l_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.CenterCrop((224, 224)), - data_keys=["image"], + data_keys=None, ) # https://github.com/pytorch/vision/pull/6883 From 895d0b9e2de6ec103438928718c82f38251218d4 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 3 Apr 2024 14:17:46 +0400 Subject: [PATCH 2/5] Bump min version of kornia --- pyproject.toml | 4 ++-- requirements/min-reqs.old | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7c8f8726065..c22ae30d6aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,8 +40,8 @@ dependencies = [ "einops>=0.3", # fiona 1.8.21+ required for Python 3.10 wheels "fiona>=1.8.21", - # kornia 0.6.9+ required for kornia.augmentation.RandomBrightness - "kornia>=0.6.9", + # kornia 0.7.2+ required for dict support in AugmentationSequential + "kornia>=0.7.2", # lightly 1.4.4+ required for MoCo v3 support # lightly 1.4.26 is incompatible with the version of timm required by smp # https://github.com/microsoft/torchgeo/issues/1824 diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 8889520d1bf..f0dcdd4ac2c 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -4,7 +4,7 @@ setuptools==61.0.0 # install einops==0.3.0 fiona==1.8.21 -kornia==0.6.9 +kornia==0.7.2 lightly==1.4.4 lightning[pytorch-extra]==2.0.0 matplotlib==3.5.0 From e480cce3c5719322ebcdbab19051073b750a4aa0 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 3 Apr 2024 14:32:41 +0400 Subject: [PATCH 3/5] mypy-fix --- torchgeo/models/swin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/models/swin.py b/torchgeo/models/swin.py index a7acac1ffc4..db7751922ba 100644 --- a/torchgeo/models/swin.py +++ b/torchgeo/models/swin.py @@ -31,14 +31,14 @@ _mean = torch.zeros_like(_std) _sentinel2_ms_satlas_transforms = K.AugmentationSequential( K.Normalize(mean=_mean, std=_std), - Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)), + Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)), # type: ignore[arg-type] data_keys=None, ) # Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501 _landsat_satlas_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)), - Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)), + Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)), # type: ignore[arg-type] data_keys=None, ) From f2faa63243b360efdb772ade86eb5fede830d9e5 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Fri, 5 Apr 2024 16:42:52 +0400 Subject: [PATCH 4/5] Remove kornia warnings from ignore list --- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c22ae30d6aa..d9abb9eea1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,8 +228,6 @@ filterwarnings = [ "ignore:Call to deprecated create function:DeprecationWarning:tensorboard.compat.proto", # https://github.com/treebeardtech/nbmake/issues/68 'ignore:The \(fspath. py.path.local\) argument to NotebookFile is deprecated:pytest.PytestDeprecationWarning:nbmake.pytest_plugin', - # https://github.com/kornia/kornia/issues/777 - "ignore:Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0:UserWarning:torch.nn.functional", # https://github.com/pytorch/pytorch/pull/24929 "ignore:Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0:UserWarning:torch.nn.functional", # https://github.com/scikit-image/scikit-image/issues/6663 @@ -259,8 +257,6 @@ filterwarnings = [ "ignore:Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package:UserWarning", # https://github.com/Lightning-AI/lightning/issues/18545 "ignore:LightningCLI's args parameter is intended to run from within Python like if it were from the command line.:UserWarning", - # https://github.com/kornia/kornia/pull/1611 - "ignore:`ColorJitter` is now following Torchvision implementation.:DeprecationWarning:kornia.augmentation._2d.intensity.color_jitter", # https://github.com/kornia/kornia/pull/1663 "ignore:`RandomGaussianBlur` has changed its behavior and now randomly sample sigma for both axes.:DeprecationWarning", # https://github.com/pytorch/pytorch/pull/111576 From 800b6b2e39fad2fb99bbf1af407c8f103ce3bf45 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Fri, 5 Apr 2024 17:00:40 +0400 Subject: [PATCH 5/5] Wrap Lambda in ImageSequential --- torchgeo/models/swin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/models/swin.py b/torchgeo/models/swin.py index db7751922ba..a44e70ab2c6 100644 --- a/torchgeo/models/swin.py +++ b/torchgeo/models/swin.py @@ -31,14 +31,14 @@ _mean = torch.zeros_like(_std) _sentinel2_ms_satlas_transforms = K.AugmentationSequential( K.Normalize(mean=_mean, std=_std), - Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)), # type: ignore[arg-type] + K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))), data_keys=None, ) # Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501 _landsat_satlas_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)), - Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)), # type: ignore[arg-type] + K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))), data_keys=None, )