diff --git a/services/inference/tests/test_inference.py b/services/inference/tests/test_inference.py index b2e3d463..5c5d7adb 100644 --- a/services/inference/tests/test_inference.py +++ b/services/inference/tests/test_inference.py @@ -68,7 +68,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() msg = { - "model_id": "ibm-granite/granite-timeseries-ttm-v1", + "model_id": "ibm/test-ttm-v1", "parameters": { # "prediction_length": params["prediction_length"], }, @@ -89,7 +89,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_ = test_data.copy() msg = { - "model_id": "ibm-granite/granite-timeseries-ttm-v1", + "model_id": "ibm/test-ttm-v1", "parameters": { # "prediction_length": params["prediction_length"], }, @@ -111,8 +111,8 @@ def test_zero_shot_forecast_inference(ts_data): @pytest.mark.parametrize( "model_path", [ - "ibm-granite/granite-timeseries-patchtst", - "ibm-granite/granite-timeseries-patchtsmixer", + "ibm/test-patchtst", + "ibm/test-patchtsmixer", ], ) def test_trained_model_inference(ts_data, model_path): diff --git a/tests/toolkit/test_time_series_forecasting_pipeline.py b/tests/toolkit/test_time_series_forecasting_pipeline.py index baa862a0..c788bcd6 100644 --- a/tests/toolkit/test_time_series_forecasting_pipeline.py +++ b/tests/toolkit/test_time_series_forecasting_pipeline.py @@ -7,7 +7,7 @@ import pytest from transformers import PatchTSTForPrediction -from tsfm_public import TinyTimeMixerForPrediction +from tsfm_public import TinyTimeMixerConfig, TinyTimeMixerForPrediction from tsfm_public.toolkit.time_series_forecasting_pipeline import ( TimeSeriesForecastingPipeline, ) @@ -25,8 +25,10 @@ def patchtst_model(): @pytest.fixture(scope="module") def ttm_model(): - model_path = "ibm-granite/granite-timeseries-ttm-v1" - model = TinyTimeMixerForPrediction.from_pretrained(model_path) + # model_path = "ibm-granite/granite-timeseries-ttm-v1" + + conf = TinyTimeMixerConfig() + model = TinyTimeMixerForPrediction(conf) return model