
Commit 932a7c7
update extend logic
wgifford committed Nov 11, 2024
Parent: c81491e
Showing 3 changed files with 59 additions and 13 deletions.
services/inference/tsfminference/hf_service_handler.py (27 changes: 22 additions & 5 deletions)
@@ -13,6 +13,7 @@
 from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

 from tsfm_public import TimeSeriesForecastingPipeline, TimeSeriesPreprocessor
+from tsfm_public.toolkit.time_series_preprocessor import extend_time_series
 from tsfm_public.toolkit.util import select_by_index

 from .inference_payloads import BaseParameters, ForecastingMetadataInput, ForecastingParameters
@@ -284,7 +285,11 @@ def _run(

         # if data is too short, raise error
         prediction_length = self.config.get("prediction_filter_length", None)
-        prediction_length = prediction_length if prediction_length is not None else self.config.prediction_length
+        has_prediction_filter = prediction_length is not None
+
+        model_prediction_length = self.config.prediction_length
+
+        prediction_length = prediction_length if prediction_length is not None else model_prediction_length
         if fd_min_data_length < prediction_length:
             err_str = (
                 "Future data should have time series of length that is at least the specified prediction length."
@@ -295,13 +300,25 @@
                 err_str += (
                     f"Received {fd_min_data_length} time points, but expected {prediction_length} time points"
                 )

             raise ValueError(err_str)

         # if data exceeds prediction filter length, truncate
-        if fd_max_data_length > prediction_length:
-            LOGGER.info(f"Truncating future series lengths to {prediction_length}")
-            future_data = select_by_index(future_data, id_columns=schema.id_columns, end_index=prediction_length)
+        if fd_max_data_length > model_prediction_length:
+            LOGGER.info(f"Truncating future series lengths to {model_prediction_length}")
+            future_data = select_by_index(
+                future_data, id_columns=schema.id_columns, end_index=model_prediction_length
+            )
+
+        # if provided data is longer than prediction_filter_length but shorter than model_prediction_length, extend it
+        if has_prediction_filter and fd_min_data_length < model_prediction_length:
+            future_data = extend_time_series(
+                time_series=future_data,
+                freq=self.preprocessor.freq,
+                timestamp_column=schema.timestamp_column,
+                grouping_columns=schema.id_columns,
+                total_periods=model_prediction_length,
+            )

         forecast_pipeline = TimeSeriesForecastingPipeline(
             model=self.model,
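
Taken together, the updated `_run` logic truncates future data that exceeds the model's native prediction length and, when a `prediction_filter_length` is set, pads shorter series out to that length with empty rows. Below is a minimal, self-contained sketch of that flow; the toy DataFrame, the column names ("id", "ts"), and the two length constants are illustrative assumptions, not the service's actual config or payload objects:

```python
import pandas as pd

from tsfm_public.toolkit.time_series_preprocessor import extend_time_series
from tsfm_public.toolkit.util import select_by_index

model_prediction_length = 48   # stand-in for self.config.prediction_length
prediction_filter_length = 24  # stand-in for the request parameter (may be None)

# two future series of 30 rows each: long enough for the filter length,
# but shorter than the model's native prediction length
future_data = pd.DataFrame(
    {
        "id": ["A"] * 30 + ["B"] * 30,
        "ts": list(pd.date_range("2024-01-01", periods=30, freq="D")) * 2,
    }
)

has_prediction_filter = prediction_filter_length is not None
prediction_length = prediction_filter_length if has_prediction_filter else model_prediction_length

lengths = future_data.groupby("id").size()
fd_min_data_length, fd_max_data_length = lengths.min(), lengths.max()

# shorter than the requested horizon -> error
if fd_min_data_length < prediction_length:
    raise ValueError("Future data is shorter than the requested prediction length.")

# longer than the model horizon -> truncate each series
if fd_max_data_length > model_prediction_length:
    future_data = select_by_index(future_data, id_columns=["id"], end_index=model_prediction_length)

# long enough for the filter but shorter than the model horizon -> pad with empty rows
if has_prediction_filter and fd_min_data_length < model_prediction_length:
    future_data = extend_time_series(
        time_series=future_data,
        freq="D",
        timestamp_column="ts",
        grouping_columns=["id"],
        total_periods=model_prediction_length,
    )

# every series now has exactly model_prediction_length rows
assert future_data.groupby("id").size().eq(model_prediction_length).all()
```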
tests/toolkit/test_time_series_preprocessor.py (19 changes: 18 additions & 1 deletion)
@@ -20,6 +20,8 @@
 )
 from tsfm_public.toolkit.util import FractionLocation

+from ..util import nreps
+

 def test_standard_scaler(sample_data):
     scaler = StandardScaler()
@@ -155,7 +157,7 @@ def test_time_series_preprocessor_inv_scales_lists(ts_data):
     assert out_inv["value2"].mean()[0] == df["value2"].mean()


-def test_augment_time_series(ts_data):
+def test_extend_time_series(ts_data):
     periods = 5
     a = extend_time_series(ts_data, timestamp_column="timestamp", grouping_columns=["id"], periods=periods)

@@ -175,6 +177,21 @@ def test_augment_time_series(ts_data):
     assert a.shape[0] == ts_data.shape[0] + 3 * periods
     assert a.shape[1] == ts_data.shape[1]

+    # test different lengths
+
+    ts_data_2 = pd.DataFrame(
+        {
+            "id": list(nreps(["A", "B"], 50)) + ["C"] * 20,
+            "timestamp": [datetime(2021, 1, 1) + timedelta(days=i) for i in range(50)] * 2
+            + [datetime(2021, 1, 1) + timedelta(days=i) for i in range(20)],
+            "value1": range(120),
+        }
+    )
+
+    a = extend_time_series(ts_data_2, timestamp_column="timestamp", grouping_columns=["id"], total_periods=60)
+
+    assert len(a) == 180
+

 def test_create_timestamps():
     base_last_timestamp = datetime(2020, 1, 1)
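
For reference, the rows that `extend_time_series` appends carry only the timestamp and grouping columns; every other field is left as NaN. A short sketch of that behavior on made-up data shaped like the test above (as in the tests, it assumes the daily frequency is inferred from the existing timestamps when `freq` is not given):

```python
from datetime import datetime, timedelta

import pandas as pd

from tsfm_public.toolkit.time_series_preprocessor import extend_time_series

df = pd.DataFrame(
    {
        "id": ["A"] * 5 + ["B"] * 3,
        "timestamp": [datetime(2021, 1, 1) + timedelta(days=i) for i in range(5)]
        + [datetime(2021, 1, 1) + timedelta(days=i) for i in range(3)],
        "value1": range(8),
    }
)

# pad both groups out to 6 rows each
out = extend_time_series(df, timestamp_column="timestamp", grouping_columns=["id"], total_periods=6)

assert len(out) == 12                                         # 6 rows per group
assert out.loc[out["id"] == "B", "value1"].isna().sum() == 3  # appended rows are empty
```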
tsfm_public/toolkit/time_series_preprocessor.py (26 changes: 19 additions & 7 deletions)
@@ -1038,29 +1038,36 @@ def estimate_frequency(timestamp_data: Union[pd.Series, np.ndarray]):

 def extend_time_series(
     time_series: pd.DataFrame,
-    # last_known_timestamp,
     timestamp_column: str,
     grouping_columns: List[str],
     freq: Optional[Union[int, float, datetime.timedelta, pd.Timedelta]] = None,
-    periods: int = 1,
-    # delta: datetime.timedelta = datetime.timedelta(days=1),
+    periods: Optional[int] = None,
+    total_periods: Optional[int] = None,
 ):
"""Extends the provided time series with empty data for the number of periods specified. For each time series, based
on groups defined by grouping columns, adds emptry records following the last timestamp. The empty records contain
only timestamps and grouping indicators, remaining fields will be null.
One of periods or total_periods must be specified.
Args:
time_series (pd.DataFrame): _description_
start_timestamp (_type_): _description_
column_name (str): _description_
grouping_columns (List[str]): _description_
freq:
periods (int, optional): _description_. Defaults to 1.
delta (datetime.timedelta, optional): _description_. Defaults to datetime.timedelta(days=1).
total_periods (int, optional): total length of the series after extending. Defaults to None.
"""

-    def augment_one_series(group: Union[pd.Series, pd.DataFrame]):
+    def augment_one_series(
+        group: Union[pd.Series, pd.DataFrame], periods: Optional[int] = None, total_periods: Optional[int] = None
+    ):
         last_timestamp = group[timestamp_column].iloc[-1]

+        if periods is None:
+            periods = total_periods - len(group)
+
         new_data = pd.DataFrame(
             {
                 timestamp_column: create_timestamps(
@@ -1078,10 +1085,15 @@ def augment_one_series(group: Union[pd.Series, pd.DataFrame]):
         )
         return df.reset_index(drop=True)

+    if (periods is None and total_periods is None) or (periods is not None and total_periods is not None):
+        raise ValueError("Exactly one of `periods` or `total_periods` must be specified")
+
     if grouping_columns == []:
-        new_time_series = augment_one_series(time_series)
+        new_time_series = augment_one_series(time_series, periods=periods, total_periods=total_periods)
     else:
-        new_time_series = time_series.groupby(grouping_columns).apply(augment_one_series, include_groups=False)
+        new_time_series = time_series.groupby(grouping_columns).apply(
+            augment_one_series, include_groups=False, periods=periods, total_periods=total_periods
+        )
         idx_names = list(new_time_series.index.names)
         idx_names[-1] = "__delete"
         new_time_series = new_time_series.reset_index(names=idx_names)
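
A usage sketch for the updated signature: exactly one of `periods` (a fixed number of rows appended to every group) or `total_periods` (a common target length per group) may be supplied; passing both, or neither, raises. The data below is illustrative:

```python
import pandas as pd

from tsfm_public.toolkit.time_series_preprocessor import extend_time_series

df = pd.DataFrame(
    {
        "id": ["A"] * 4 + ["B"] * 2,
        "timestamp": list(pd.date_range("2024-01-01", periods=4, freq="D"))
        + list(pd.date_range("2024-01-01", periods=2, freq="D")),
        "value": range(6),
    }
)

# append a fixed number of empty periods to every group: (4 + 3) + (2 + 3) rows
fixed = extend_time_series(df, timestamp_column="timestamp", grouping_columns=["id"], periods=3)
assert len(fixed) == 12

# pad every group out to the same total length: 5 + 5 rows
padded = extend_time_series(df, timestamp_column="timestamp", grouping_columns=["id"], total_periods=5)
assert len(padded) == 10

# supplying both (or neither) is an error
try:
    extend_time_series(df, timestamp_column="timestamp", grouping_columns=["id"], periods=1, total_periods=5)
except ValueError as err:
    print(err)  # Exactly one of `periods` or `total_periods` must be specified
```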
