Skip to content

Commit

Permalink
check fill value
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Jun 12, 2024
1 parent 23dcd60 commit ec29125
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions tests/toolkit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def test_forecasting_df_dataset_stride(ts_data_with_categorical):
def test_forecasting_observed_mask(ts_data_with_categorical):
prediction_length = 2
context_length = 5
fill_value = 0.0
target_columns = ["value2", "value3"]

df = ts_data_with_categorical.copy()
Expand All @@ -208,6 +209,7 @@ def test_forecasting_observed_mask(ts_data_with_categorical):
target_columns=target_columns,
context_length=context_length,
prediction_length=prediction_length,
fill_value=fill_value,
)

# check matching size
Expand All @@ -218,6 +220,23 @@ def test_forecasting_observed_mask(ts_data_with_categorical):
np.testing.assert_allclose(ds[4]["future_observed_mask"], np.array([[True, True], [True, False]]))
np.testing.assert_allclose(ds[6]["past_observed_mask"][-1, :], np.array([True, False]))

# Check mask value is correct
ds[4]["future_values"][1, 1] == fill_value

# Check mask value is correct again
fill_value = -100.0
ds = ForecastDFDataset(
df,
timestamp_column="timestamp",
id_columns=["id"],
target_columns=target_columns,
context_length=context_length,
prediction_length=prediction_length,
fill_value=fill_value,
)

ds[4]["future_values"][1, 1] == fill_value


def test_forecasting_df_dataset_non_autoregressive(ts_data_with_categorical):
prediction_length = 2
Expand Down

0 comments on commit ec29125

Please sign in to comment.