Skip to content

Commit

Permalink
Merge pull request #566 from charlielam0615/add_multiclass_marginloss
Browse files Browse the repository at this point in the history
add support for multi-class margin loss
  • Loading branch information
chaoming0625 authored Dec 27, 2023
2 parents ff81c8f + 68da27e commit fb9a321
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
45 changes: 45 additions & 0 deletions brainpy/_src/losses/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
'log_cosh_loss',
'ctc_loss_with_forward_probs',
'ctc_loss',
'multi_margin_loss',
]


Expand Down Expand Up @@ -1050,3 +1051,47 @@ def ctc_loss(logits: ArrayType,
logits, logit_paddings, labels, label_paddings,
blank_id=blank_id, log_epsilon=log_epsilon)
return per_seq_loss


def multi_margin_loss(predicts, targets, margin=1.0, p=1, reduction='mean'):
r"""Computes multi-class margin loss, also called multi-class hinge loss.
This loss function is often used in multi-class classification problems.
It is a type of hinge loss that tries to ensure the correct class score is greater than the scores of other classes by a margin.
The loss function for sample :math:`i` is:
.. math::
\ell(x, y) = \sum_{j \neq y_i} \max(0, x_{y_j} - x_{y_i} + \text{margin})
where :math:`x` is the input, :math:`y` is the target, and :math:`y_i` is the index of the true class,
and :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`
and :math:`i \neq y`.
Args:
predicts: :math:`(N, C)` where `C = number of classes`.
target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`.
margin (float, optional): Has a default value of :math:`1`.
p (float, optional): Has a default value of :math:`1`.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
be applied, ``'mean'``: the sum of the output will be divided by the
number of elements in the output, ``'sum'``: the output will be summed.
Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
and in the meantime, specifying either of those two args will override :attr:`reduction`.
Default: ``'mean'``
Returns:
a scalar representing the multi-class margin loss. If `reduction` is ``'none'``, then :math:`(N)`.
"""
assert p == 1 or p == 2, 'p should be 1 or 2'
batch_size = predicts.shape[0]
correct_scores = predicts[jnp.arange(batch_size), targets]
margins = jnp.power(jnp.maximum(0, predicts - correct_scores[:, jnp.newaxis] + margin), p)
margins = margins.at[jnp.arange(batch_size), targets].set(0)
if reduction == 'mean':
return jnp.sum(margins) / batch_size
elif reduction == 'sum':
return jnp.sum(margins)
elif reduction == 'none':
return margins
1 change: 1 addition & 0 deletions brainpy/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
log_cosh_loss as log_cosh_loss,
ctc_loss_with_forward_probs as ctc_loss_with_forward_probs,
ctc_loss as ctc_loss,
multi_margin_loss as multi_margin_loss,
)

from brainpy._src.losses.comparison import (
Expand Down

0 comments on commit fb9a321

Please sign in to comment.