Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chronos handler to support chronos-bolt-* models #223

Open
wants to merge 40 commits into
base: new_model_integrations
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
1f00c9a
relax parameters strictness
wgifford Nov 26, 2024
27f6a3c
allow extra
wgifford Nov 26, 2024
fa81367
Merge pull request #214 from ibm-granite/service_updates
ssiegel95 Nov 30, 2024
6dd4d4d
clarify citations
wgifford Dec 3, 2024
ab5777d
Merge pull request #218 from ibm-granite/wiki_update
wgifford Dec 3, 2024
2c7d487
gluonts data wrapper, and ttm gluonts predictor
ajati Dec 3, 2024
3598090
enable truncation of context len in ttm
ajati Dec 4, 2024
176e7d1
fix issues with future exogenous
wgifford Dec 4, 2024
3859359
force_return in get_model
ajati Dec 4, 2024
a1da2c5
Merge pull request #220 from ibm-granite/pipeline_exog
wgifford Dec 4, 2024
beba825
code moved to extras folder outside tsfm_public
ajati Dec 5, 2024
c0b03eb
tests moved
ajati Dec 5, 2024
3df4e68
gift srcs removed, get_model updated
ajati Dec 5, 2024
69ed4fd
revert toml and visualization functions
ajati Dec 5, 2024
62ea27f
add optional verbose payload dumps
ssiegel95 Dec 6, 2024
bfa6535
exception -> valueerror
ajati Dec 6, 2024
98730fd
Merge pull request #219 from ibm-granite/gift
wgifford Dec 6, 2024
8354f6a
we can't resolve to a single directory here, need to scan them in load
ssiegel95 Dec 6, 2024
96e5bb4
add additional directory to TSFM_MODEL_DIR
ssiegel95 Dec 6, 2024
7d8d3be
model path resolver
ssiegel95 Dec 6, 2024
82ab987
ignore prometheus metrics dir
ssiegel95 Dec 6, 2024
1a1b75a
model dir resolver
ssiegel95 Dec 6, 2024
06eb16b
use model path resolver
ssiegel95 Dec 6, 2024
eddab4e
test model path resolver
ssiegel95 Dec 6, 2024
b0a6809
Merge remote-tracking branch 'origin/main' into byom
ssiegel95 Dec 6, 2024
f45c0b7
boilerplate code
ssiegel95 Dec 9, 2024
1e67a11
ignore dirutil.py
ssiegel95 Dec 9, 2024
35870fd
automate maintenance of .gitignore
ssiegel95 Dec 9, 2024
0afbf8d
Merge pull request #222 from ibm-granite/byom
ssiegel95 Dec 9, 2024
6d90328
chronos handler to support chronos-bolt-* models
gganapavarapu Dec 9, 2024
cd10ff4
test min context length 2
gganapavarapu Dec 10, 2024
5330e1e
explicitly set device
wgifford Dec 10, 2024
c96905d
select device
wgifford Dec 10, 2024
48346a6
Merge pull request #225 from ibm-granite/set_service_device
wgifford Dec 11, 2024
071f0c3
merge main
gganapavarapu Dec 11, 2024
88067b3
fix merge issue and revert cd10ff4
gganapavarapu Dec 11, 2024
ebe0f8e
chronos repo name in deps, make style
gganapavarapu Dec 11, 2024
d4f57ee
poetry lock
gganapavarapu Dec 12, 2024
9b793f2
ID column support for chronos models
gganapavarapu Dec 13, 2024
61e822a
support no ID columns as well for chronos models
gganapavarapu Dec 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
enable truncation of context len in ttm
  • Loading branch information
ajati committed Dec 4, 2024
commit 3598090ba2dca854d9af8226e1b9e814b3471026
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ packages = ["tsfm_public", "tsfmhfdemos"]
[project.optional-dependencies]


all = ["tsfm_public[notebooks,testing,dev]"]
all = ["tsfm_public[notebooks,external,testing,dev]"]

notebooks = [
"jupyter",
Expand All @@ -42,7 +42,8 @@ notebooks = [
"kaleido",
"tensorboard",
]
testing = ["pytest", "tsfm_public[notebooks]", "parameterized"]
external = ["tsfm_public[notebooks]", "gluonts"]
testing = ["pytest", "tsfm_public[external]", "parameterized"]
dev = ["pre-commit", "tsfm_public[testing]", "ruff==0.4.4"]

# ogv deployments will already have jupyter
Expand Down
9 changes: 9 additions & 0 deletions tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,15 @@ def forward(
Returns:

"""
if past_values.dim() != 3:
raise ValueError(
"`past_values` must have 3 dimensions of shape `(batch_size, sequence_length, num_input_channels)`."
)
if past_values.shape[1] > self.config.context_length:
past_values = past_values[:, -self.config.context_length :, :]
elif past_values.shape[1] < self.config.context_length:
raise ValueError("Context length in `past_values` is shorter that TTM context_length.")

if self.loss == "mse":
loss = nn.MSELoss(reduction="mean")
elif self.loss == "mae":
Expand Down
14 changes: 2 additions & 12 deletions tsfm_public/toolkit/gluonts_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from typing import Union

import numpy as np
import torch
from gluonts.dataset.split import InputDataset, LabelDataset, TrainingDataset
from gluonts.itertools import batcher
from gluonts.transform.feature import LastValueImputation
from torch.utils.data import Dataset
from tqdm import tqdm

from tsfm_public.toolkit.dataset import _torch


def impute_series(target):
if np.isnan(target).any():
Expand All @@ -23,17 +24,6 @@ def impute_series(target):
return target


def np_to_torch(np):
if np.dtype == "float" or np.dtype == "float32":
return torch.from_numpy(np).float()
elif np.dtype == "int":
return torch.from_numpy(np)


def _torch(*nps):
return tuple(np_to_torch(x) for x in nps)


class StandardScalingGluonTSDataset:
"""
TTM works best on standard scaled data, especially if fewshot
Expand Down
Loading