
about multi-scale derivative loss #60

Open

lianNice opened this issue Oct 28, 2024 · 7 comments

@lianNice

Could you provide an implementation of the multi-scale derivative loss, or point to related resources?

@AakashKumarNain

AakashKumarNain commented Nov 7, 2024

I have the same question. IMO this is a very important detail and should have been included in the repo.

@PeideChi

Hi guys, here is my implementation of the multi-scale derivative loss:

import torch
import torch.nn as nn
import torch.nn.functional as F

def ssi_normalize_depth(depth):
    # Scale-and-shift-invariant normalization: subtract the median and
    # divide by the mean absolute deviation (as in the MiDaS SSI loss).
    median = torch.median(depth)
    abs_diff = torch.abs(depth - median)
    mean_abs_diff = torch.mean(abs_diff)
    # Small epsilon guards against division by zero on constant inputs.
    normalized_depth = (depth - median) / (mean_abs_diff + 1e-6)
    return normalized_depth

class TrimMAELoss:
    """MAE over the (1 - trim) fraction of pixels with the smallest residuals."""

    def __init__(self, trim=0.2):
        self.trim = trim

    def __call__(self, prediction, target):
        res = (prediction - target).abs()
        sorted_res, _ = torch.sort(res.view(-1), descending=False)
        # Use numel(), not len(): len() of a multi-dim tensor only returns
        # the size of its first dimension.
        trimmed = sorted_res[: int(res.numel() * (1.0 - self.trim))]
        return trimmed.sum() / res.numel()

class MultiScaleDeriLoss(nn.Module):
    def __init__(self, operator='Scharr', norm=1, scales=6, trim=False, ssi=False, amp=False):
        super().__init__()
        self.name = "MultiScaleDerivativeLoss"
        self.operator = operator
        dtype = torch.float16 if amp else torch.float
        # Kernels are created directly on the GPU; move them to the input
        # device instead if you train on CPU.
        self.operators = {
            "Scharr": {
                'x': torch.tensor([[[[-3, 0, 3], [-10, 0, 10], [-3, 0, 3]]]], dtype=dtype).cuda(),
                # The y kernel is the transpose of the x kernel; note the
                # -10 in the top row (the original post had a sign error).
                'y': torch.tensor([[[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]]], dtype=dtype).cuda(),
            },
            "Laplace": {
                # The Laplacian is isotropic, so both directions use the same kernel.
                'x': torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], dtype=dtype).cuda(),
                'y': torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], dtype=dtype).cuda(),
            }
        }
        self.op_x = self.operators[operator]['x']
        self.op_y = self.operators[operator]['y']
        if norm == 1:
            self.loss_function = nn.L1Loss(reduction='mean')
        elif norm == 2:
            # '=' here, not '=='; the original comparison left the attribute unset.
            self.loss_function = nn.MSELoss(reduction='mean')
        if trim:
            self.loss_function = TrimMAELoss()
        self.ssi = ssi
        self.scales = scales

    def gradients(self, input_tensor):
        # Depthwise convolution: one copy of the derivative kernel per channel.
        op_x, op_y = self.op_x, self.op_y
        groups = input_tensor.shape[1]
        op_x = op_x.repeat(groups, 1, 1, 1)
        op_y = op_y.repeat(groups, 1, 1, 1)
        grad_x = F.conv2d(input_tensor, op_x, groups=groups)
        grad_y = F.conv2d(input_tensor, op_y, groups=groups)
        return grad_x, grad_y

    def forward(self, prediction, target, mask=None):
        # Note: mask is accepted but not applied in this version.
        if self.ssi:
            prediction_ = ssi_normalize_depth(prediction)
            target_ = ssi_normalize_depth(target)
        else:
            prediction_ = prediction
            target_ = target
        # Expects (C, H, W) inputs; unsqueeze adds the batch dimension.
        prediction_ = prediction_.unsqueeze(0)
        target_ = target_.unsqueeze(0)
        total_loss = 0.0
        for scale in range(self.scales):
            grad_prediction_x, grad_prediction_y = self.gradients(prediction_)
            grad_target_x, grad_target_y = self.gradients(target_)
            loss_x = self.loss_function(grad_prediction_x, grad_target_x)
            loss_y = self.loss_function(grad_prediction_y, grad_target_y)
            total_loss += torch.mean(loss_x + loss_y)
            # Halve the resolution for the next scale.
            prediction_ = F.interpolate(prediction_, scale_factor=0.5)
            target_ = F.interpolate(target_, scale_factor=0.5)
        return total_loss / self.scales

If you have any questions, feel free to let me know!
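
In case it's useful, a minimal usage sketch (the shapes below are arbitrary examples, and it assumes a CUDA device since the kernels above are created with .cuda()):

# Minimal usage sketch: inputs are (C, H, W); the batch dimension is
# added inside forward(). 384x384 is an arbitrary illustrative size.
loss_fn = MultiScaleDeriLoss(operator='Scharr', norm=1, scales=6, ssi=True)

prediction = torch.rand(1, 384, 384, device='cuda', requires_grad=True)
target = torch.rand(1, 384, 384, device='cuda')

loss = loss_fn(prediction, target)
loss.backward()
print(loss.item())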

@lianNice
Author

> Hi guys, here is my implementation of the multi-scale derivative loss: […]

Thank you very much for your reply. After reviewing your loss implementation, I have a few questions:

  1. In ssi_normalize_depth, you do not exclude pixels corresponding to infinitely far points (e.g., space) or pixels without a depth value (depth of 0) from the normalization. Could you explain why?

  2. During training, does MultiScaleDeriLoss run into any NaN or Inf issues?

  3. Additionally, when I incorporate sparse point cloud data into training, the ranging accuracy improves but the structural consistency deteriorates. Do you have any suggestions for addressing this?

@PeideChi

> Thank you very much for your reply. After reviewing your loss implementation, I have a few questions: […]

Thanks for your attention.

  1. I do this because I train DepthPro only on my own datasets, which are synthetic and contain images where every pixel has a valid relative depth. For your case, you can easily add a depth mask; a sketch of a masked variant follows below.
  2. No. The loss decreases gradually and training proceeds without problems.
  3. I do not have such data, so I cannot offer concrete suggestions, but trimming the top 20% of depth errors might help.
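
As a rough sketch of the masked variant mentioned in point 1 (masked_ssi_normalize_depth is a hypothetical name, and the "depth > 0 means valid" convention is my assumption; adapt it to your data):

def masked_ssi_normalize_depth(depth, mask=None):
    # Hypothetical masked variant: statistics are computed over valid
    # pixels only. Assumes depth == 0 marks invalid pixels when no mask
    # is given; change this convention to match your dataset.
    if mask is None:
        mask = depth > 0
    valid = depth[mask]  # will error if the mask is empty; guard upstream
    median = torch.median(valid)
    mean_abs_diff = torch.mean(torch.abs(valid - median))
    normalized = torch.zeros_like(depth)
    normalized[mask] = (depth[mask] - median) / (mean_abs_diff + 1e-6)
    return normalized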

By the way, on some of my datasets this implementation does not work well, so there may be more to do to improve it. If you have any ideas, I'd welcome further replies.

@lianNice
Author

> Thanks for your attention. […]

Thank you very much for your response once again. I have one more question: I noticed that the input image is directly resized to (1536, 1536) without maintaining the aspect ratio. Does this not affect the results?

Additionally, for absolute depth, only fx is used. Why is fy not needed?

@PeideChi

> Thank you very much for your response once again. I have one more question: […]

  1. Resizing is a common method to match the input image to the network's processing resolution. In my opinion, if the image is resized to a larger resolution there is almost no effect; otherwise, the result would degrade.
  2. Computing metric (absolute) depth follows from the mathematics of computational photography; you can search for the details online.
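
If it helps: my understanding of the Depth Pro formulation (please verify against the paper) is that the network predicts a canonical inverse depth C at a fixed square resolution, and metric depth is recovered as

    D_m = f_px / (w * C)

where f_px is the horizontal focal length in pixels and w is the image width. Because the image is resampled to a fixed square grid, only the horizontal field of view enters the scaling, which would explain why fy is not needed.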

@lianNice
Author

> 1. Resizing is a common method to match the input image to the network's processing resolution. […]

Thank you for your explanation.
