Skip to content

Commit

Permalink
Add new matric "spatial distortion index" (#2260)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
5 people authored Dec 21, 2023
1 parent b9fe394 commit 2cab7b3
Show file tree
Hide file tree
Showing 9 changed files with 966 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `aggregate` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220))


- Added `SpatialDistortionIndex` metric to image domain ([#2260](https://github.com/Lightning-AI/torchmetrics/pull/2260))


- Added `CriticalSuccessIndex` metric to image subpackage ([#2257](https://github.com/Lightning-AI/torchmetrics/pull/2257))


### Changed

- Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145))
Expand Down
21 changes: 21 additions & 0 deletions docs/source/image/spatial_distortion_index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Spatial Distortion Index
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

########################
Spatial Distortion Index
########################

Module Interface
________________

.. autoclass:: torchmetrics.image.SpatialDistortionIndex
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.spatial_distortion_index
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
.. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim
.. _UniversalImageQualityIndex: https://ieeexplore.ieee.org/abstract/document/995823
.. _SpectralDistortionIndex: https://www.semanticscholar.org/paper/Multispectral-and-panchromatic-data-fusion-without-Alparone-Aiazzi/b6db12e3785326577cb95fd743fecbf5bc66c7c9
.. _SpatialDistortionIndex: https://www.semanticscholar.org/paper/Multispectral-and-panchromatic-data-fusion-without-Alparone-Aiazzi/b6db12e3785326577cb95fd743fecbf5bc66c7c9
.. _RelativeAverageSpectralError: https://www.semanticscholar.org/paper/Data-Fusion.-Definitions-and-Architectures-Fusion-Wald/51b2b81e5124b3bb7ec53517a5dd64d8e348cadf
.. _WMAPE: https://en.wikipedia.org/wiki/WMAPE
.. _CER: https://rechtsprechung-im-ostseeraum.archiv.uni-greifswald.de/word-error-rate-character-error-rate-how-to-evaluate-a-model
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.functional.image.csi import critical_success_index
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.d_s import spatial_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
Expand All @@ -32,6 +33,7 @@

__all__ = [
"spectral_distortion_index",
"spatial_distortion_index",
"error_relative_global_dimensionless_synthesis",
"image_gradients",
"peak_signal_noise_ratio",
Expand Down
267 changes: 267 additions & 0 deletions src/torchmetrics/functional/image/d_s.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.utilities.distributed import reduce
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE

if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["_spatial_distortion_index_compute", "spatial_distortion_index"]


def _spatial_distortion_index_update(
preds: Tensor, ms: Tensor, pan: Tensor, pan_lr: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]:
"""Update and returns variables required to compute Spatial Distortion Index.
Args:
preds: High resolution multispectral image.
ms: Low resolution multispectral image.
pan: High resolution panchromatic image.
pan_lr: Low resolution panchromatic image.
Return:
A tuple of Tensors containing ``preds``, ``ms``, ``pan`` and ``pan_lr``.
Raises:
TypeError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same data type.
ValueError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have ``BxCxHxW shape``.
ValueError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same batch and channel sizes.
ValueError:
If ``preds`` and ``pan`` don't have the same dimension.
ValueError:
If ``ms`` and ``pan_lr`` don't have the same dimension.
ValueError:
If ``preds`` and ``pan`` don't have dimension which is multiple of that of ``ms``.
"""
if len(preds.shape) != 4:
raise ValueError(f"Expected `preds` to have BxCxHxW shape. Got preds: {preds.shape}.")
if preds.dtype != ms.dtype:
raise TypeError(
f"Expected `preds` and `ms` to have the same data type. Got preds: {preds.dtype} and ms: {ms.dtype}."
)
if preds.dtype != pan.dtype:
raise TypeError(
f"Expected `preds` and `pan` to have the same data type. Got preds: {preds.dtype} and pan: {pan.dtype}."
)
if pan_lr is not None and preds.dtype != pan_lr.dtype:
raise TypeError(
f"Expected `preds` and `pan_lr` to have the same data type."
f" Got preds: {preds.dtype} and pan_lr: {pan_lr.dtype}."
)
if len(ms.shape) != 4:
raise ValueError(f"Expected `ms` to have BxCxHxW shape. Got ms: {ms.shape}.")
if len(pan.shape) != 4:
raise ValueError(f"Expected `pan` to have BxCxHxW shape. Got pan: {pan.shape}.")
if pan_lr is not None and len(pan_lr.shape) != 4:
raise ValueError(f"Expected `pan_lr` to have BxCxHxW shape. Got pan_lr: {pan_lr.shape}.")
if preds.shape[:2] != ms.shape[:2]:
raise ValueError(
f"Expected `preds` and `ms` to have the same batch and channel sizes."
f" Got preds: {preds.shape} and ms: {ms.shape}."
)
if preds.shape[:2] != pan.shape[:2]:
raise ValueError(
f"Expected `preds` and `pan` to have the same batch and channel sizes."
f" Got preds: {preds.shape} and pan: {pan.shape}."
)
if pan_lr is not None and preds.shape[:2] != pan_lr.shape[:2]:
raise ValueError(
f"Expected `preds` and `pan_lr` to have the same batch and channel sizes."
f" Got preds: {preds.shape} and pan_lr: {pan_lr.shape}."
)

preds_h, preds_w = preds.shape[-2:]
ms_h, ms_w = ms.shape[-2:]
pan_h, pan_w = pan.shape[-2:]
if preds_h != pan_h:
raise ValueError(f"Expected `preds` and `pan` to have the same height. Got preds: {preds_h} and pan: {pan_h}")
if preds_w != pan_w:
raise ValueError(f"Expected `preds` and `pan` to have the same width. Got preds: {preds_w} and pan: {pan_w}")
if preds_h % ms_h != 0:
raise ValueError(
f"Expected height of `preds` to be multiple of height of `ms`. Got preds: {preds_h} and ms: {ms_h}."
)
if preds_w % ms_w != 0:
raise ValueError(
f"Expected width of `preds` to be multiple of width of `ms`. Got preds: {preds_w} and ms: {ms_w}."
)
if pan_h % ms_h != 0:
raise ValueError(
f"Expected height of `pan` to be multiple of height of `ms`. Got preds: {pan_h} and ms: {ms_h}."
)
if pan_w % ms_w != 0:
raise ValueError(f"Expected width of `pan` to be multiple of width of `ms`. Got preds: {pan_w} and ms: {ms_w}.")

if pan_lr is not None:
pan_lr_h, pan_lr_w = pan_lr.shape[-2:]
if pan_lr_h != ms_h:
raise ValueError(
f"Expected `ms` and `pan_lr` to have the same height. Got ms: {ms_h} and pan_lr: {pan_lr_h}."
)
if pan_lr_w != ms_w:
raise ValueError(
f"Expected `ms` and `pan_lr` to have the same width. Got ms: {ms_w} and pan_lr: {pan_lr_w}."
)

return preds, ms, pan, pan_lr


def _spatial_distortion_index_compute(
preds: Tensor,
ms: Tensor,
pan: Tensor,
pan_lr: Optional[Tensor] = None,
norm_order: int = 1,
window_size: int = 7,
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Tensor:
"""Compute Spatial Distortion Index (SpatialDistortionIndex_).
Args:
preds: High resolution multispectral image.
ms: Low resolution multispectral image.
pan: High resolution panchromatic image.
pan_lr: Low resolution panchromatic image.
norm_order: Order of the norm applied on the difference.
window_size: Window size of the filter applied to degrade the high resolution panchromatic image.
reduction: A method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
Return:
Tensor with SpatialDistortionIndex score
Raises:
ValueError
If ``window_size`` is smaller than dimension of ``ms``.
Example:
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 32, 32])
>>> ms = torch.rand([16, 3, 16, 16])
>>> pan = torch.rand([16, 3, 32, 32])
>>> preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan)
>>> _spatial_distortion_index_compute(preds, ms, pan, pan_lr)
tensor(0.0090)
"""
length = preds.shape[1]

ms_h, ms_w = ms.shape[-2:]
if window_size >= ms_h or window_size >= ms_w:
raise ValueError(
f"Expected `window_size` to be smaller than dimension of `ms`. Got window_size: {window_size}."
)

if pan_lr is None:
if not _TORCHVISION_AVAILABLE:
raise ValueError(
"When `pan_lr` is not provided as input to metric Spatial distortion index, torchvision should be "
"installed. Please install with `pip install torchvision` or `pip install torchmetrics[image]`."
)
from torchvision.transforms.functional import resize

from torchmetrics.functional.image.helper import _uniform_filter

pan_degraded = _uniform_filter(pan, window_size=window_size)
pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False)
else:
pan_degraded = pan_lr

m1 = torch.zeros(length, device=preds.device)
m2 = torch.zeros(length, device=preds.device)

for i in range(length):
m1[i] = universal_image_quality_index(ms[:, i : i + 1], pan_degraded[:, i : i + 1])
m2[i] = universal_image_quality_index(preds[:, i : i + 1], pan[:, i : i + 1])
diff = (m1 - m2).abs() ** norm_order
return reduce(diff, reduction) ** (1 / norm_order)


def spatial_distortion_index(
preds: Tensor,
ms: Tensor,
pan: Tensor,
pan_lr: Optional[Tensor] = None,
norm_order: int = 1,
window_size: int = 7,
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Tensor:
"""Calculate `Spatial Distortion Index`_ (SpatialDistortionIndex_) also known as D_s.
Metric is used to compare the spatial distortion between two images.
Args:
preds: High resolution multispectral image.
ms: Low resolution multispectral image.
pan: High resolution panchromatic image.
pan_lr: Low resolution panchromatic image.
norm_order: Order of the norm applied on the difference.
window_size: Window size of the filter applied to degrade the high resolution panchromatic image.
reduction: A method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
Return:
Tensor with SpatialDistortionIndex score
Raises:
TypeError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same data type.
ValueError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have ``BxCxHxW shape``.
ValueError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same batch and channel sizes.
ValueError:
If ``preds`` and ``pan`` don't have the same dimension.
ValueError:
If ``ms`` and ``pan_lr`` don't have the same dimension.
ValueError:
If ``preds`` and ``pan`` don't have dimension which is multiple of that of ``ms``.
ValueError:
If ``norm_order`` is not a positive integer.
ValueError:
If ``window_size`` is not a positive integer.
Example:
>>> from torchmetrics.functional.image import spatial_distortion_index
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 32, 32])
>>> ms = torch.rand([16, 3, 16, 16])
>>> pan = torch.rand([16, 3, 32, 32])
>>> spatial_distortion_index(preds, ms, pan)
tensor(0.0090)
"""
if not isinstance(norm_order, int) or norm_order <= 0:
raise ValueError(f"Expected `norm_order` to be a positive integer. Got norm_order: {norm_order}.")
if not isinstance(window_size, int) or window_size <= 0:
raise ValueError(f"Expected `window_size` to be a positive integer. Got window_size: {window_size}.")
preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan, pan_lr)
return _spatial_distortion_index_compute(preds, ms, pan, pan_lr, norm_order, window_size, reduction)
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.image.csi import CriticalSuccessIndex
from torchmetrics.image.d_lambda import SpectralDistortionIndex
from torchmetrics.image.d_s import SpatialDistortionIndex
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
from torchmetrics.image.psnr import PeakSignalNoiseRatio
Expand All @@ -31,6 +32,7 @@

__all__ = [
"SpectralDistortionIndex",
"SpatialDistortionIndex",
"ErrorRelativeGlobalDimensionlessSynthesis",
"PeakSignalNoiseRatio",
"PeakSignalNoiseRatioWithBlockedEffect",
Expand Down
Loading

0 comments on commit 2cab7b3

Please sign in to comment.