Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference performance #49

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 62 additions & 21 deletions tests/toolkit/test_time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,34 @@ def test_forecasting_pipeline_forecasts():
forecasts_exploded = forecast_pipeline(test_data)
assert forecasts_exploded.shape == (prediction_length, len(target_columns) + 1)

forecast_pipeline = TimeSeriesForecastingPipeline(
model=model,
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
freq="1h",
batch_size=10,
)

dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv"
test_end_index = 12 * 30 * 24 + 8 * 30 * 24
test_start_index = test_end_index - context_length - 9

data = pd.read_csv(
dataset_path,
parse_dates=[timestamp_column],
)

test_data = select_by_index(
data,
id_columns=id_columns,
start_index=test_start_index,
end_index=test_end_index,
)
forecasts = forecast_pipeline(test_data)
assert forecast_pipeline._batch_size == 10
assert forecasts.shape == (10, 2 * len(target_columns) + 1)


def test_forecasting_pipeline_forecasts_with_preprocessor():
timestamp_column = "date"
Expand All @@ -92,30 +120,13 @@ def test_forecasting_pipeline_forecasts_with_preprocessor():
model = PatchTSTForPrediction.from_pretrained(model_path)
context_length = model.config.context_length

tsp = TimeSeriesPreprocessor(
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
context_length=context_length,
prediction_length=prediction_length,
freq="1h",
)

forecast_pipeline = TimeSeriesForecastingPipeline(
model=model,
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
freq="1h",
feature_extractor=tsp,
explode_forecasts=False,
)

dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv"
data = pd.read_csv(
dataset_path,
parse_dates=[timestamp_column],
)
train_end_index = 12 * 30 * 24

test_end_index = 12 * 30 * 24 + 8 * 30 * 24
test_start_index = test_end_index - context_length - 4

Expand All @@ -124,18 +135,48 @@ def test_forecasting_pipeline_forecasts_with_preprocessor():
parse_dates=[timestamp_column],
)

train_data = select_by_index(
data,
id_columns=id_columns,
start_index=0,
end_index=train_end_index,
)
test_data = select_by_index(
data,
id_columns=id_columns,
start_index=test_start_index,
end_index=test_end_index,
)

forecasts = forecast_pipeline(test_data)
tsp = TimeSeriesPreprocessor(
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
context_length=context_length,
prediction_length=prediction_length,
freq="1h",
scaling=True,
)

tsp.train(train_data)

forecast_pipeline = TimeSeriesForecastingPipeline(
model=model,
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
freq="1h",
feature_extractor=tsp,
explode_forecasts=False,
inverse_scale_outputs=True,
)

forecasts = forecast_pipeline(tsp.preprocess(test_data))

assert forecasts.shape == (
test_end_index - test_start_index - context_length + 1,
2 * len(target_columns) + 1,
)

# to do: add check on the scaling
# if we have inverse scaled mean should be larger
assert forecasts["HUFL_prediction"].mean().mean() > 10
129 changes: 105 additions & 24 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@

import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers.data.data_collator import default_data_collator
from transformers.pipelines.base import (
GenericTensor,
Pipeline,
build_pipeline_init_args,
)
from transformers.trainer_utils import RemoveColumnsCollator
from transformers.utils import add_end_docstrings, logging

from .dataset import ForecastDFDataset
Expand All @@ -31,10 +34,75 @@
logger = logging.get_logger(__name__)


