Hi,
Thanks for your excellent work! One simple question: while reading the code, I found the following focal loss implementation:

```python
import torch

def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
    num_classes = logits.shape[1]
    gamma = gamma
    alpha = alpha
    dtype = targets.dtype
    device = targets.device
    # Class indices 1..num_classes, created on the same device as `targets`
    class_range = torch.arange(1, num_classes+1, dtype=dtype, device=device).unsqueeze(0)

    t = targets.unsqueeze(1)
    p = torch.sigmoid(logits)
    term1 = (1 - p) ** gamma * torch.log(p)
    term2 = p ** gamma * torch.log(1 - p)
    return -(t == class_range).float() * term1 * alpha - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha)
```
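For reference, this is how I read what the function computes. With $p_{ic} = \sigma(x_{ic})$ for logit column $c$ (counted from 1, matching `class_range`) and integer target $t_i$, each element of the returned tensor is the following (label 0 acts as background, negative labels are ignored, and no reduction is applied):

```math
\ell_{ic} =
\begin{cases}
-\alpha \, (1 - p_{ic})^{\gamma} \log p_{ic} & \text{if } t_i = c, \\
-(1 - \alpha) \, p_{ic}^{\gamma} \log (1 - p_{ic}) & \text{if } t_i \neq c \text{ and } t_i \geq 0, \\
0 & \text{if } t_i < 0.
\end{cases}
```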
You named it `sigmoid_focal_loss_cpu`.
I am a little confused by the name, since this code can also run on a GPU (when the input tensors are on a CUDA device). Am I right?
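A quick way to check is to run the same function on both devices. The sketch below assumes PyTorch with optional CUDA support; the helper `run_on` and the toy inputs are hypothetical. It copies the function above (using `&` instead of `*` on the boolean masks, which gives the same result) and compares CPU and GPU outputs when a GPU is present:

```python
import torch


def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
    # Same computation as the quoted function; `&` replaces `*` on the boolean masks.
    num_classes = logits.shape[1]
    # class_range is created on targets.device, so nothing is pinned to the CPU
    class_range = torch.arange(
        1, num_classes + 1, dtype=targets.dtype, device=targets.device
    ).unsqueeze(0)
    t = targets.unsqueeze(1)
    p = torch.sigmoid(logits)
    term1 = (1 - p) ** gamma * torch.log(p)
    term2 = p ** gamma * torch.log(1 - p)
    pos = (t == class_range).float()                  # target class of each sample
    neg = ((t != class_range) & (t >= 0)).float()     # other classes, ignoring t < 0
    return -pos * term1 * alpha - neg * term2 * (1 - alpha)


def run_on(device):
    # Hypothetical toy batch: 4 samples, 3 classes; label 0 = background, -1 = ignored
    torch.manual_seed(0)
    logits = torch.randn(4, 3)
    targets = torch.tensor([0, 1, 3, -1])
    return sigmoid_focal_loss_cpu(
        logits.to(device), targets.to(device), gamma=2.0, alpha=0.25
    )


cpu_loss = run_on("cpu")
print(cpu_loss.device)  # cpu

if torch.cuda.is_available():
    gpu_loss = run_on("cuda")
    print(gpu_loss.device)
    # Same inputs, same math: results should agree up to floating-point tolerance
    print(torch.allclose(cpu_loss, gpu_loss.cpu(), atol=1e-6))
```

Since every operation is an ordinary elementwise tensor op, the output lands on whatever device the inputs are on. My guess (an assumption) is that the `_cpu` suffix only marks this as the fallback next to a dedicated CUDA kernel.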
Many Thanks!