Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up parameter handling in forecasting pipeline #33

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,25 @@ class TimeSeriesForecastingPipeline(Pipeline):
def __init__(
self,
*args,
freq: Optional[str] = None,
explode_forecasts: bool = False,
freq: Optional[Union[Any]] = None,
inverse_scale_outputs: bool = True,
**kwargs,
):
kwargs["freq"] = freq
kwargs["explode_forecasts"] = explode_forecasts
kwargs["inverse_scale_outputs"] = inverse_scale_outputs
super().__init__(*args, **kwargs)

if self.framework == "tf":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")

self.explode_forecasts = explode_forecasts
self.freq = freq
# self.check_model_type(MODEL_FOR_TIME_SERIES_FORECASTING_MAPPING)

def _sanitize_parameters(self, **kwargs):
def _sanitize_parameters(
self,
**kwargs,
):
"""Assign parameters to the different parts of the process.

For expected parameters see the call method below.
Expand Down Expand Up @@ -94,6 +99,9 @@ def _sanitize_parameters(self, **kwargs):
"control_columns",
"conditional_columns",
"static_categorical_columns",
"freq",
"explode_forecasts",
"inverse_scale_outputs",
]

for c in preprocess_params:
Expand Down Expand Up @@ -178,6 +186,8 @@ def __call__(
explode_forecasts (bool): If true, forecasts are returned one value per row of the pandas dataframe. If false, the
forecast over the prediction length will be contained as a list in a single row of the pandas dataframe.

inverse_scale_outputs (bool): If true and a valid feature extractor is provided, the outputs will be inverse scaled.

Return (pandas dataframe):
A new pandas dataframe containing the forecasts. Each row will contain the id, timestamp, the original
input feature values and the output forecast for each input column. The output forecast is a list containing
Expand Down Expand Up @@ -293,6 +303,7 @@ def postprocess(self, input, **kwargs):
"""
out = {}

print(kwargs)
model_output_key = "prediction_outputs" if "prediction_outputs" in input.keys() else "prediction_logits"

# name the predictions of target columns
Expand All @@ -312,7 +323,7 @@ def postprocess(self, input, **kwargs):
out[c] = [elem[i] for elem in input["id"]]
out = pd.DataFrame(out)

if self.explode_forecasts:
if kwargs["explode_forecasts"]:
# we made only one forecast per time series, explode results
# explode == expand the lists in the dataframe
out_explode = []
Expand All @@ -321,7 +332,7 @@ def postprocess(self, input, **kwargs):
tmp = {}
if "timestamp_column" in kwargs:
tmp[kwargs["timestamp_column"]] = create_timestamps(
row[kwargs["timestamp_column"]], freq=self.freq, periods=l
row[kwargs["timestamp_column"]], freq=kwargs["freq"], periods=l
) # expand timestamps
if "id_columns" in kwargs:
for c in kwargs["id_columns"]:
Expand All @@ -345,7 +356,7 @@ def postprocess(self, input, **kwargs):
out = out[cols_ordered]

# inverse scale if we have a feature extractor
if self.feature_extractor is not None:
if self.feature_extractor is not None and kwargs["inverse_scale_outputs"]:
out = self.feature_extractor.inverse_scale_targets(out, suffix="_prediction")

return out
Loading