class TimeSeriesPipeline(Pipeline):
    """Pipeline base class whose `run_single` applies the model in mini-batches."""

    def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
        """Run preprocess, batched forward passes, and postprocess for one request.

        Replaces the base `run_single` method so that batching happens during
        inference. This is needed to support large inference requests.

        Args:
            inputs: Raw input for the request; passed unchanged to `self.preprocess`,
                which is expected to return an indexable dataset of dict records.
            preprocess_params (dict): Keyword arguments forwarded to `self.preprocess`.
            forward_params (dict): Keyword arguments forwarded to `self.forward`. Its
                "batch_size" entry also sets the dataloader batch size; when it is
                missing or None, `self._batch_size` is used, and finally 1.
            postprocess_params (dict): Keyword arguments forwarded to `self.postprocess`.

        Returns:
            The value returned by `self.postprocess` for the combined dataset fields
            and model predictions.
        """
        # our preprocess returns a dataset
        dataset = self.preprocess(inputs, **preprocess_params)

        # A None batch size would disable automatic batching in DataLoader, so fall
        # back to the pipeline default and finally to 1.
        batch_size = forward_params.get("batch_size")
        if batch_size is None:
            batch_size = getattr(self, "_batch_size", None) or 1

        # Only the columns accepted by the model's forward() are kept in each batch.
        signature_columns = list(inspect.signature(self.model.forward).parameters.keys())
        remove_columns_collator = RemoveColumnsCollator(
            data_collator=default_data_collator,
            signature_columns=signature_columns,
            logger=None,
            description=None,
            model_name=self.model.__class__.__name__,
        )
        dataloader = DataLoader(
            dataset,
            num_workers=1,
            batch_size=batch_size,
            collate_fn=remove_columns_collator,
            shuffle=False,
        )

        # Run the model batch by batch, keeping only the prediction tensor.
        accumulator = []
        model_output_key = None
        for batch in dataloader:
            item = self.forward(batch, **forward_params)
            if model_output_key is None:
                # The key is decided once, from the first batch's output.
                model_output_key = (
                    "prediction_outputs" if "prediction_outputs" in item.keys() else "prediction_logits"
                )
            accumulator.append(item[model_output_key])

        # Collect all outputs needed for post processing. The dataset is read once
        # here instead of once per key, so each record is only built a single time.
        records = [dataset[i] for i in range(len(dataset))]
        first = records[0]
        model_outputs = {}
        for k, v in first.items():
            if isinstance(v, torch.Tensor):
                # tensors are stacked along a new leading dimension
                model_outputs[k] = torch.stack([r[k] for r in records])
            else:
                # other values are passed through as a list
                model_outputs[k] = [r[k] for r in records]

        # without shuffling in the dataloader above, we assume that order is preserved
        # otherwise we need to incorporate sequence id somewhere and do a proper join
        model_outputs["prediction_outputs"] = torch.cat(accumulator, dim=0)

        # call postprocess
        return self.postprocess(model_outputs, **postprocess_params)


@add_end_docstrings(
build_pipeline_init_args(has_tokenizer=False, has_feature_extractor=True, has_image_processor=False)
)
class TimeSeriesForecastingPipeline(Pipeline):
class TimeSeriesForecastingPipeline(TimeSeriesPipeline):
"""Hugging Face Pipeline for Time Series Forecasting

feature_extractor (TimeSeriesPreprocessor): A time series preprocessor object that specifies how the time
Expand Down Expand Up @@ -112,6 +180,16 @@ def _sanitize_parameters(
if c in kwargs:
postprocess_kwargs[c] = kwargs[c]

# same logic as HF Pipeline
batch_size = kwargs.get("batch_size", self._batch_size)
if batch_size is None:
if self._batch_size is None:
batch_size = 1
else:
batch_size = self._batch_size

forward_kwargs = {"batch_size": batch_size}

# if "id_columns" in kwargs:
# preprocess_kwargs["id_columns"] = kwargs["id_columns"]
# postprocess_kwargs["id_columns"] = kwargs["id_columns"]
Expand All @@ -128,7 +206,7 @@ def _sanitize_parameters(
# preprocess_kwargs["output_columns"] = kwargs["input_columns"]
# postprocess_kwargs["output_columns"] = kwargs["input_columns"]

return preprocess_kwargs, {}, postprocess_kwargs
return preprocess_kwargs, forward_kwargs, postprocess_kwargs

def __call__(
self,
Expand Down Expand Up @@ -248,17 +326,18 @@ def preprocess(self, time_series, **kwargs) -> Dict[str, Union[GenericTensor, Li
**kwargs,
)

# stack all the outputs
# torch tensors are stacked, but other values are passed through as a list
first = dataset[0]
full_output = {}
for k, v in first.items():
if isinstance(v, torch.Tensor):
full_output[k] = torch.stack(tuple(r[k] for r in dataset))
else:
full_output[k] = [r[k] for r in dataset]
# # stack all the outputs
# # torch tensors are stacked, but other values are passed through as a list
# first = dataset[0]
# full_output = {}
# for k, v in first.items():
# if isinstance(v, torch.Tensor):
# full_output[k] = torch.stack(tuple(r[k] for r in dataset))
# else:
# full_output[k] = [r[k] for r in dataset]

return full_output
# return full_output
return dataset

def _forward(self, model_inputs, **kwargs):
"""Forward step
Expand All @@ -279,20 +358,22 @@ def _forward(self, model_inputs, **kwargs):
# "freq_token",
# } # todo: this should not be hardcoded

