diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py index 6f9961da..ef53a67e 100644 --- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py +++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py @@ -52,6 +52,7 @@ def run_single(self, inputs, preprocess_params, forward_params, postprocess_para dataset = self.preprocess(inputs, **preprocess_params) batch_size = forward_params["batch_size"] + num_workers = forward_params["num_workers"] signature = inspect.signature(self.model.forward) signature_columns = list(signature.parameters.keys()) @@ -66,7 +67,9 @@ def run_single(self, inputs, preprocess_params, forward_params, postprocess_para description=None, model_name=None, ) - dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=remove_columns_collator, shuffle=False) + dataloader = DataLoader( + dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=remove_columns_collator, shuffle=False + ) # iterate over dataloader it = iter(dataloader) @@ -204,7 +207,14 @@ def _sanitize_parameters( else: batch_size = self._batch_size - forward_kwargs = {"batch_size": batch_size} + num_workers = kwargs.get("num_workers", self._num_workers) + if num_workers is None: + if self._num_workers is None: + num_workers = 0 + else: + num_workers = self._num_workers + + forward_kwargs = {"batch_size": batch_size, "num_workers": num_workers} # if "id_columns" in kwargs: # preprocess_kwargs["id_columns"] = kwargs["id_columns"]