about multi-scale derivative loss #60
Comments
I have the same question. IMO that single bit is a very important detail and should have been included in the repo.
Hi guys, here is my implementation of multi-scale derivative loss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def ssi_normalize_depth(depth):
    # Scale-and-shift-invariant normalization (MiDaS-style):
    # subtract the median, then divide by the mean absolute deviation.
    median = torch.median(depth)
    abs_diff = torch.abs(depth - median)
    mean_abs_diff = torch.mean(abs_diff)
    normalized_depth = (depth - median) / mean_abs_diff
    return normalized_depth


class TrimMAELoss:
    # Trimmed MAE: discard the largest `trim` fraction of residuals.
    def __init__(self, trim=0.2):
        self.trim = trim

    def __call__(self, prediction, target):
        res = (prediction - target).abs()
        sorted_res, _ = torch.sort(res.view(-1), descending=False)
        # Keep only the smallest (1 - trim) fraction of residuals.
        trimmed = sorted_res[: int(res.numel() * (1.0 - self.trim))]
        return trimmed.sum() / res.numel()


class MultiScaleDeriLoss(nn.Module):
    def __init__(self, operator='Scharr', norm=1, scales=6, trim=False, ssi=False, amp=False):
        super().__init__()
        self.name = "MultiScaleDerivativeLoss"
        self.operator = operator
        dtype = torch.float16 if amp else torch.float
        self.operators = {
            "Scharr": {
                'x': torch.tensor([[[[-3, 0, 3], [-10, 0, 10], [-3, 0, 3]]]], dtype=dtype).cuda(),
                'y': torch.tensor([[[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]]], dtype=dtype).cuda(),
            },
            "Laplace": {
                'x': torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], dtype=dtype).cuda(),
                'y': torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], dtype=dtype).cuda(),
            }
        }
        self.op_x = self.operators[operator]['x']
        self.op_y = self.operators[operator]['y']
        if norm == 1:
            self.loss_function = nn.L1Loss(reduction='mean')
        elif norm == 2:
            self.loss_function = nn.MSELoss(reduction='mean')
        if trim:
            self.loss_function = TrimMAELoss()
        self.ssi = ssi
        self.scales = scales

    def gradients(self, input_tensor):
        # Depthwise convolution: one copy of each kernel per input channel.
        op_x, op_y = self.op_x, self.op_y
        groups = input_tensor.shape[1]
        op_x = op_x.repeat(groups, 1, 1, 1)
        op_y = op_y.repeat(groups, 1, 1, 1)
        grad_x = F.conv2d(input_tensor, op_x, groups=groups)
        grad_y = F.conv2d(input_tensor, op_y, groups=groups)
        return grad_x, grad_y

    def forward(self, prediction, target, mask=None):
        # Note: `mask` is accepted but not used in this implementation.
        if self.ssi:
            prediction_ = ssi_normalize_depth(prediction)
            target_ = ssi_normalize_depth(target)
        else:
            prediction_ = prediction
            target_ = target
        # Expects (C, H, W) inputs; add a batch dimension for conv2d.
        prediction_ = prediction_.unsqueeze(0)
        target_ = target_.unsqueeze(0)
        total_loss = 0.0
        for scale in range(self.scales):
            grad_prediction_x, grad_prediction_y = self.gradients(prediction_)
            grad_target_x, grad_target_y = self.gradients(target_)
            loss_x = self.loss_function(grad_prediction_x, grad_target_x)
            loss_y = self.loss_function(grad_prediction_y, grad_target_y)
            total_loss += torch.mean(loss_x + loss_y)
            # Halve the resolution for the next scale.
            prediction_ = F.interpolate(prediction_, scale_factor=0.5)
            target_ = F.interpolate(target_, scale_factor=0.5)
        return total_loss / self.scales
```

If you have any questions, feel free to let me know!
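For what it's worth, a minimal usage sketch (the shapes, device, and hyperparameters here are assumptions for illustration; `forward` expects single-image `(C, H, W)` tensors, since it adds the batch dimension itself, and the kernels are created on CUDA):

```python
import torch

# Hypothetical shapes: a single-channel 384x384 depth map, as (C, H, W).
prediction = torch.rand(1, 384, 384, device='cuda', requires_grad=True)
target = torch.rand(1, 384, 384, device='cuda')

criterion = MultiScaleDeriLoss(operator='Scharr', norm=1, scales=6, ssi=True)
loss = criterion(prediction, target)
loss.backward()
```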
Thank you very much for your reply. After reviewing your implementation of the loss function, I have a few questions I'd like to ask you:
Thanks for your attention.
By the way, on some of my datasets this implementation sometimes does not work well, so there may be more that needs to be done to modify it. If you have any ideas, please feel free to share them.
Thank you very much for your response once again. I have one more question: I noticed that the input image is resized directly to (1536, 1536) without maintaining the aspect ratio. Does this not affect the results? Additionally, for absolute depth only fx is used; why is fy not needed?
Thank you for your explanation.
Could you provide an implementation of the multi-scale derivative loss, or point me to related resources?