Skip to content

Commit

Permalink
Merge pull request #136 from ibm-granite/issue_128
Browse files: browse the repository at this point in the history
Issue 128
  • Loading branch information
wgifford authored Oct 25, 2024
2 parents b1a9c89 + 3e1f4a5 commit 442e4e4
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
58 changes: 58 additions & 0 deletions services/inference/tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,26 @@ def test_zero_shot_forecast_inference(ts_data):
assert len(df_out) == 1
assert df_out[0].shape[0] == prediction_length

# test single, very short (length 2)
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()

msg = {
"model_id": model_id_path,
"parameters": {
"prediction_length": params["prediction_length"],
},
"schema": {
"timestamp_column": params["timestamp_column"],
"id_columns": params["id_columns"],
"target_columns": params["target_columns"],
},
"data": encode_data(test_data_.iloc[:2], params["timestamp_column"]),
"future_data": {},
}

out = get_inference_response(msg)
assert "Received 2 time points for id a" in out.text

# test single, more data
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()

Expand Down Expand Up @@ -361,3 +381,41 @@ def test_trained_model_inference(ts_data):
df_out = get_inference_response(msg)
assert len(df_out) == 1
assert df_out[0].shape[0] == prediction_length


# def test_simple():
# import numpy as np
# import pandas as pd

# series_length = 512
# timestamps = pd.date_range("2021-01-01", periods=series_length).to_list()
# num_series = 5

# def encode_data(df: pd.DataFrame, timestamp_column: str) -> Dict[str, Any]:
# df[timestamp_column] = df[timestamp_column].apply(lambda x: x.isoformat())
# data_payload = df.to_dict(orient="list")
# return data_payload

# test_data = pd.DataFrame(
# {
# "date": timestamps * num_series,
# "id": np.array([f"id{i}" for i in range(num_series)]).repeat(series_length),
# "target": np.tile(np.arange(series_length).astype(float), num_series),
# }
# )

# msg = {
# "model_id": "ttm-r2",
# "parameters": {
# "prediction_length": 96,
# },
# "schema": {
# "timestamp_column": "date",
# "id_columns": ["id"],
# "target_columns": ["target"],
# },
# "data": encode_data(test_data, "date"),
# }

# df_out = get_inference_response(msg)
# print(df_out)
28 changes: 28 additions & 0 deletions tests/toolkit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,34 @@ def test_forecasting_df_dataset(ts_data_with_categorical):
assert np.all(ds[0]["future_values"][:, 2].numpy() == 0)


def test_short_forecasting_df_dataset(ts_data_with_categorical):
    """Check ForecastDFDataset on a series shorter than the context window.

    A single-row frame is passed together with ``context_length=4`` and
    ``prediction_length=3``; the test asserts that the ``"timestamp"`` entry
    of the first dataset item is ``pd.NaT``.

    The ``ts_data_with_categorical`` argument is accepted but not used; the
    input frame is built inline instead.
    """
    n_rows = 10
    full_frame = pd.DataFrame(
        {
            "timestamp": pd.to_datetime(range(n_rows)),
            "id": ["A"] * n_rows,
            "value1": range(n_rows),
        }
    )
    # Keep only the first row so the series has a single time point.
    single_row = full_frame.iloc[:1]

    dataset = ForecastDFDataset(
        single_row,
        timestamp_column="timestamp",
        id_columns=["id"],
        target_columns=["value1"],
        context_length=4,
        prediction_length=3,
    )

    first_item = dataset[0]
    assert first_item["timestamp"] is pd.NaT


def test_forecasting_df_dataset_stride(ts_data_with_categorical):
prediction_length = 2
context_length = 3
Expand Down
4 changes: 3 additions & 1 deletion tsfm_public/toolkit/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,9 @@ def ts_padding(
pad_df[c] = pad_df[c].astype(df.dtypes[c], copy=False)

if timestamp_column:
if (df[timestamp_column].dtype.type == np.datetime64) or (df[timestamp_column].dtype == int):
if len(df) < 2:
pad_df[timestamp_column] = None
elif (df[timestamp_column].dtype.type == np.datetime64) or (df[timestamp_column].dtype == int):
last_timestamp = df.iloc[0][timestamp_column]
period = df.iloc[1][timestamp_column] - df.iloc[0][timestamp_column]
prepended_timestamps = [last_timestamp + offset * period for offset in range(-fill_length, 0)]
Expand Down

0 comments on commit 442e4e4

Please sign in to comment.