Skip to content

Commit

Permalink
test all
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Oct 21, 2024
1 parent fd33d44 commit 4cbd387
Showing 1 changed file with 57 additions and 20 deletions.
77 changes: 57 additions & 20 deletions services/inference/tests/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright contributors to the TSFM project
#
from pathlib import Path
from typing import Any, Dict

import numpy as np
Expand All @@ -11,21 +12,44 @@
from tsfm_public.toolkit.util import select_by_index


model_param_map = {
"ttm-r1": {"context_length": 512, "prediction_length": 96},
"ttm-1024-96-r1": {"context_length": 1024, "prediction_length": 96},
"ttm-r2": {"context_length": 512, "prediction_length": 96},
"ttm-1024-96-r2": {"context_length": 1024, "prediction_length": 96},
"ttm-1536-96-r2": {"context_length": 1536, "prediction_length": 96},
"ibm/test-patchtst": {"context_length": 512, "prediction_length": 96},
"ibm/test-patchtsmixer": {"context_length": 512, "prediction_length": 96},
}


@pytest.fixture(scope="module")
def ts_data():
def ts_data_base():
dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"

forecast_length = 96
context_length = 512
# forecast_length = 96
# context_length = 512
timestamp_column = "date"

data = pd.read_csv(
dataset_path,
parse_dates=[timestamp_column],
)

return data


@pytest.fixture(scope="module")
def ts_data(ts_data_base, request):
# forecast_length = 96
# context_length = 512
model_id = request.param
prediction_length = model_param_map[model_id]["prediction_length"]
context_length = model_param_map[model_id]["context_length"]
timestamp_column = "date"

test_data = select_by_index(
data,
ts_data_base,
start_index=12 * 30 * 24 + 4 * 30 * 24 - context_length * 5,
end_index=12 * 30 * 24 + 4 * 30 * 24,
).reset_index(drop=True)
Expand All @@ -36,7 +60,9 @@ def ts_data():
"timestamp_column": timestamp_column,
"id_columns": ["id"],
"target_columns": ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"],
"prediction_length": forecast_length,
"prediction_length": prediction_length,
"context_length": context_length,
"model_id": model_id,
}


Expand All @@ -63,8 +89,17 @@ def encode_data(df: pd.DataFrame, timestamp_column: str) -> Dict[str, Any]:
return data_payload


@pytest.mark.parametrize(
"ts_data", ["ttm-r1", "ttm-1024-96-r1", "ttm-r2", "ttm-1024-96-r2", "ttm-1536-96-r2"], indirect=True
)
def test_zero_shot_forecast_inference(ts_data):
test_data, params = ts_data

prediction_length = params["prediction_length"]
context_length = params["context_length"]
model_id = params["model_id"]
model_id_path = str(Path("/tmp/test-tsfm") / model_id)

id_columns = params["id_columns"]

prediction_length = 96
Expand All @@ -74,7 +109,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()

msg = {
"model_id": "ibm/test-ttm-v1",
"model_id": model_id_path,
"parameters": {
# "prediction_length": params["prediction_length"],
},
Expand All @@ -100,7 +135,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data_.fillna(0)

msg = {
"model_id": "ibm/test-ttm-v1",
"model_id": model_id_path,
"parameters": {
# "prediction_length": params["prediction_length"],
},
Expand All @@ -121,7 +156,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data.copy()

msg = {
"model_id": "ibm/test-ttm-v1",
"model_id": model_id_path,
"parameters": {
# "prediction_length": params["prediction_length"],
},
Expand All @@ -144,7 +179,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data_.iloc[3:]

msg = {
"model_id": "ibm/test-ttm-v1",
"model_id": model_id_path,
"parameters": {
# "prediction_length": params["prediction_length"],
},
Expand All @@ -158,7 +193,7 @@ def test_zero_shot_forecast_inference(ts_data):
}

out = get_inference_response(msg)
assert "Received 509 time points for id a" in out.text
assert f"Received {context_length-3} time points for id a" in out.text

# test multi-time series, multi-id
# error due to insufficient context
Expand All @@ -167,7 +202,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_["id2"] = test_data_[params["id_columns"]]

msg = {
"model_id": "ibm/test-ttm-v1",
"model_id": model_id_path,
"parameters": {
# "prediction_length": params["prediction_length"],
},
Expand All @@ -181,13 +216,13 @@ def test_zero_shot_forecast_inference(ts_data):
}

out = get_inference_response(msg)
assert "Received 509 time points for id ('a', 'a')" in out.text
assert f"Received {context_length-3} time points for id ('a', 'a')" in out.text

# single series, less columns
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()

msg = {
"model_id": "ibm/test-ttm-v1",
"model_id": model_id_path,
"parameters": {
"prediction_length": params["prediction_length"],
},
Expand All @@ -209,7 +244,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()

msg = {
"model_id": "ibm/test-ttm-v1",
"model_id": model_id_path,
"parameters": {
"prediction_length": params["prediction_length"],
},
Expand All @@ -231,7 +266,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()

msg = {
"model_id": "ibm/test-ttm-v1",
"model_id": model_id_path,
"parameters": {
"prediction_length": params["prediction_length"] // 4,
},
Expand All @@ -253,7 +288,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()

msg = {
"model_id": "ibm/test-ttm-v1",
"model_id": model_id_path,
"parameters": {
"prediction_length": params["prediction_length"] * 4,
},
Expand All @@ -273,7 +308,7 @@ def test_zero_shot_forecast_inference(ts_data):
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()

msg = {
"model_id": "ibm/test-ttm-v1",
"model_id": model_id_path,
"parameters": {
"prediction_length": params["prediction_length"] // 4,
},
Expand All @@ -292,15 +327,17 @@ def test_zero_shot_forecast_inference(ts_data):


@pytest.mark.parametrize(
"model_path",
"ts_data",
[
"ibm/test-patchtst",
"ibm/test-patchtsmixer",
],
indirect=True,
)
def test_trained_model_inference(ts_data, model_path):
def test_trained_model_inference(ts_data):
test_data, params = ts_data
id_columns = params["id_columns"]
model_id = params["model_id"]

prediction_length = 96

Expand All @@ -309,7 +346,7 @@ def test_trained_model_inference(ts_data, model_path):
encoded_data = encode_data(test_data_, params["timestamp_column"])

msg = {
"model_id": model_path,
"model_id": model_id,
"parameters": {
# "prediction_length": params["prediction_length"],
},
Expand Down

0 comments on commit 4cbd387

Please sign in to comment.