Add new metric "spatial distortion index" #2260

Merged
merged 41 commits into from
Dec 21, 2023
2a6dd63
added new matric "spatial distortion index"
ywchan2005 Dec 4, 2023
6bc62bf
added missing docs
ywchan2005 Dec 4, 2023
fed2f4a
moved kornia from image_test.txt to image.txt
ywchan2005 Dec 4, 2023
aab1a48
fixed typo in version in requirements
ywchan2005 Dec 4, 2023
440266b
fixed docstrings
ywchan2005 Dec 4, 2023
3fe20a9
changed kornia to lazy import
ywchan2005 Dec 5, 2023
fe9a835
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2023
15eb079
changed torchvision to lazy import
ywchan2005 Dec 5, 2023
4035bfe
fix style
ywchan2005 Dec 5, 2023
b4fbb08
fix type hint
ywchan2005 Dec 6, 2023
dace7f6
fix missing link in doc
ywchan2005 Dec 6, 2023
2fa4951
fix type hint
ywchan2005 Dec 7, 2023
075cf50
fix mypy error
ywchan2005 Dec 8, 2023
d6fccf6
remove dependence of kornia
ywchan2005 Dec 8, 2023
9df3782
fixed ruff error
ywchan2005 Dec 8, 2023
0f8a1e6
Update docs/source/image/spatial_distortion_index.rst
ywchan2005 Dec 11, 2023
caf9dea
Update src/torchmetrics/functional/image/d_s.py
ywchan2005 Dec 11, 2023
7f38b0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2023
00f29a0
Update src/torchmetrics/image/d_s.py
ywchan2005 Dec 11, 2023
c64915b
fix style
ywchan2005 Dec 11, 2023
f169369
moved checking of tensor input to update
ywchan2005 Dec 11, 2023
cd9c4dc
Update src/torchmetrics/functional/image/d_s.py
ywchan2005 Dec 14, 2023
c6087d6
Update src/torchmetrics/functional/image/d_s.py
ywchan2005 Dec 14, 2023
af5288b
Update src/torchmetrics/functional/image/d_s.py
ywchan2005 Dec 14, 2023
f7debf0
changed `ws` to `window_size`
ywchan2005 Dec 14, 2023
3137b47
changed `p` to `norm_order`
ywchan2005 Dec 14, 2023
c9a506c
Update src/torchmetrics/functional/image/d_s.py
ywchan2005 Dec 14, 2023
3b40466
fix assert regex in tests
ywchan2005 Dec 14, 2023
af4b50a
changed `_update` and `_compute` functions to take `ms`, `pan` and `p…
ywchan2005 Dec 15, 2023
115ead3
fix type hint
ywchan2005 Dec 18, 2023
fbfad53
Merge branch 'master' into spatial-distortion-index-metric
SkafteNicki Dec 20, 2023
a334561
skip on missing import
SkafteNicki Dec 20, 2023
b9af837
Merge branch 'spatial-distortion-index-metric' of https://github.com/…
SkafteNicki Dec 20, 2023
f926209
skip on missing import
SkafteNicki Dec 20, 2023
c647390
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 20, 2023
9047d67
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 20, 2023
a026d10
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 20, 2023
3991278
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 20, 2023
4df37e5
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 20, 2023
779e73d
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 21, 2023
5f308ed
Merge branch 'master' into spatial-distortion-index-metric
Borda Dec 21, 2023
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -28,8 +28,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `aggregate` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220))


- Added `SpatialDistortionIndex` metric to image domain ([#2260](https://github.com/Lightning-AI/torchmetrics/pull/2260))


- Added `CriticalSuccessIndex` metric to image subpackage ([#2257](https://github.com/Lightning-AI/torchmetrics/pull/2257))


### Changed

- Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145))
21 changes: 21 additions & 0 deletions docs/source/image/spatial_distortion_index.rst
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Spatial Distortion Index
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

########################
Spatial Distortion Index
########################

Module Interface
________________

.. autoclass:: torchmetrics.image.SpatialDistortionIndex
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.spatial_distortion_index
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -84,6 +84,7 @@
.. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim
.. _UniversalImageQualityIndex: https://ieeexplore.ieee.org/abstract/document/995823
.. _SpectralDistortionIndex: https://www.semanticscholar.org/paper/Multispectral-and-panchromatic-data-fusion-without-Alparone-Aiazzi/b6db12e3785326577cb95fd743fecbf5bc66c7c9
.. _SpatialDistortionIndex: https://www.semanticscholar.org/paper/Multispectral-and-panchromatic-data-fusion-without-Alparone-Aiazzi/b6db12e3785326577cb95fd743fecbf5bc66c7c9
.. _RelativeAverageSpectralError: https://www.semanticscholar.org/paper/Data-Fusion.-Definitions-and-Architectures-Fusion-Wald/51b2b81e5124b3bb7ec53517a5dd64d8e348cadf
.. _WMAPE: https://en.wikipedia.org/wiki/WMAPE
.. _CER: https://rechtsprechung-im-ostseeraum.archiv.uni-greifswald.de/word-error-rate-character-error-rate-how-to-evaluate-a-model
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.functional.image.csi import critical_success_index
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.d_s import spatial_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
@@ -32,6 +33,7 @@

__all__ = [
"spectral_distortion_index",
"spatial_distortion_index",
"error_relative_global_dimensionless_synthesis",
"image_gradients",
"peak_signal_noise_ratio",
267 changes: 267 additions & 0 deletions src/torchmetrics/functional/image/d_s.py
@@ -0,0 +1,267 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.utilities.distributed import reduce
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE

if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["_spatial_distortion_index_compute", "spatial_distortion_index"]


def _spatial_distortion_index_update(
preds: Tensor, ms: Tensor, pan: Tensor, pan_lr: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]:
"""Update and return variables required to compute Spatial Distortion Index.

Args:
preds: High resolution multispectral image.
ms: Low resolution multispectral image.
pan: High resolution panchromatic image.
pan_lr: Low resolution panchromatic image.

Return:
A tuple of Tensors containing ``preds``, ``ms``, ``pan`` and ``pan_lr``.

Raises:
TypeError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same data type.
ValueError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have ``BxCxHxW`` shape.
ValueError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same batch and channel sizes.
ValueError:
If ``preds`` and ``pan`` don't have the same dimension.
ValueError:
If ``ms`` and ``pan_lr`` don't have the same dimension.
ValueError:
If the height or width of ``preds`` and ``pan`` is not a multiple of that of ``ms``.

"""
if len(preds.shape) != 4:
raise ValueError(f"Expected `preds` to have BxCxHxW shape. Got preds: {preds.shape}.")
if preds.dtype != ms.dtype:
raise TypeError(
f"Expected `preds` and `ms` to have the same data type. Got preds: {preds.dtype} and ms: {ms.dtype}."
)
if preds.dtype != pan.dtype:
raise TypeError(
f"Expected `preds` and `pan` to have the same data type. Got preds: {preds.dtype} and pan: {pan.dtype}."
)
if pan_lr is not None and preds.dtype != pan_lr.dtype:
raise TypeError(
f"Expected `preds` and `pan_lr` to have the same data type."
f" Got preds: {preds.dtype} and pan_lr: {pan_lr.dtype}."
)
if len(ms.shape) != 4:
raise ValueError(f"Expected `ms` to have BxCxHxW shape. Got ms: {ms.shape}.")
if len(pan.shape) != 4:
raise ValueError(f"Expected `pan` to have BxCxHxW shape. Got pan: {pan.shape}.")
if pan_lr is not None and len(pan_lr.shape) != 4:
raise ValueError(f"Expected `pan_lr` to have BxCxHxW shape. Got pan_lr: {pan_lr.shape}.")
if preds.shape[:2] != ms.shape[:2]:
raise ValueError(
f"Expected `preds` and `ms` to have the same batch and channel sizes."
f" Got preds: {preds.shape} and ms: {ms.shape}."
)
if preds.shape[:2] != pan.shape[:2]:
raise ValueError(
f"Expected `preds` and `pan` to have the same batch and channel sizes."
f" Got preds: {preds.shape} and pan: {pan.shape}."
)
if pan_lr is not None and preds.shape[:2] != pan_lr.shape[:2]:
raise ValueError(
f"Expected `preds` and `pan_lr` to have the same batch and channel sizes."
f" Got preds: {preds.shape} and pan_lr: {pan_lr.shape}."
)

preds_h, preds_w = preds.shape[-2:]
ms_h, ms_w = ms.shape[-2:]
pan_h, pan_w = pan.shape[-2:]
if preds_h != pan_h:
raise ValueError(f"Expected `preds` and `pan` to have the same height. Got preds: {preds_h} and pan: {pan_h}")
if preds_w != pan_w:
raise ValueError(f"Expected `preds` and `pan` to have the same width. Got preds: {preds_w} and pan: {pan_w}")
if preds_h % ms_h != 0:
raise ValueError(
f"Expected height of `preds` to be multiple of height of `ms`. Got preds: {preds_h} and ms: {ms_h}."
)
if preds_w % ms_w != 0:
raise ValueError(
f"Expected width of `preds` to be multiple of width of `ms`. Got preds: {preds_w} and ms: {ms_w}."
)
if pan_h % ms_h != 0:
raise ValueError(
f"Expected height of `pan` to be multiple of height of `ms`. Got pan: {pan_h} and ms: {ms_h}."
)
if pan_w % ms_w != 0:
raise ValueError(f"Expected width of `pan` to be multiple of width of `ms`. Got pan: {pan_w} and ms: {ms_w}.")

if pan_lr is not None:
pan_lr_h, pan_lr_w = pan_lr.shape[-2:]
if pan_lr_h != ms_h:
raise ValueError(
f"Expected `ms` and `pan_lr` to have the same height. Got ms: {ms_h} and pan_lr: {pan_lr_h}."
)
if pan_lr_w != ms_w:
raise ValueError(
f"Expected `ms` and `pan_lr` to have the same width. Got ms: {ms_w} and pan_lr: {pan_lr_w}."
)

return preds, ms, pan, pan_lr
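
The shape rules enforced above can be summarised in a small standalone sketch. It works on plain `(B, C, H, W)` tuples rather than tensors, skips the dtype checks, and does not reproduce the exact order or wording of the errors. The helper name `check_shapes` and the `Shape` alias are hypothetical and not part of this PR.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]  # expected layout: (batch, channels, height, width)


def check_shapes(preds: Shape, ms: Shape, pan: Shape, pan_lr: Optional[Shape] = None) -> None:
    """Hypothetical helper: raise ValueError if the shapes are incompatible."""
    named = {"preds": preds, "ms": ms, "pan": pan}
    if pan_lr is not None:
        named["pan_lr"] = pan_lr

    for name, shape in named.items():
        # Every input must be a 4D BxCxHxW image batch.
        if len(shape) != 4:
            raise ValueError(f"Expected `{name}` to have BxCxHxW shape. Got {shape}.")
        # Batch and channel sizes must agree with `preds`.
        if shape[:2] != preds[:2]:
            raise ValueError(f"`{name}` must share batch and channel sizes with `preds`.")

    # The two high resolution images must have identical height and width.
    if preds[2:] != pan[2:]:
        raise ValueError("`preds` and `pan` must have the same height and width.")
    # The high resolution size must be an integer multiple of the low resolution size.
    if preds[2] % ms[2] != 0 or preds[3] % ms[3] != 0:
        raise ValueError("Height and width of `preds` must be multiples of those of `ms`.")
    # The two low resolution images must have identical height and width.
    if pan_lr is not None and pan_lr[2:] != ms[2:]:
        raise ValueError("`pan_lr` and `ms` must have the same height and width.")


# Shapes taken from the docstring example in this file: accepted without error.
check_shapes((16, 3, 32, 32), (16, 3, 16, 16), (16, 3, 32, 32))
```

Once `preds` and `pan` are known to share height and width, a separate multiple-of check on `pan` adds no information, which is why the sketch only tests `preds` against `ms`.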


def _spatial_distortion_index_compute(
preds: Tensor,
ms: Tensor,
pan: Tensor,
pan_lr: Optional[Tensor] = None,
norm_order: int = 1,
window_size: int = 7,
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Tensor:
"""Compute Spatial Distortion Index (SpatialDistortionIndex_).

Args:
preds: High resolution multispectral image.
ms: Low resolution multispectral image.
pan: High resolution panchromatic image.
pan_lr: Low resolution panchromatic image.
norm_order: Order of the norm applied on the difference.
window_size: Window size of the filter applied to degrade the high resolution panchromatic image.
reduction: A method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

Return:
Tensor with SpatialDistortionIndex score

Raises:
ValueError
If ``window_size`` is not smaller than dimension of ``ms``.

Example:
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 32, 32])
>>> ms = torch.rand([16, 3, 16, 16])
>>> pan = torch.rand([16, 3, 32, 32])
>>> preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan)
>>> _spatial_distortion_index_compute(preds, ms, pan, pan_lr)
tensor(0.0090)

"""
length = preds.shape[1]

ms_h, ms_w = ms.shape[-2:]
if window_size >= ms_h or window_size >= ms_w:
raise ValueError(
f"Expected `window_size` to be smaller than dimension of `ms`. Got window_size: {window_size}."
)

if pan_lr is None:
if not _TORCHVISION_AVAILABLE:
raise ValueError(
"When `pan_lr` is not provided as input to metric Spatial distortion index, torchvision should be "
"installed. Please install with `pip install torchvision` or `pip install torchmetrics[image]`."
)
from torchvision.transforms.functional import resize

from torchmetrics.functional.image.helper import _uniform_filter

pan_degraded = _uniform_filter(pan, window_size=window_size)
pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False)
else:
pan_degraded = pan_lr

m1 = torch.zeros(length, device=preds.device)
m2 = torch.zeros(length, device=preds.device)

for i in range(length):
m1[i] = universal_image_quality_index(ms[:, i : i + 1], pan_degraded[:, i : i + 1])
m2[i] = universal_image_quality_index(preds[:, i : i + 1], pan[:, i : i + 1])
diff = (m1 - m2).abs() ** norm_order
return reduce(diff, reduction) ** (1 / norm_order)
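
The last two lines of the function combine the per-channel quality scores into a single value. The sketch below repeats that step on plain Python floats so the effect of `norm_order` and `reduction` is easy to see. The function name `combine_channel_scores` is hypothetical, and the handling of the `reduction` strings is an assumption about what `torchmetrics.utilities.distributed.reduce` does (mean, sum, or identity).

```python
from typing import List, Union


def combine_channel_scores(
    m1: List[float],
    m2: List[float],
    norm_order: int = 1,
    reduction: str = "elementwise_mean",
) -> Union[float, List[float]]:
    """Hypothetical pure-Python version of the final norm step."""
    # |m1 - m2| ** norm_order, one value per channel.
    diff = [abs(a - b) ** norm_order for a, b in zip(m1, m2)]

    if reduction == "elementwise_mean":
        reduced = sum(diff) / len(diff)
    elif reduction == "sum":
        reduced = sum(diff)
    elif reduction == "none":
        # No reduction: the root is applied to each channel separately.
        return [d ** (1 / norm_order) for d in diff]
    else:
        raise ValueError(f"Unknown reduction: {reduction}")

    # Undo the power so the result is on the same scale as the scores.
    return reduced ** (1 / norm_order)


# Made-up per-channel scores for a three-channel image.
low_res_scores = [0.9, 0.8, 0.7]   # stands in for UQI(ms, degraded pan)
high_res_scores = [0.8, 0.8, 0.9]  # stands in for UQI(preds, pan)
print(combine_channel_scores(low_res_scores, high_res_scores))                # about 0.1
print(combine_channel_scores(low_res_scores, high_res_scores, norm_order=2))  # about 0.129
```

With `norm_order=1` the result is the mean absolute difference. Larger orders give more weight to the channel with the largest disagreement.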


def spatial_distortion_index(
preds: Tensor,
ms: Tensor,
pan: Tensor,
pan_lr: Optional[Tensor] = None,
norm_order: int = 1,
window_size: int = 7,
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Tensor:
"""Calculate `Spatial Distortion Index`_ (SpatialDistortionIndex_) also known as D_s.

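The metric quantifies the spatial distortion of the high resolution multispectral image (``preds``) relative to
the panchromatic image (``pan``), using the low resolution pair (``ms``, ``pan_lr``) as the baseline.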

Args:
preds: High resolution multispectral image.
ms: Low resolution multispectral image.
pan: High resolution panchromatic image.
pan_lr: Low resolution panchromatic image.
norm_order: Order of the norm applied on the difference.
window_size: Window size of the filter applied to degrade the high resolution panchromatic image.
reduction: A method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

Return:
Tensor with SpatialDistortionIndex score

Raises:
TypeError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same data type.
ValueError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have ``BxCxHxW`` shape.
ValueError:
If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same batch and channel sizes.
ValueError:
If ``preds`` and ``pan`` don't have the same dimension.
ValueError:
If ``ms`` and ``pan_lr`` don't have the same dimension.
ValueError:
If the height or width of ``preds`` and ``pan`` is not a multiple of that of ``ms``.
ValueError:
If ``norm_order`` is not a positive integer.
ValueError:
If ``window_size`` is not a positive integer.

Example:
>>> from torchmetrics.functional.image import spatial_distortion_index
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 32, 32])
>>> ms = torch.rand([16, 3, 16, 16])
>>> pan = torch.rand([16, 3, 32, 32])
>>> spatial_distortion_index(preds, ms, pan)
tensor(0.0090)

"""
if not isinstance(norm_order, int) or norm_order <= 0:
raise ValueError(f"Expected `norm_order` to be a positive integer. Got norm_order: {norm_order}.")
if not isinstance(window_size, int) or window_size <= 0:
raise ValueError(f"Expected `window_size` to be a positive integer. Got window_size: {window_size}.")
preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan, pan_lr)
return _spatial_distortion_index_compute(preds, ms, pan, pan_lr, norm_order, window_size, reduction)
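
Read together, the update and compute steps in this file appear to evaluate the expression below. This is a reading of the code under the default `elementwise_mean` reduction, not a formula quoted from the PR. The symbols are introduced here for illustration: $C$ is the number of channels, $q$ is `norm_order`, $Q$ is the value returned by `universal_image_quality_index` for one channel over the whole batch, and $\widetilde{\mathrm{pan}}$ is `pan_lr` when it is given, otherwise `pan` after the uniform filter and the resize to the size of `ms`.

```latex
D_s = \left( \frac{1}{C} \sum_{c=1}^{C}
      \left| Q\!\left(\mathrm{ms}_c, \widetilde{\mathrm{pan}}_c\right)
           - Q\!\left(\mathrm{preds}_c, \mathrm{pan}_c\right) \right|^{q}
      \right)^{1/q}
```

A value near zero means the relation between the multispectral and panchromatic images is about the same at both resolutions.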
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.image.csi import CriticalSuccessIndex
from torchmetrics.image.d_lambda import SpectralDistortionIndex
from torchmetrics.image.d_s import SpatialDistortionIndex
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
from torchmetrics.image.psnr import PeakSignalNoiseRatio
@@ -31,6 +32,7 @@

__all__ = [
"SpectralDistortionIndex",
"SpatialDistortionIndex",
"ErrorRelativeGlobalDimensionlessSynthesis",
"PeakSignalNoiseRatio",
"PeakSignalNoiseRatioWithBlockedEffect",