Adding more loss options during training #32

Open · wants to merge 2 commits into base: dev
2 changes: 1 addition & 1 deletion GETTING_STARTED.md
@@ -76,7 +76,7 @@ machineB.save()
 
 ### Training & Evaluation in Command Line
 
-We provide a script in "medsegpy/train_net.py", that is made to train
+We provide a script in "tools/train_net.py", that is made to train
 all the configs provided in medsegpy.
 You may want to use it as a reference to write your own training script for
 new research.
4 changes: 4 additions & 0 deletions medsegpy/config.py
@@ -70,6 +70,10 @@ class Config(object):
     # Class name for robust loss computation
     ROBUST_LOSS_NAME = ""
     ROBUST_LOSS_STEP_SIZE = 1e-1
+    # Additional loss functions to track during training.
+    # Format: [[(id_1, output_mode_1), class_weights_1],
+    #          [(id_2, output_mode_2), class_weights_2], ...]
+    LOSS_METRICS = []
 
     # PIDS to include, None = all pids
     PIDS = None
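As a sketch of how the new field might be populated (the values below are illustrative, not taken from the PR; the loss identifiers are the tuples defined in medsegpy/losses.py):

# Each entry is [(loss_id, output_mode), class_weights]. An integer
# class_weights selects a single class (see get_class_weights_from_int below).
cfg.LOSS_METRICS = [
    [("avg_dice_no_reduce", "softmax"), 0],   # Dice of class 0 only
    [("avg_dice_no_reduce", "softmax"), 1],   # Dice of class 1 only
    [("avg_dice", "softmax"), [1, 1]],        # average Dice over both classes
]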
9 changes: 8 additions & 1 deletion medsegpy/engine/trainer.py
@@ -176,7 +176,14 @@ def _train_model(self):
         # TODO: Add more options for metrics.
         optimizer = solver.build_optimizer(cfg)
         loss_func = self.build_loss()
-        metrics = [lr_callback(optimizer), dice_loss]
+
+        loss_metrics = []
+        if len(cfg.LOSS_METRICS) > 0:
+            for loss_idx, loss_metric in enumerate(cfg.LOSS_METRICS):
+                new_metric = build_loss(cfg, build_additional_metric=True, additional_metric=loss_metric)
+                new_metric.name = f'{loss_metric[0][0]}_{loss_idx}'
+                loss_metrics.append(new_metric)
+        metrics = [lr_callback(optimizer), dice_loss] + loss_metrics
 
         callbacks = self.build_callbacks()
         if isinstance(loss_func, kc.Callback):
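Each additional metric is compiled alongside the main loss under the name assigned via new_metric.name (loss id plus index). A minimal illustration of the naming scheme, assuming the example cfg.LOSS_METRICS sketched earlier:

names = [f'{m[0][0]}_{i}' for i, m in enumerate(cfg.LOSS_METRICS)]
# -> ['avg_dice_no_reduce_0', 'avg_dice_no_reduce_1', 'avg_dice_2']
# These are the labels under which each metric should show up in the
# Keras training logs.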
4 changes: 4 additions & 0 deletions medsegpy/loss/utils.py
@@ -65,6 +65,10 @@ def reduce_tensor(x, reduction="mean", axis=None, weights=None):
     use_weights = weights is not None
     if use_weights:
         x *= weights
+        if (reduction in ("none", None)) and (len(tf.where(weights == 0)) == (len(weights) - 1)):
+            # If exactly one weight is nonzero (e.g. one-hot weights), keep only that
+            # value's loss; scale by len(weights) because the final reduction is a mean.
+            return x * len(weights)
 
     if reduction == "mean" and use_weights:
         ndim = K.ndim(x)
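A minimal NumPy sketch (illustrative values) of why the len(weights) factor is needed: the downstream reduction takes a mean over all entries, which divides by the entry count, so pre-scaling the single surviving entry by that count recovers the selected class's loss unchanged.

import numpy as np

x = np.array([0.7, 0.2, 0.4])        # per-class loss values
weights = np.array([0.0, 1.0, 0.0])  # one-hot weights: keep class 1 only
out = x * weights * len(weights)     # what this branch of reduce_tensor returns
print(out.mean())                    # 0.2 == x[1], the selected class's loss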
48 changes: 43 additions & 5 deletions medsegpy/losses.py
@@ -17,6 +17,7 @@
 AVG_DICE_LOSS = ("avg_dice", "sigmoid")
 AVG_DICE_LOSS_SOFTMAX = ("avg_dice", "softmax")
 AVG_DICE_NO_REDUCE = ("avg_dice_no_reduce", "sigmoid")
+AVG_DICE_NO_REDUCE_SOFTMAX = ("avg_dice_no_reduce", "softmax")
 WEIGHTED_CROSS_ENTROPY_LOSS = ("weighted_cross_entropy", "softmax")
 WEIGHTED_CROSS_ENTROPY_SIGMOID_LOSS = ("weighted_cross_entropy_sigmoid", "sigmoid")
@@ -36,6 +37,7 @@
     "AVG_DICE_LOSS",
     "AVG_DICE_LOSS_SOFTMAX",
     "AVG_DICE_NO_REDUCE",
+    "AVG_DICE_NO_REDUCE_SOFTMAX",
     "WEIGHTED_CROSS_ENTROPY_LOSS",
     "WEIGHTED_CROSS_ENTROPY_SIGMOID_LOSS",
     "BINARY_CROSS_ENTROPY_LOSS",
@@ -46,11 +48,30 @@
 ]
 
 
-def build_loss(cfg):
-    loss = cfg.LOSS
+def build_loss(cfg, build_additional_metric=False, additional_metric: list = None):
+    if build_additional_metric is False:
+        loss = cfg.LOSS
+        robust_loss_cls = cfg.ROBUST_LOSS_NAME
+        robust_step_size = cfg.ROBUST_LOSS_STEP_SIZE
+        class_weights = cfg.CLASS_WEIGHTS
+    elif build_additional_metric is True:
+        loss = additional_metric[0]
+        # YAML has trouble loading a list of tuples, so convert manually.
+        if type(loss) == list:
+            loss = tuple(loss)
+        class_weights = additional_metric[1]
+        # Robust loss is not supported for additional metrics (for now).
+        robust_loss_cls = False
+        robust_step_size = None
 
     num_classes = len(cfg.CATEGORIES)
-    robust_loss_cls = cfg.ROBUST_LOSS_NAME
-    robust_step_size = cfg.ROBUST_LOSS_STEP_SIZE
+
+    # Allow the config to specify the weights as an integer indicating that
+    # only one of the classes should be evaluated.
+    if type(class_weights) in (list, tuple):
+        pass
+    elif type(class_weights) is int:
+        class_weights = get_class_weights_from_int(class_weights, num_classes)
 
     if robust_loss_cls:
         reduction = "class"
@@ -64,7 +85,7 @@ def build_loss(cfg):
         pass
     loss = get_training_loss(
         loss,
-        weights=cfg.CLASS_WEIGHTS,
+        weights=class_weights,
         # Remove computation on the background class.
         remove_background=cfg.INCLUDE_BACKGROUND,
         reduce=reduction,
@@ -79,6 +100,12 @@
     else:
         raise ValueError(f"{robust_loss_cls} not supported")
 
+
+def get_class_weights_from_int(label, num_classes):
+    """Return one-hot class weights for an integer class label."""
+    class_weights = [0] * num_classes
+    class_weights[label] = 1
+    return class_weights
 
 
 # TODO (arjundd): Add ability to exclude specific indices from loss function.
 def get_training_loss_from_str(loss_str: str):
@@ -91,6 +118,8 @@ def get_training_loss_from_str(loss_str: str):
         return AVG_DICE_LOSS
     elif loss_str == "AVG_DICE_NO_REDUCE":
         return AVG_DICE_NO_REDUCE
+    elif loss_str == "AVG_DICE_NO_REDUCE_SOFTMAX":
+        return AVG_DICE_NO_REDUCE_SOFTMAX
     elif loss_str == "WEIGHTED_CROSS_ENTROPY_LOSS":
         return WEIGHTED_CROSS_ENTROPY_LOSS
     elif loss_str == "WEIGHTED_CROSS_ENTROPY_SIGMOID_LOSS":
@@ -134,6 +163,15 @@ def get_training_loss(loss, **kwargs):
         kwargs.pop("reduce", None)
         kwargs["reduction"] = "none"
         return DiceLoss(**kwargs)
+    elif loss == AVG_DICE_NO_REDUCE_SOFTMAX:
+        # Same handling as AVG_DICE_NO_REDUCE above; the two branches could be merged.
+        kwargs.pop("reduce", None)
+        kwargs["reduction"] = "none"
+        # No need to add the softmax activation here - it is already added by the
+        # model builder:
+        # https://github.com/ad12/MedSegPy/blob/0c316baaf040c22d562940a198a0e48eef2d36a8/medsegpy/modeling/meta_arch/unet.py#L152
+        # kwargs["activation"] = "softmax"
+        return DiceLoss(**kwargs)
     else:
         raise ValueError("Loss type not supported")
 
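Putting the pieces together, a rough sketch of how the new entry points compose (cfg is assumed to be a populated Config; the metric spec is illustrative, and its inner tuple is written as a list because that is how YAML hands it over, which the manual tuple conversion in build_loss accounts for):

weights = get_class_weights_from_int(1, 3)  # integer label -> one-hot weights
# weights == [0, 1, 0]

metric_spec = [["avg_dice_no_reduce", "softmax"], 1]
metric_fn = build_loss(cfg, build_additional_metric=True, additional_metric=metric_spec)
# Internally: the list is converted to the AVG_DICE_NO_REDUCE_SOFTMAX tuple, the
# integer expands to one-hot class weights, and get_training_loss returns a
# DiceLoss with reduction="none".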