Skip to content

Commit

Permalink
allow batch_size, num_workers
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Jul 26, 2024
1 parent 493a2f7 commit 8176f85
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def run_single(self, inputs, preprocess_params, forward_params, postprocess_para
dataset = self.preprocess(inputs, **preprocess_params)

batch_size = forward_params["batch_size"]
num_workers = forward_params["num_workers"]
signature = inspect.signature(self.model.forward)
signature_columns = list(signature.parameters.keys())

Expand All @@ -66,7 +67,9 @@ def run_single(self, inputs, preprocess_params, forward_params, postprocess_para
description=None,
model_name=None,
)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=remove_columns_collator, shuffle=False)
dataloader = DataLoader(
dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=remove_columns_collator, shuffle=False
)

# iterate over dataloader
it = iter(dataloader)
Expand Down Expand Up @@ -204,7 +207,14 @@ def _sanitize_parameters(
else:
batch_size = self._batch_size

forward_kwargs = {"batch_size": batch_size}
num_workers = kwargs.get("num_workers", self._num_workers)
if num_workers is None:
if self._num_workers is None:
num_workers = 0
else:
num_workers = self._num_workers

forward_kwargs = {"batch_size": batch_size, "num_workers": num_workers}

# if "id_columns" in kwargs:
# preprocess_kwargs["id_columns"] = kwargs["id_columns"]
Expand Down

0 comments on commit 8176f85

Please sign in to comment.