Skip to content

Commit

Permalink
support predicting a subset of series (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Aug 15, 2023
1 parent 8ed4580 commit 7c5406f
Show file tree
Hide file tree
Showing 9 changed files with 306 additions and 91 deletions.
2 changes: 2 additions & 0 deletions mlforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@
'mlforecast/grouped_array.py'),
'mlforecast.grouped_array.GroupedArray.restore_difference': ( 'grouped_array.html#groupedarray.restore_difference',
'mlforecast/grouped_array.py'),
'mlforecast.grouped_array.GroupedArray.take': ( 'grouped_array.html#groupedarray.take',
'mlforecast/grouped_array.py'),
'mlforecast.grouped_array.GroupedArray.take_from_groups': ( 'grouped_array.html#groupedarray.take_from_groups',
'mlforecast/grouped_array.py'),
'mlforecast.grouped_array.GroupedArray.transform_series': ( 'grouped_array.html#groupedarray.transform_series',
Expand Down
40 changes: 30 additions & 10 deletions mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def _update_features(self) -> pd.DataFrame:
features[feat_name] = feat_vals

features_df = pd.DataFrame(features, columns=self.features)
features_df[self.id_col] = self.uids
features_df[self.id_col] = self._uids
features_df[self.time_col] = self.curr_dates
return self.static_features_.merge(features_df, on=self.id_col)

Expand All @@ -436,7 +436,7 @@ def _get_predictions(self) -> pd.DataFrame:
"""Get all the predicted values with their corresponding ids and datestamps."""
n_preds = len(self.y_pred)
uids = pd.Series(
np.repeat(self.uids, n_preds), name=self.id_col, dtype=self.uids.dtype
np.repeat(self._uids, n_preds), name=self.id_col, dtype=self.uids.dtype
)
df = pd.DataFrame(
{
Expand All @@ -448,10 +448,13 @@ def _get_predictions(self) -> pd.DataFrame:
return df

def _predict_setup(self) -> None:
self.ga = GroupedArray(self._ga.data, self._ga.indptr)
self.curr_dates = self.last_dates.copy()
if self._idxs is not None:
self.ga = self.ga.take(self._idxs)
self.curr_dates = self.curr_dates[self._idxs]
self.test_dates = []
self.y_pred = []
self.ga = GroupedArray(self._ga.data, self._ga.indptr)
if self.keep_last_n is not None:
self.ga = self.ga.take_from_groups(slice(-self.keep_last_n, None))
self._h = 0
Expand All @@ -463,7 +466,7 @@ def _get_features_for_next_step(self, dynamic_dfs, X_df=None):
new_x = new_x.merge(df, how="left")
new_x = new_x.sort_values(self.id_col)
if X_df is not None:
n_series = self.uids.size
n_series = len(self._uids)
X = X_df.iloc[self._h * n_series : (self._h + 1) * n_series]
new_x = pd.concat([new_x, X.reset_index(drop=True)], axis=1)
nulls = new_x.isnull().any()
Expand Down Expand Up @@ -492,7 +495,7 @@ def _predict_recursive(
new_x = before_predict_callback(new_x)
predictions = model.predict(new_x)
if after_predict_callback is not None:
predictions_serie = pd.Series(predictions, index=self.uids)
predictions_serie = pd.Series(predictions, index=self._uids)
predictions = after_predict_callback(predictions_serie).values
self._update_y(predictions)
if i == 0:
Expand All @@ -519,7 +522,7 @@ def _predict_multi(
)
if dynamic_dfs is None:
dynamic_dfs = []
uids = np.repeat(self.uids, horizon)
uids = np.repeat(self._uids, horizon)
dates = np.hstack(
[
date + (i + 1) * self.freq
Expand Down Expand Up @@ -548,7 +551,21 @@ def predict(
before_predict_callback: Optional[Callable] = None,
after_predict_callback: Optional[Callable] = None,
X_df: Optional[pd.DataFrame] = None,
ids: Optional[List[str]] = None,
) -> pd.DataFrame:
if ids is not None:
unseen = set(ids) - set(self.uids)
if unseen:
raise ValueError(
f"The following ids weren't seen during training and thus can't be forecasted: {unseen}"
)
self._uids = self.uids[self.uids.isin(ids)]
self._idxs: Optional[np.ndarray] = np.where(self.uids.isin(self._uids))[0]
last_dates = self.last_dates[self._idxs]
else:
self._uids = self.uids
self._idxs = None
last_dates = self.last_dates
if X_df is not None:
if self.id_col not in X_df or self.time_col not in X_df:
raise ValueError(
Expand All @@ -567,14 +584,14 @@ def predict(
)
dates_validation = pd.DataFrame(
{
self.id_col: self.uids,
"_start": self.last_dates + self.freq,
"_end": self.last_dates + horizon * self.freq,
self.id_col: self._uids,
"_start": last_dates + self.freq,
"_end": last_dates + horizon * self.freq,
}
)
X_df = X_df.merge(dates_validation, on=[self.id_col])
X_df = X_df[X_df[self.time_col].between(X_df["_start"], X_df["_end"])]
if X_df.shape[0] != self.uids.size * horizon:
if X_df.shape[0] != len(self._uids) * horizon:
raise ValueError(
"Found missing inputs in X_df. "
"It should have one row per id and date for the complete forecasting horizon"
Expand All @@ -601,7 +618,10 @@ def predict(
)
if self.target_transforms is not None:
for tfm in self.target_transforms[::-1]:
tfm.idxs = self._idxs
preds = tfm.inverse_transform(preds)
tfm.idxs = None
del self._uids, self._idxs
return preds

def update(self, df: pd.DataFrame) -> None:
Expand Down
20 changes: 12 additions & 8 deletions mlforecast/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
num_threads: int = 1,
target_transforms: Optional[List[BaseTargetTransform]] = None,
):
"""Create forecast object
"""Forecasting pipeline
Parameters
----------
Expand Down Expand Up @@ -409,11 +409,12 @@ def predict(
new_df: Optional[pd.DataFrame] = None,
level: Optional[List[Union[int, float]]] = None,
X_df: Optional[pd.DataFrame] = None,
ids: Optional[List[str]] = None,
*,
horizon: Optional[int] = None, # noqa: ARG002
new_data: Optional[pd.DataFrame] = None, # noqa: ARG002
) -> pd.DataFrame:
"""Compute the predictions for the next `horizon` steps.
"""Compute the predictions for the next `h` steps.
Parameters
----------
Expand All @@ -437,6 +438,8 @@ def predict(
Confidence levels between 0 and 100 for prediction intervals.
X_df : pandas DataFrame, optional (default=None)
Dataframe with the future exogenous features. Should have the id column and the time column.
ids : list of str, optional (default=None)
List with subset of ids seen during training for which the forecasts should be computed.
horizon : int
Number of periods to predict. This argument has been replaced by h and will be removed in a later release.
new_data : pandas DataFrame, optional (default=None)
Expand Down Expand Up @@ -488,12 +491,13 @@ def predict(
ts = self.ts

forecasts = ts.predict(
self.models_,
h,
dynamic_dfs,
before_predict_callback,
after_predict_callback,
X_df,
models=self.models_,
horizon=h,
dynamic_dfs=dynamic_dfs,
before_predict_callback=before_predict_callback,
after_predict_callback=after_predict_callback,
X_df=X_df,
ids=ids,
)
if level is not None:
if self._cs_df is None:
Expand Down
8 changes: 8 additions & 0 deletions mlforecast/grouped_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ def __setitem__(self, idx: int, vals: np.ndarray):
raise ValueError(f"vals must be of size {self[idx].size}")
self[idx][:] = vals

def take(self, idxs: np.ndarray) -> "GroupedArray":
ranges = [range(self.indptr[i], self.indptr[i + 1]) for i in idxs]
items = [self.data[rng] for rng in ranges]
sizes = np.array([item.size for item in items])
data = np.hstack(items)
indptr = np.append(0, sizes.cumsum())
return GroupedArray(data, indptr)

@classmethod
def from_sorted_df(
cls, df: "pd.DataFrame", id_col: str, target_col: str
Expand Down
21 changes: 13 additions & 8 deletions mlforecast/target_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# %% ../nbs/target_transforms.ipynb 2
import abc
import reprlib
from typing import TYPE_CHECKING, Iterable
from typing import TYPE_CHECKING, Iterable, Optional

if TYPE_CHECKING:
import pandas as pd
Expand All @@ -17,6 +17,8 @@

# %% ../nbs/target_transforms.ipynb 3
class BaseTargetTransform(abc.ABC):
idxs: Optional[np.ndarray] = None

def set_column_names(self, id_col: str, time_col: str, target_col: str):
self.id_col = id_col
self.time_col = time_col
Expand Down Expand Up @@ -53,18 +55,20 @@ def fit_transform(self, df: "pd.DataFrame") -> "pd.DataFrame":
new_indptr = d * np.arange(n_series + 1, dtype=np.int32)
_apply_difference(ga.data, ga.indptr, new_data, new_indptr, d)
self.original_values_.append(GroupedArray(new_data, new_indptr))
df = df.copy()
df = df.copy(deep=False)
df[self.target_col] = ga.data
return df

def inverse_transform(self, df: "pd.DataFrame") -> "pd.DataFrame":
model_cols = df.columns.drop([self.id_col, self.time_col])
df = df.copy()
df = df.copy(deep=False)
for model in model_cols:
model_preds = df[model].values.copy()
for d, ga in zip(
reversed(self.differences), reversed(self.original_values_)
):
if self.idxs is not None:
ga = ga.take(self.idxs)
ga.restore_difference(model_preds, d)
df[model] = model_preds
return df
Expand All @@ -76,8 +80,8 @@ def _standard_scaler_transform(data, indptr, stats, out):
for i in range(n_series):
sl = slice(indptr[i], indptr[i + 1])
subs = data[sl]
mean_ = subs.mean()
std_ = subs.std()
mean_ = np.nanmean(subs)
std_ = np.nanstd(subs)
stats[i] = mean_, std_
out[sl] = (data[sl] - mean_) / std_

Expand All @@ -102,15 +106,16 @@ def fit_transform(self, df: "pd.DataFrame") -> "pd.DataFrame":
self.stats_ = np.empty((len(ga.indptr) - 1, 2))
out = np.empty_like(ga.data)
_standard_scaler_transform(ga.data, ga.indptr, self.stats_, out)
df = df.copy()
df = df.copy(deep=False)
df[self.target_col] = out
return df

def inverse_transform(self, df: "pd.DataFrame") -> "pd.DataFrame":
df = df.copy()
df = df.copy(deep=False)
model_cols = df.columns.drop([self.id_col, self.time_col])
stats = self.stats_ if self.idxs is None else self.stats_[self.idxs]
for model in model_cols:
model_preds = df[model].values
_standard_scaler_inverse_transform(model_preds, self.stats_)
_standard_scaler_inverse_transform(model_preds, stats)
df[model] = model_preds
return df
Loading

0 comments on commit 7c5406f

Please sign in to comment.