Skip to content

Commit

Permalink
Merge pull request #226 from ibm-granite/quantile
Browse files Browse the repository at this point in the history
Add Quantile and Huber loss to TTM
  • Loading branch information
vijaye12 authored Dec 12, 2024
2 parents 48346a6 + f52f3e4 commit 9e4077e
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 59 deletions.
297 changes: 243 additions & 54 deletions notebooks/hfdemo/ttm_getting_started.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/models/tinytimemixer/test_modeling_tinytimemixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def check_module(
[True, False, "mean", "std"],
[True, False],
[None, [0, 2]],
["mse", "mae", None],
["mse", "mae", "pinball", "huber", None],
[8, 16],
[True, False],
)
Expand Down
10 changes: 7 additions & 3 deletions tsfm_public/models/tinytimemixer/configuration_tinytimemixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ class TinyTimeMixerConfig(PretrainedConfig):
Whether to scale the input targets via "mean" scaler, "std" scaler or no scaler if `None`. If `True`, the
scaler is set to "mean".
loss (`string`, *optional*, defaults to `"mse"`):
The loss function for the model corresponding to the `distribution_output` head. For parametric
distributions it is the negative log likelihood ("nll") and for point estimates it is the mean squared
error "mse" or "mae". Distribution head (nll) is currently disabled and not allowed.
            The loss function to finetune or pretrain the model. Allowed values are "mse", "mae", "pinball" or "huber".
Use pinball loss for probabilistic forecasts of different quantiles.
Distribution head (nll) is currently disabled and not allowed.
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated normal weight initialization distribution.
post_init (`bool`, *optional*, defaults to `False`):
Expand Down Expand Up @@ -212,6 +212,8 @@ def __init__(
# initialization parameters
init_linear: str = "pytorch",
init_embed: str = "pytorch",
quantile: float = 0.5,
huber_delta: float = 1,
**kwargs,
):
self.num_input_channels = num_input_channels
Expand Down Expand Up @@ -266,6 +268,8 @@ def __init__(
self.prediction_filter_length = prediction_filter_length
self.init_linear = init_linear
self.init_embed = init_embed
self.quantile = quantile
self.huber_delta = huber_delta

super().__init__(**kwargs)

Expand Down
35 changes: 34 additions & 1 deletion tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,35 @@
"""


class PinballLoss(nn.Module):
    """Pinball (quantile) loss for multidimensional tensors.

    For a target ``y``, prediction ``y_hat`` and quantile ``q``, the
    elementwise loss is ``max(q * (y - y_hat), (q - 1) * (y - y_hat))``,
    which penalizes under-prediction with weight ``q`` and over-prediction
    with weight ``1 - q``. Minimizing it yields the ``q``-th quantile.
    """

    def __init__(self, quantile: float):
        """
        Initialize the Pinball Loss.

        Args:
            quantile (float): The desired quantile in the open interval (0, 1)
                (e.g., 0.5 for median, 0.9 for 90th percentile).

        Raises:
            ValueError: If `quantile` is not strictly between 0 and 1; values
                outside this range would make the loss meaningless (it could
                become unbounded below).
        """
        super().__init__()
        if not 0.0 < quantile < 1.0:
            raise ValueError(f"quantile must be in the open interval (0, 1), got {quantile}")
        self.quantile = quantile

    def forward(self, predictions, targets):
        """
        Compute the Pinball Loss for shape [b, seq_len, channels].

        Args:
            predictions (torch.Tensor): Predicted values, shape [b, seq_len, channels].
            targets (torch.Tensor): Ground truth values, shape [b, seq_len, channels].

        Returns:
            torch.Tensor: The mean pinball loss over all dimensions (scalar tensor).
        """
        errors = targets - predictions

        # max() picks q*e when e >= 0 (under-prediction) and (q-1)*e when
        # e < 0 (over-prediction); both branches are non-negative.
        loss = torch.max(self.quantile * errors, (self.quantile - 1) * errors)

        return loss.mean()


class TinyTimeMixerGatedAttention(nn.Module):
"""
Module that applies gated attention to input data.
Expand Down Expand Up @@ -1723,7 +1752,7 @@ def __init__(self, config: TinyTimeMixerConfig):

self.prediction_filter_length = config.prediction_filter_length

if config.loss in ["mse", "mae"] or config.loss is None:
if config.loss in ["mse", "mae", "pinball", "huber"] or config.loss is None:
self.distribution_output = None
elif config.loss == "nll":
if self.prediction_filter_length is None:
Expand Down Expand Up @@ -1815,6 +1844,10 @@ def forward(
loss = nn.MSELoss(reduction="mean")
elif self.loss == "mae":
loss = nn.L1Loss(reduction="mean")
elif self.loss == "pinball":
loss = PinballLoss(quantile=self.config.quantile)
elif self.loss == "huber":
loss = nn.HuberLoss(delta=self.config.huber_delta)
elif self.loss == "nll":
raise Exception(
"NLL loss and Distribution heads are currently not allowed. Use mse or mae as loss functions."
Expand Down

0 comments on commit 9e4077e

Please sign in to comment.