diff --git a/.github/workflows/build-docs.yaml b/.github/workflows/build-docs.yaml
index 7574408a..469ff80e 100644
--- a/.github/workflows/build-docs.yaml
+++ b/.github/workflows/build-docs.yaml
@@ -51,6 +51,18 @@ jobs:
publish_dir: ./_docs
user_name: github-actions[bot]
user_email: 41898282+github-actions[bot]@users.noreply.github.com
+ - name: Trigger mintlify workflow
+ if: github.event_name == 'push'
+ uses: actions/github-script@v7
+ with:
+ github-token: ${{ secrets.DOCS_WORKFLOW_TOKEN }}
+ script: |
+ await github.rest.actions.createWorkflowDispatch({
+ owner: 'nixtla',
+ repo: 'docs',
+ workflow_id: 'mintlify-action.yml',
+ ref: 'main',
+ });
- name: Deploy to Github Pages
if: github.event_name == 'push'
uses: peaceiris/actions-gh-pages@v3
diff --git a/environment.yml b/environment.yml
index 6d762a48..16745300 100644
--- a/environment.yml
+++ b/environment.yml
@@ -31,5 +31,5 @@ dependencies:
- polars
- ray<2.8
- triad==0.9.1
- - utilsforecast>=0.0.21
+ - utilsforecast>=0.0.24
- xgboost_ray
diff --git a/local_environment.yml b/local_environment.yml
index d5c711c5..9debf0e1 100644
--- a/local_environment.yml
+++ b/local_environment.yml
@@ -21,4 +21,4 @@ dependencies:
- datasetsforecast
- nbdev
- polars
- - utilsforecast>=0.0.21
+ - utilsforecast>=0.0.24
diff --git a/mlforecast/__init__.py b/mlforecast/__init__.py
index d9d20db7..e01d5f6d 100644
--- a/mlforecast/__init__.py
+++ b/mlforecast/__init__.py
@@ -1,3 +1,3 @@
-__version__ = "0.11.2"
+__version__ = "0.11.5"
__all__ = ['MLForecast']
from mlforecast.forecast import MLForecast
diff --git a/mlforecast/_modidx.py b/mlforecast/_modidx.py
index 19ca6691..b596715b 100644
--- a/mlforecast/_modidx.py
+++ b/mlforecast/_modidx.py
@@ -31,6 +31,8 @@
'mlforecast/core.py'),
'mlforecast.core.TimeSeries._get_raw_predictions': ( 'core.html#timeseries._get_raw_predictions',
'mlforecast/core.py'),
+ 'mlforecast.core.TimeSeries._has_ga_target_tfms': ( 'core.html#timeseries._has_ga_target_tfms',
+ 'mlforecast/core.py'),
'mlforecast.core.TimeSeries._predict_multi': ('core.html#timeseries._predict_multi', 'mlforecast/core.py'),
'mlforecast.core.TimeSeries._predict_recursive': ( 'core.html#timeseries._predict_recursive',
'mlforecast/core.py'),
@@ -194,6 +196,8 @@
'mlforecast/grouped_array.py'),
'mlforecast.grouped_array.GroupedArray.take_from_groups': ( 'grouped_array.html#groupedarray.take_from_groups',
'mlforecast/grouped_array.py'),
+ 'mlforecast.grouped_array.GroupedArray.update_difference': ( 'grouped_array.html#groupedarray.update_difference',
+ 'mlforecast/grouped_array.py'),
'mlforecast.grouped_array._append_one': ( 'grouped_array.html#_append_one',
'mlforecast/grouped_array.py'),
'mlforecast.grouped_array._append_several': ( 'grouped_array.html#_append_several',
@@ -208,7 +212,9 @@
'mlforecast.grouped_array._restore_fitted_difference': ( 'grouped_array.html#_restore_fitted_difference',
'mlforecast/grouped_array.py'),
'mlforecast.grouped_array._transform_series': ( 'grouped_array.html#_transform_series',
- 'mlforecast/grouped_array.py')},
+ 'mlforecast/grouped_array.py'),
+ 'mlforecast.grouped_array._update_difference': ( 'grouped_array.html#_update_difference',
+ 'mlforecast/grouped_array.py')},
'mlforecast.lag_transforms': { 'mlforecast.lag_transforms.BaseLagTransform': ( 'lag_transforms.html#baselagtransform',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.BaseLagTransform.transform': ( 'lag_transforms.html#baselagtransform.transform',
@@ -315,14 +321,20 @@
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.BaseGroupedArrayTargetTransform.inverse_transform_fitted': ( 'target_transforms.html#basegroupedarraytargettransform.inverse_transform_fitted',
'mlforecast/target_transforms.py'),
+ 'mlforecast.target_transforms.BaseGroupedArrayTargetTransform.update': ( 'target_transforms.html#basegroupedarraytargettransform.update',
+ 'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.BaseLocalScaler': ( 'target_transforms.html#baselocalscaler',
'mlforecast/target_transforms.py'),
+ 'mlforecast.target_transforms.BaseLocalScaler._is_utils_tfm': ( 'target_transforms.html#baselocalscaler._is_utils_tfm',
+ 'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.BaseLocalScaler.fit_transform': ( 'target_transforms.html#baselocalscaler.fit_transform',
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.BaseLocalScaler.inverse_transform': ( 'target_transforms.html#baselocalscaler.inverse_transform',
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.BaseLocalScaler.inverse_transform_fitted': ( 'target_transforms.html#baselocalscaler.inverse_transform_fitted',
'mlforecast/target_transforms.py'),
+ 'mlforecast.target_transforms.BaseLocalScaler.update': ( 'target_transforms.html#baselocalscaler.update',
+ 'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.BaseTargetTransform': ( 'target_transforms.html#basetargettransform',
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.BaseTargetTransform.fit_transform': ( 'target_transforms.html#basetargettransform.fit_transform',
@@ -331,6 +343,8 @@
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.BaseTargetTransform.set_column_names': ( 'target_transforms.html#basetargettransform.set_column_names',
'mlforecast/target_transforms.py'),
+ 'mlforecast.target_transforms.BaseTargetTransform.update': ( 'target_transforms.html#basetargettransform.update',
+ 'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.Differences': ( 'target_transforms.html#differences',
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.Differences.__init__': ( 'target_transforms.html#differences.__init__',
@@ -341,6 +355,8 @@
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.Differences.inverse_transform_fitted': ( 'target_transforms.html#differences.inverse_transform_fitted',
'mlforecast/target_transforms.py'),
+ 'mlforecast.target_transforms.Differences.update': ( 'target_transforms.html#differences.update',
+ 'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.GlobalSklearnTransformer': ( 'target_transforms.html#globalsklearntransformer',
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.GlobalSklearnTransformer.__init__': ( 'target_transforms.html#globalsklearntransformer.__init__',
@@ -349,6 +365,8 @@
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.GlobalSklearnTransformer.inverse_transform': ( 'target_transforms.html#globalsklearntransformer.inverse_transform',
'mlforecast/target_transforms.py'),
+ 'mlforecast.target_transforms.GlobalSklearnTransformer.update': ( 'target_transforms.html#globalsklearntransformer.update',
+ 'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.LocalBoxCox': ( 'target_transforms.html#localboxcox',
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.LocalBoxCox.__init__': ( 'target_transforms.html#localboxcox.__init__',
diff --git a/mlforecast/compat.py b/mlforecast/compat.py
index edc93d7b..1bd565b1 100644
--- a/mlforecast/compat.py
+++ b/mlforecast/compat.py
@@ -6,6 +6,7 @@
# %% ../nbs/compat.ipynb 1
try:
import coreforecast.lag_transforms as core_tfms
+ import coreforecast.scalers as core_scalers
from coreforecast.grouped_array import GroupedArray as CoreGroupedArray
from mlforecast.lag_transforms import BaseLagTransform, Lag
@@ -13,6 +14,7 @@
CORE_INSTALLED = True
except ImportError:
core_tfms = None
+ core_scalers = None
CoreGroupedArray = None
class BaseLagTransform:
diff --git a/mlforecast/core.py b/mlforecast/core.py
index 6f9efa13..87d2528e 100644
--- a/mlforecast/core.py
+++ b/mlforecast/core.py
@@ -561,7 +561,7 @@ def _get_predictions(self) -> DataFrame:
return df
def _predict_setup(self) -> None:
- self._ga = copy.copy(self.ga)
+ self.ga = copy.copy(self._ga)
if isinstance(self.last_dates, pl_Series):
self.curr_dates = self.last_dates.clone()
else:
@@ -657,6 +657,12 @@ def _predict_multi(
result = ufp.assign_columns(result, name, raw_preds)
return result
+ def _has_ga_target_tfms(self):
+ return any(
+ isinstance(tfm, BaseGroupedArrayTargetTransform)
+ for tfm in self.target_transforms
+ )
+
def predict(
self,
models: Dict[str, Union[BaseEstimator, List[BaseEstimator]]],
@@ -723,26 +729,30 @@ def predict(
raise ValueError(msg)
drop_cols = [self.id_col, self.time_col, "_start", "_end"]
X_df = ufp.sort(X_df, [self.id_col, self.time_col]).drop(columns=drop_cols)
- if getattr(self, "max_horizon", None) is None:
- preds = self._predict_recursive(
- models=models,
- horizon=horizon,
- before_predict_callback=before_predict_callback,
- after_predict_callback=after_predict_callback,
- X_df=X_df,
- )
- else:
- preds = self._predict_multi(
- models=models,
- horizon=horizon,
- before_predict_callback=before_predict_callback,
- X_df=X_df,
- )
+ # back up the original series; the ga attribute gets modified
+ # and is copied from _ga at the start of each model's predict
+ self._ga = copy.copy(self.ga)
+ try:
+ if getattr(self, "max_horizon", None) is None:
+ preds = self._predict_recursive(
+ models=models,
+ horizon=horizon,
+ before_predict_callback=before_predict_callback,
+ after_predict_callback=after_predict_callback,
+ X_df=X_df,
+ )
+ else:
+ preds = self._predict_multi(
+ models=models,
+ horizon=horizon,
+ before_predict_callback=before_predict_callback,
+ X_df=X_df,
+ )
+ finally:
+ self.ga = self._ga
+ del self._ga
if self.target_transforms is not None:
- if any(
- isinstance(tfm, BaseGroupedArrayTargetTransform)
- for tfm in self.target_transforms
- ):
+ if self._has_ga_target_tfms():
model_cols = [
c for c in preds.columns if c not in (self.id_col, self.time_col)
]
@@ -751,14 +761,15 @@ def predict(
if isinstance(tfm, BaseGroupedArrayTargetTransform):
tfm.idxs = self._idxs
for col in model_cols:
- ga = GroupedArray(preds[col].to_numpy(), indptr)
+ ga = GroupedArray(
+ preds[col].to_numpy().astype(self.ga.data.dtype), indptr
+ )
ga = tfm.inverse_transform(ga)
preds = ufp.assign_columns(preds, col, ga.data)
tfm.idxs = None
else:
preds = tfm.inverse_transform(preds)
- self.ga = self._ga
- del self._uids, self._idxs, self._static_features, self._ga
+ del self._uids, self._idxs, self._static_features
return preds
def save(self, path: Union[str, Path]) -> None:
@@ -778,11 +789,17 @@ def update(self, df: DataFrame) -> None:
if isinstance(uids, pd.Index):
uids = pd.Series(uids)
uids, new_ids = ufp.match_if_categorical(uids, df[self.id_col])
+ df = ufp.copy_if_pandas(df, deep=False)
df = ufp.assign_columns(df, self.id_col, new_ids)
df = ufp.sort(df, by=[self.id_col, self.time_col])
values = df[self.target_col].to_numpy()
+ values = values.astype(self.ga.data.dtype, copy=False)
id_counts = ufp.counts_by_id(df, self.id_col)
- sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer")
+ try:
+ sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer_coalesce")
+ except (KeyError, ValueError):
+ # pandas raises a KeyError, polars without outer_coalesce raises a ValueError
+ sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer")
sizes = ufp.fill_null(sizes, {"counts": 0})
sizes = ufp.sort(sizes, by=self.id_col)
new_groups = ~ufp.is_in(sizes[self.id_col], uids)
@@ -794,9 +811,12 @@ def update(self, df: DataFrame) -> None:
last_dates = ufp.sort(last_dates, by=self.id_col)
self.last_dates = ufp.cast(last_dates[self.time_col], self.last_dates.dtype)
self.uids = ufp.sort(sizes[self.id_col])
- if isinstance(self.uids, pd.Series):
+ if isinstance(df, pd.DataFrame):
self.uids = pd.Index(self.uids)
+ self.last_dates = pd.Index(self.last_dates)
if new_groups.any():
+ if self.target_transforms is not None:
+ raise ValueError("Cannot update target_transforms with new series.")
new_ids = ufp.filter_with_mask(sizes[self.id_col], new_groups)
new_ids_df = ufp.filter_with_mask(df, ufp.is_in(df[self.id_col], new_ids))
new_ids_counts = ufp.counts_by_id(new_ids_df, self.id_col)
@@ -808,6 +828,17 @@ def update(self, df: DataFrame) -> None:
[self.static_features_, new_statics]
)
self.static_features_ = ufp.sort(self.static_features_, self.id_col)
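+ # transform the new values with the already-fitted target transforms so they can be appended to the stored (transformed) series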
+ if self.target_transforms is not None:
+ if self._has_ga_target_tfms():
+ indptr = np.append(0, id_counts["counts"]).cumsum()
+ for tfm in self.target_transforms:
+ if isinstance(tfm, BaseGroupedArrayTargetTransform):
+ ga = GroupedArray(values, indptr)
+ ga = tfm.update(ga)
+ df = ufp.assign_columns(df, self.target_col, ga.data)
+ else:
+ df = tfm.update(df)
+ values = df[self.target_col].to_numpy()
self.ga = self.ga.append_several(
new_sizes=sizes["counts"].to_numpy().astype(np.int32),
new_values=values,
diff --git a/mlforecast/distributed/forecast.py b/mlforecast/distributed/forecast.py
index cfede794..28a55a3c 100644
--- a/mlforecast/distributed/forecast.py
+++ b/mlforecast/distributed/forecast.py
@@ -431,6 +431,7 @@ def _predict(
horizon,
before_predict_callback=None,
after_predict_callback=None,
+ X_df=None,
) -> Iterable[pd.DataFrame]:
for serialized_ts, _, serialized_valid in items:
valid = cloudpickle.loads(serialized_valid)
@@ -440,6 +441,7 @@ def _predict(
horizon=horizon,
before_predict_callback=before_predict_callback,
after_predict_callback=after_predict_callback,
+ X_df=X_df,
)
if valid is not None:
res = res.merge(valid, how="left")
@@ -459,6 +461,7 @@ def predict(
h: int,
before_predict_callback: Optional[Callable] = None,
after_predict_callback: Optional[Callable] = None,
+ X_df: Optional[pd.DataFrame] = None,
new_df: Optional[fugue.AnyDataFrame] = None,
) -> fugue.AnyDataFrame:
"""Compute the predictions for the next `horizon` steps.
@@ -475,6 +478,8 @@ def predict(
Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index.
+ X_df : pandas DataFrame, optional (default=None)
+ Dataframe with the future exogenous features. Should have the id column and the time column.
new_df : dask or spark DataFrame, optional (default=None)
Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
@@ -499,6 +504,8 @@ def predict(
else:
partition_results = self._partition_results
schema = self._get_predict_schema()
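+ # X_df is passed as a parameter to every partition, so it must be a local pandas DataFrame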
+ if X_df is not None and not isinstance(X_df, pd.DataFrame):
+ raise ValueError("`X_df` should be a pandas DataFrame")
res = fa.transform(
partition_results,
DistributedMLForecast._predict,
@@ -507,6 +514,7 @@ def predict(
"horizon": h,
"before_predict_callback": before_predict_callback,
"after_predict_callback": after_predict_callback,
+ "X_df": X_df,
},
schema=schema,
engine=self.engine,
diff --git a/mlforecast/forecast.py b/mlforecast/forecast.py
index 56409602..e359be3c 100644
--- a/mlforecast/forecast.py
+++ b/mlforecast/forecast.py
@@ -536,11 +536,33 @@ def fit(
self.fcst_fitted_values_ = fitted_values
return self
- def forecast_fitted_values(self):
- """Access in-sample predictions."""
+ def forecast_fitted_values(
+ self, level: Optional[List[Union[int, float]]] = None
+ ) -> DataFrame:
+ """Access in-sample predictions.
+
+ Parameters
+ ----------
+ level : list of ints or floats, optional (default=None)
+ Confidence levels between 0 and 100 for prediction intervals.
+
+ Returns
+ -------
+ pandas or polars DataFrame
+ Dataframe with predictions for the training set
+ """
if not hasattr(self, "fcst_fitted_values_"):
raise Exception("Please run the `fit` method using `fitted=True`")
- return self.fcst_fitted_values_
+ res = self.fcst_fitted_values_
+ if level is not None:
+ res = ufp.add_insample_levels(
+ res,
+ models=self.models_.keys(),
+ level=level,
+ id_col=self.ts.id_col,
+ target_col=self.ts.target_col,
+ )
+ return res
def make_future_dataframe(self, h: int) -> DataFrame:
"""Create a dataframe with all ids and future times in the forecasting horizon.
diff --git a/mlforecast/grouped_array.py b/mlforecast/grouped_array.py
index 457c56d1..949e4f09 100644
--- a/mlforecast/grouped_array.py
+++ b/mlforecast/grouped_array.py
@@ -79,6 +79,23 @@ def _restore_fitted_difference(diffs_data, diffs_indptr, data, indptr, d):
serie[j] += diffs_data[diffs_indptr[i + 1] - serie.size - d + j]
+@njit
+def _update_difference(
+ d: int,
+ orig_data: np.ndarray,
+ orig_indptr: np.ndarray,
+ data: np.ndarray,
+ indptr: np.ndarray,
+):
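+ # for each series: append the new values to the stored last `d` original values,
+ # re-apply the difference over the combined array, keep only the transformed new
+ # values, and store the last `d` raw values as the tail for future updates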
+ n_series = len(indptr) - 1
+ for i in range(n_series):
+ orig = orig_data[orig_indptr[i] : orig_indptr[i + 1]]
+ transformed = data[indptr[i] : indptr[i + 1]]
+ combined = np.append(orig, transformed)
+ data[indptr[i] : indptr[i + 1]] = _diff(combined, d)[-transformed.size :]
+ orig_data[orig_indptr[i] : orig_indptr[i + 1]] = combined[-d:]
+
+
@njit
def _expand_target(data, indptr, max_horizon):
out = np.empty((data.size, max_horizon), dtype=data.dtype)
@@ -264,6 +281,9 @@ def restore_fitted_difference(
d,
)
+ def update_difference(self, d: int, ga: "GroupedArray") -> None:
+ _update_difference(d, self.data, self.indptr, ga.data, ga.indptr)
+
def expand_target(self, max_horizon: int) -> np.ndarray:
return _expand_target(self.data, self.indptr, max_horizon)
diff --git a/mlforecast/target_transforms.py b/mlforecast/target_transforms.py
index fc0cfc92..e6f7617f 100644
--- a/mlforecast/target_transforms.py
+++ b/mlforecast/target_transforms.py
@@ -14,6 +14,7 @@
from sklearn.base import TransformerMixin, clone
from utilsforecast.compat import DataFrame
from utilsforecast.target_transforms import (
+ BaseTargetTransform as UtilsTargetTransform,
LocalBoxCox as BoxCox,
LocalMinMaxScaler as MinMaxScaler,
LocalRobustScaler as RobustScaler,
@@ -22,6 +23,7 @@
_transform,
)
+from .compat import CORE_INSTALLED, CoreGroupedArray, core_scalers
from .grouped_array import GroupedArray, _apply_difference
from .utils import _ShortSeriesException
@@ -34,13 +36,16 @@ def set_column_names(self, id_col: str, time_col: str, target_col: str):
self.time_col = time_col
self.target_col = target_col
+ def update(self, df: DataFrame) -> DataFrame:
+ raise NotImplementedError
+
@abc.abstractmethod
def fit_transform(self, df: DataFrame) -> DataFrame:
- raise NotImplementedError
+ ...
@abc.abstractmethod
def inverse_transform(self, df: DataFrame) -> DataFrame:
- raise NotImplementedError
+ ...
# %% ../nbs/target_transforms.ipynb 6
class BaseGroupedArrayTargetTransform(abc.ABC):
@@ -48,13 +53,17 @@ class BaseGroupedArrayTargetTransform(abc.ABC):
idxs: Optional[np.ndarray] = None
+ @abc.abstractmethod
+ def update(self, ga: GroupedArray) -> GroupedArray:
+ ...
+
@abc.abstractmethod
def fit_transform(self, ga: GroupedArray) -> GroupedArray:
- raise NotImplementedError
+ ...
@abc.abstractmethod
def inverse_transform(self, ga: GroupedArray) -> GroupedArray:
- raise NotImplementedError
+ ...
def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray:
return self.inverse_transform(ga)
@@ -89,6 +98,12 @@ def fit_transform(self, ga: GroupedArray) -> GroupedArray:
self.original_values_.append(GroupedArray(new_data, new_indptr))
return ga
+ def update(self, ga: GroupedArray) -> GroupedArray:
+ transformed = copy.copy(ga)
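+ # re-apply each stored difference in order; update_difference also refreshes the stored tails in place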
+ for d, orig_ga in zip(self.differences, self.original_values_):
+ orig_ga.update_difference(d, transformed)
+ return transformed
+
def inverse_transform(self, ga: GroupedArray) -> GroupedArray:
ga = copy.copy(ga)
for d, orig_vals_ga in zip(
@@ -109,13 +124,24 @@ def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray:
# %% ../nbs/target_transforms.ipynb 10
class BaseLocalScaler(BaseGroupedArrayTargetTransform):
- """Standardizes each serie by subtracting its mean and dividing by its standard deviation."""
-
scaler_factory: type
+ def _is_utils_tfm(self):
+ return isinstance(self.scaler_, UtilsTargetTransform)
+
+ def update(self, ga: GroupedArray) -> GroupedArray:
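+ # coreforecast scalers expect a CoreGroupedArray, the utilsforecast ones work on our GroupedArray directly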
+ if not self._is_utils_tfm():
+ ga = CoreGroupedArray(ga.data, ga.indptr)
+ return GroupedArray(self.scaler_.transform(ga), ga.indptr)
+
def fit_transform(self, ga: GroupedArray) -> GroupedArray:
self.scaler_ = self.scaler_factory()
- transformed = self.scaler_.fit_transform(ga)
+ if self._is_utils_tfm():
+ transformed = self.scaler_.fit_transform(ga)
+ else:
+ core_ga = CoreGroupedArray(ga.data, ga.indptr)
+ self.scaler_.fit(core_ga)
+ transformed = self.scaler_.transform(core_ga)
return GroupedArray(transformed, ga.indptr)
def inverse_transform(self, ga: GroupedArray) -> GroupedArray:
@@ -124,9 +150,14 @@ def inverse_transform(self, ga: GroupedArray) -> GroupedArray:
stats = stats[self.idxs]
if stats.shape[0] != ga.n_groups:
raise ValueError("Found different number of groups in scaler.")
- transformed = _transform(
- ga.data, ga.indptr, stats, _common_scaler_inverse_transform
- )
+ if self._is_utils_tfm() or self.idxs is not None:
+ # core scalers can't transform a subset
+ transformed = _transform(
+ ga.data, ga.indptr, stats, _common_scaler_inverse_transform
+ )
+ else:
+ core_ga = CoreGroupedArray(ga.data, ga.indptr)
+ transformed = self.scaler_.inverse_transform(core_ga)
return GroupedArray(transformed, ga.indptr)
def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray:
@@ -136,13 +167,15 @@ def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray:
class LocalStandardScaler(BaseLocalScaler):
"""Standardizes each serie by subtracting its mean and dividing by its standard deviation."""
- scaler_factory = StandardScaler
+ scaler_factory = (
+ core_scalers.LocalStandardScaler if CORE_INSTALLED else StandardScaler
+ )
# %% ../nbs/target_transforms.ipynb 14
class LocalMinMaxScaler(BaseLocalScaler):
"""Scales each serie to be in the [0, 1] interval."""
- scaler_factory = MinMaxScaler
+ scaler_factory = core_scalers.LocalMinMaxScaler if CORE_INSTALLED else MinMaxScaler
# %% ../nbs/target_transforms.ipynb 16
class LocalRobustScaler(BaseLocalScaler):
@@ -155,23 +188,23 @@ class LocalRobustScaler(BaseLocalScaler):
"""
def __init__(self, scale: str):
- self.scaler_factory = lambda: RobustScaler(scale) # type: ignore
+ self.scaler_factory = lambda: core_scalers.LocalRobustScaler(scale) if CORE_INSTALLED else RobustScaler(scale) # type: ignore
# %% ../nbs/target_transforms.ipynb 19
class LocalBoxCox(BaseLocalScaler):
"""Finds the optimum lambda for each serie and applies the Box-Cox transformation"""
def __init__(self):
- self.scaler = BoxCox()
+ self.scaler_ = BoxCox()
def fit_transform(self, ga: GroupedArray) -> GroupedArray:
- return GroupedArray(self.scaler.fit_transform(ga), ga.indptr)
+ return GroupedArray(self.scaler_.fit_transform(ga), ga.indptr)
def inverse_transform(self, ga: GroupedArray) -> GroupedArray:
from scipy.special import inv_boxcox1p
sizes = np.diff(ga.indptr)
- lmbdas = self.scaler.lmbdas_
+ lmbdas = self.scaler_.lmbdas_
if self.idxs is not None:
lmbdas = lmbdas[self.idxs]
lmbdas = np.repeat(lmbdas, sizes, axis=0)
@@ -184,6 +217,11 @@ class GlobalSklearnTransformer(BaseTargetTransform):
def __init__(self, transformer: TransformerMixin):
self.transformer = transformer
+ def update(self, df: pd.DataFrame) -> pd.DataFrame:
+ df = df.copy(deep=False)
+ df[self.target_col] = self.transformer_.transform(df[[self.target_col]].values)
+ return df
+
def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
df = df.copy(deep=False)
self.transformer_ = clone(self.transformer)
diff --git a/nbs/compat.ipynb b/nbs/compat.ipynb
index b5572270..ba136a10 100644
--- a/nbs/compat.ipynb
+++ b/nbs/compat.ipynb
@@ -20,6 +20,7 @@
"#| export\n",
"try:\n",
" import coreforecast.lag_transforms as core_tfms\n",
+ " import coreforecast.scalers as core_scalers \n",
" from coreforecast.grouped_array import GroupedArray as CoreGroupedArray\n",
" \n",
" from mlforecast.lag_transforms import BaseLagTransform, Lag\n",
@@ -27,6 +28,7 @@
" CORE_INSTALLED = True\n",
"except ImportError:\n",
" core_tfms = None\n",
+ " core_scalers = None\n",
" CoreGroupedArray = None\n",
"\n",
" class BaseLagTransform:\n",
diff --git a/nbs/core.ipynb b/nbs/core.ipynb
index 992f6c5b..37a085f0 100644
--- a/nbs/core.ipynb
+++ b/nbs/core.ipynb
@@ -1046,7 +1046,7 @@
" return df\n",
"\n",
" def _predict_setup(self) -> None:\n",
- " self._ga = copy.copy(self.ga)\n",
+ " self.ga = copy.copy(self._ga)\n",
" if isinstance(self.last_dates, pl_Series):\n",
" self.curr_dates = self.last_dates.clone()\n",
" else:\n",
@@ -1123,7 +1123,7 @@
" raise ValueError(f'horizon must be at most max_horizon ({self.max_horizon})')\n",
" self._predict_setup()\n",
" uids = self._get_future_ids(horizon)\n",
- " starts = ufp.offset_times(self.curr_dates, self.freq, 1) \n",
+ " starts = ufp.offset_times(self.curr_dates, self.freq, 1)\n",
" dates = ufp.time_ranges(starts, self.freq, periods=horizon)\n",
" if isinstance(self.curr_dates, pl_Series):\n",
" df_constructor = pl_DataFrame\n",
@@ -1142,6 +1142,9 @@
" result = ufp.assign_columns(result, name, raw_preds)\n",
" return result\n",
"\n",
+ " def _has_ga_target_tfms(self):\n",
+ " return any(isinstance(tfm, BaseGroupedArrayTargetTransform) for tfm in self.target_transforms)\n",
+ "\n",
" def predict(\n",
" self,\n",
" models: Dict[str, Union[BaseEstimator, List[BaseEstimator]]],\n",
@@ -1202,37 +1205,43 @@
" raise ValueError(msg)\n",
" drop_cols = [self.id_col, self.time_col, '_start', '_end']\n",
" X_df = ufp.sort(X_df, [self.id_col, self.time_col]).drop(columns=drop_cols)\n",
- " if getattr(self, 'max_horizon', None) is None:\n",
- " preds = self._predict_recursive(\n",
- " models=models,\n",
- " horizon=horizon,\n",
- " before_predict_callback=before_predict_callback,\n",
- " after_predict_callback=after_predict_callback,\n",
- " X_df=X_df,\n",
- " )\n",
- " else:\n",
- " preds = self._predict_multi(\n",
- " models=models,\n",
- " horizon=horizon,\n",
- " before_predict_callback=before_predict_callback,\n",
- " X_df=X_df,\n",
- " )\n",
+ " # back up the original series; the ga attribute gets modified\n",
+ " # and is copied from _ga at the start of each model's predict\n",
+ " self._ga = copy.copy(self.ga)\n",
+ " try: \n",
+ " if getattr(self, 'max_horizon', None) is None:\n",
+ " preds = self._predict_recursive(\n",
+ " models=models,\n",
+ " horizon=horizon,\n",
+ " before_predict_callback=before_predict_callback,\n",
+ " after_predict_callback=after_predict_callback,\n",
+ " X_df=X_df,\n",
+ " )\n",
+ " else:\n",
+ " preds = self._predict_multi(\n",
+ " models=models,\n",
+ " horizon=horizon,\n",
+ " before_predict_callback=before_predict_callback,\n",
+ " X_df=X_df,\n",
+ " )\n",
+ " finally:\n",
+ " self.ga = self._ga\n",
+ " del self._ga \n",
" if self.target_transforms is not None:\n",
- " if any(isinstance(tfm, BaseGroupedArrayTargetTransform) for tfm in self.target_transforms):\n",
+ " if self._has_ga_target_tfms():\n",
" model_cols = [c for c in preds.columns if c not in (self.id_col, self.time_col)]\n",
" indptr = np.arange(0, horizon * (len(self._uids) + 1), horizon)\n",
" for tfm in self.target_transforms[::-1]:\n",
" if isinstance(tfm, BaseGroupedArrayTargetTransform):\n",
" tfm.idxs = self._idxs\n",
" for col in model_cols:\n",
- " ga = GroupedArray(preds[col].to_numpy(), indptr)\n",
+ " ga = GroupedArray(preds[col].to_numpy().astype(self.ga.data.dtype), indptr)\n",
" ga = tfm.inverse_transform(ga)\n",
" preds = ufp.assign_columns(preds, col, ga.data)\n",
" tfm.idxs = None\n",
" else:\n",
" preds = tfm.inverse_transform(preds)\n",
- " self.ga = self._ga\n",
- " del self._uids, self._idxs, self._static_features, self._ga\n",
+ " del self._uids, self._idxs, self._static_features\n",
" return preds\n",
"\n",
" def save(self, path: Union[str, Path]) -> None:\n",
@@ -1254,11 +1263,17 @@
" if isinstance(uids, pd.Index):\n",
" uids = pd.Series(uids)\n",
" uids, new_ids = ufp.match_if_categorical(uids, df[self.id_col])\n",
+ " df = ufp.copy_if_pandas(df, deep=False)\n",
" df = ufp.assign_columns(df, self.id_col, new_ids)\n",
" df = ufp.sort(df, by=[self.id_col, self.time_col])\n",
- " values = df[self.target_col].to_numpy() \n",
+ " values = df[self.target_col].to_numpy()\n",
+ " values = values.astype(self.ga.data.dtype, copy=False)\n",
" id_counts = ufp.counts_by_id(df, self.id_col)\n",
- " sizes = ufp.join(uids, id_counts, on=self.id_col, how='outer')\n",
+ " try:\n",
+ " sizes = ufp.join(uids, id_counts, on=self.id_col, how='outer_coalesce')\n",
+ " except (KeyError, ValueError):\n",
+ " # pandas raises a KeyError, polars without outer_coalesce raises a ValueError\n",
+ " sizes = ufp.join(uids, id_counts, on=self.id_col, how='outer')\n",
" sizes = ufp.fill_null(sizes, {'counts': 0})\n",
" sizes = ufp.sort(sizes, by=self.id_col)\n",
" new_groups = ~ufp.is_in(sizes[self.id_col], uids)\n",
@@ -1270,9 +1285,12 @@
" last_dates = ufp.sort(last_dates, by=self.id_col)\n",
" self.last_dates = ufp.cast(last_dates[self.time_col], self.last_dates.dtype)\n",
" self.uids = ufp.sort(sizes[self.id_col])\n",
- " if isinstance(self.uids, pd.Series):\n",
+ " if isinstance(df, pd.DataFrame):\n",
" self.uids = pd.Index(self.uids)\n",
+ " self.last_dates = pd.Index(self.last_dates)\n",
" if new_groups.any():\n",
+ " if self.target_transforms is not None:\n",
+ " raise ValueError('Cannot update target_transforms with new series.')\n",
" new_ids = ufp.filter_with_mask(sizes[self.id_col], new_groups)\n",
" new_ids_df = ufp.filter_with_mask(df, ufp.is_in(df[self.id_col], new_ids))\n",
" new_ids_counts = ufp.counts_by_id(new_ids_df, self.id_col)\n",
@@ -1280,6 +1298,17 @@
" new_statics = new_statics[self.static_features_.columns]\n",
" self.static_features_ = ufp.vertical_concat([self.static_features_, new_statics])\n",
" self.static_features_ = ufp.sort(self.static_features_, self.id_col)\n",
+ " if self.target_transforms is not None:\n",
+ " if self._has_ga_target_tfms(): \n",
+ " indptr = np.append(0, id_counts['counts']).cumsum()\n",
+ " for tfm in self.target_transforms:\n",
+ " if isinstance(tfm, BaseGroupedArrayTargetTransform):\n",
+ " ga = GroupedArray(values, indptr)\n",
+ " ga = tfm.update(ga)\n",
+ " df = ufp.assign_columns(df, self.target_col, ga.data)\n",
+ " else:\n",
+ " df = tfm.update(df)\n",
+ " values = df[self.target_col].to_numpy() \n",
" self.ga = self.ga.append_several(\n",
" new_sizes=sizes['counts'].to_numpy().astype(np.int32),\n",
" new_values=values,\n",
@@ -1549,6 +1578,7 @@
"ts._uids = ts.uids\n",
"ts._idxs = np.arange(len(ts.ga))\n",
"ts._static_features = ts.static_features_\n",
+ "ts._ga = copy.copy(ts.ga)\n",
"ts._predict_setup()\n",
"updates = ts._update_features()\n",
"\n",
@@ -1585,6 +1615,7 @@
"ts._uids = ts.uids\n",
"ts._idxs = np.arange(len(ts.ga))\n",
"ts._static_features = ts.static_features_\n",
+ "ts._ga = copy.copy(ts.ga)\n",
"ts._predict_setup()\n",
"ts._update_features()\n",
"ts._update_y([1.])\n",
@@ -1788,6 +1819,7 @@
"df = ts.fit_transform(series, id_col='unique_id', time_col='ds', target_col='y', keep_last_n=keep_last_n)\n",
"ts._uids = ts.uids\n",
"ts._idxs = np.arange(len(ts.ga))\n",
+ "ts._ga = copy.copy(ts.ga)\n",
"ts._predict_setup()\n",
"\n",
"expected_lags = ['lag7', 'lag14']\n",
@@ -2114,9 +2146,9 @@
"outputs": [],
"source": [
"#| hide\n",
- "class ZerosModel:\n",
+ "class SeasonalNaiveModel:\n",
" def predict(self, X):\n",
- " return np.full(X.shape[0], 0)\n",
+ " return X['lag7']\n",
"\n",
"class NaiveModel:\n",
" def predict(self, X: pd.DataFrame):\n",
@@ -2124,14 +2156,13 @@
"\n",
"two_series = series[series['unique_id'].isin(['id_00', 'id_19'])].copy()\n",
"two_series['unique_id'] = pd.Categorical(two_series['unique_id'], ['id_00', 'id_19'])\n",
- "ts = TimeSeries(freq='D', lags=[1])\n",
+ "ts = TimeSeries(freq='D', lags=[1], date_features=['dayofweek'])\n",
"ts.fit_transform(\n",
" two_series,\n",
" id_col='unique_id',\n",
" time_col='ds',\n",
" target_col='y',\n",
")\n",
- "ts.predict({'zero': ZerosModel()}, 4)\n",
"last_vals_two_series = two_series.groupby('unique_id').tail(1)\n",
"last_val_id0 = last_vals_two_series[lambda x: x['unique_id'].eq('id_00')].copy()\n",
"new_values = last_val_id0.copy()\n",
@@ -2164,32 +2195,57 @@
" .astype(ts.static_features_.dtypes)\n",
" .reset_index(drop=True)\n",
" )\n",
- ")"
+ ")\n",
+ "# with target transforms\n",
+ "ts = TimeSeries(\n",
+ " freq='D',\n",
+ " lags=[7],\n",
+ " target_transforms=[Differences([1, 2]), LocalStandardScaler()],\n",
+ ")\n",
+ "ts.fit_transform(two_series, id_col='unique_id', time_col='ds', target_col='y')\n",
+ "new_values = two_series.groupby('unique_id').tail(7).copy()\n",
+ "new_values['ds'] += 7 * pd.offsets.Day()\n",
+ "orig_last7 = ts.ga.take_from_groups(slice(-7, None)).data\n",
+ "ts.update(new_values)\n",
+ "preds = ts.predict({'SeasonalNaive': SeasonalNaiveModel()}, 7)\n",
+ "np.testing.assert_allclose(\n",
+ " new_values['y'].values,\n",
+ " preds['SeasonalNaive'].values,\n",
+ ")\n",
+ "last7 = ts.ga.take_from_groups(slice(-7, None)).data\n",
+ "assert 0 < np.abs(last7 / orig_last7 - 1).mean() < 0.5"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "sys:1: UserWarning: Local categoricals have different encodings, expensive re-encoding is done to perform this merge operation. Consider using a StringCache or an Enum type if the categories are known in advance\n"
+ ]
+ }
+ ],
"source": [
"#| hide\n",
"#| polars\n",
"two_series = generate_daily_series(2, n_static_features=2, engine='polars')\n",
- "ts = TimeSeries(freq='1d', lags=[1])\n",
+ "ts = TimeSeries(freq='1d', lags=[1], date_features=['weekday'])\n",
"ts.fit_transform(\n",
" two_series,\n",
" id_col='unique_id',\n",
" time_col='ds',\n",
" target_col='y',\n",
")\n",
- "ts.predict({'zero': ZerosModel()}, 4)\n",
"last_vals_two_series = two_series.join(\n",
" two_series.group_by('unique_id').agg(pl.col('ds').max()), on=['unique_id', 'ds']\n",
")\n",
"last_val_id0 = last_vals_two_series.filter(pl.col('unique_id') == 'id_0')\n",
"new_values = last_val_id0.with_columns(\n",
- " pl.col('unique_id').cast(pl.Utf8),\n",
+ " pl.col('unique_id').cast(pl.Categorical),\n",
" pl.col('ds').dt.offset_by('1d'),\n",
" pl.col('static_0').cast(pl.Int64),\n",
" pl.col('static_1').cast(pl.Int64),\n",
@@ -2200,7 +2256,10 @@
" 'y': [5.0, 6.0],\n",
" 'static_0': [0, 0],\n",
" 'static_1': [1, 1],\n",
- "}).with_columns(pl.col('ds').dt.cast_time_unit('ns'))\n",
+ "}).with_columns(\n",
+ " pl.col('ds').dt.cast_time_unit('ns'),\n",
+ " pl.col('unique_id').cast(pl.Categorical),\n",
+ ")\n",
"new_values = pl.concat([new_values, new_serie])\n",
"ts.update(new_values)\n",
"preds = ts.predict({'Naive': NaiveModel()}, 1)\n",
@@ -2226,7 +2285,25 @@
" .astype(ts.static_features_.to_pandas().dtypes)\n",
" .reset_index(drop=True)\n",
" )\n",
- ")"
+ ")\n",
+ "# with target transforms\n",
+ "ts = TimeSeries(\n",
+ " freq='1d',\n",
+ " lags=[7],\n",
+ " target_transforms=[Differences([1, 2]), LocalStandardScaler()],\n",
+ ")\n",
+ "ts.fit_transform(two_series, id_col='unique_id', time_col='ds', target_col='y')\n",
+ "new_values = two_series.group_by('unique_id').tail(7)\n",
+ "new_values = new_values.with_columns(pl.col('ds').dt.offset_by('7d'))\n",
+ "orig_last7 = ts.ga.take_from_groups(slice(-7, None)).data\n",
+ "ts.update(new_values)\n",
+ "preds = ts.predict({'SeasonalNaive': SeasonalNaiveModel()}, 7)\n",
+ "np.testing.assert_allclose(\n",
+ " new_values['y'].to_numpy(),\n",
+ " preds['SeasonalNaive'].to_numpy(),\n",
+ ")\n",
+ "last7 = ts.ga.take_from_groups(slice(-7, None)).data\n",
+ "assert 0 < np.abs(last7 / orig_last7 - 1).mean() < 0.5"
]
},
{
@@ -2343,6 +2420,36 @@
")"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# test predict\n",
+ "class Lag1PlusOneModel:\n",
+ " def predict(self, X):\n",
+ " return X['lag1'] + 1\n",
+ "\n",
+ "ts = TimeSeries(freq='D', lags=[1])\n",
+ "for max_horizon in [None, 2]:\n",
+ " if max_horizon is None:\n",
+ " mod1 = Lag1PlusOneModel()\n",
+ " mod2 = Lag1PlusOneModel()\n",
+ " else:\n",
+ " mod1 = [Lag1PlusOneModel() for _ in range(max_horizon)]\n",
+ " mod2 = [Lag1PlusOneModel() for _ in range(max_horizon)]\n",
+ " ts.fit_transform(train, 'unique_id', 'ds', 'y', max_horizon=max_horizon)\n",
+ " # each model gets the correct historic values\n",
+ " preds = ts.predict(models={'mod1': mod1, 'mod2': mod2}, horizon=2)\n",
+ " np.testing.assert_allclose(preds['mod1'], preds['mod2'])\n",
+ " # idempotency\n",
+ " preds2 = ts.predict(models={'mod1': mod1, 'mod2': mod2}, horizon=2)\n",
+ " np.testing.assert_allclose(preds2['mod1'], preds2['mod2'])\n",
+ " pd.testing.assert_frame_equal(preds, preds2)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
diff --git a/nbs/distributed.forecast.ipynb b/nbs/distributed.forecast.ipynb
index 6483def4..2aea1ed5 100644
--- a/nbs/distributed.forecast.ipynb
+++ b/nbs/distributed.forecast.ipynb
@@ -485,6 +485,7 @@
" horizon,\n",
" before_predict_callback=None,\n",
" after_predict_callback=None,\n",
+ " X_df=None, \n",
" ) -> Iterable[pd.DataFrame]:\n",
" for serialized_ts, _, serialized_valid in items:\n",
" valid = cloudpickle.loads(serialized_valid)\n",
@@ -494,6 +495,7 @@
" horizon=horizon,\n",
" before_predict_callback=before_predict_callback,\n",
" after_predict_callback=after_predict_callback,\n",
+ " X_df=X_df,\n",
" )\n",
" if valid is not None:\n",
" res = res.merge(valid, how='left')\n",
@@ -510,6 +512,7 @@
" h: int,\n",
" before_predict_callback: Optional[Callable] = None,\n",
" after_predict_callback: Optional[Callable] = None,\n",
+ " X_df: Optional[pd.DataFrame] = None,\n",
" new_df: Optional[fugue.AnyDataFrame] = None,\n",
" ) -> fugue.AnyDataFrame:\n",
" \"\"\"Compute the predictions for the next `horizon` steps.\n",
@@ -526,6 +529,8 @@
" Function to call on the predictions before updating the targets.\n",
" This function will take a pandas Series with the predictions and should return another one with the same structure.\n",
" The series identifier is on the index.\n",
+ " X_df : pandas DataFrame, optional (default=None)\n",
+ " Dataframe with the future exogenous features. Should have the id column and the time column. \n",
" new_df : dask or spark DataFrame, optional (default=None)\n",
" Series data of new observations for which forecasts are to be generated.\n",
" This dataframe should have the same structure as the one used to fit the model, including any features and time series data.\n",
@@ -550,6 +555,8 @@
" else:\n",
" partition_results = self._partition_results\n",
" schema = self._get_predict_schema()\n",
+ " if X_df is not None and not isinstance(X_df, pd.DataFrame):\n",
+ " raise ValueError('`X_df` should be a pandas DataFrame')\n",
" res = fa.transform(\n",
" partition_results,\n",
" DistributedMLForecast._predict,\n",
@@ -558,6 +565,7 @@
" 'horizon': h,\n",
" 'before_predict_callback': before_predict_callback,\n",
" 'after_predict_callback': after_predict_callback,\n",
+ " 'X_df': X_df,\n",
" },\n",
" schema=schema,\n",
" engine=self.engine,\n",
@@ -761,17 +769,20 @@
"text/markdown": [
"---\n",
"\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
"### DistributedMLForecast\n",
"\n",
"> DistributedMLForecast (models,\n",
"> freq:Union[int,str,pandas._libs.tslibs.offsets.Bas\n",
- "> eOffset,NoneType]=None,\n",
- "> lags:Optional[Iterable[int]]=None, lag_transforms:\n",
- "> Optional[Dict[int,List[Union[Callable,Tuple[Callab\n",
- "> le,Any]]]]]=None, date_features:Optional[Iterable[\n",
- "> Union[str,Callable]]]=None, num_threads:int=1, tar\n",
- "> get_transforms:Optional[List[mlforecast.target_tra\n",
- "> nsforms.BaseTargetTransform]]=None, engine=None,\n",
+ "> eOffset], lags:Optional[Iterable[int]]=None, lag_t\n",
+ "> ransforms:Optional[Dict[int,List[Union[Callable,Tu\n",
+ "> ple[Callable,Any]]]]]=None, date_features:Optional\n",
+ "> [Iterable[Union[str,Callable]]]=None,\n",
+ "> num_threads:int=1, target_transforms:Optional[List\n",
+ "> [Union[mlforecast.target_transforms.BaseTargetTran\n",
+ "> sform,mlforecast.target_transforms.BaseGroupedArra\n",
+ "> yTargetTransform]]]=None, engine=None,\n",
"> num_partitions:Optional[int]=None)\n",
"\n",
"Multi backend distributed pipeline"
@@ -779,17 +790,20 @@
"text/plain": [
"---\n",
"\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
"### DistributedMLForecast\n",
"\n",
"> DistributedMLForecast (models,\n",
"> freq:Union[int,str,pandas._libs.tslibs.offsets.Bas\n",
- "> eOffset,NoneType]=None,\n",
- "> lags:Optional[Iterable[int]]=None, lag_transforms:\n",
- "> Optional[Dict[int,List[Union[Callable,Tuple[Callab\n",
- "> le,Any]]]]]=None, date_features:Optional[Iterable[\n",
- "> Union[str,Callable]]]=None, num_threads:int=1, tar\n",
- "> get_transforms:Optional[List[mlforecast.target_tra\n",
- "> nsforms.BaseTargetTransform]]=None, engine=None,\n",
+ "> eOffset], lags:Optional[Iterable[int]]=None, lag_t\n",
+ "> ransforms:Optional[Dict[int,List[Union[Callable,Tu\n",
+ "> ple[Callable,Any]]]]]=None, date_features:Optional\n",
+ "> [Iterable[Union[str,Callable]]]=None,\n",
+ "> num_threads:int=1, target_transforms:Optional[List\n",
+ "> [Union[mlforecast.target_transforms.BaseTargetTran\n",
+ "> sform,mlforecast.target_transforms.BaseGroupedArra\n",
+ "> yTargetTransform]]]=None, engine=None,\n",
"> num_partitions:Optional[int]=None)\n",
"\n",
"Multi backend distributed pipeline"
@@ -815,14 +829,15 @@
"text/markdown": [
"---\n",
"\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L380){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
"### DistributedMLForecast.fit\n",
"\n",
"> DistributedMLForecast.fit (df:~AnyDataFrame, id_col:str='unique_id',\n",
"> time_col:str='ds', target_col:str='y',\n",
"> static_features:Optional[List[str]]=None,\n",
"> dropna:bool=True,\n",
- "> keep_last_n:Optional[int]=None,\n",
- "> data:Optional[~AnyDataFrame]=None)\n",
+ "> keep_last_n:Optional[int]=None)\n",
"\n",
"Apply the feature engineering and train the models.\n",
"\n",
@@ -835,20 +850,20 @@
"| static_features | Optional | None | Names of the features that are static and will be repeated when forecasting. |\n",
"| dropna | bool | True | Drop rows with missing values produced by the transformations. |\n",
"| keep_last_n | Optional | None | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. |\n",
- "| data | Optional | None | |\n",
- "| **Returns** | **DistributedMLForecast** | | **noqa: ARG002** |"
+ "| **Returns** | **DistributedMLForecast** | | **Forecast object with series values and trained models.** |"
],
"text/plain": [
"---\n",
"\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L380){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
"### DistributedMLForecast.fit\n",
"\n",
"> DistributedMLForecast.fit (df:~AnyDataFrame, id_col:str='unique_id',\n",
"> time_col:str='ds', target_col:str='y',\n",
"> static_features:Optional[List[str]]=None,\n",
"> dropna:bool=True,\n",
- "> keep_last_n:Optional[int]=None,\n",
- "> data:Optional[~AnyDataFrame]=None)\n",
+ "> keep_last_n:Optional[int]=None)\n",
"\n",
"Apply the feature engineering and train the models.\n",
"\n",
@@ -861,8 +876,7 @@
"| static_features | Optional | None | Names of the features that are static and will be repeated when forecasting. |\n",
"| dropna | bool | True | Drop rows with missing values produced by the transformations. |\n",
"| keep_last_n | Optional | None | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. |\n",
- "| data | Optional | None | |\n",
- "| **Returns** | **DistributedMLForecast** | | **noqa: ARG002** |"
+ "| **Returns** | **DistributedMLForecast** | | **Forecast object with series values and trained models.** |"
]
},
"execution_count": null,
@@ -885,15 +899,16 @@
"text/markdown": [
"---\n",
"\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L451){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
"### DistributedMLForecast.predict\n",
"\n",
"> DistributedMLForecast.predict (h:int,\n",
"> before_predict_callback:Optional[Callable]\n",
"> =None, after_predict_callback:Optional[Cal\n",
- "> lable]=None,\n",
- "> new_df:Optional[~AnyDataFrame]=None,\n",
- "> horizon:Optional[int]=None,\n",
- "> new_data:Optional[~AnyDataFrame]=None)\n",
+ "> lable]=None, X_df:Optional[pandas.core.fra\n",
+ "> me.DataFrame]=None,\n",
+ "> new_df:Optional[~AnyDataFrame]=None)\n",
"\n",
"Compute the predictions for the next `horizon` steps.\n",
"\n",
@@ -902,23 +917,23 @@
"| h | int | | Forecast horizon. |\n",
"| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n",
"| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n",
+ "| X_df | Optional | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n",
"| new_df | Optional | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n",
- "| horizon | Optional | None | |\n",
- "| new_data | Optional | None | |\n",
"| **Returns** | **AnyDataFrame** | | **Predictions for each serie and timestep, with one column per model.** |"
],
"text/plain": [
"---\n",
"\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L451){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
"### DistributedMLForecast.predict\n",
"\n",
"> DistributedMLForecast.predict (h:int,\n",
"> before_predict_callback:Optional[Callable]\n",
"> =None, after_predict_callback:Optional[Cal\n",
- "> lable]=None,\n",
- "> new_df:Optional[~AnyDataFrame]=None,\n",
- "> horizon:Optional[int]=None,\n",
- "> new_data:Optional[~AnyDataFrame]=None)\n",
+ "> lable]=None, X_df:Optional[pandas.core.fra\n",
+ "> me.DataFrame]=None,\n",
+ "> new_df:Optional[~AnyDataFrame]=None)\n",
"\n",
"Compute the predictions for the next `horizon` steps.\n",
"\n",
@@ -927,9 +942,8 @@
"| h | int | | Forecast horizon. |\n",
"| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n",
"| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n",
+ "| X_df | Optional | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n",
"| new_df | Optional | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n",
- "| horizon | Optional | None | |\n",
- "| new_data | Optional | None | |\n",
"| **Returns** | **AnyDataFrame** | | **Predictions for each serie and timestep, with one column per model.** |"
]
},
@@ -953,6 +967,8 @@
"text/markdown": [
"---\n",
"\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L284){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
"### DistributedMLForecast.preprocess\n",
"\n",
"> DistributedMLForecast.preprocess (df:~AnyDataFrame,\n",
@@ -960,8 +976,7 @@
"> time_col:str='ds', target_col:str='y', \n",
"> static_features:Optional[List[str]]=Non\n",
"> e, dropna:bool=True,\n",
- "> keep_last_n:Optional[int]=None,\n",
- "> data:Optional[~AnyDataFrame]=None)\n",
+ "> keep_last_n:Optional[int]=None)\n",
"\n",
"Add the features to `data`.\n",
"\n",
@@ -974,12 +989,13 @@
"| static_features | Optional | None | Names of the features that are static and will be repeated when forecasting. |\n",
"| dropna | bool | True | Drop rows with missing values produced by the transformations. |\n",
"| keep_last_n | Optional | None | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. |\n",
- "| data | Optional | None | |\n",
- "| **Returns** | **AnyDataFrame** | | **noqa: ARG002** |"
+ "| **Returns** | **AnyDataFrame** | | **`df` with added features.** |"
],
"text/plain": [
"---\n",
"\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L284){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
"### DistributedMLForecast.preprocess\n",
"\n",
"> DistributedMLForecast.preprocess (df:~AnyDataFrame,\n",
@@ -987,8 +1003,7 @@
"> time_col:str='ds', target_col:str='y', \n",
"> static_features:Optional[List[str]]=Non\n",
"> e, dropna:bool=True,\n",
- "> keep_last_n:Optional[int]=None,\n",
- "> data:Optional[~AnyDataFrame]=None)\n",
+ "> keep_last_n:Optional[int]=None)\n",
"\n",
"Add the features to `data`.\n",
"\n",
@@ -1001,8 +1016,7 @@
"| static_features | Optional | None | Names of the features that are static and will be repeated when forecasting. |\n",
"| dropna | bool | True | Drop rows with missing values produced by the transformations. |\n",
"| keep_last_n | Optional | None | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. |\n",
- "| data | Optional | None | |\n",
- "| **Returns** | **AnyDataFrame** | | **noqa: ARG002** |"
+ "| **Returns** | **AnyDataFrame** | | **`df` with added features.** |"
]
},
"execution_count": null,
@@ -1025,6 +1039,8 @@
"text/markdown": [
"---\n",
"\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L510){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
"### DistributedMLForecast.cross_validation\n",
"\n",
"> DistributedMLForecast.cross_validation (df:~AnyDataFrame, n_windows:int,\n",
@@ -1039,9 +1055,7 @@
"> allback:Optional[Callable]=None, \n",
"> after_predict_callback:Optional[C\n",
"> allable]=None,\n",
- "> input_size:Optional[int]=None, da\n",
- "> ta:Optional[~AnyDataFrame]=None,\n",
- "> window_size:Optional[int]=None)\n",
+ "> input_size:Optional[int]=None)\n",
"\n",
"Perform time series cross validation.\n",
"Creates `n_windows` splits where each window has `h` test periods,\n",
@@ -1063,13 +1077,13 @@
"| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n",
"| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n",
"| input_size | Optional | None | Maximum training samples per serie in each window. If None, will use an expanding window. |\n",
- "| data | Optional | None | |\n",
- "| window_size | Optional | None | |\n",
- "| **Returns** | **AnyDataFrame** | | **noqa: ARG002
noqa: ARG002** |"
+ "| **Returns** | **AnyDataFrame** | | **Predictions for each window with the series id, timestamp, target value and predictions from each model.** |"
],
"text/plain": [
"---\n",
"\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L510){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
"### DistributedMLForecast.cross_validation\n",
"\n",
"> DistributedMLForecast.cross_validation (df:~AnyDataFrame, n_windows:int,\n",
@@ -1084,9 +1098,7 @@
"> allback:Optional[Callable]=None, \n",
"> after_predict_callback:Optional[C\n",
"> allable]=None,\n",
- "> input_size:Optional[int]=None, da\n",
- "> ta:Optional[~AnyDataFrame]=None,\n",
- "> window_size:Optional[int]=None)\n",
+ "> input_size:Optional[int]=None)\n",
"\n",
"Perform time series cross validation.\n",
"Creates `n_windows` splits where each window has `h` test periods,\n",
@@ -1108,9 +1120,7 @@
"| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n",
"| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n",
"| input_size | Optional | None | Maximum training samples per serie in each window. If None, will use an expanding window. |\n",
- "| data | Optional | None | |\n",
- "| window_size | Optional | None | |\n",
- "| **Returns** | **AnyDataFrame** | | **noqa: ARG002
noqa: ARG002** |"
+ "| **Returns** | **AnyDataFrame** | | **Predictions for each window with the series id, timestamp, target value and predictions from each model.** |"
]
},
"execution_count": null,
diff --git a/nbs/docs/getting-started/quick_start_distributed.ipynb b/nbs/docs/getting-started/quick_start_distributed.ipynb
index 5941d6c8..1a8be2ea 100644
--- a/nbs/docs/getting-started/quick_start_distributed.ipynb
+++ b/nbs/docs/getting-started/quick_start_distributed.ipynb
@@ -492,6 +492,41 @@
"pd.testing.assert_frame_equal(preds, preds3)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0d26427d-314e-4589-8567-097ddf5adce1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# test X_df\n",
+ "prices = generate_prices_for_series(series)\n",
+ "series_wexog = series.merge(prices, on=['unique_id', 'ds'])\n",
+ "npartitions = 10\n",
+ "partitioned_series_exog = dd.from_pandas(series_wexog.set_index('unique_id'), npartitions=npartitions)\n",
+ "partitioned_series_exog = partitioned_series_exog.map_partitions(lambda df: df.reset_index())\n",
+ "partitioned_series_exog['unique_id'] = partitioned_series_exog['unique_id'].astype(str)\n",
+ "fcst_exog = DistributedMLForecast(\n",
+ " models=models,\n",
+ " freq='D',\n",
+ " lags=[7],\n",
+ " lag_transforms={\n",
+ " 1: [expanding_mean],\n",
+ " 7: [(rolling_mean, 14)]\n",
+ " },\n",
+ " date_features=['dayofweek', 'month'],\n",
+ " num_threads=1,\n",
+ " engine=client,\n",
+ ")\n",
+ "fcst_exog.fit(partitioned_series_exog, static_features=['static_0', 'static_1'])\n",
+ "preds_exog = fcst_exog.predict(h=7, X_df=prices).compute()\n",
+ "full_preds = preds.merge(preds_exog, on=['unique_id', 'ds'], suffixes=('', '_exog'))\n",
+ "for model in ('DaskXGBForecast', 'DaskLGBMForecast'):\n",
+ " pct_diff = abs(1 - full_preds[f'{model}_exog'].div(full_preds[f'{model}']).mean())\n",
+ " assert 0 < pct_diff < 0.1"
+ ]
+ },
{
"cell_type": "markdown",
"id": "502aeadd-2fd5-4d16-8dfb-56a77080c072",
@@ -1754,48 +1789,48 @@
"
\n",
" \n",
" 0 | \n",
- " id_01 | \n",
+ " id_00 | \n",
" 2001-05-01 | \n",
" 124.758319 | \n",
" 152.856125 | \n",
" 2001-04-30 | \n",
- " 117.876479 | \n",
+ " 400.369603 | \n",
"
\n",
" \n",
" 1 | \n",
- " id_01 | \n",
+ " id_00 | \n",
" 2001-05-02 | \n",
" 145.041000 | \n",
" 177.355331 | \n",
" 2001-04-30 | \n",
- " 153.394375 | \n",
+ " 513.321524 | \n",
"
\n",
" \n",
" 2 | \n",
- " id_01 | \n",
+ " id_00 | \n",
" 2001-05-03 | \n",
" 178.838681 | \n",
" 66.459068 | \n",
" 2001-04-30 | \n",
- " 175.337772 | \n",
+ " 39.373177 | \n",
"
\n",
" \n",
" 3 | \n",
- " id_01 | \n",
+ " id_00 | \n",
" 2001-05-04 | \n",
" 27.212783 | \n",
" 94.735237 | \n",
" 2001-04-30 | \n",
- " 13.202898 | \n",
+ " 108.139791 | \n",
"
\n",
" \n",
" 4 | \n",
- " id_01 | \n",
+ " id_00 | \n",
" 2001-05-05 | \n",
" 56.624979 | \n",
" 125.717896 | \n",
" 2001-04-30 | \n",
- " 30.203090 | \n",
+ " 167.265248 | \n",
"
\n",
" \n",
"\n",
diff --git a/nbs/docs/how-to-guides/exogenous_features.ipynb b/nbs/docs/how-to-guides/exogenous_features.ipynb
index 2f91c90d..3d6fb630 100644
--- a/nbs/docs/how-to-guides/exogenous_features.ipynb
+++ b/nbs/docs/how-to-guides/exogenous_features.ipynb
@@ -144,6 +144,14 @@
"series.head()"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "782a6bf1-b360-4aff-b37f-b9ed3ee131e5",
+ "metadata": {},
+ "source": [
+ "## Use existing exogenous features"
+ ]
+ },
{
"cell_type": "markdown",
"id": "96fe6893-70c5-4ef5-b318-686c69660a1d",
@@ -369,7 +377,7 @@
{
"data": {
"text/plain": [
- "MLForecast(models=[LGBMRegressor], freq=, lag_features=['lag7', 'expanding_mean_lag1', 'rolling_mean_lag7_window_size14'], date_features=['dayofweek', 'month'], num_threads=2)"
+ "MLForecast(models=[LGBMRegressor], freq=D, lag_features=['lag7', 'expanding_mean_lag1', 'rolling_mean_lag7_window_size14'], date_features=['dayofweek', 'month'], num_threads=2)"
]
},
"execution_count": null,
@@ -522,6 +530,526 @@
"preds.head()"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "39f21994-af9a-4376-9e96-1d67d87e6336",
+ "metadata": {},
+ "source": [
+ "## Generating exogenous features"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f9bb4f27-0a3d-43a8-9f48-57e1983e67d7",
+ "metadata": {},
+ "source": [
+ "Nixtla provides some utilities to generate exogenous features for both training and forecasting such as [statsforecast's mstl_decomposition](https://nixtlaverse.nixtla.io/statsforecast/docs/how-to-guides/generating_features.html) or the [transform_exog function](transforming_exog.ipynb). We also have [utilsforecast's fourier function](https://nixtlaverse.nixtla.io/utilsforecast/feature_engineering.html#fourier), which we'll demonstrate here."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "625b300c-9a92-4cf4-af09-d8759e1cef85",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.linear_model import LinearRegression\n",
+ "from utilsforecast.feature_engineering import fourier"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f482e987-401c-4a75-bcc6-ed2c77021baa",
+ "metadata": {},
+ "source": [
+    "Suppose you start with some data like the one above, which includes a couple of static features."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "90b0f9eb-6c07-417f-a5d3-2dc3d16b1d24",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+        "(HTML table markup stripped during extraction; see the equivalent text/plain output below)"
+ ],
+ "text/plain": [
+ " unique_id ds y static_0 product_id\n",
+ "0 id_00 2000-10-05 39.811983 79 45\n",
+ "1 id_00 2000-10-06 103.274013 79 45\n",
+ "2 id_00 2000-10-07 176.574744 79 45\n",
+ "3 id_00 2000-10-08 258.987900 79 45\n",
+ "4 id_00 2000-10-09 344.940404 79 45"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "series.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7855322f-1d14-4c9d-9b5f-c8d464896c6f",
+ "metadata": {},
+ "source": [
+    "Now we'd like to add some Fourier terms to model the seasonality. We can do that with the following:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "535ae0fd-d16a-49b6-8bdb-c750fdd9a5e1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "transformed_df, future_df = fourier(series, freq='D', season_length=7, k=2, h=7)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7e0c42c7-5d9e-45b4-b7db-19e72086f5d3",
+ "metadata": {},
+ "source": [
+    "This returns the training data extended with the Fourier features."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3c4c662e-1a35-46b0-a5f1-19e52584b0f5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+        "(HTML table markup stripped during extraction; see the equivalent text/plain output below)"
+ ],
+ "text/plain": [
+ " unique_id ds y static_0 product_id sin1_7 sin2_7 \\\n",
+ "0 id_00 2000-10-05 39.811983 79 45 0.781832 0.974928 \n",
+ "1 id_00 2000-10-06 103.274013 79 45 0.974928 -0.433884 \n",
+ "2 id_00 2000-10-07 176.574744 79 45 0.433884 -0.781831 \n",
+ "3 id_00 2000-10-08 258.987900 79 45 -0.433884 0.781832 \n",
+ "4 id_00 2000-10-09 344.940404 79 45 -0.974928 0.433884 \n",
+ "\n",
+ " cos1_7 cos2_7 \n",
+ "0 0.623490 -0.222521 \n",
+ "1 -0.222521 -0.900969 \n",
+ "2 -0.900969 0.623490 \n",
+ "3 -0.900969 0.623490 \n",
+ "4 -0.222521 -0.900969 "
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "transformed_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8e311c3c-392e-4f96-90b8-c1095fe2c21e",
+ "metadata": {},
+ "source": [
+    "It also returns the future values of the features, which we'll need at predict time."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "53b1e42a-9d13-4469-882b-269b21a00964",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+        "(HTML table markup stripped during extraction; see the equivalent text/plain output below)"
+ ],
+ "text/plain": [
+ " unique_id ds sin1_7 sin2_7 cos1_7 cos2_7\n",
+ "0 id_00 2001-05-15 -0.781828 -0.974930 0.623494 -0.222511\n",
+ "1 id_00 2001-05-16 0.000006 0.000011 1.000000 1.000000\n",
+ "2 id_00 2001-05-17 0.781835 0.974925 0.623485 -0.222533\n",
+ "3 id_00 2001-05-18 0.974927 -0.433895 -0.222527 -0.900963\n",
+ "4 id_00 2001-05-19 0.433878 -0.781823 -0.900972 0.623500"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "future_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8df48b27-7765-4107-aa7d-09c8737669f6",
+ "metadata": {},
+ "source": [
+ "We can now train using only these features (and the static ones)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "53874524-d2d7-40bc-8d82-93575c8c35f0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "MLForecast(models=[LinearRegression], freq=D, lag_features=[], date_features=[], num_threads=1)"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "fcst2 = MLForecast(models=LinearRegression(), freq='D')\n",
+ "fcst2.fit(transformed_df, static_features=['static_0', 'product_id'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a16a719d-dfc0-4800-bac5-345ad2dc2681",
+ "metadata": {},
+ "source": [
+    "We then provide the future values of the features to the `predict` method through `X_df`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "666c78ff-3814-44ee-acd3-ae61bd9e771b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+        "(HTML table markup stripped during extraction; see the equivalent text/plain output below)"
+ ],
+ "text/plain": [
+ " unique_id ds LinearRegression\n",
+ "0 id_00 2001-05-15 275.822342\n",
+ "1 id_00 2001-05-16 262.258117\n",
+ "2 id_00 2001-05-17 238.195850\n",
+ "3 id_00 2001-05-18 240.997814\n",
+ "4 id_00 2001-05-19 262.247123"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "fcst2.predict(h=7, X_df=future_df).head()"
+ ]
+ },
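
Taken together, the cells added above amount to this end-to-end flow (a condensed sketch using the series frame from this notebook):

    from sklearn.linear_model import LinearRegression
    from utilsforecast.feature_engineering import fourier

    from mlforecast import MLForecast

    # seasonal Fourier terms for the training period plus the next 7 days
    transformed_df, future_df = fourier(series, freq='D', season_length=7, k=2, h=7)

    fcst2 = MLForecast(models=LinearRegression(), freq='D')
    fcst2.fit(transformed_df, static_features=['static_0', 'product_id'])
    preds = fcst2.predict(h=7, X_df=future_df)  # future feature values go through X_df
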
{
"cell_type": "code",
"execution_count": null,
diff --git a/nbs/docs/how-to-guides/predict_callbacks.ipynb b/nbs/docs/how-to-guides/predict_callbacks.ipynb
index 6fa5a40e..fa7464d6 100644
--- a/nbs/docs/how-to-guides/predict_callbacks.ipynb
+++ b/nbs/docs/how-to-guides/predict_callbacks.ipynb
@@ -38,6 +38,8 @@
"metadata": {},
"outputs": [],
"source": [
+ "import copy\n",
+ "\n",
"import lightgbm as lgb\n",
"import numpy as np\n",
"from IPython.display import display\n",
@@ -442,6 +444,7 @@
"fcst.ts._uids = fcst.ts.uids\n",
"fcst.ts._idxs = None\n",
"fcst.ts._static_features = fcst.ts.static_features_\n",
+ "fcst.ts._ga = copy.copy(fcst.ts.ga)\n",
"fcst.ts._predict_setup()\n",
"\n",
"for attr in ('head', 'tail'):\n",
diff --git a/nbs/forecast.ipynb b/nbs/forecast.ipynb
index 3ae4ebd1..8000e6ab 100644
--- a/nbs/forecast.ipynb
+++ b/nbs/forecast.ipynb
@@ -625,11 +625,31 @@
" self.fcst_fitted_values_ = fitted_values\n",
" return self\n",
"\n",
- " def forecast_fitted_values(self):\n",
- " \"\"\"Access in-sample predictions.\"\"\"\n",
+ " def forecast_fitted_values(self, level: Optional[List[Union[int, float]]] = None) -> DataFrame:\n",
+ " \"\"\"Access in-sample predictions.\n",
+ " \n",
+ " Parameters\n",
+ " ----------\n",
+ " level : list of ints or floats, optional (default=None)\n",
+ " Confidence levels between 0 and 100 for prediction intervals.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " pandas or polars DataFrame\n",
+ " Dataframe with predictions for the training set\n",
+ " \"\"\"\n",
" if not hasattr(self, 'fcst_fitted_values_'):\n",
" raise Exception('Please run the `fit` method using `fitted=True`')\n",
- " return self.fcst_fitted_values_\n",
+ " res = self.fcst_fitted_values_\n",
+ " if level is not None:\n",
+ " res = ufp.add_insample_levels(\n",
+ " res,\n",
+ " models=self.models_.keys(),\n",
+ " level=level,\n",
+ " id_col=self.ts.id_col,\n",
+ " target_col=self.ts.target_col,\n",
+ " )\n",
+ " return res\n",
"\n",
" def make_future_dataframe(self, h: int) -> DataFrame:\n",
" \"\"\"Create a dataframe with all ids and future times in the forecasting horizon.\n",
@@ -1686,9 +1706,16 @@
"\n",
"### MLForecast.forecast_fitted_values\n",
"\n",
- "> MLForecast.forecast_fitted_values ()\n",
+ "> MLForecast.forecast_fitted_values\n",
+ "> (level:Optional[List[Union[int,float]]\n",
+ "> ]=None)\n",
+ "\n",
+ "Access in-sample predictions.\n",
"\n",
- "Access in-sample predictions."
+ "| | **Type** | **Default** | **Details** |\n",
+ "| -- | -------- | ----------- | ----------- |\n",
+ "| level | Optional | None | Confidence levels between 0 and 100 for prediction intervals. |\n",
+ "| **Returns** | **Union** | | **Dataframe with predictions for the training set** |"
],
"text/plain": [
"---\n",
@@ -1697,9 +1724,16 @@
"\n",
"### MLForecast.forecast_fitted_values\n",
"\n",
- "> MLForecast.forecast_fitted_values ()\n",
+ "> MLForecast.forecast_fitted_values\n",
+ "> (level:Optional[List[Union[int,float]]\n",
+ "> ]=None)\n",
"\n",
- "Access in-sample predictions."
+ "Access in-sample predictions.\n",
+ "\n",
+ "| | **Type** | **Default** | **Details** |\n",
+ "| -- | -------- | ----------- | ----------- |\n",
+ "| level | Optional | None | Confidence levels between 0 and 100 for prediction intervals. |\n",
+ "| **Returns** | **Union** | | **Dataframe with predictions for the training set** |"
]
},
"execution_count": null,
@@ -1853,6 +1887,185 @@
"fcst.forecast_fitted_values()"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a9e84d1e-5c92-4d3c-99e8-e9d897ac0f6f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+        "(HTML table markup stripped during extraction; see the equivalent text/plain output below)"
+ ],
+ "text/plain": [
+ " unique_id ds y LGBMRegressor LGBMRegressor-lo-90 \\\n",
+ "0 H196 193 12.7 12.671271 12.540634 \n",
+ "1 H196 194 12.3 12.271271 12.140634 \n",
+ "2 H196 195 11.9 11.871271 11.740634 \n",
+ "3 H196 196 11.7 11.671271 11.540634 \n",
+ "4 H196 197 11.4 11.471271 11.340634 \n",
+ "... ... ... ... ... ... \n",
+ "3067 H413 956 59.0 68.280574 58.846640 \n",
+ "3068 H413 957 58.0 70.427570 60.993636 \n",
+ "3069 H413 958 53.0 44.767965 35.334031 \n",
+ "3070 H413 959 38.0 48.691257 39.257323 \n",
+ "3071 H413 960 46.0 46.652238 37.218304 \n",
+ "\n",
+ " LGBMRegressor-hi-90 \n",
+ "0 12.801909 \n",
+ "1 12.401909 \n",
+ "2 12.001909 \n",
+ "3 11.801909 \n",
+ "4 11.601909 \n",
+ "... ... \n",
+ "3067 77.714509 \n",
+ "3068 79.861504 \n",
+ "3069 54.201899 \n",
+ "3070 58.125191 \n",
+ "3071 56.086172 \n",
+ "\n",
+ "[3072 rows x 6 columns]"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "fcst.forecast_fitted_values(level=[90])"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -1930,7 +2143,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L554){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L607){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.predict\n",
"\n",
@@ -1960,7 +2173,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L554){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L607){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.predict\n",
"\n",
@@ -2519,7 +2732,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L216){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.preprocess\n",
"\n",
@@ -2551,7 +2764,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L216){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.preprocess\n",
"\n",
@@ -2871,7 +3084,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L272){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L258){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.fit_models\n",
"\n",
@@ -2890,7 +3103,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L272){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L258){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.fit_models\n",
"\n",
@@ -3015,7 +3228,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L693){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L746){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.cross_validation\n",
"\n",
@@ -3068,7 +3281,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L693){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L746){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.cross_validation\n",
"\n",
@@ -4009,7 +4222,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L188){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.from_cv\n",
"\n",
@@ -4018,7 +4231,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L188){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.from_cv\n",
"\n",
diff --git a/nbs/grouped_array.ipynb b/nbs/grouped_array.ipynb
index 84e1a5aa..981fd3d5 100644
--- a/nbs/grouped_array.ipynb
+++ b/nbs/grouped_array.ipynb
@@ -97,6 +97,16 @@
" serie[j] += diffs_data[diffs_indptr[i + 1] - serie.size - d + j]\n",
"\n",
"@njit\n",
+ "def _update_difference(d: int, orig_data: np.ndarray, orig_indptr: np.ndarray, data: np.ndarray, indptr: np.ndarray):\n",
+ " n_series = len(indptr) - 1\n",
+ " for i in range(n_series):\n",
+ " orig = orig_data[orig_indptr[i] : orig_indptr[i + 1]]\n",
+ " transformed = data[indptr[i] : indptr[i + 1]]\n",
+ " combined = np.append(orig, transformed)\n",
+ " data[indptr[i] : indptr[i + 1]] = _diff(combined, d)[-transformed.size:]\n",
+ " orig_data[orig_indptr[i] : orig_indptr[i + 1]] = combined[-d:]\n",
+ "\n",
+ "@njit\n",
"def _expand_target(data, indptr, max_horizon):\n",
" out = np.empty((data.size, max_horizon), dtype=data.dtype)\n",
" n_series = len(indptr) - 1\n",
@@ -310,6 +320,9 @@
" d,\n",
" )\n",
"\n",
+ " def update_difference(self, d: int, ga: 'GroupedArray') -> None:\n",
+ " _update_difference(d, self.data, self.indptr, ga.data, ga.indptr)\n",
+ "\n",
" def expand_target(self, max_horizon: int) -> np.ndarray:\n",
" return _expand_target(self.data, self.indptr, max_horizon)\n",
" \n",
diff --git a/nbs/target_transforms.ipynb b/nbs/target_transforms.ipynb
index d5878423..8206eeb1 100644
--- a/nbs/target_transforms.ipynb
+++ b/nbs/target_transforms.ipynb
@@ -48,6 +48,7 @@
"from sklearn.base import TransformerMixin, clone\n",
"from utilsforecast.compat import DataFrame\n",
"from utilsforecast.target_transforms import (\n",
+ " BaseTargetTransform as UtilsTargetTransform,\n",
" LocalBoxCox as BoxCox,\n",
" LocalMinMaxScaler as MinMaxScaler,\n",
" LocalRobustScaler as RobustScaler,\n",
@@ -56,6 +57,7 @@
" _transform,\n",
")\n",
"\n",
+ "from mlforecast.compat import CORE_INSTALLED, CoreGroupedArray, core_scalers\n",
"from mlforecast.grouped_array import GroupedArray, _apply_difference\n",
"from mlforecast.utils import _ShortSeriesException"
]
@@ -92,13 +94,16 @@
" self.time_col = time_col\n",
" self.target_col = target_col\n",
"\n",
+ " def update(self, df: DataFrame) -> DataFrame:\n",
+ " raise NotImplementedError\n",
+ "\n",
" @abc.abstractmethod\n",
" def fit_transform(self, df: DataFrame) -> DataFrame:\n",
- " raise NotImplementedError\n",
+ " ...\n",
" \n",
" @abc.abstractmethod\n",
" def inverse_transform(self, df: DataFrame) -> DataFrame:\n",
- " raise NotImplementedError"
+ " ..."
]
},
{
@@ -114,12 +119,16 @@
" idxs: Optional[np.ndarray] = None\n",
"\n",
" @abc.abstractmethod\n",
+ " def update(self, ga: GroupedArray) -> GroupedArray:\n",
+ " ...\n",
+ " \n",
+ " @abc.abstractmethod\n",
" def fit_transform(self, ga: GroupedArray) -> GroupedArray:\n",
- " raise NotImplementedError\n",
+ " ...\n",
" \n",
" @abc.abstractmethod\n",
" def inverse_transform(self, ga: GroupedArray) -> GroupedArray:\n",
- " raise NotImplementedError\n",
+ " ...\n",
"\n",
" def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray:\n",
" return self.inverse_transform(ga)"
@@ -161,6 +170,12 @@
" self.original_values_.append(GroupedArray(new_data, new_indptr))\n",
" return ga\n",
"\n",
+ " def update(self, ga: GroupedArray) -> GroupedArray:\n",
+ " transformed = copy.copy(ga)\n",
+ " for d, orig_ga in zip(self.differences, self.original_values_):\n",
+ " orig_ga.update_difference(d, transformed)\n",
+ " return transformed\n",
+ "\n",
" def inverse_transform(self, ga: GroupedArray) -> GroupedArray:\n",
" ga = copy.copy(ga)\n",
" for d, orig_vals_ga in zip(reversed(self.differences), reversed(self.original_values_)):\n",
@@ -217,6 +232,24 @@
"restored_subs = diffs.inverse_transform_fitted(transformed.take_from_groups(slice(8, None)))\n",
"np.testing.assert_allclose(ga.data[keep_mask], restored_subs.data)\n",
"\n",
+    "# test update\n",
+ "new_ga = GroupedArray(np.random.rand(10), np.arange(11))\n",
+ "prev_orig = [diffs.original_values_[i].data[::d].copy() for i, d in enumerate(diffs.differences)]\n",
+ "expected = new_ga.data - np.add.reduce(prev_orig)\n",
+ "updates = diffs.update(new_ga)\n",
+ "np.testing.assert_allclose(expected, updates.data)\n",
+ "np.testing.assert_allclose(diffs.original_values_[0].data, new_ga.data)\n",
+ "np.testing.assert_allclose(diffs.original_values_[1].data[1::2], new_ga.data - prev_orig[0])\n",
+ "np.testing.assert_allclose(diffs.original_values_[2].data[4::5], new_ga.data - np.add.reduce(prev_orig[:2]))\n",
+ "# variable sizes\n",
+ "diff1 = Differences([1])\n",
+ "ga = GroupedArray(np.arange(10), np.array([0, 3, 10]))\n",
+ "diff1.fit_transform(ga)\n",
+ "new_ga = GroupedArray(np.arange(4), np.array([0, 1, 4]))\n",
+ "updates = diff1.update(new_ga)\n",
+ "np.testing.assert_allclose(updates.data, np.array([0 - 2, 1 - 9, 2 - 1, 3 - 2]))\n",
+ "np.testing.assert_allclose(diff1.original_values_[0].data, np.array([0, 3]))\n",
+ "\n",
"# short series\n",
"ga = GroupedArray(np.arange(20), np.array([0, 2, 20]))\n",
"test_fail(lambda: diffs.fit_transform(ga), contains=\"[0]\")"
@@ -231,12 +264,24 @@
"source": [
"#| exporti\n",
"class BaseLocalScaler(BaseGroupedArrayTargetTransform):\n",
- " \"\"\"Standardizes each serie by subtracting its mean and dividing by its standard deviation.\"\"\"\n",
" scaler_factory: type\n",
+ "\n",
+ " def _is_utils_tfm(self):\n",
+ " return isinstance(self.scaler_, UtilsTargetTransform)\n",
+ "\n",
+ " def update(self, ga: GroupedArray) -> GroupedArray:\n",
+ " if not self._is_utils_tfm():\n",
+ " ga = CoreGroupedArray(ga.data, ga.indptr)\n",
+ " return GroupedArray(self.scaler_.transform(ga), ga.indptr)\n",
" \n",
" def fit_transform(self, ga: GroupedArray) -> GroupedArray:\n",
" self.scaler_ = self.scaler_factory()\n",
- " transformed = self.scaler_.fit_transform(ga)\n",
+ " if self._is_utils_tfm():\n",
+ " transformed = self.scaler_.fit_transform(ga)\n",
+ " else:\n",
+ " core_ga = CoreGroupedArray(ga.data, ga.indptr)\n",
+ " self.scaler_.fit(core_ga)\n",
+ " transformed = self.scaler_.transform(core_ga)\n",
" return GroupedArray(transformed, ga.indptr)\n",
"\n",
" def inverse_transform(self, ga: GroupedArray) -> GroupedArray:\n",
@@ -245,7 +290,12 @@
" stats = stats[self.idxs]\n",
" if stats.shape[0] != ga.n_groups:\n",
" raise ValueError('Found different number of groups in scaler.')\n",
- " transformed = _transform(ga.data, ga.indptr, stats, _common_scaler_inverse_transform)\n",
+ " if self._is_utils_tfm() or self.idxs is not None:\n",
+ " # core scalers can't transform a subset\n",
+ " transformed = _transform(ga.data, ga.indptr, stats, _common_scaler_inverse_transform)\n",
+ " else:\n",
+ " core_ga = CoreGroupedArray(ga.data, ga.indptr)\n",
+ " transformed = self.scaler_.inverse_transform(core_ga)\n",
" return GroupedArray(transformed, ga.indptr)\n",
"\n",
" def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray:\n",
@@ -268,13 +318,8 @@
" sc.inverse_transform(transformed).data,\n",
" ga.data,\n",
" )\n",
- " \n",
- " def filter_df(df):\n",
- " return (\n",
- " df[df['unique_id'].isin(['id_0', 'id_7'])]\n",
- " .groupby('unique_id', observed=True)\n",
- " .head(10)\n",
- " )\n",
+ " transformed2 = sc.update(ga)\n",
+ " np.testing.assert_allclose(transformed.data, transformed2.data)\n",
" \n",
" idxs = [0, 7]\n",
" subset = ga.take(idxs)\n",
@@ -296,7 +341,7 @@
"#| export\n",
"class LocalStandardScaler(BaseLocalScaler):\n",
" \"\"\"Standardizes each serie by subtracting its mean and dividing by its standard deviation.\"\"\"\n",
- " scaler_factory = StandardScaler"
+ " scaler_factory = core_scalers.LocalStandardScaler if CORE_INSTALLED else StandardScaler"
]
},
{
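
These scalers are used through MLForecast's target_transforms argument, e.g. (the model and frequency below are placeholders):

    from sklearn.linear_model import LinearRegression

    from mlforecast import MLForecast
    from mlforecast.target_transforms import Differences, LocalStandardScaler

    fcst = MLForecast(
        models=[LinearRegression()],
        freq='D',
        lags=[7],
        # transforms are applied in order before fitting and inverted after predicting
        target_transforms=[Differences([7]), LocalStandardScaler()],
    )

As the change above shows, when coreforecast is available (CORE_INSTALLED) the core_scalers implementations are used as the scaler_factory, otherwise the utilsforecast ones; the public classes stay the same either way.
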
@@ -319,7 +364,7 @@
"#| export\n",
"class LocalMinMaxScaler(BaseLocalScaler):\n",
" \"\"\"Scales each serie to be in the [0, 1] interval.\"\"\"\n",
- " scaler_factory = MinMaxScaler"
+ " scaler_factory = core_scalers.LocalMinMaxScaler if CORE_INSTALLED else MinMaxScaler"
]
},
{
@@ -350,7 +395,7 @@
" \"\"\"\n",
"\n",
" def __init__(self, scale: str):\n",
- " self.scaler_factory = lambda: RobustScaler(scale) # type: ignore"
+ " self.scaler_factory = lambda: core_scalers.LocalRobustScaler(scale) if CORE_INSTALLED else RobustScaler(scale) # type: ignore"
]
},
{
@@ -384,16 +429,16 @@
"class LocalBoxCox(BaseLocalScaler):\n",
" \"\"\"Finds the optimum lambda for each serie and applies the Box-Cox transformation\"\"\"\n",
" def __init__(self):\n",
- " self.scaler = BoxCox()\n",
+ " self.scaler_ = BoxCox()\n",
" \n",
" def fit_transform(self, ga: GroupedArray) -> GroupedArray:\n",
- " return GroupedArray(self.scaler.fit_transform(ga), ga.indptr)\n",
+ " return GroupedArray(self.scaler_.fit_transform(ga), ga.indptr)\n",
"\n",
" def inverse_transform(self, ga: GroupedArray) -> GroupedArray:\n",
" from scipy.special import inv_boxcox1p\n",
"\n",
" sizes = np.diff(ga.indptr)\n",
- " lmbdas = self.scaler.lmbdas_\n",
+ " lmbdas = self.scaler_.lmbdas_\n",
" if self.idxs is not None:\n",
" lmbdas = lmbdas[self.idxs]\n",
" lmbdas = np.repeat(lmbdas, sizes, axis=0)\n",
@@ -423,6 +468,11 @@
" def __init__(self, transformer: TransformerMixin):\n",
" self.transformer = transformer\n",
"\n",
+ " def update(self, df: pd.DataFrame) -> pd.DataFrame:\n",
+ " df = df.copy(deep=False)\n",
+ " df[self.target_col] = self.transformer_.transform(df[[self.target_col]].values)\n",
+ " return df\n",
+ "\n",
" def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:\n",
" df = df.copy(deep=False)\n",
" self.transformer_ = clone(self.transformer)\n",
diff --git a/settings.ini b/settings.ini
index 54f77f8f..5d80afca 100644
--- a/settings.ini
+++ b/settings.ini
@@ -8,7 +8,7 @@ author = José Morales
author_email = jmoralz92@gmail.com
copyright = Nixtla
branch = main
-version = 0.11.2
+version = 0.11.5
min_python = 3.8
audience = Developers
language = English