
Commit

handle batch size, test
wgifford committed May 13, 2024
1 parent d6cf39c commit db7d6c6
Showing 2 changed files with 40 additions and 2 deletions.
28 changes: 28 additions & 0 deletions tests/toolkit/test_time_series_forecasting_pipeline.py
@@ -81,6 +81,34 @@ def test_forecasting_pipeline_forecasts():
    forecasts_exploded = forecast_pipeline(test_data)
    assert forecasts_exploded.shape == (prediction_length, len(target_columns) + 1)

    forecast_pipeline = TimeSeriesForecastingPipeline(
        model=model,
        timestamp_column=timestamp_column,
        id_columns=id_columns,
        target_columns=target_columns,
        freq="1h",
        batch_size=10,
    )

    dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv"
    test_end_index = 12 * 30 * 24 + 8 * 30 * 24
    test_start_index = test_end_index - context_length - 9

    data = pd.read_csv(
        dataset_path,
        parse_dates=[timestamp_column],
    )

    test_data = select_by_index(
        data,
        id_columns=id_columns,
        start_index=test_start_index,
        end_index=test_end_index,
    )
    forecasts = forecast_pipeline(test_data)
    assert forecast_pipeline._batch_size == 10
    assert forecasts.shape == (10, 2 * len(target_columns) + 1)


def test_forecasting_pipeline_forecasts_with_preprocessor():
    timestamp_column = "date"
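The new assertions in test_forecasting_pipeline_forecasts above construct the pipeline with batch_size=10, verify the value is stored on the pipeline, and check the forecast frame's shape: the selected slice holds context_length + 9 rows, so a rolling forecast yields 10 full context windows (one output row each, assuming end_index is exclusive, as the expected 10 rows suggest), and the frame carries the timestamp column plus two columns per target (presumably the input and prediction series), i.e. 2 * len(target_columns) + 1 columns. A quick check of the row-count arithmetic, with an illustrative context_length:

# Illustrative value; the test uses the model's actual context_length.
context_length = 512

test_end_index = 12 * 30 * 24 + 8 * 30 * 24             # as in the test
test_start_index = test_end_index - context_length - 9

num_rows = test_end_index - test_start_index             # context_length + 9 rows selected
num_windows = num_rows - context_length + 1              # rolling windows of full context length
assert num_windows == 10                                  # matches forecasts.shape[0]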
14 changes: 12 additions & 2 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
@@ -51,7 +51,7 @@ def run_single(self, inputs, preprocess_params, forward_params, postprocess_para
        # our preprocess returns a dataset
        dataset = self.preprocess(inputs, **preprocess_params)

-        batch_size = 100
+        batch_size = forward_params["batch_size"]
        signature = inspect.signature(self.model.forward)
        signature_columns = list(signature.parameters.keys())

@@ -180,6 +180,16 @@ def _sanitize_parameters(
            if c in kwargs:
                postprocess_kwargs[c] = kwargs[c]

        # same logic as HF Pipeline
        batch_size = kwargs.get("batch_size", self._batch_size)
        if batch_size is None:
            if self._batch_size is None:
                batch_size = 1
            else:
                batch_size = self._batch_size

        forward_kwargs = {"batch_size": batch_size}

        # if "id_columns" in kwargs:
        # preprocess_kwargs["id_columns"] = kwargs["id_columns"]
        # postprocess_kwargs["id_columns"] = kwargs["id_columns"]
@@ -196,7 +206,7 @@
        # preprocess_kwargs["output_columns"] = kwargs["input_columns"]
        # postprocess_kwargs["output_columns"] = kwargs["input_columns"]

-        return preprocess_kwargs, {}, postprocess_kwargs
+        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def __call__(
        self,
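The _sanitize_parameters change above resolves the effective batch size the way the diff's own comment describes ("same logic as HF Pipeline"): a batch_size passed through kwargs wins, otherwise the value stored on the pipeline at construction time, otherwise 1; the result is handed to run_single via forward_params. A minimal standalone sketch of that precedence (the function name and assertions are illustrative, not part of the library):

from typing import Optional


def resolve_batch_size(call_kwarg: Optional[int], pipeline_default: Optional[int]) -> int:
    """Mirror the precedence implemented in _sanitize_parameters above."""
    # kwargs.get("batch_size", self._batch_size): a value supplied in kwargs,
    # if any, falls back to the value set when the pipeline was constructed.
    batch_size = call_kwarg if call_kwarg is not None else pipeline_default
    # If neither was set, default to unbatched inference.
    if batch_size is None:
        batch_size = 1
    return batch_size


assert resolve_batch_size(None, None) == 1   # nothing set -> 1
assert resolve_batch_size(None, 10) == 10    # constructor batch_size=10, as in the new test
assert resolve_batch_size(32, 10) == 32      # explicit kwargs value wins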

0 comments on commit db7d6c6
