diff --git a/tests/toolkit/test_dataset.py b/tests/toolkit/test_dataset.py
index 5741158a..18f27b68 100644
--- a/tests/toolkit/test_dataset.py
+++ b/tests/toolkit/test_dataset.py
@@ -196,6 +196,51 @@ def test_forecasting_df_dataset_stride(ts_data_with_categorical):
     np.testing.assert_allclose(ds_past_np, ds_past_np_expected)
 
 
+def test_forecasting_observed_mask(ts_data_with_categorical):
+    prediction_length = 2
+    context_length = 5
+    fill_value = 0.0
+    target_columns = ["value2", "value3"]
+
+    df = ts_data_with_categorical.copy()
+    df.loc[10, "value3"] = np.nan
+
+    ds = ForecastDFDataset(
+        df,
+        timestamp_column="timestamp",
+        id_columns=["id"],
+        target_columns=target_columns,
+        context_length=context_length,
+        prediction_length=prediction_length,
+        fill_value=fill_value,
+    )
+
+    # Check that the masks match the shapes of the value tensors
+    assert ds[0]["past_observed_mask"].shape == ds[0]["past_values"].shape
+    assert ds[0]["future_observed_mask"].shape == ds[0]["future_values"].shape
+
+    # Check that the masks flag the inserted NaN as unobserved
+    np.testing.assert_allclose(ds[4]["future_observed_mask"], np.array([[True, True], [True, False]]))
+    np.testing.assert_allclose(ds[6]["past_observed_mask"][-1, :], np.array([True, False]))
+
+    # Check that the NaN position was filled with the default fill value
+    assert ds[4]["future_values"][1, 1] == fill_value
+
+    # Check again with a non-default fill value
+    fill_value = -100.0
+    ds = ForecastDFDataset(
+        df,
+        timestamp_column="timestamp",
+        id_columns=["id"],
+        target_columns=target_columns,
+        context_length=context_length,
+        prediction_length=prediction_length,
+        fill_value=fill_value,
+    )
+
+    assert ds[4]["future_values"][1, 1] == fill_value
+
+
 def test_forecasting_df_dataset_non_autoregressive(ts_data_with_categorical):
     prediction_length = 2
     target_columns = ["value1"]
diff --git a/tsfm_public/toolkit/dataset.py b/tsfm_public/toolkit/dataset.py
index 9520e271..1487563d 100644
--- a/tsfm_public/toolkit/dataset.py
+++ b/tsfm_public/toolkit/dataset.py
@@ -43,6 +43,7 @@ def __init__(
         prediction_length: int = 0,
         zero_padding: bool = True,
         stride: int = 1,
+        fill_value: Union[float, int] = 0.0,
     ):
         super().__init__()
         if not isinstance(x_cols, list):
@@ -71,6 +72,7 @@ def __init__(
         self.context_length = context_length
         self.prediction_length = prediction_length
         self.zero_padding = zero_padding
+        self.fill_value = fill_value
         self.timestamps = None
         self.group_id = group_id
         self.stride = stride
@@ -154,6 +156,7 @@ def __init__(
         context_length: int = 1,
         prediction_length: int = 1,
         num_workers: int = 1,
+        fill_value: Union[float, int] = 0.0,
         cls=BaseDFDataset,
         stride: int = 1,
         **kwargs,
@@ -171,6 +174,8 @@ def __init__(
         self.prediction_length = prediction_length
         self.stride = stride
         self.extra_kwargs = kwargs
+        self.fill_value = fill_value
+        self.cls = cls
 
         # create groupby object
         if len(id_columns) == 1:
@@ -213,6 +218,7 @@ def concat_dataset(self):
                 self.prediction_length,
                 self.drop_cols,
                 self.stride,
+                self.fill_value,
                 self.extra_kwargs,
             )
             for group_id, group in group_df
@@ -234,6 +240,7 @@ def get_group_data(
     prediction_length: int = 1,
     drop_cols: Optional[List[str]] = None,
     stride: int = 1,
+    fill_value: Union[float, int] = 0.0,
     extra_kwargs: Dict[str, Any] = {},
 ):
     return cls(
@@ -245,6 +252,7 @@ def get_group_data(
         prediction_length=prediction_length,
         drop_cols=drop_cols,
         stride=stride,
+        fill_value=fill_value,
         **extra_kwargs,
     )
 
@@ -272,6 +280,7 @@ def __init__(
         context_length: int = 1,
         num_workers: int = 1,
         stride: int = 1,
+        fill_value: Union[float, int] = 0.0,
     ):
         super().__init__(
             data_df=data,
@@ -283,6 +292,7 @@ def __init__(
             cls=self.BasePretrainDFDataset,
             target_columns=target_columns,
             stride=stride,
+            fill_value=fill_value,
         )
         self.n_inp = 1
 
@@ -298,6 +308,7 @@ def __init__(
             timestamp_column: Optional[str] = None,
             target_columns: List[str] = [],
             stride: int = 1,
+            fill_value: Union[float, int] = 0.0,
         ):
             self.target_columns = target_columns
 
@@ -315,12 +326,16 @@ def __init__(
                 group_id=group_id,
                 drop_cols=drop_cols,
                 stride=stride,
+                fill_value=fill_value,
             )
 
         def __getitem__(self, index):
             time_id = index * self.stride
             seq_x = self.X[time_id : time_id + self.context_length].values
-            ret = {"past_values": np_to_torch(seq_x)}
+            ret = {
+                "past_values": np_to_torch(np.nan_to_num(seq_x, nan=self.fill_value)),
+                "past_observed_mask": np_to_torch(~np.isnan(seq_x)),
+            }
             if self.datetime_col:
                 ret["timestamp"] = self.timestamps[time_id + self.context_length - 1]
             if self.group_id:
@@ -359,6 +374,7 @@ def __init__(
         frequency_token: Optional[int] = None,
         autoregressive_modeling: bool = True,
         stride: int = 1,
+        fill_value: Union[float, int] = 0.0,
     ):
         # output_columns_tmp = input_columns if output_columns == [] else output_columns
 
@@ -369,6 +385,7 @@ def __init__(
             num_workers=num_workers,
             context_length=context_length,
             prediction_length=prediction_length,
+            fill_value=fill_value,
             cls=self.BaseForecastDFDataset,
             stride=stride,
             # extra_args
@@ -406,6 +423,7 @@ def __init__(
             frequency_token: Optional[int] = None,
             autoregressive_modeling: bool = True,
             stride: int = 1,
+            fill_value: Union[float, int] = 0.0,
         ):
             self.frequency_token = frequency_token
             self.target_columns = target_columns
@@ -446,6 +464,7 @@ def __init__(
                 group_id=group_id,
                 drop_cols=drop_cols,
                 stride=stride,
+                fill_value=fill_value,
             )
 
         def __getitem__(self, index):
@@ -465,9 +484,12 @@ def __getitem__(self, index):
                 seq_y[:, self.y_mask_conditional] = 0
 
             ret = {
-                "past_values": np_to_torch(seq_x),
-                "future_values": np_to_torch(seq_y),
+                "past_values": np_to_torch(np.nan_to_num(seq_x, nan=self.fill_value)),
+                "future_values": np_to_torch(np.nan_to_num(seq_y, nan=self.fill_value)),
+                "past_observed_mask": np_to_torch(~np.isnan(seq_x)),
+                "future_observed_mask": np_to_torch(~np.isnan(seq_y)),
             }
+
             if self.datetime_col:
                 ret["timestamp"] = self.timestamps[time_id + self.context_length - 1]
@@ -599,6 +621,8 @@ def np_to_torch(data: np.array, float_type=np.float32):
         return torch.from_numpy(data.astype(float_type))
     elif data.dtype == "int":
         return torch.from_numpy(data)
+    elif data.dtype == "bool":  # keep observed masks as torch.bool rather than casting to float
+        return torch.from_numpy(data)
     return torch.from_numpy(data)
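
For reviewers, a minimal usage sketch of the new behavior, assuming the patch above is
applied. The toy frame, column names, and expected outputs below are illustrative and
are not taken from the test fixtures.

import numpy as np
import pandas as pd

from tsfm_public.toolkit.dataset import ForecastDFDataset

# One short series with a single missing target value.
ts_data = pd.DataFrame(
    {
        "timestamp": pd.date_range("2021-01-01", periods=10, freq="D"),
        "id": ["A"] * 10,
        "value": [1.0, 2.0, np.nan, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
    }
)

ds = ForecastDFDataset(
    ts_data,
    timestamp_column="timestamp",
    id_columns=["id"],
    target_columns=["value"],
    context_length=5,
    prediction_length=2,
    fill_value=0.0,  # NaNs in past_values/future_values are replaced with this
)

sample = ds[0]
# The NaN is filled in the values and flagged False in the observed mask:
# sample["past_values"]        -> tensor([[1.], [2.], [0.], [4.], [5.]])
# sample["past_observed_mask"] -> tensor([[ True], [ True], [False], [ True], [ True]])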