Skip to content

Commit

Permalink
models: Switch to kornia AugmentationSequential (#1979)
Browse files Browse the repository at this point in the history
* models: Switch to kornia AugmentationSequential

* Bump min version of kornia

* mypy-fix

* Remove kornia warnings from ignore list

* Wrap Lambda in ImageSequential
  • Loading branch information
ashnair1 authored Apr 5, 2024
1 parent 2425e29 commit abceea0
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 36 deletions.
8 changes: 2 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ dependencies = [
"einops>=0.3",
# fiona 1.8.21+ required for Python 3.10 wheels
"fiona>=1.8.21",
# kornia 0.6.9+ required for kornia.augmentation.RandomBrightness
"kornia>=0.6.9",
# kornia 0.7.2+ required for dict support in AugmentationSequential
"kornia>=0.7.2",
# lightly 1.4.4+ required for MoCo v3 support
# lightly 1.4.26 is incompatible with the version of timm required by smp
# https://github.com/microsoft/torchgeo/issues/1824
Expand Down Expand Up @@ -228,8 +228,6 @@ filterwarnings = [
"ignore:Call to deprecated create function:DeprecationWarning:tensorboard.compat.proto",
# https://github.com/treebeardtech/nbmake/issues/68
'ignore:The \(fspath. py.path.local\) argument to NotebookFile is deprecated:pytest.PytestDeprecationWarning:nbmake.pytest_plugin',
# https://github.com/kornia/kornia/issues/777
"ignore:Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0:UserWarning:torch.nn.functional",
# https://github.com/pytorch/pytorch/pull/24929
"ignore:Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0:UserWarning:torch.nn.functional",
# https://github.com/scikit-image/scikit-image/issues/6663
Expand Down Expand Up @@ -259,8 +257,6 @@ filterwarnings = [
"ignore:Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package:UserWarning",
# https://github.com/Lightning-AI/lightning/issues/18545
"ignore:LightningCLI's args parameter is intended to run from within Python like if it were from the command line.:UserWarning",
# https://github.com/kornia/kornia/pull/1611
"ignore:`ColorJitter` is now following Torchvision implementation.:DeprecationWarning:kornia.augmentation._2d.intensity.color_jitter",
# https://github.com/kornia/kornia/pull/1663
"ignore:`RandomGaussianBlur` has changed its behavior and now randomly sample sigma for both axes.:DeprecationWarning",
# https://github.com/pytorch/pytorch/pull/111576
Expand Down
2 changes: 1 addition & 1 deletion requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ setuptools==61.0.0
# install
einops==0.3.0
fiona==1.8.21
kornia==0.6.9
kornia==0.7.2
lightly==1.4.4
lightning[pytorch-extra]==2.0.0
matplotlib==3.5.0
Expand Down
4 changes: 1 addition & 3 deletions torchgeo/models/dofa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from torch import Tensor
from torchvision.models._api import Weights, WeightsEnum

from ..transforms import AugmentationSequential

__all__ = ["DOFABase16_Weights", "DOFALarge16_Weights"]


Expand Down Expand Up @@ -375,7 +373,7 @@ def forward(self, x: Tensor, wavelengths: list[float]) -> Tensor:

# https://github.com/zhu-xlab/DOFA/blob/master/normalize_dataset.py
# Normalization is sensor-dependent and is therefore left out
_dofa_transforms = AugmentationSequential(K.CenterCrop((224, 224)), data_keys=["image"])
_dofa_transforms = K.AugmentationSequential(K.CenterCrop((224, 224)), data_keys=None)

# https://github.com/pytorch/vision/pull/6883
# https://github.com/pytorch/vision/pull/7107
Expand Down
18 changes: 8 additions & 10 deletions torchgeo/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,17 @@
from timm.models import ResNet
from torchvision.models._api import Weights, WeightsEnum

from ..transforms import AugmentationSequential

__all__ = ["ResNet50_Weights", "ResNet18_Weights"]


# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# Normalization is done either by dividing by 10K or channel-wise using band statistics
_zhu_xlab_transforms = AugmentationSequential(
_zhu_xlab_transforms = K.AugmentationSequential(
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)),
data_keys=["image"],
data_keys=None,
)

# Normalization only available for RGB dataset, defined here:
Expand All @@ -32,31 +30,31 @@
_max = torch.tensor([88, 103, 129])
_mean = torch.tensor([0.485, 0.456, 0.406])
_std = torch.tensor([0.229, 0.224, 0.225])
_seco_transforms = AugmentationSequential(
_seco_transforms = K.AugmentationSequential(
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=_min, std=_max - _min),
K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),
K.Normalize(mean=_mean, std=_std),
data_keys=["image"],
data_keys=None,
)

# Normalization only available for RGB dataset, defined here:
# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 # noqa: E501
_mean = torch.tensor([0.485, 0.456, 0.406])
_std = torch.tensor([0.229, 0.224, 0.225])
_gassl_transforms = AugmentationSequential(
_gassl_transforms = K.AugmentationSequential(
K.Resize(224),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.Normalize(mean=_mean, std=_std),
data_keys=["image"],
data_keys=None,
)

# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501
_ssl4eo_l_transforms = AugmentationSequential(
_ssl4eo_l_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.CenterCrop((224, 224)),
data_keys=["image"],
data_keys=None,
)

# https://github.com/pytorch/vision/pull/6883
Expand Down
18 changes: 8 additions & 10 deletions torchgeo/models/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
from torchvision.models import SwinTransformer
from torchvision.models._api import Weights, WeightsEnum

from ..transforms import AugmentationSequential

__all__ = ["Swin_V2_B_Weights"]

# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501
# Satlas Sentinel-1, RGB Sentinel-2, and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. # noqa: E501
_satlas_transforms = AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=["image"]
_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None
)

# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
Expand All @@ -31,17 +29,17 @@
[255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0]
) # noqa: E501
_mean = torch.zeros_like(_std)
_sentinel2_ms_satlas_transforms = AugmentationSequential(
_sentinel2_ms_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=_mean, std=_std),
Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)),
data_keys=["image"],
K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
data_keys=None,
)

# Satlas Landsat imagery is 16-bit; each pixel value N is normalized as (N-4000)/16320 and then clipped to (0, 1). # noqa: E501
_landsat_satlas_transforms = AugmentationSequential(
_landsat_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)),
data_keys=["image"],
K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
data_keys=None,
)

# https://github.com/pytorch/vision/pull/6883
Expand Down
10 changes: 4 additions & 6 deletions torchgeo/models/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,23 @@
from timm.models.vision_transformer import VisionTransformer
from torchvision.models._api import Weights, WeightsEnum

from ..transforms import AugmentationSequential

__all__ = ["ViTSmall16_Weights"]

# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# Normalization is done either by dividing by 10K or channel-wise using band statistics
_zhu_xlab_transforms = AugmentationSequential(
_zhu_xlab_transforms = K.AugmentationSequential(
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)),
data_keys=["image"],
data_keys=None,
)

# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501
_ssl4eo_l_transforms = AugmentationSequential(
_ssl4eo_l_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.CenterCrop((224, 224)),
data_keys=["image"],
data_keys=None,
)

# https://github.com/pytorch/vision/pull/6883
Expand Down

0 comments on commit abceea0

Please sign in to comment.