Skip to content

Commit

Permalink
Merge pull request #112 from ibm-granite/model_path
Browse files Browse the repository at this point in the history
Redirect to tests to "test-specific" models
  • Loading branch information
wgifford authored Aug 16, 2024
2 parents 0263bd5 + 2daa4a5 commit cb3982e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
8 changes: 4 additions & 4 deletions services/inference/tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()

msg = {
"model_id": "ibm-granite/granite-timeseries-ttm-v1",
"model_id": "ibm/test-ttm-v1",
"parameters": {
# "prediction_length": params["prediction_length"],
},
Expand All @@ -89,7 +89,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data.copy()

msg = {
"model_id": "ibm-granite/granite-timeseries-ttm-v1",
"model_id": "ibm/test-ttm-v1",
"parameters": {
# "prediction_length": params["prediction_length"],
},
Expand All @@ -111,8 +111,8 @@ def test_zero_shot_forecast_inference(ts_data):
@pytest.mark.parametrize(
"model_path",
[
"ibm-granite/granite-timeseries-patchtst",
"ibm-granite/granite-timeseries-patchtsmixer",
"ibm/test-patchtst",
"ibm/test-patchtsmixer",
],
)
def test_trained_model_inference(ts_data, model_path):
Expand Down
8 changes: 5 additions & 3 deletions tests/toolkit/test_time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from transformers import PatchTSTForPrediction

from tsfm_public import TinyTimeMixerForPrediction
from tsfm_public import TinyTimeMixerConfig, TinyTimeMixerForPrediction
from tsfm_public.toolkit.time_series_forecasting_pipeline import (
TimeSeriesForecastingPipeline,
)
Expand All @@ -25,8 +25,10 @@ def patchtst_model():

@pytest.fixture(scope="module")
def ttm_model():
model_path = "ibm-granite/granite-timeseries-ttm-v1"
model = TinyTimeMixerForPrediction.from_pretrained(model_path)
# model_path = "ibm-granite/granite-timeseries-ttm-v1"

conf = TinyTimeMixerConfig()
model = TinyTimeMixerForPrediction(conf)

return model

Expand Down

0 comments on commit cb3982e

Please sign in to comment.