diff --git a/services/inference/tests/test_inference.py b/services/inference/tests/test_inference.py index 6eea8092..3b9e9de5 100644 --- a/services/inference/tests/test_inference.py +++ b/services/inference/tests/test_inference.py @@ -1,5 +1,6 @@ # Copyright contributors to the TSFM project # +from pathlib import Path from typing import Any, Dict import numpy as np @@ -11,12 +12,23 @@ from tsfm_public.toolkit.util import select_by_index +model_param_map = { + "ttm-r1": {"context_length": 512, "prediction_length": 96}, + "ttm-1024-96-r1": {"context_length": 1024, "prediction_length": 96}, + "ttm-r2": {"context_length": 512, "prediction_length": 96}, + "ttm-1024-96-r2": {"context_length": 1024, "prediction_length": 96}, + "ttm-1536-96-r2": {"context_length": 1536, "prediction_length": 96}, + "ibm/test-patchtst": {"context_length": 512, "prediction_length": 96}, + "ibm/test-patchtsmixer": {"context_length": 512, "prediction_length": 96}, +} + + @pytest.fixture(scope="module") -def ts_data(): +def ts_data_base(): dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv" - forecast_length = 96 - context_length = 512 + # forecast_length = 96 + # context_length = 512 timestamp_column = "date" data = pd.read_csv( @@ -24,8 +36,20 @@ def ts_data(): parse_dates=[timestamp_column], ) + return data + + +@pytest.fixture(scope="module") +def ts_data(ts_data_base, request): + # forecast_length = 96 + # context_length = 512 + model_id = request.param + prediction_length = model_param_map[model_id]["prediction_length"] + context_length = model_param_map[model_id]["context_length"] + timestamp_column = "date" + test_data = select_by_index( - data, + ts_data_base, start_index=12 * 30 * 24 + 4 * 30 * 24 - context_length * 5, end_index=12 * 30 * 24 + 4 * 30 * 24, ).reset_index(drop=True) @@ -36,7 +60,9 @@ def ts_data(): "timestamp_column": timestamp_column, "id_columns": ["id"], "target_columns": ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"], - "prediction_length": forecast_length, + "prediction_length": prediction_length, + "context_length": context_length, + "model_id": model_id, } @@ -63,8 +89,17 @@ def encode_data(df: pd.DataFrame, timestamp_column: str) -> Dict[str, Any]: return data_payload +@pytest.mark.parametrize( + "ts_data", ["ttm-r1", "ttm-1024-96-r1", "ttm-r2", "ttm-1024-96-r2", "ttm-1536-96-r2"], indirect=True +) def test_zero_shot_forecast_inference(ts_data): test_data, params = ts_data + + prediction_length = params["prediction_length"] + context_length = params["context_length"] + model_id = params["model_id"] + model_id_path = str(Path("/tmp/test-tsfm") / model_id) + id_columns = params["id_columns"] prediction_length = 96 @@ -74,7 +109,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() msg = { - "model_id": "ibm/test-ttm-v1", + "model_id": model_id_path, "parameters": { # "prediction_length": params["prediction_length"], }, @@ -100,7 +135,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_ = test_data_.fillna(0) msg = { - "model_id": "ibm/test-ttm-v1", + "model_id": model_id_path, "parameters": { # "prediction_length": params["prediction_length"], }, @@ -121,7 +156,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_ = test_data.copy() msg = { - "model_id": "ibm/test-ttm-v1", + "model_id": model_id_path, "parameters": { # "prediction_length": params["prediction_length"], }, @@ -144,7 +179,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_ = test_data_.iloc[3:] msg = { - "model_id": "ibm/test-ttm-v1", + "model_id": model_id_path, "parameters": { # "prediction_length": params["prediction_length"], }, @@ -158,7 +193,7 @@ def test_zero_shot_forecast_inference(ts_data): } out = get_inference_response(msg) - assert "Received 509 time points for id a" in out.text + assert f"Received {context_length-3} time points for id a" in out.text # test multi-time series, multi-id # error due to insufficient context @@ -167,7 +202,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_["id2"] = test_data_[params["id_columns"]] msg = { - "model_id": "ibm/test-ttm-v1", + "model_id": model_id_path, "parameters": { # "prediction_length": params["prediction_length"], }, @@ -181,13 +216,13 @@ def test_zero_shot_forecast_inference(ts_data): } out = get_inference_response(msg) - assert "Received 509 time points for id ('a', 'a')" in out.text + assert f"Received {context_length-3} time points for id ('a', 'a')" in out.text # single series, less columns test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() msg = { - "model_id": "ibm/test-ttm-v1", + "model_id": model_id_path, "parameters": { "prediction_length": params["prediction_length"], }, @@ -209,7 +244,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() msg = { - "model_id": "ibm/test-ttm-v1", + "model_id": model_id_path, "parameters": { "prediction_length": params["prediction_length"], }, @@ -231,7 +266,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() msg = { - "model_id": "ibm/test-ttm-v1", + "model_id": model_id_path, "parameters": { "prediction_length": params["prediction_length"] // 4, }, @@ -253,7 +288,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() msg = { - "model_id": "ibm/test-ttm-v1", + "model_id": model_id_path, "parameters": { "prediction_length": params["prediction_length"] * 4, }, @@ -273,7 +308,7 @@ def test_zero_shot_forecast_inference(ts_data): test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() msg = { - "model_id": "ibm/test-ttm-v1", + "model_id": model_id_path, "parameters": { "prediction_length": params["prediction_length"] // 4, }, @@ -292,15 +327,17 @@ def test_zero_shot_forecast_inference(ts_data): @pytest.mark.parametrize( - "model_path", + "ts_data", [ "ibm/test-patchtst", "ibm/test-patchtsmixer", ], + indirect=True, ) -def test_trained_model_inference(ts_data, model_path): +def test_trained_model_inference(ts_data): test_data, params = ts_data id_columns = params["id_columns"] + model_id = params["model_id"] prediction_length = 96 @@ -309,7 +346,7 @@ def test_trained_model_inference(ts_data, model_path): encoded_data = encode_data(test_data_, params["timestamp_column"]) msg = { - "model_id": model_path, + "model_id": model_id, "parameters": { # "prediction_length": params["prediction_length"], },