Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal Changes #507

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions init2winit/model_lib/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
return the mean of function values. This is to make trainer.py more agnostic to
the details of the padding and masking.
"""

import functools

from absl import logging
Expand All @@ -27,6 +28,7 @@
import jax.numpy as jnp
import optax


bi_tempered_loss = None
try:
from jax_bitempered_loss import loss as bi_tempered_loss # pylint: disable=g-import-not-at-top
Expand Down Expand Up @@ -394,9 +396,35 @@ def weighted_mean_absolute_error(logits, targets, weights=None):
return jnp.sum(unnormalized_mean_absolute_error) / normalization


def multiclass_hinge_loss(logits, targets, weights=None):
"""Implements the multiclass hinge loss.

Args:
logits: Array with shape [batch_size, num_labels].
targets: One-hot encoded labels with shape [batch size, num labels].
weights: None or float array of shape (batch,).

Returns:
The multiclass hinge loss for classification, averaged over the first
dimension (batch) to match cross_entropy_loss.
"""
losses = jnp.max(logits + 1.0 - targets, axis=-1) - jnp.sum(
logits * targets, axis=-1)
if weights is not None:
if weights.ndim != targets.ndim - 1:
raise ValueError('Incorrect shapes. Got shape %s weights and %s targets' %
(str(weights.shape), str(targets.shape)))
normalization = weights.sum()
losses *= weights
else:
normalization = targets.shape[0]

return jnp.sum(losses) / normalization

# TODO(cheolmin): add mean_squared_error
_ALL_LOSS_FUNCTIONS = {
'rescaled_mean_squared_error': (rescaled_mean_squared_error, None),
'multiclass_hinge_loss': (multiclass_hinge_loss, None),
'sigmoid_mean_squared_error': (sigmoid_mean_squared_error, jax.nn.sigmoid),
'sigmoid_binary_cross_entropy':
(sigmoid_binary_cross_entropy, jax.nn.sigmoid),
Expand Down