Focal Tversky Loss implementation #2

Open
YoonSungLee opened this issue Oct 13, 2020 · 1 comment

Comments

@YoonSungLee

Hi, thank you for this useful repo covering so many loss functions.
I had a class imbalance problem in my project and solved it by using this repo.

However, I was not able to use the 'focal_tversky' loss function from this repo.
Whenever I use this code, the validation loss comes out as NaN.

What can I do to use this loss function in my project?
Here is the code I used; I ported the function from TensorFlow to PyTorch.

# focal tversky loss
import torch
import torch.nn as nn

class FocalTverskyLoss(nn.Module):
    def __init__(self, weight=None, size_average=True, smooth=1, alpha=0.7, gamma=0.75):
        super(FocalTverskyLoss, self).__init__()
        self.smooth = smooth
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # focal tversky loss: (1 - Tversky index) ** gamma
        pt_1 = self.tversky_index(inputs, targets)
        return torch.pow((1 - pt_1), self.gamma)

    def tversky_index(self, inputs, targets):
        # Tversky index: alpha weights false negatives, (1 - alpha) weights false positives
        y_true_pos = torch.flatten(targets)
        y_pred_pos = torch.flatten(inputs)
        true_pos = torch.sum(y_true_pos * y_pred_pos)
        false_neg = torch.sum(y_true_pos * (1 - y_pred_pos))
        false_pos = torch.sum((1 - y_true_pos) * y_pred_pos)
        return (true_pos + self.smooth) / (
                    true_pos + self.alpha * false_neg + (1 - self.alpha) * false_pos + self.smooth)
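
For reference, a minimal call sketch with random stand-in tensors (illustrative only, not my actual training loop); it assumes the predictions handed to the loss are already probabilities in [0, 1] and the targets are binary masks:

preds = torch.sigmoid(torch.randn(4, 1, 64, 64))       # stand-in predictions in [0, 1]
masks = torch.randint(0, 2, (4, 1, 64, 64)).float()    # stand-in binary ground-truth masks
criterion = FocalTverskyLoss()
loss = criterion(preds, masks)                          # scalar; stays finite because preds are in [0, 1]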

Thank you.

@raimannma

raimannma commented Dec 15, 2022

The problem is that gamma is 0.75.

If (1 - pt_1) in the forward method takes negative values, raising it to a power less than 1 is not defined, because you cannot take a fractional root of a negative number (in the reals); torch.pow then returns NaN.

In the paper they say γ can range over [1, 3].

I think it is a bug in the code.
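
Below is a minimal sketch of one possible guard (my own suggestion, not the repo's official fix): map raw logits to probabilities with a sigmoid so the Tversky index stays in [0, 1], clamp it to be safe, and pick a gamma from the [1, 3] range quoted above. The class name and defaults here are illustrative.

import torch
import torch.nn as nn

class SafeFocalTverskyLoss(nn.Module):
    # Hypothetical variant: keeps the base of the power non-negative and uses gamma in [1, 3].
    def __init__(self, smooth=1, alpha=0.7, gamma=4 / 3):
        super().__init__()
        self.smooth = smooth
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        inputs = torch.sigmoid(logits)  # drop this line if the inputs are already probabilities
        pt_1 = self.tversky_index(inputs, targets).clamp(0, 1)  # base of the power stays in [0, 1]
        return torch.pow(1 - pt_1, self.gamma)

    def tversky_index(self, inputs, targets):
        y_true_pos = torch.flatten(targets)
        y_pred_pos = torch.flatten(inputs)
        true_pos = torch.sum(y_true_pos * y_pred_pos)
        false_neg = torch.sum(y_true_pos * (1 - y_pred_pos))
        false_pos = torch.sum((1 - y_true_pos) * y_pred_pos)
        return (true_pos + self.smooth) / (
            true_pos + self.alpha * false_neg + (1 - self.alpha) * false_pos + self.smooth)

Note that the sigmoid/clamp is what actually prevents the NaN, since a negative base with any non-integer exponent is still undefined; the gamma change just follows the range quoted from the paper.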
