-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Making functions in utils/metrics.py
compatible with torch.Tensor
inputs
#216
base: main
Are you sure you want to change the base?
Making functions in utils/metrics.py
compatible with torch.Tensor
inputs
#216
Conversation
…SNR` * updated tests to reflect changes
…careamics into fc/refac/torch_metrics
Float tensor. | ||
""" | ||
if isinstance(x, torch.Tensor) and x.dtype != torch.float64: | ||
warn( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the warning necessary? I ask because it will never lead to any information loss. Also, we do not return this tensor; we just use it to compute the metrics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd definitely remove it "in production", but I have temporarily put it there for debugging :)
Predicted image. | ||
|
||
Returns | ||
------- | ||
Union[float, torch.tensor] | ||
Scale invariant PSNR value. | ||
""" | ||
range_parameter = (np.max(gt) - np.min(gt)) / np.std(gt) | ||
gt_ = _zero_mean(gt) / np.std(gt) | ||
# cast tensors to double dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic looks good.
What is the status of this PR? |
I remember there was a problem with dtypes changing from |
Let's leave it for now! |
Description
This PR aims to make all the functions in
utils/metric.py
torch
-compatible.torch.Tensor
inputs.torch.Tensor
's. Therefore, havingtorch
compatible metrics readily available avoids cumbersome type conversion within the NN models.numpy
(e.g.,np.max(arr)
) with more flexible method calls (e.g.,arr.max()
).Changes Made
torch
requires tensors to be of typedouble
(i.e.,torch.float64
) to compute things like mean and std. Therefore, I added a private method that caststorch.tensor
's todouble
to enablemean()
andstd()
computation. Casting tonumpy
happens instead in the PSNR computation that usesskimage
module, since the module explicitly requires that. Finally, added notebook to show current issues with the tests.Please ensure your PR meets the following requirements: