diff --git a/pyproject.toml b/pyproject.toml index 60a7c490..d21fb799 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ packages = ["tsfm_public", "tsfmhfdemos"] [project.optional-dependencies] -all = ["tsfm_public[notebooks,testing,dev]"] +all = ["tsfm_public[notebooks,external,testing,dev]"] notebooks = [ "jupyter", @@ -42,7 +42,8 @@ notebooks = [ "kaleido", "tensorboard", ] -testing = ["pytest", "tsfm_public[notebooks]", "parameterized"] +external = ["tsfm_public[notebooks]", "gluonts"] +testing = ["pytest", "tsfm_public[external]", "parameterized"] dev = ["pre-commit", "tsfm_public[testing]", "ruff==0.4.4"] # ogv deployments will already have jupyter diff --git a/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py b/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py index 2cf77f35..7ec8651e 100644 --- a/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py +++ b/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py @@ -1802,6 +1802,15 @@ def forward( Returns: """ + if past_values.dim() != 3: + raise ValueError( + "`past_values` must have 3 dimensions of shape `(batch_size, sequence_length, num_input_channels)`." 
+ ) + if past_values.shape[1] > self.config.context_length: + past_values = past_values[:, -self.config.context_length :, :] + elif past_values.shape[1] < self.config.context_length: + raise ValueError("Context length in `past_values` is shorter than TTM context_length.") + if self.loss == "mse": loss = nn.MSELoss(reduction="mean") elif self.loss == "mae": diff --git a/tsfm_public/toolkit/gluonts_data_wrapper.py b/tsfm_public/toolkit/gluonts_data_wrapper.py index e96043d1..40e58c1c 100644 --- a/tsfm_public/toolkit/gluonts_data_wrapper.py +++ b/tsfm_public/toolkit/gluonts_data_wrapper.py @@ -2,13 +2,14 @@ from typing import Union import numpy as np -import torch from gluonts.dataset.split import InputDataset, LabelDataset, TrainingDataset from gluonts.itertools import batcher from gluonts.transform.feature import LastValueImputation from torch.utils.data import Dataset from tqdm import tqdm +from tsfm_public.toolkit.dataset import _torch + def impute_series(target): if np.isnan(target).any(): @@ -23,17 +24,6 @@ def impute_series(target): return target -def np_to_torch(np): - if np.dtype == "float" or np.dtype == "float32": - return torch.from_numpy(np).float() - elif np.dtype == "int": - return torch.from_numpy(np) - - -def _torch(*nps): - return tuple(np_to_torch(x) for x in nps) - - class StandardScalingGluonTSDataset: """ TTM works best on standard scaled data, especially if fewshot