-
Notifications
You must be signed in to change notification settings - Fork 204
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add tests for data point counts, update calculation
- Loading branch information
Showing
2 changed files
with
47 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,7 +78,7 @@ def get_inference_response( | |
resp = req.json() | ||
|
||
df = [pd.DataFrame.from_dict(r) for r in resp["results"]] | ||
return df | ||
return df, {k: v for k, v in resp.items() if "data_point" in k} | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
wgifford
Author
Collaborator
|
||
else: | ||
print(req.text) | ||
return req | ||
|
@@ -123,9 +123,11 @@ def test_zero_shot_forecast_inference(ts_data): | |
"future_data": {}, | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, counts = get_inference_response(msg) | ||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length | ||
assert counts["input_data_points"] == context_length * len(params["target_columns"]) | ||
assert counts["output_data_points"] == prediction_length * len(params["target_columns"]) | ||
|
||
# test single, very short (length 2) | ||
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() | ||
|
@@ -169,9 +171,11 @@ def test_zero_shot_forecast_inference(ts_data): | |
"future_data": {}, | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, counts = get_inference_response(msg) | ||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length | ||
assert counts["input_data_points"] == context_length * len(params["target_columns"]) | ||
assert counts["output_data_points"] == prediction_length * len(params["target_columns"]) | ||
|
||
# test multi-time series | ||
test_data_ = test_data.copy() | ||
|
@@ -190,10 +194,12 @@ def test_zero_shot_forecast_inference(ts_data): | |
"future_data": {}, | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, counts = get_inference_response(msg) | ||
|
||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length * num_ids | ||
assert counts["input_data_points"] == context_length * len(params["target_columns"]) * num_ids | ||
assert counts["output_data_points"] == prediction_length * len(params["target_columns"]) * num_ids | ||
|
||
# test multi-time series, errors | ||
test_data_ = test_data.copy() | ||
|
@@ -256,10 +262,12 @@ def test_zero_shot_forecast_inference(ts_data): | |
"future_data": {}, | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, counts = get_inference_response(msg) | ||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length | ||
assert df_out[0].shape[1] == 6 | ||
assert counts["input_data_points"] == context_length * len(params["target_columns"][:4]) | ||
assert counts["output_data_points"] == prediction_length * len(params["target_columns"][:4]) | ||
|
||
# single series, less columns, no id | ||
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() | ||
|
@@ -278,10 +286,12 @@ def test_zero_shot_forecast_inference(ts_data): | |
"future_data": {}, | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, counts = get_inference_response(msg) | ||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length | ||
assert df_out[0].shape[1] == 2 | ||
assert counts["input_data_points"] == context_length | ||
assert counts["output_data_points"] == prediction_length | ||
|
||
# single series, different prediction length | ||
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy() | ||
|
@@ -300,9 +310,11 @@ def test_zero_shot_forecast_inference(ts_data): | |
"future_data": {}, | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, counts = get_inference_response(msg) | ||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length // 4 | ||
assert counts["input_data_points"] == context_length * len(params["target_columns"]) | ||
assert counts["output_data_points"] == (prediction_length // 4) * len(params["target_columns"]) | ||
|
||
# single series | ||
# error wrong prediction length | ||
|
@@ -342,16 +354,19 @@ def test_zero_shot_forecast_inference(ts_data): | |
"future_data": {}, | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, counts = get_inference_response(msg) | ||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length // 4 | ||
assert counts["input_data_points"] == context_length * len(params["target_columns"][1:]) | ||
assert counts["output_data_points"] == (prediction_length // 4) * len(params["target_columns"][1:]) | ||
|
||
|
||
@pytest.mark.parametrize("ts_data", ["ttm-r2"], indirect=True) | ||
def test_future_data_forecast_inference(ts_data): | ||
test_data, params = ts_data | ||
|
||
prediction_length = params["prediction_length"] | ||
context_length = params["context_length"] | ||
model_id = params["model_id"] | ||
model_id_path: str = model_id | ||
|
||
|
@@ -392,9 +407,18 @@ def test_future_data_forecast_inference(ts_data): | |
"future_data": encode_data(future_data, params["timestamp_column"]), | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, counts = get_inference_response(msg) | ||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length * num_ids | ||
assert ( | ||
counts["input_data_points"] | ||
== ( | ||
context_length * len(params["target_columns"]) | ||
+ prediction_length * (len(params["target_columns"]) - len(target_columns)) | ||
) | ||
* num_ids | ||
) | ||
assert counts["output_data_points"] == prediction_length * 1 * num_ids | ||
|
||
|
||
@pytest.mark.parametrize( | ||
|
@@ -429,10 +453,9 @@ def test_zero_shot_forecast_inference_no_timestamp(ts_data): | |
"future_data": {}, | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, _ = get_inference_response(msg) | ||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length | ||
print(df_out[0].head()) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
|
@@ -487,7 +510,7 @@ def test_finetuned_model_inference(ts_data): | |
"future_data": {}, | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, _ = get_inference_response(msg) | ||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length | ||
|
||
|
@@ -525,6 +548,6 @@ def test_trained_model_inference(ts_data): | |
"future_data": {}, | ||
} | ||
|
||
df_out = get_inference_response(msg) | ||
df_out, _ = get_inference_response(msg) | ||
assert len(df_out) == 1 | ||
assert df_out[0].shape[0] == prediction_length |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if "data_point" in k
or would it be something likeif k.find("data_point") >= 0
?