Add new metric "spatial distortion index" #2260

Merged
merged 41 commits into from
Dec 21, 2023
Changes from 5 commits
Commits
2a6dd63
added new matric "spatial distortion index"
ywchan2005 Dec 4, 2023
6bc62bf
added missing docs
ywchan2005 Dec 4, 2023
fed2f4a
moved kornia from image_test.txt to image.txt
ywchan2005 Dec 4, 2023
aab1a48
fixed typo in version in requirements
ywchan2005 Dec 4, 2023
440266b
fixed docstrings
ywchan2005 Dec 4, 2023
3fe20a9
changed kornia to lazy import
ywchan2005 Dec 5, 2023
fe9a835
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2023
15eb079
changed torchvision to lazy import
ywchan2005 Dec 5, 2023
4035bfe
fix style
ywchan2005 Dec 5, 2023
b4fbb08
fix type hint
ywchan2005 Dec 6, 2023
dace7f6
fix missing link in doc
ywchan2005 Dec 6, 2023
2fa4951
fix type hint
ywchan2005 Dec 7, 2023
075cf50
fix mypy error
ywchan2005 Dec 8, 2023
d6fccf6
remove dependence of kornia
ywchan2005 Dec 8, 2023
9df3782
fixed ruff error
ywchan2005 Dec 8, 2023
0f8a1e6
Update docs/source/image/spatial_distortion_index.rst
ywchan2005 Dec 11, 2023
caf9dea
Update src/torchmetrics/functional/image/d_s.py
ywchan2005 Dec 11, 2023
7f38b0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2023
00f29a0
Update src/torchmetrics/image/d_s.py
ywchan2005 Dec 11, 2023
c64915b
fix style
ywchan2005 Dec 11, 2023
f169369
moved checking of tensor input to update
ywchan2005 Dec 11, 2023
cd9c4dc
Update src/torchmetrics/functional/image/d_s.py
ywchan2005 Dec 14, 2023
c6087d6
Update src/torchmetrics/functional/image/d_s.py
ywchan2005 Dec 14, 2023
af5288b
Update src/torchmetrics/functional/image/d_s.py
ywchan2005 Dec 14, 2023
f7debf0
changed `ws` to `window_size`
ywchan2005 Dec 14, 2023
3137b47
changed `p` to `norm_order`
ywchan2005 Dec 14, 2023
c9a506c
Update src/torchmetrics/functional/image/d_s.py
ywchan2005 Dec 14, 2023
3b40466
fix assert regex in tests
ywchan2005 Dec 14, 2023
af4b50a
changed `_update` and `_compute` functions to take `ms`, `pan` and `p…
ywchan2005 Dec 15, 2023
115ead3
fix type hint
ywchan2005 Dec 18, 2023
fbfad53
Merge branch 'master' into spatial-distortion-index-metric
SkafteNicki Dec 20, 2023
a334561
skip on missing import
SkafteNicki Dec 20, 2023
b9af837
Merge branch 'spatial-distortion-index-metric' of https://github.com/…
SkafteNicki Dec 20, 2023
f926209
skip on missing import
SkafteNicki Dec 20, 2023
c647390
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 20, 2023
9047d67
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 20, 2023
a026d10
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 20, 2023
3991278
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 20, 2023
4df37e5
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 20, 2023
779e73d
Merge branch 'master' into spatial-distortion-index-metric
mergify[bot] Dec 21, 2023
5f308ed
Merge branch 'master' into spatial-distortion-index-metric
Borda Dec 21, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `aggregate` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220))


- Added `SpatialDistortionIndex` metric to image domain ([#2260](https://github.com/Lightning-AI/torchmetrics/pull/2260))


### Changed

- Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145))
21 changes: 21 additions & 0 deletions docs/source/image/spatial_distortion_index.rst
@@ -0,0 +1,21 @@
.. customcarditem::
   :header: Spatial Distortion Index
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
   :tags: Image

.. include:: ../links.rst

#########################
Spatial Distortion Index
#########################

Module Interface
________________

.. autoclass:: torchmetrics.image.SpatialDistortionIndex
    :exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.spatial_distortion_index
1 change: 1 addition & 0 deletions requirements/image.txt
@@ -4,3 +4,4 @@
scipy >1.0.0, <1.11.0
torchvision >=0.8, <0.17.0
torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing
kornia >=0.6.7, <0.7.1
1 change: 0 additions & 1 deletion requirements/image_test.txt
@@ -2,7 +2,6 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

scikit-image >=0.19.0, <=0.21.0
kornia >=0.6.7, <0.7.1
pytorch-msssim ==1.0.0
sewar >=0.4.4, <=0.4.6
numpy <1.25.0
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.d_s import spatial_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
@@ -31,6 +32,7 @@

__all__ = [
    "spectral_distortion_index",
    "spatial_distortion_index",
    "error_relative_global_dimensionless_synthesis",
    "image_gradients",
    "peak_signal_noise_ratio",
246 changes: 246 additions & 0 deletions src/torchmetrics/functional/image/d_s.py
@@ -0,0 +1,246 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Tuple

import torch
from kornia.filters import filter2d
from torch import Tensor
from torchvision.transforms.functional import resize
from typing_extensions import Literal

from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.utilities.distributed import reduce


def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
    """Update and return variables required to compute Spatial Distortion Index.

    Args:
        preds: High resolution multispectral image.
        target: A dictionary containing the following keys:

            - ``'ms'``: low resolution multispectral image.
            - ``'pan'``: high resolution panchromatic image.
            - ``'pan_lr'``: (optional) low resolution panchromatic image.

    Return:
        A tuple containing ``preds`` and ``target``.

    Raises:
        TypeError:
            If ``preds`` and ``target`` don't have the same data type.
        ValueError:
            If ``preds`` and ``target`` don't have ``BxCxHxW`` shape.
        ValueError:
            If ``preds`` and ``target`` don't have the same batch and channel sizes.
        ValueError:
            If ``target`` doesn't have ``ms`` and ``pan``.

    """
    if len(preds.shape) != 4:
        raise ValueError(f"Expected `preds` to have BxCxHxW shape. Got preds: {preds.shape}.")
    if "ms" not in target or "pan" not in target:
        raise ValueError(f"Expected `target` to have keys ('ms', 'pan'). Got target: {target.keys()}")
    for name, t in target.items():
        if preds.dtype != t.dtype:
            raise TypeError(
                f"Expected `preds` and `{name}` to have the same data type. "
                f"Got preds: {preds.dtype} and {name}: {t.dtype}."
            )
    for name, t in target.items():
        if len(t.shape) != 4:
            raise ValueError(f"Expected `{name}` to have BxCxHxW shape. Got {name}: {t.shape}.")
    for name, t in target.items():
        if preds.shape[:2] != t.shape[:2]:
            raise ValueError(
                f"Expected `preds` and `{name}` to have same batch and channel sizes. "
                f"Got preds: {preds.shape} and {name}: {t.shape}."
            )
    return preds, target
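
A note on the multi-line error messages in the checks above: Python joins adjacent string literals at compile time, but the `f` prefix applies to each literal separately. The snippet below is a minimal, standalone illustration of that behaviour. The variable names and values are made up for the example and are not taken from the PR.

```python
# Illustrative values only; they stand in for a tensor name and two dtypes.
name, expected, got = "pan", "torch.float32", "torch.float64"

# Only the first literal carries the f prefix, so the second one is kept verbatim.
broken = (
    f"Expected `preds` and `{name}` to have the same data type. "
    "Got preds: {expected} and {name}: {got}."
)

# Prefixing every literal that contains placeholders interpolates the whole message.
fixed = (
    f"Expected `preds` and `{name}` to have the same data type. "
    f"Got preds: {expected} and {name}: {got}."
)

print(broken)
print(fixed)
```

The first message still contains the raw `{expected}` and `{got}` placeholders, while the second shows the actual values.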


def _spatial_distortion_index_compute(
    preds: Tensor,
    target: Dict[str, Tensor],
    p: int = 1,
    ws: int = 7,
    reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Tensor:
    """Compute Spatial Distortion Index (SpatialDistortionIndex_).

    Args:
        preds: High resolution multispectral image.
        target: A dictionary containing the following keys:

            - ``'ms'``: low resolution multispectral image.
            - ``'pan'``: high resolution panchromatic image.
            - ``'pan_lr'``: (optional) low resolution panchromatic image.

        p: Order of the norm applied on the difference.
        ws: Window size of the filter applied to degrade the high resolution panchromatic image.
        reduction: A method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Return:
        Tensor with SpatialDistortionIndex score

    Raises:
        ValueError:
            If ``preds`` and ``pan`` don't have the same dimension.
        ValueError:
            If ``ms`` and ``pan_lr`` don't have the same dimension.
        ValueError:
            If the dimensions of ``preds`` and ``pan`` are not multiples of those of ``ms``.

    Example:
        >>> _ = torch.manual_seed(42)
        >>> preds = torch.rand([16, 3, 32, 32])
        >>> target = {
        ...     'ms': torch.rand([16, 3, 16, 16]),
        ...     'pan': torch.rand([16, 3, 32, 32]),
        ... }
        >>> preds, target = _spatial_distortion_index_update(preds, target)
        >>> _spatial_distortion_index_compute(preds, target)
        tensor(0.0090)

    """
    length = preds.shape[1]

    ms = target["ms"]
    pan = target["pan"]
    pan_lr = target["pan_lr"] if "pan_lr" in target else None

    preds_h, preds_w = preds.shape[-2:]
    ms_h, ms_w = ms.shape[-2:]
    pan_h, pan_w = pan.shape[-2:]
    if preds_h != pan_h:
        raise ValueError(f"Expected `preds` and `pan` to have the same height. Got preds: {preds_h} and pan: {pan_h}")
    if preds_w != pan_w:
        raise ValueError(f"Expected `preds` and `pan` to have the same width. Got preds: {preds_w} and pan: {pan_w}")
    if preds_h % ms_h != 0:
        raise ValueError(
            f"Expected height of `preds` to be multiple of height of `ms`. Got preds: {preds_h} and ms: {ms_h}."
        )
    if preds_w % ms_w != 0:
        raise ValueError(
            f"Expected width of `preds` to be multiple of width of `ms`. Got preds: {preds_w} and ms: {ms_w}."
        )
    if pan_h % ms_h != 0:
        raise ValueError(
            f"Expected height of `pan` to be multiple of height of `ms`. Got pan: {pan_h} and ms: {ms_h}."
        )
    if pan_w % ms_w != 0:
        raise ValueError(f"Expected width of `pan` to be multiple of width of `ms`. Got pan: {pan_w} and ms: {ms_w}.")
    if ws >= ms_h or ws >= ms_w:
        raise ValueError(f"Expected `ws` to be smaller than dimension of `ms`. Got ws: {ws}.")

    if pan_lr is not None:
        pan_lr_h, pan_lr_w = pan_lr.shape[-2:]
        if pan_lr_h != ms_h:
            raise ValueError(
                f"Expected `ms` and `pan_lr` to have the same height. Got ms: {ms_h} and pan_lr: {pan_lr_h}."
            )
        if pan_lr_w != ms_w:
            raise ValueError(
                f"Expected `ms` and `pan_lr` to have the same width. Got ms: {ms_w} and pan_lr: {pan_lr_w}."
            )

    pan_degraded = pan_lr
    if pan_degraded is None:
        kernel = torch.ones(size=(1, ws, ws))
        pan_degraded = filter2d(pan, kernel, border_type="replicate", normalized=True)
        pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False)

    m1 = torch.zeros(length, device=preds.device)
    m2 = torch.zeros(length, device=preds.device)

    for i in range(length):
        m1[i] = universal_image_quality_index(ms[:, i : i + 1], pan_degraded[:, i : i + 1])
        m2[i] = universal_image_quality_index(preds[:, i : i + 1], pan[:, i : i + 1])
    diff = (m1 - m2).abs() ** p
    return reduce(diff, reduction) ** (1 / p)
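
For readers who want to see the degradation step in isolation: when `pan_lr` is not supplied, the code above box-filters `pan` with replicate padding and resizes the result to the spatial size of `ms`. The sketch below approximates that step with plain `torch.nn.functional` calls. The helper name `degrade_pan` is hypothetical, and the result is not guaranteed to match `kornia.filters.filter2d` plus `torchvision` `resize` bit for bit (padding for even window sizes and interpolation details may differ).

```python
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def degrade_pan(pan: Tensor, out_size: Tuple[int, int], window_size: int = 7) -> Tensor:
    """Box-filter ``pan`` with replicate padding, then resize it to ``out_size``.

    Hypothetical helper for illustration; not part of torchmetrics.
    """
    channels = pan.shape[1]
    # One normalized window_size x window_size kernel per channel (depthwise convolution).
    kernel = torch.ones(channels, 1, window_size, window_size, dtype=pan.dtype, device=pan.device)
    kernel = kernel / (window_size * window_size)
    # Pad so the filtered image keeps the input's spatial size, also for even window sizes.
    before = (window_size - 1) // 2
    after = window_size - 1 - before
    padded = F.pad(pan, (before, after, before, after), mode="replicate")
    blurred = F.conv2d(padded, kernel, groups=channels)
    # Bilinear interpolation without antialiasing, in the spirit of `antialias=False` above.
    return F.interpolate(blurred, size=out_size, mode="bilinear", align_corners=False)


pan = torch.rand(2, 3, 32, 32)
pan_lr = degrade_pan(pan, out_size=(16, 16))
print(pan_lr.shape)  # torch.Size([2, 3, 16, 16])
```

Creating the kernel with the dtype and device of `pan` keeps the sketch usable on GPU and with double-precision inputs without an extra conversion.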


def spatial_distortion_index(
    preds: Tensor,
    target: Dict[str, Tensor],
    p: int = 1,
    ws: int = 7,
    reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Tensor:
    """Calculate `Spatial Distortion Index`_ (SpatialDistortionIndex_) also known as D_s.

    Metric is used to compare the spatial distortion between two images.

    Args:
        preds: High resolution multispectral image.
        target: A dictionary containing the following keys:

            - ``'ms'``: low resolution multispectral image.
            - ``'pan'``: high resolution panchromatic image.
            - ``'pan_lr'``: (optional) low resolution panchromatic image.

        p: Order of the norm applied on the difference.
        ws: Window size of the filter applied to degrade the high resolution panchromatic image.
        reduction: A method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Return:
        Tensor with SpatialDistortionIndex score

    Raises:
        TypeError:
            If ``preds`` and ``target`` don't have the same data type.
        ValueError:
            If ``preds`` and ``target`` don't have ``BxCxHxW`` shape.
        ValueError:
            If ``preds`` and ``target`` don't have the same batch and channel sizes.
        ValueError:
            If ``target`` doesn't have ``ms`` and ``pan``.
        ValueError:
            If ``preds`` and ``pan`` don't have the same dimension.
        ValueError:
            If ``ms`` and ``pan_lr`` don't have the same dimension.
        ValueError:
            If the dimensions of ``preds`` and ``pan`` are not multiples of those of ``ms``.
        ValueError:
            If ``p`` is not a positive integer.
        ValueError:
            If ``ws`` is not a positive integer.

    Example:
        >>> from torchmetrics.functional.image import spatial_distortion_index
        >>> _ = torch.manual_seed(42)
        >>> preds = torch.rand([16, 3, 32, 32])
        >>> target = {
        ...     'ms': torch.rand([16, 3, 16, 16]),
        ...     'pan': torch.rand([16, 3, 32, 32]),
        ... }
        >>> spatial_distortion_index(preds, target)
        tensor(0.0090)

    """
    if not isinstance(p, int) or p <= 0:
        raise ValueError(f"Expected `p` to be a positive integer. Got p: {p}.")
    if not isinstance(ws, int) or ws <= 0:
        raise ValueError(f"Expected `ws` to be a positive integer. Got ws: {ws}.")
    preds, target = _spatial_distortion_index_update(preds, target)
    return _spatial_distortion_index_compute(preds, target, p, ws, reduction)
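
To make the role of `p` concrete: the compute step takes one UQI score per channel at low resolution and one at high resolution, raises their absolute difference to the power `p`, reduces over channels, and takes the `p`-th root. The snippet below walks through that arithmetic with made-up scores and the mean reduction. The function name `d_s_from_scores` and the numbers are assumptions for illustration only; the UQI computation itself is left out.

```python
import torch
from torch import Tensor

# Hypothetical per-channel UQI scores, chosen only to make the arithmetic easy to follow.
q_low = torch.tensor([0.90, 0.80, 0.70])   # stands in for Q(ms_c, pan_lr_c)
q_high = torch.tensor([0.85, 0.90, 0.70])  # stands in for Q(preds_c, pan_c)


def d_s_from_scores(q_low: Tensor, q_high: Tensor, p: int = 1) -> Tensor:
    """Combine per-channel scores the way the mean reduction does (illustrative helper)."""
    diff = (q_low - q_high).abs() ** p
    return diff.mean() ** (1 / p)


print(d_s_from_scores(q_low, q_high, p=1))  # roughly 0.05
print(d_s_from_scores(q_low, q_high, p=2))  # roughly 0.0645
```

With `p=1` the result is the mean absolute difference; larger `p` gives more weight to the channel with the largest gap.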
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.image.d_lambda import SpectralDistortionIndex
from torchmetrics.image.d_s import SpatialDistortionIndex
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
from torchmetrics.image.psnr import PeakSignalNoiseRatio
@@ -30,6 +31,7 @@

__all__ = [
    "SpectralDistortionIndex",
    "SpatialDistortionIndex",
    "ErrorRelativeGlobalDimensionlessSynthesis",
    "PeakSignalNoiseRatio",
    "PeakSignalNoiseRatioWithBlockedEffect",