def multi_margin_loss(predicts, targets, margin=1.0, p=1, reduction='mean'):
  r"""Computes multi-class margin loss, also called multi-class hinge loss.

  This loss function is often used in multi-class classification problems.
  It is a type of hinge loss that tries to ensure the correct class score is
  greater than the scores of the other classes by at least ``margin``.

  The loss for sample :math:`i`, with scores :math:`x = \text{predicts}[i]`
  and true class index :math:`y_i = \text{targets}[i]`, is:

  .. math::
      \ell(x, y_i) = \sum_{j \neq y_i} \max(0, x_j - x_{y_i} + \text{margin})^p

  where :math:`j \in \left\{0, \; \cdots , \; C - 1\right\}` runs over every
  class except the true one. The per-sample sum is not divided by the number
  of classes.

  Args:
    predicts: :math:`(N, C)` where `C = number of classes`.
    targets: :math:`(N)` where each value is
      :math:`0 \leq \text{targets}[i] \leq C-1`.
    margin (float, optional): Has a default value of :math:`1`.
    p (int, optional): Exponent applied to each hinge term; must be ``1`` or
      ``2``. Has a default value of :math:`1`.
    reduction (str, optional): Specifies the reduction to apply to the output:
      ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: the per-sample losses
      are returned, ``'mean'``: the sum of all per-sample losses is divided by
      the batch size :math:`N`, ``'sum'``: the per-sample losses are summed.
      Default: ``'mean'``

  Returns:
    A scalar representing the multi-class margin loss. If `reduction` is
    ``'none'``, then an array of shape :math:`(N)` holding one loss per sample.

  Raises:
    AssertionError: If `p` is neither 1 nor 2.
    ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
  """
  assert p == 1 or p == 2, 'p should be 1 or 2'
  if reduction not in ('none', 'mean', 'sum'):
    # Previously an unknown reduction fell through and silently returned None.
    raise ValueError(
      f"reduction should be 'none', 'mean' or 'sum', but got {reduction!r}"
    )

  batch_size = predicts.shape[0]
  rows = jnp.arange(batch_size)
  # Score assigned to the true class of every sample, shape (N,).
  correct_scores = predicts[rows, targets]
  # Hinge term for every (sample, class) pair, shape (N, C).
  margins = jnp.power(
    jnp.maximum(0, predicts - correct_scores[:, jnp.newaxis] + margin), p
  )
  # The true class would contribute margin ** p to its own row; exclude it.
  margins = margins.at[rows, targets].set(0)

  if reduction == 'none':
    # One loss per sample, matching the documented (N) output shape.
    return jnp.sum(margins, axis=1)
  total = jnp.sum(margins)
  if reduction == 'sum':
    return total
  return total / batch_size