signature = inspect.signature(self.model.forward)
model_input_keys = list(signature.parameters.keys())
# signature = inspect.signature(self.model.forward)
# model_input_keys = list(signature.parameters.keys())

# model_inputs_only = {}
# for k in model_input_keys:
# if k in model_inputs:
# model_inputs_only[k] = model_inputs[k]

model_inputs_only = {}
for k in model_input_keys:
if k in model_inputs:
model_inputs_only[k] = model_inputs[k]
# model_outputs = self.model(**model_inputs_only)

model_outputs = self.model(**model_inputs_only)
# # copy the other inputs
# copy_inputs = True
# for k in [akey for akey in model_inputs.keys() if (akey not in model_input_keys) or copy_inputs]:
# model_outputs[k] = model_inputs[k]

# copy the other inputs
copy_inputs = True
for k in [akey for akey in model_inputs.keys() if (akey not in model_input_keys) or copy_inputs]:
model_outputs[k] = model_inputs[k]
model_outputs = self.model(**model_inputs)

return model_outputs

Expand All @@ -307,7 +388,7 @@ def postprocess(self, input, **kwargs):
"""
out = {}

model_output_key = "prediction_outputs" if "prediction_outputs" in input.keys() else "prediction_logits"
model_output_key = "prediction_outputs" # if "prediction_outputs" in input.keys() else "prediction_logits"

# name the predictions of target columns
# outputs should only have size equal to target columns
Expand Down
30 changes: 20 additions & 10 deletions tsfm_public/toolkit/time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,7 @@ def get_datasets(
split_config: Dict[str, Union[List[Union[int, float]], float]],
fewshot_fraction: Optional[float] = None,
fewshot_location: str = FractionLocation.LAST.value,
return_dataframe: bool = False,
) -> Tuple[Any]:
"""Creates the preprocessed pytorch datasets needed for training and evaluation
using the HuggingFace trainer
Expand Down Expand Up @@ -697,6 +698,8 @@ def get_datasets(
fewshot_location (str): Determines where the fewshot data is chosen. Valid options are "first" and "last"
as described in the enum FewshotLocation. Default is to choose the fewshot data at the end
of the training dataset (i.e., "last").
return_dataframe (bool): Instead of returning pytorch datasets, return a tuple of pandas dataframes, after any
preprocessing. Defaults to False.

Returns:
Tuple of pytorch datasets, including: train, validation, test.
Expand Down Expand Up @@ -752,16 +755,23 @@ def get_datasets(
params["prediction_length"] = self.prediction_length

# get torch datasets
test_dataset = ForecastDFDataset(
self.preprocess(test_data),
**params,
)
train_dataset = ForecastDFDataset(self.preprocess(train_data), **params)
valid_dataset = ForecastDFDataset(
self.preprocess(valid_data),
**params,
)
return train_dataset, valid_dataset, test_dataset
train_valid_test = [train_data, valid_data, test_data]

if return_dataframe:
return tuple(train_valid_test)

return tuple([ForecastDFDataset(self.preprocess(d), **params) for d in train_valid_test])

# test_dataset = ForecastDFDataset(
# self.preprocess(test_data),
# **params,
# )
# train_dataset = ForecastDFDataset(self.preprocess(train_data), **params)
# valid_dataset = ForecastDFDataset(
# self.preprocess(valid_data),
# **params,
# )
# return train_dataset, valid_dataset, test_dataset


def create_timestamps(
Expand Down
Loading