Skip to content

Commit

Permalink
enable truncation of context len in ttm
Browse files Browse the repository at this point in the history
  • Loading branch information
ajati committed Dec 4, 2024
1 parent 2c7d487 commit 3598090
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ packages = ["tsfm_public", "tsfmhfdemos"]
[project.optional-dependencies]


all = ["tsfm_public[notebooks,testing,dev]"]
all = ["tsfm_public[notebooks,external,testing,dev]"]

notebooks = [
"jupyter",
Expand All @@ -42,7 +42,8 @@ notebooks = [
"kaleido",
"tensorboard",
]
testing = ["pytest", "tsfm_public[notebooks]", "parameterized"]
external = ["tsfm_public[notebooks]", "gluonts"]
testing = ["pytest", "tsfm_public[external]", "parameterized"]
dev = ["pre-commit", "tsfm_public[testing]", "ruff==0.4.4"]

# ogv deployments will already have jupyter
Expand Down
9 changes: 9 additions & 0 deletions tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,15 @@ def forward(
Returns:
"""
if past_values.dim() != 3:
raise ValueError(
"`past_values` must have 3 dimensions of shape `(batch_size, sequence_length, num_input_channels)`."
)
if past_values.shape[1] > self.config.context_length:
past_values = past_values[:, -self.config.context_length :, :]
elif past_values.shape[1] < self.config.context_length:
raise ValueError("Context length in `past_values` is shorter that TTM context_length.")

if self.loss == "mse":
loss = nn.MSELoss(reduction="mean")
elif self.loss == "mae":
Expand Down
14 changes: 2 additions & 12 deletions tsfm_public/toolkit/gluonts_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from typing import Union

import numpy as np
import torch
from gluonts.dataset.split import InputDataset, LabelDataset, TrainingDataset
from gluonts.itertools import batcher
from gluonts.transform.feature import LastValueImputation
from torch.utils.data import Dataset
from tqdm import tqdm

from tsfm_public.toolkit.dataset import _torch


def impute_series(target):
if np.isnan(target).any():
Expand All @@ -23,17 +24,6 @@ def impute_series(target):
return target


def np_to_torch(np):
if np.dtype == "float" or np.dtype == "float32":
return torch.from_numpy(np).float()
elif np.dtype == "int":
return torch.from_numpy(np)


def _torch(*nps):
return tuple(np_to_torch(x) for x in nps)


class StandardScalingGluonTSDataset:
"""
TTM works best on standard scaled data, especially if fewshot
Expand Down

0 comments on commit 3598090

Please sign in to comment.