diff --git a/tests/toolkit/test_dataset.py b/tests/toolkit/test_dataset.py index 8486396d..29e2a337 100644 --- a/tests/toolkit/test_dataset.py +++ b/tests/toolkit/test_dataset.py @@ -196,6 +196,7 @@ def test_forecasting_df_dataset_stride(ts_data_with_categorical): def test_forecasting_observed_mask(ts_data_with_categorical): prediction_length = 2 context_length = 5 + fill_value = 0.0 target_columns = ["value2", "value3"] df = ts_data_with_categorical.copy() @@ -208,6 +209,7 @@ def test_forecasting_observed_mask(ts_data_with_categorical): target_columns=target_columns, context_length=context_length, prediction_length=prediction_length, + fill_value=fill_value, ) # check matching size @@ -218,6 +220,23 @@ def test_forecasting_observed_mask(ts_data_with_categorical): np.testing.assert_allclose(ds[4]["future_observed_mask"], np.array([[True, True], [True, False]])) np.testing.assert_allclose(ds[6]["past_observed_mask"][-1, :], np.array([True, False])) + # Check mask value is correct + ds[4]["future_values"][1, 1] == fill_value + + # Check mask value is correct again + fill_value = -100.0 + ds = ForecastDFDataset( + df, + timestamp_column="timestamp", + id_columns=["id"], + target_columns=target_columns, + context_length=context_length, + prediction_length=prediction_length, + fill_value=fill_value, + ) + + ds[4]["future_values"][1, 1] == fill_value + def test_forecasting_df_dataset_non_autoregressive(ts_data_with_categorical): prediction_length = 2