Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making functions in utils/metrics.py compatible with torch.Tensor inputs #216

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

federico-carrara
Copy link
Collaborator

Description

This PR aims to make all the functions in utils/metric.py torch-compatible.

  • What: Made all functions for computing metrics compatible with torch.Tensor inputs.
  • Why: Most of these metrics are used either during model validation or evaluation, hence they need to handle torch.Tensor's. Therefore, having torch compatible metrics readily available avoids cumbersome type conversion within the NN models.
  • How: Replaced function calls explicitly requesting numpy (e.g., np.max(arr)) with more flexible method calls (e.g., arr.max()).

Changes Made

  • Added: torch requires tensors to be of type double (i.e., torch.float64) to compute things like mean and std. Therefore, I added a private method that casts torch.tensor's to double to enable mean() and std() computation. Casting to numpy happens instead in the PSNR computation that uses skimage module, since the module explicitly requires that. Finally, added notebook to show current issues with the tests.
  • Modified: replaced function calls with method calls.
  • Removed: Nothing.


Please ensure your PR meets the following requirements:

  • Code builds and passes tests locally, including doctests
  • New tests have been added (for bug fixes/features)
  • Pre-commit passes
  • PR to the documentation exists (for bug fixes / features)

Float tensor.
"""
if isinstance(x, torch.Tensor) and x.dtype != torch.float64:
warn(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the warning necessary? I ask because it will never lead to any information loss. Also, we do not return this tensor; we just use it to compute the metrics.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd definitely remove it "in production", but I have temporarily put it there for debugging :)

Predicted image.

Returns
-------
Union[float, torch.tensor]
Scale invariant PSNR value.
"""
range_parameter = (np.max(gt) - np.min(gt)) / np.std(gt)
gt_ = _zero_mean(gt) / np.std(gt)
# cast tensors to double dtype
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic looks good.

@jdeschamps
Copy link
Member

What is the status of this PR?

@federico-carrara
Copy link
Collaborator Author

I remember there was a problem with dtypes changing from numpy to torch (at least this is what I have in my notes). Unless this is a feature we want to have soon, I'll just close the PR!

@jdeschamps
Copy link
Member

Let's leave it for now!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants