diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index e3022f0a1b..43abd37031 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -282,7 +282,8 @@ def compute_loss( if torch.any(thresh_mask): l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps) - return l1_loss + focal_scale * focal_loss + dice_loss + # return l1_loss + focal_scale * focal_loss + dice_loss + return focal_scale * focal_loss + dice_loss def _dbnet(