diff --git a/tests/toolkit/test_time_series_forecasting_pipeline.py b/tests/toolkit/test_time_series_forecasting_pipeline.py index 87e8753c..c22b1cca 100644 --- a/tests/toolkit/test_time_series_forecasting_pipeline.py +++ b/tests/toolkit/test_time_series_forecasting_pipeline.py @@ -81,6 +81,34 @@ def test_forecasting_pipeline_forecasts(): forecasts_exploded = forecast_pipeline(test_data) assert forecasts_exploded.shape == (prediction_length, len(target_columns) + 1) + forecast_pipeline = TimeSeriesForecastingPipeline( + model=model, + timestamp_column=timestamp_column, + id_columns=id_columns, + target_columns=target_columns, + freq="1h", + batch_size=10, + ) + + dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv" + test_end_index = 12 * 30 * 24 + 8 * 30 * 24 + test_start_index = test_end_index - context_length - 9 + + data = pd.read_csv( + dataset_path, + parse_dates=[timestamp_column], + ) + + test_data = select_by_index( + data, + id_columns=id_columns, + start_index=test_start_index, + end_index=test_end_index, + ) + forecasts = forecast_pipeline(test_data) + assert forecast_pipeline._batch_size == 10 + assert forecasts.shape == (10, 2 * len(target_columns) + 1) + def test_forecasting_pipeline_forecasts_with_preprocessor(): timestamp_column = "date" diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py index e0788324..35bb2ec9 100644 --- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py +++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py @@ -51,7 +51,7 @@ def run_single(self, inputs, preprocess_params, forward_params, postprocess_para # our preprocess returns a dataset dataset = self.preprocess(inputs, **preprocess_params) - batch_size = 100 + batch_size = forward_params["batch_size"] signature = inspect.signature(self.model.forward) signature_columns = list(signature.parameters.keys()) @@ -180,6 +180,16 @@ def _sanitize_parameters( if c in kwargs: postprocess_kwargs[c] = kwargs[c] + # same logic as HF Pipeline + batch_size = kwargs.get("batch_size", self._batch_size) + if batch_size is None: + if self._batch_size is None: + batch_size = 1 + else: + batch_size = self._batch_size + + forward_kwargs = {"batch_size": batch_size} + # if "id_columns" in kwargs: # preprocess_kwargs["id_columns"] = kwargs["id_columns"] # postprocess_kwargs["id_columns"] = kwargs["id_columns"] @@ -196,7 +206,7 @@ def _sanitize_parameters( # preprocess_kwargs["output_columns"] = kwargs["input_columns"] # postprocess_kwargs["output_columns"] = kwargs["input_columns"] - return preprocess_kwargs, {}, postprocess_kwargs + return preprocess_kwargs, forward_kwargs, postprocess_kwargs def __call__( self,