Skip to content

Commit

Permalink
clean up parameter handling
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Apr 5, 2024
1 parent 547f0fb commit 550fe39
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,25 @@ class TimeSeriesForecastingPipeline(Pipeline):
def __init__(
self,
*args,
freq: Optional[str] = None,
explode_forecasts: bool = False,
freq: Optional[Union[Any]] = None,
inverse_scale_outputs: bool = True,
**kwargs,
):
kwargs["freq"] = freq
kwargs["explode_forecasts"] = explode_forecasts
kwargs["inverse_scale_outputs"] = inverse_scale_outputs
super().__init__(*args, **kwargs)

if self.framework == "tf":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")

self.explode_forecasts = explode_forecasts
self.freq = freq
# self.check_model_type(MODEL_FOR_TIME_SERIES_FORECASTING_MAPPING)

def _sanitize_parameters(self, **kwargs):
def _sanitize_parameters(
self,
**kwargs,
):
"""Assign parameters to the different parts of the process.
For expected parameters see the call method below.
Expand Down Expand Up @@ -94,6 +99,9 @@ def _sanitize_parameters(self, **kwargs):
"control_columns",
"conditional_columns",
"static_categorical_columns",
"freq",
"explode_forecasts",
"inverse_scale_outputs",
]

for c in preprocess_params:
Expand Down Expand Up @@ -178,6 +186,8 @@ def __call__(
explode_forecasts (bool): If true, forecasts are returned one value per row of the pandas dataframe. If false, the
forecast over the prediction length will be contained as a list in a single row of the pandas dataframe.
inverse_scale_outputs (bool): If true and a valid feature extractor is provided, the outputs will be inverse scaled.
Return (pandas dataframe):
A new pandas dataframe containing the forecasts. Each row will contain the id, timestamp, the original
input feature values and the output forecast for each input column. The output forecast is a list containing
Expand Down Expand Up @@ -293,6 +303,7 @@ def postprocess(self, input, **kwargs):
"""
out = {}

print(kwargs)
model_output_key = "prediction_outputs" if "prediction_outputs" in input.keys() else "prediction_logits"

# name the predictions of target columns
Expand All @@ -312,7 +323,7 @@ def postprocess(self, input, **kwargs):
out[c] = [elem[i] for elem in input["id"]]
out = pd.DataFrame(out)

if self.explode_forecasts:
if kwargs["explode_forecasts"]:
# we made only one forecast per time series, explode results
# explode == expand the lists in the dataframe
out_explode = []
Expand All @@ -321,7 +332,7 @@ def postprocess(self, input, **kwargs):
tmp = {}
if "timestamp_column" in kwargs:
tmp[kwargs["timestamp_column"]] = create_timestamps(
row[kwargs["timestamp_column"]], freq=self.freq, periods=l
row[kwargs["timestamp_column"]], freq=kwargs["freq"], periods=l
) # expand timestamps
if "id_columns" in kwargs:
for c in kwargs["id_columns"]:
Expand All @@ -345,7 +356,7 @@ def postprocess(self, input, **kwargs):
out = out[cols_ordered]

# inverse scale if we have a feature extractor
if self.feature_extractor is not None:
if self.feature_extractor is not None and kwargs["inverse_scale_outputs"]:
out = self.feature_extractor.inverse_scale_targets(out, suffix="_prediction")

return out

0 comments on commit 550fe39

Please sign in to comment.