diff --git a/.github/workflows/build-docs.yaml b/.github/workflows/build-docs.yaml index 7574408a..469ff80e 100644 --- a/.github/workflows/build-docs.yaml +++ b/.github/workflows/build-docs.yaml @@ -51,6 +51,18 @@ jobs: publish_dir: ./_docs user_name: github-actions[bot] user_email: 41898282+github-actions[bot]@users.noreply.github.com + - name: Trigger mintlify workflow + if: github.event_name == 'push' + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.DOCS_WORKFLOW_TOKEN }} + script: | + await github.rest.actions.createWorkflowDispatch({ + owner: 'nixtla', + repo: 'docs', + workflow_id: 'mintlify-action.yml', + ref: 'main', + }); - name: Deploy to Github Pages if: github.event_name == 'push' uses: peaceiris/actions-gh-pages@v3 diff --git a/environment.yml b/environment.yml index 6d762a48..16745300 100644 --- a/environment.yml +++ b/environment.yml @@ -31,5 +31,5 @@ dependencies: - polars - ray<2.8 - triad==0.9.1 - - utilsforecast>=0.0.21 + - utilsforecast>=0.0.24 - xgboost_ray diff --git a/local_environment.yml b/local_environment.yml index d5c711c5..9debf0e1 100644 --- a/local_environment.yml +++ b/local_environment.yml @@ -21,4 +21,4 @@ dependencies: - datasetsforecast - nbdev - polars - - utilsforecast>=0.0.21 + - utilsforecast>=0.0.24 diff --git a/mlforecast/__init__.py b/mlforecast/__init__.py index d9d20db7..e01d5f6d 100644 --- a/mlforecast/__init__.py +++ b/mlforecast/__init__.py @@ -1,3 +1,3 @@ -__version__ = "0.11.2" +__version__ = "0.11.5" __all__ = ['MLForecast'] from mlforecast.forecast import MLForecast diff --git a/mlforecast/_modidx.py b/mlforecast/_modidx.py index 19ca6691..b596715b 100644 --- a/mlforecast/_modidx.py +++ b/mlforecast/_modidx.py @@ -31,6 +31,8 @@ 'mlforecast/core.py'), 'mlforecast.core.TimeSeries._get_raw_predictions': ( 'core.html#timeseries._get_raw_predictions', 'mlforecast/core.py'), + 'mlforecast.core.TimeSeries._has_ga_target_tfms': ( 'core.html#timeseries._has_ga_target_tfms', + 'mlforecast/core.py'), 'mlforecast.core.TimeSeries._predict_multi': ('core.html#timeseries._predict_multi', 'mlforecast/core.py'), 'mlforecast.core.TimeSeries._predict_recursive': ( 'core.html#timeseries._predict_recursive', 'mlforecast/core.py'), @@ -194,6 +196,8 @@ 'mlforecast/grouped_array.py'), 'mlforecast.grouped_array.GroupedArray.take_from_groups': ( 'grouped_array.html#groupedarray.take_from_groups', 'mlforecast/grouped_array.py'), + 'mlforecast.grouped_array.GroupedArray.update_difference': ( 'grouped_array.html#groupedarray.update_difference', + 'mlforecast/grouped_array.py'), 'mlforecast.grouped_array._append_one': ( 'grouped_array.html#_append_one', 'mlforecast/grouped_array.py'), 'mlforecast.grouped_array._append_several': ( 'grouped_array.html#_append_several', @@ -208,7 +212,9 @@ 'mlforecast.grouped_array._restore_fitted_difference': ( 'grouped_array.html#_restore_fitted_difference', 'mlforecast/grouped_array.py'), 'mlforecast.grouped_array._transform_series': ( 'grouped_array.html#_transform_series', - 'mlforecast/grouped_array.py')}, + 'mlforecast/grouped_array.py'), + 'mlforecast.grouped_array._update_difference': ( 'grouped_array.html#_update_difference', + 'mlforecast/grouped_array.py')}, 'mlforecast.lag_transforms': { 'mlforecast.lag_transforms.BaseLagTransform': ( 'lag_transforms.html#baselagtransform', 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.BaseLagTransform.transform': ( 'lag_transforms.html#baselagtransform.transform', @@ -315,14 +321,20 @@ 'mlforecast/target_transforms.py'), 
'mlforecast.target_transforms.BaseGroupedArrayTargetTransform.inverse_transform_fitted': ( 'target_transforms.html#basegroupedarraytargettransform.inverse_transform_fitted', 'mlforecast/target_transforms.py'), + 'mlforecast.target_transforms.BaseGroupedArrayTargetTransform.update': ( 'target_transforms.html#basegroupedarraytargettransform.update', + 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.BaseLocalScaler': ( 'target_transforms.html#baselocalscaler', 'mlforecast/target_transforms.py'), + 'mlforecast.target_transforms.BaseLocalScaler._is_utils_tfm': ( 'target_transforms.html#baselocalscaler._is_utils_tfm', + 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.BaseLocalScaler.fit_transform': ( 'target_transforms.html#baselocalscaler.fit_transform', 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.BaseLocalScaler.inverse_transform': ( 'target_transforms.html#baselocalscaler.inverse_transform', 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.BaseLocalScaler.inverse_transform_fitted': ( 'target_transforms.html#baselocalscaler.inverse_transform_fitted', 'mlforecast/target_transforms.py'), + 'mlforecast.target_transforms.BaseLocalScaler.update': ( 'target_transforms.html#baselocalscaler.update', + 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.BaseTargetTransform': ( 'target_transforms.html#basetargettransform', 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.BaseTargetTransform.fit_transform': ( 'target_transforms.html#basetargettransform.fit_transform', @@ -331,6 +343,8 @@ 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.BaseTargetTransform.set_column_names': ( 'target_transforms.html#basetargettransform.set_column_names', 'mlforecast/target_transforms.py'), + 'mlforecast.target_transforms.BaseTargetTransform.update': ( 'target_transforms.html#basetargettransform.update', + 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.Differences': ( 'target_transforms.html#differences', 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.Differences.__init__': ( 'target_transforms.html#differences.__init__', @@ -341,6 +355,8 @@ 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.Differences.inverse_transform_fitted': ( 'target_transforms.html#differences.inverse_transform_fitted', 'mlforecast/target_transforms.py'), + 'mlforecast.target_transforms.Differences.update': ( 'target_transforms.html#differences.update', + 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.GlobalSklearnTransformer': ( 'target_transforms.html#globalsklearntransformer', 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.GlobalSklearnTransformer.__init__': ( 'target_transforms.html#globalsklearntransformer.__init__', @@ -349,6 +365,8 @@ 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.GlobalSklearnTransformer.inverse_transform': ( 'target_transforms.html#globalsklearntransformer.inverse_transform', 'mlforecast/target_transforms.py'), + 'mlforecast.target_transforms.GlobalSklearnTransformer.update': ( 'target_transforms.html#globalsklearntransformer.update', + 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.LocalBoxCox': ( 'target_transforms.html#localboxcox', 'mlforecast/target_transforms.py'), 'mlforecast.target_transforms.LocalBoxCox.__init__': ( 'target_transforms.html#localboxcox.__init__', diff --git a/mlforecast/compat.py b/mlforecast/compat.py index edc93d7b..1bd565b1 100644 --- 
a/mlforecast/compat.py +++ b/mlforecast/compat.py @@ -6,6 +6,7 @@ # %% ../nbs/compat.ipynb 1 try: import coreforecast.lag_transforms as core_tfms + import coreforecast.scalers as core_scalers from coreforecast.grouped_array import GroupedArray as CoreGroupedArray from mlforecast.lag_transforms import BaseLagTransform, Lag @@ -13,6 +14,7 @@ CORE_INSTALLED = True except ImportError: core_tfms = None + core_scalers = None CoreGroupedArray = None class BaseLagTransform: diff --git a/mlforecast/core.py b/mlforecast/core.py index 6f9efa13..87d2528e 100644 --- a/mlforecast/core.py +++ b/mlforecast/core.py @@ -561,7 +561,7 @@ def _get_predictions(self) -> DataFrame: return df def _predict_setup(self) -> None: - self._ga = copy.copy(self.ga) + self.ga = copy.copy(self._ga) if isinstance(self.last_dates, pl_Series): self.curr_dates = self.last_dates.clone() else: @@ -657,6 +657,12 @@ def _predict_multi( result = ufp.assign_columns(result, name, raw_preds) return result + def _has_ga_target_tfms(self): + return any( + isinstance(tfm, BaseGroupedArrayTargetTransform) + for tfm in self.target_transforms + ) + def predict( self, models: Dict[str, Union[BaseEstimator, List[BaseEstimator]]], @@ -723,26 +729,30 @@ def predict( raise ValueError(msg) drop_cols = [self.id_col, self.time_col, "_start", "_end"] X_df = ufp.sort(X_df, [self.id_col, self.time_col]).drop(columns=drop_cols) - if getattr(self, "max_horizon", None) is None: - preds = self._predict_recursive( - models=models, - horizon=horizon, - before_predict_callback=before_predict_callback, - after_predict_callback=after_predict_callback, - X_df=X_df, - ) - else: - preds = self._predict_multi( - models=models, - horizon=horizon, - before_predict_callback=before_predict_callback, - X_df=X_df, - ) + # backup original series. 
the ga attribute gets modified + # and is copied from _ga at the start of each model's predict + self._ga = copy.copy(self.ga) + try: + if getattr(self, "max_horizon", None) is None: + preds = self._predict_recursive( + models=models, + horizon=horizon, + before_predict_callback=before_predict_callback, + after_predict_callback=after_predict_callback, + X_df=X_df, + ) + else: + preds = self._predict_multi( + models=models, + horizon=horizon, + before_predict_callback=before_predict_callback, + X_df=X_df, + ) + finally: + self.ga = self._ga + del self._ga if self.target_transforms is not None: - if any( - isinstance(tfm, BaseGroupedArrayTargetTransform) - for tfm in self.target_transforms - ): + if self._has_ga_target_tfms(): model_cols = [ c for c in preds.columns if c not in (self.id_col, self.time_col) ] @@ -751,14 +761,15 @@ def predict( if isinstance(tfm, BaseGroupedArrayTargetTransform): tfm.idxs = self._idxs for col in model_cols: - ga = GroupedArray(preds[col].to_numpy(), indptr) + ga = GroupedArray( + preds[col].to_numpy().astype(self.ga.data.dtype), indptr + ) ga = tfm.inverse_transform(ga) preds = ufp.assign_columns(preds, col, ga.data) tfm.idxs = None else: preds = tfm.inverse_transform(preds) - self.ga = self._ga - del self._uids, self._idxs, self._static_features, self._ga + del self._uids, self._idxs, self._static_features return preds def save(self, path: Union[str, Path]) -> None: @@ -778,11 +789,17 @@ def update(self, df: DataFrame) -> None: if isinstance(uids, pd.Index): uids = pd.Series(uids) uids, new_ids = ufp.match_if_categorical(uids, df[self.id_col]) + df = ufp.copy_if_pandas(df, deep=False) df = ufp.assign_columns(df, self.id_col, new_ids) df = ufp.sort(df, by=[self.id_col, self.time_col]) values = df[self.target_col].to_numpy() + values = values.astype(self.ga.data.dtype, copy=False) id_counts = ufp.counts_by_id(df, self.id_col) - sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer") + try: + sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer_coalesce") + except (KeyError, ValueError): + # pandas raises key error, polars before coalesce raises value error + sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer") sizes = ufp.fill_null(sizes, {"counts": 0}) sizes = ufp.sort(sizes, by=self.id_col) new_groups = ~ufp.is_in(sizes[self.id_col], uids) @@ -794,9 +811,12 @@ def update(self, df: DataFrame) -> None: last_dates = ufp.sort(last_dates, by=self.id_col) self.last_dates = ufp.cast(last_dates[self.time_col], self.last_dates.dtype) self.uids = ufp.sort(sizes[self.id_col]) - if isinstance(self.uids, pd.Series): + if isinstance(df, pd.DataFrame): self.uids = pd.Index(self.uids) + self.last_dates = pd.Index(self.last_dates) if new_groups.any(): + if self.target_transforms is not None: + raise ValueError("Can not update target_transforms with new series.") new_ids = ufp.filter_with_mask(sizes[self.id_col], new_groups) new_ids_df = ufp.filter_with_mask(df, ufp.is_in(df[self.id_col], new_ids)) new_ids_counts = ufp.counts_by_id(new_ids_df, self.id_col) @@ -808,6 +828,17 @@ def update(self, df: DataFrame) -> None: [self.static_features_, new_statics] ) self.static_features_ = ufp.sort(self.static_features_, self.id_col) + if self.target_transforms is not None: + if self._has_ga_target_tfms(): + indptr = np.append(0, id_counts["counts"]).cumsum() + for tfm in self.target_transforms: + if isinstance(tfm, BaseGroupedArrayTargetTransform): + ga = GroupedArray(values, indptr) + ga = tfm.update(ga) + df = ufp.assign_columns(df, self.target_col, ga.data) + 
else: + df = tfm.update(df) + values = df[self.target_col].to_numpy() self.ga = self.ga.append_several( new_sizes=sizes["counts"].to_numpy().astype(np.int32), new_values=values, diff --git a/mlforecast/distributed/forecast.py b/mlforecast/distributed/forecast.py index cfede794..28a55a3c 100644 --- a/mlforecast/distributed/forecast.py +++ b/mlforecast/distributed/forecast.py @@ -431,6 +431,7 @@ def _predict( horizon, before_predict_callback=None, after_predict_callback=None, + X_df=None, ) -> Iterable[pd.DataFrame]: for serialized_ts, _, serialized_valid in items: valid = cloudpickle.loads(serialized_valid) @@ -440,6 +441,7 @@ def _predict( horizon=horizon, before_predict_callback=before_predict_callback, after_predict_callback=after_predict_callback, + X_df=X_df, ) if valid is not None: res = res.merge(valid, how="left") @@ -459,6 +461,7 @@ def predict( h: int, before_predict_callback: Optional[Callable] = None, after_predict_callback: Optional[Callable] = None, + X_df: Optional[pd.DataFrame] = None, new_df: Optional[fugue.AnyDataFrame] = None, ) -> fugue.AnyDataFrame: """Compute the predictions for the next `horizon` steps. @@ -475,6 +478,8 @@ def predict( Function to call on the predictions before updating the targets. This function will take a pandas Series with the predictions and should return another one with the same structure. The series identifier is on the index. + X_df : pandas DataFrame, optional (default=None) + Dataframe with the future exogenous features. Should have the id column and the time column. new_df : dask or spark DataFrame, optional (default=None) Series data of new observations for which forecasts are to be generated. This dataframe should have the same structure as the one used to fit the model, including any features and time series data. @@ -499,6 +504,8 @@ def predict( else: partition_results = self._partition_results schema = self._get_predict_schema() + if X_df is not None and not isinstance(X_df, pd.DataFrame): + raise ValueError("`X_df` should be a pandas DataFrame") res = fa.transform( partition_results, DistributedMLForecast._predict, @@ -507,6 +514,7 @@ def predict( "horizon": h, "before_predict_callback": before_predict_callback, "after_predict_callback": after_predict_callback, + "X_df": X_df, }, schema=schema, engine=self.engine, diff --git a/mlforecast/forecast.py b/mlforecast/forecast.py index 56409602..e359be3c 100644 --- a/mlforecast/forecast.py +++ b/mlforecast/forecast.py @@ -536,11 +536,33 @@ def fit( self.fcst_fitted_values_ = fitted_values return self - def forecast_fitted_values(self): - """Access in-sample predictions.""" + def forecast_fitted_values( + self, level: Optional[List[Union[int, float]]] = None + ) -> DataFrame: + """Access in-sample predictions. + + Parameters + ---------- + level : list of ints or floats, optional (default=None) + Confidence levels between 0 and 100 for prediction intervals. + + Returns + ------- + pandas or polars DataFrame + Dataframe with predictions for the training set + """ if not hasattr(self, "fcst_fitted_values_"): raise Exception("Please run the `fit` method using `fitted=True`") - return self.fcst_fitted_values_ + res = self.fcst_fitted_values_ + if level is not None: + res = ufp.add_insample_levels( + res, + models=self.models_.keys(), + level=level, + id_col=self.ts.id_col, + target_col=self.ts.target_col, + ) + return res def make_future_dataframe(self, h: int) -> DataFrame: """Create a dataframe with all ids and future times in the forecasting horizon. 
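[Editor's note — illustrative sketch, not part of this patch: the mlforecast/forecast.py hunk above adds an optional `level` argument to `MLForecast.forecast_fitted_values`, which attaches in-sample prediction intervals via utilsforecast's `add_insample_levels`. The snippet below shows one way the new argument might be exercised; the data generator, model choice, and the interval column names ('lr-lo-95', etc.) are assumptions based on Nixtla conventions, not guarantees from this diff.]

from sklearn.linear_model import LinearRegression
from utilsforecast.data import generate_series
from mlforecast import MLForecast

# Assumed toy setup: two daily series with the default unique_id/ds/y columns.
series = generate_series(n_series=2, freq='D', min_length=50)
fcst = MLForecast(models={'lr': LinearRegression()}, freq='D', lags=[1, 7])
# fitted=True stores the in-sample predictions that forecast_fitted_values returns.
fcst.fit(series, fitted=True)
# New in this release: request in-sample prediction intervals.
fitted = fcst.forecast_fitted_values(level=[80, 95])
# Expected extra columns (assumption): 'lr-lo-95', 'lr-lo-80', 'lr-hi-80', 'lr-hi-95'
print(fitted.columns.tolist())
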
diff --git a/mlforecast/grouped_array.py b/mlforecast/grouped_array.py index 457c56d1..949e4f09 100644 --- a/mlforecast/grouped_array.py +++ b/mlforecast/grouped_array.py @@ -79,6 +79,23 @@ def _restore_fitted_difference(diffs_data, diffs_indptr, data, indptr, d): serie[j] += diffs_data[diffs_indptr[i + 1] - serie.size - d + j] +@njit +def _update_difference( + d: int, + orig_data: np.ndarray, + orig_indptr: np.ndarray, + data: np.ndarray, + indptr: np.ndarray, +): + n_series = len(indptr) - 1 + for i in range(n_series): + orig = orig_data[orig_indptr[i] : orig_indptr[i + 1]] + transformed = data[indptr[i] : indptr[i + 1]] + combined = np.append(orig, transformed) + data[indptr[i] : indptr[i + 1]] = _diff(combined, d)[-transformed.size :] + orig_data[orig_indptr[i] : orig_indptr[i + 1]] = combined[-d:] + + @njit def _expand_target(data, indptr, max_horizon): out = np.empty((data.size, max_horizon), dtype=data.dtype) @@ -264,6 +281,9 @@ def restore_fitted_difference( d, ) + def update_difference(self, d: int, ga: "GroupedArray") -> None: + _update_difference(d, self.data, self.indptr, ga.data, ga.indptr) + def expand_target(self, max_horizon: int) -> np.ndarray: return _expand_target(self.data, self.indptr, max_horizon) diff --git a/mlforecast/target_transforms.py b/mlforecast/target_transforms.py index fc0cfc92..e6f7617f 100644 --- a/mlforecast/target_transforms.py +++ b/mlforecast/target_transforms.py @@ -14,6 +14,7 @@ from sklearn.base import TransformerMixin, clone from utilsforecast.compat import DataFrame from utilsforecast.target_transforms import ( + BaseTargetTransform as UtilsTargetTransform, LocalBoxCox as BoxCox, LocalMinMaxScaler as MinMaxScaler, LocalRobustScaler as RobustScaler, @@ -22,6 +23,7 @@ _transform, ) +from .compat import CORE_INSTALLED, CoreGroupedArray, core_scalers from .grouped_array import GroupedArray, _apply_difference from .utils import _ShortSeriesException @@ -34,13 +36,16 @@ def set_column_names(self, id_col: str, time_col: str, target_col: str): self.time_col = time_col self.target_col = target_col + def update(self, df: DataFrame) -> DataFrame: + raise NotImplementedError + @abc.abstractmethod def fit_transform(self, df: DataFrame) -> DataFrame: - raise NotImplementedError + ... @abc.abstractmethod def inverse_transform(self, df: DataFrame) -> DataFrame: - raise NotImplementedError + ... # %% ../nbs/target_transforms.ipynb 6 class BaseGroupedArrayTargetTransform(abc.ABC): @@ -48,13 +53,17 @@ class BaseGroupedArrayTargetTransform(abc.ABC): idxs: Optional[np.ndarray] = None + @abc.abstractmethod + def update(self, ga: GroupedArray) -> GroupedArray: + ... + @abc.abstractmethod def fit_transform(self, ga: GroupedArray) -> GroupedArray: - raise NotImplementedError + ... @abc.abstractmethod def inverse_transform(self, ga: GroupedArray) -> GroupedArray: - raise NotImplementedError + ... 
def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray: return self.inverse_transform(ga) @@ -89,6 +98,12 @@ def fit_transform(self, ga: GroupedArray) -> GroupedArray: self.original_values_.append(GroupedArray(new_data, new_indptr)) return ga + def update(self, ga: GroupedArray) -> GroupedArray: + transformed = copy.copy(ga) + for d, orig_ga in zip(self.differences, self.original_values_): + orig_ga.update_difference(d, transformed) + return transformed + def inverse_transform(self, ga: GroupedArray) -> GroupedArray: ga = copy.copy(ga) for d, orig_vals_ga in zip( @@ -109,13 +124,24 @@ def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray: # %% ../nbs/target_transforms.ipynb 10 class BaseLocalScaler(BaseGroupedArrayTargetTransform): - """Standardizes each serie by subtracting its mean and dividing by its standard deviation.""" - scaler_factory: type + def _is_utils_tfm(self): + return isinstance(self.scaler_, UtilsTargetTransform) + + def update(self, ga: GroupedArray) -> GroupedArray: + if not self._is_utils_tfm(): + ga = CoreGroupedArray(ga.data, ga.indptr) + return GroupedArray(self.scaler_.transform(ga), ga.indptr) + def fit_transform(self, ga: GroupedArray) -> GroupedArray: self.scaler_ = self.scaler_factory() - transformed = self.scaler_.fit_transform(ga) + if self._is_utils_tfm(): + transformed = self.scaler_.fit_transform(ga) + else: + core_ga = CoreGroupedArray(ga.data, ga.indptr) + self.scaler_.fit(core_ga) + transformed = self.scaler_.transform(core_ga) return GroupedArray(transformed, ga.indptr) def inverse_transform(self, ga: GroupedArray) -> GroupedArray: @@ -124,9 +150,14 @@ def inverse_transform(self, ga: GroupedArray) -> GroupedArray: stats = stats[self.idxs] if stats.shape[0] != ga.n_groups: raise ValueError("Found different number of groups in scaler.") - transformed = _transform( - ga.data, ga.indptr, stats, _common_scaler_inverse_transform - ) + if self._is_utils_tfm() or self.idxs is not None: + # core scalers can't transform a subset + transformed = _transform( + ga.data, ga.indptr, stats, _common_scaler_inverse_transform + ) + else: + core_ga = CoreGroupedArray(ga.data, ga.indptr) + transformed = self.scaler_.inverse_transform(core_ga) return GroupedArray(transformed, ga.indptr) def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray: @@ -136,13 +167,15 @@ def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray: class LocalStandardScaler(BaseLocalScaler): """Standardizes each serie by subtracting its mean and dividing by its standard deviation.""" - scaler_factory = StandardScaler + scaler_factory = ( + core_scalers.LocalStandardScaler if CORE_INSTALLED else StandardScaler + ) # %% ../nbs/target_transforms.ipynb 14 class LocalMinMaxScaler(BaseLocalScaler): """Scales each serie to be in the [0, 1] interval.""" - scaler_factory = MinMaxScaler + scaler_factory = core_scalers.LocalMinMaxScaler if CORE_INSTALLED else MinMaxScaler # %% ../nbs/target_transforms.ipynb 16 class LocalRobustScaler(BaseLocalScaler): @@ -155,23 +188,23 @@ class LocalRobustScaler(BaseLocalScaler): """ def __init__(self, scale: str): - self.scaler_factory = lambda: RobustScaler(scale) # type: ignore + self.scaler_factory = lambda: core_scalers.LocalRobustScaler(scale) if CORE_INSTALLED else RobustScaler(scale) # type: ignore # %% ../nbs/target_transforms.ipynb 19 class LocalBoxCox(BaseLocalScaler): """Finds the optimum lambda for each serie and applies the Box-Cox transformation""" def __init__(self): - self.scaler = BoxCox() + self.scaler_ = 
BoxCox() def fit_transform(self, ga: GroupedArray) -> GroupedArray: - return GroupedArray(self.scaler.fit_transform(ga), ga.indptr) + return GroupedArray(self.scaler_.fit_transform(ga), ga.indptr) def inverse_transform(self, ga: GroupedArray) -> GroupedArray: from scipy.special import inv_boxcox1p sizes = np.diff(ga.indptr) - lmbdas = self.scaler.lmbdas_ + lmbdas = self.scaler_.lmbdas_ if self.idxs is not None: lmbdas = lmbdas[self.idxs] lmbdas = np.repeat(lmbdas, sizes, axis=0) @@ -184,6 +217,11 @@ class GlobalSklearnTransformer(BaseTargetTransform): def __init__(self, transformer: TransformerMixin): self.transformer = transformer + def update(self, df: pd.DataFrame) -> pd.DataFrame: + df = df.copy(deep=False) + df[self.target_col] = self.transformer_.transform(df[[self.target_col]].values) + return df + def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame: df = df.copy(deep=False) self.transformer_ = clone(self.transformer) diff --git a/nbs/compat.ipynb b/nbs/compat.ipynb index b5572270..ba136a10 100644 --- a/nbs/compat.ipynb +++ b/nbs/compat.ipynb @@ -20,6 +20,7 @@ "#| export\n", "try:\n", " import coreforecast.lag_transforms as core_tfms\n", + " import coreforecast.scalers as core_scalers \n", " from coreforecast.grouped_array import GroupedArray as CoreGroupedArray\n", " \n", " from mlforecast.lag_transforms import BaseLagTransform, Lag\n", @@ -27,6 +28,7 @@ " CORE_INSTALLED = True\n", "except ImportError:\n", " core_tfms = None\n", + " core_scalers = None\n", " CoreGroupedArray = None\n", "\n", " class BaseLagTransform:\n", diff --git a/nbs/core.ipynb b/nbs/core.ipynb index 992f6c5b..37a085f0 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -1046,7 +1046,7 @@ " return df\n", "\n", " def _predict_setup(self) -> None:\n", - " self._ga = copy.copy(self.ga)\n", + " self.ga = copy.copy(self._ga)\n", " if isinstance(self.last_dates, pl_Series):\n", " self.curr_dates = self.last_dates.clone()\n", " else:\n", @@ -1123,7 +1123,7 @@ " raise ValueError(f'horizon must be at most max_horizon ({self.max_horizon})')\n", " self._predict_setup()\n", " uids = self._get_future_ids(horizon)\n", - " starts = ufp.offset_times(self.curr_dates, self.freq, 1) \n", + " starts = ufp.offset_times(self.curr_dates, self.freq, 1)\n", " dates = ufp.time_ranges(starts, self.freq, periods=horizon)\n", " if isinstance(self.curr_dates, pl_Series):\n", " df_constructor = pl_DataFrame\n", @@ -1142,6 +1142,9 @@ " result = ufp.assign_columns(result, name, raw_preds)\n", " return result\n", "\n", + " def _has_ga_target_tfms(self):\n", + " return any(isinstance(tfm, BaseGroupedArrayTargetTransform) for tfm in self.target_transforms)\n", + "\n", " def predict(\n", " self,\n", " models: Dict[str, Union[BaseEstimator, List[BaseEstimator]]],\n", @@ -1202,37 +1205,43 @@ " raise ValueError(msg)\n", " drop_cols = [self.id_col, self.time_col, '_start', '_end']\n", " X_df = ufp.sort(X_df, [self.id_col, self.time_col]).drop(columns=drop_cols)\n", - " if getattr(self, 'max_horizon', None) is None:\n", - " preds = self._predict_recursive(\n", - " models=models,\n", - " horizon=horizon,\n", - " before_predict_callback=before_predict_callback,\n", - " after_predict_callback=after_predict_callback,\n", - " X_df=X_df,\n", - " )\n", - " else:\n", - " preds = self._predict_multi(\n", - " models=models,\n", - " horizon=horizon,\n", - " before_predict_callback=before_predict_callback,\n", - " X_df=X_df,\n", - " )\n", + " # backup original series. 
the ga attribute gets modified\n", + " # and is copied from _ga at the start of each model's predict\n", + " self._ga = copy.copy(self.ga)\n", + " try: \n", + " if getattr(self, 'max_horizon', None) is None:\n", + " preds = self._predict_recursive(\n", + " models=models,\n", + " horizon=horizon,\n", + " before_predict_callback=before_predict_callback,\n", + " after_predict_callback=after_predict_callback,\n", + " X_df=X_df,\n", + " )\n", + " else:\n", + " preds = self._predict_multi(\n", + " models=models,\n", + " horizon=horizon,\n", + " before_predict_callback=before_predict_callback,\n", + " X_df=X_df,\n", + " )\n", + " finally:\n", + " self.ga = self._ga\n", + " del self._ga \n", " if self.target_transforms is not None:\n", - " if any(isinstance(tfm, BaseGroupedArrayTargetTransform) for tfm in self.target_transforms):\n", + " if self._has_ga_target_tfms():\n", " model_cols = [c for c in preds.columns if c not in (self.id_col, self.time_col)]\n", " indptr = np.arange(0, horizon * (len(self._uids) + 1), horizon)\n", " for tfm in self.target_transforms[::-1]:\n", " if isinstance(tfm, BaseGroupedArrayTargetTransform):\n", " tfm.idxs = self._idxs\n", " for col in model_cols:\n", - " ga = GroupedArray(preds[col].to_numpy(), indptr)\n", + " ga = GroupedArray(preds[col].to_numpy().astype(self.ga.data.dtype), indptr)\n", " ga = tfm.inverse_transform(ga)\n", " preds = ufp.assign_columns(preds, col, ga.data)\n", " tfm.idxs = None\n", " else:\n", " preds = tfm.inverse_transform(preds)\n", - " self.ga = self._ga\n", - " del self._uids, self._idxs, self._static_features, self._ga\n", + " del self._uids, self._idxs, self._static_features\n", " return preds\n", "\n", " def save(self, path: Union[str, Path]) -> None:\n", @@ -1254,11 +1263,17 @@ " if isinstance(uids, pd.Index):\n", " uids = pd.Series(uids)\n", " uids, new_ids = ufp.match_if_categorical(uids, df[self.id_col])\n", + " df = ufp.copy_if_pandas(df, deep=False)\n", " df = ufp.assign_columns(df, self.id_col, new_ids)\n", " df = ufp.sort(df, by=[self.id_col, self.time_col])\n", - " values = df[self.target_col].to_numpy() \n", + " values = df[self.target_col].to_numpy()\n", + " values = values.astype(self.ga.data.dtype, copy=False)\n", " id_counts = ufp.counts_by_id(df, self.id_col)\n", - " sizes = ufp.join(uids, id_counts, on=self.id_col, how='outer')\n", + " try:\n", + " sizes = ufp.join(uids, id_counts, on=self.id_col, how='outer_coalesce')\n", + " except (KeyError, ValueError):\n", + " # pandas raises key error, polars before coalesce raises value error\n", + " sizes = ufp.join(uids, id_counts, on=self.id_col, how='outer')\n", " sizes = ufp.fill_null(sizes, {'counts': 0})\n", " sizes = ufp.sort(sizes, by=self.id_col)\n", " new_groups = ~ufp.is_in(sizes[self.id_col], uids)\n", @@ -1270,9 +1285,12 @@ " last_dates = ufp.sort(last_dates, by=self.id_col)\n", " self.last_dates = ufp.cast(last_dates[self.time_col], self.last_dates.dtype)\n", " self.uids = ufp.sort(sizes[self.id_col])\n", - " if isinstance(self.uids, pd.Series):\n", + " if isinstance(df, pd.DataFrame):\n", " self.uids = pd.Index(self.uids)\n", + " self.last_dates = pd.Index(self.last_dates)\n", " if new_groups.any():\n", + " if self.target_transforms is not None:\n", + " raise ValueError('Can not update target_transforms with new series.')\n", " new_ids = ufp.filter_with_mask(sizes[self.id_col], new_groups)\n", " new_ids_df = ufp.filter_with_mask(df, ufp.is_in(df[self.id_col], new_ids))\n", " new_ids_counts = ufp.counts_by_id(new_ids_df, self.id_col)\n", @@ -1280,6 +1298,17 @@ " 
new_statics = new_statics[self.static_features_.columns]\n", " self.static_features_ = ufp.vertical_concat([self.static_features_, new_statics])\n", " self.static_features_ = ufp.sort(self.static_features_, self.id_col)\n", + " if self.target_transforms is not None:\n", + " if self._has_ga_target_tfms(): \n", + " indptr = np.append(0, id_counts['counts']).cumsum()\n", + " for tfm in self.target_transforms:\n", + " if isinstance(tfm, BaseGroupedArrayTargetTransform):\n", + " ga = GroupedArray(values, indptr)\n", + " ga = tfm.update(ga)\n", + " df = ufp.assign_columns(df, self.target_col, ga.data)\n", + " else:\n", + " df = tfm.update(df)\n", + " values = df[self.target_col].to_numpy() \n", " self.ga = self.ga.append_several(\n", " new_sizes=sizes['counts'].to_numpy().astype(np.int32),\n", " new_values=values,\n", @@ -1549,6 +1578,7 @@ "ts._uids = ts.uids\n", "ts._idxs = np.arange(len(ts.ga))\n", "ts._static_features = ts.static_features_\n", + "ts._ga = copy.copy(ts.ga)\n", "ts._predict_setup()\n", "updates = ts._update_features()\n", "\n", @@ -1585,6 +1615,7 @@ "ts._uids = ts.uids\n", "ts._idxs = np.arange(len(ts.ga))\n", "ts._static_features = ts.static_features_\n", + "ts._ga = copy.copy(ts.ga)\n", "ts._predict_setup()\n", "ts._update_features()\n", "ts._update_y([1.])\n", @@ -1788,6 +1819,7 @@ "df = ts.fit_transform(series, id_col='unique_id', time_col='ds', target_col='y', keep_last_n=keep_last_n)\n", "ts._uids = ts.uids\n", "ts._idxs = np.arange(len(ts.ga))\n", + "ts._ga = copy.copy(ts.ga)\n", "ts._predict_setup()\n", "\n", "expected_lags = ['lag7', 'lag14']\n", @@ -2114,9 +2146,9 @@ "outputs": [], "source": [ "#| hide\n", - "class ZerosModel:\n", + "class SeasonalNaiveModel:\n", " def predict(self, X):\n", - " return np.full(X.shape[0], 0)\n", + " return X['lag7']\n", "\n", "class NaiveModel:\n", " def predict(self, X: pd.DataFrame):\n", @@ -2124,14 +2156,13 @@ "\n", "two_series = series[series['unique_id'].isin(['id_00', 'id_19'])].copy()\n", "two_series['unique_id'] = pd.Categorical(two_series['unique_id'], ['id_00', 'id_19'])\n", - "ts = TimeSeries(freq='D', lags=[1])\n", + "ts = TimeSeries(freq='D', lags=[1], date_features=['dayofweek'])\n", "ts.fit_transform(\n", " two_series,\n", " id_col='unique_id',\n", " time_col='ds',\n", " target_col='y',\n", ")\n", - "ts.predict({'zero': ZerosModel()}, 4)\n", "last_vals_two_series = two_series.groupby('unique_id').tail(1)\n", "last_val_id0 = last_vals_two_series[lambda x: x['unique_id'].eq('id_00')].copy()\n", "new_values = last_val_id0.copy()\n", @@ -2164,32 +2195,57 @@ " .astype(ts.static_features_.dtypes)\n", " .reset_index(drop=True)\n", " )\n", - ")" + ")\n", + "# with target transforms\n", + "ts = TimeSeries(\n", + " freq='D',\n", + " lags=[7],\n", + " target_transforms=[Differences([1, 2]), LocalStandardScaler()],\n", + ")\n", + "ts.fit_transform(two_series, id_col='unique_id', time_col='ds', target_col='y')\n", + "new_values = two_series.groupby('unique_id').tail(7).copy()\n", + "new_values['ds'] += 7 * pd.offsets.Day()\n", + "orig_last7 = ts.ga.take_from_groups(slice(-7, None)).data\n", + "ts.update(new_values)\n", + "preds = ts.predict({'SeasonalNaive': SeasonalNaiveModel()}, 7)\n", + "np.testing.assert_allclose(\n", + " new_values['y'].values,\n", + " preds['SeasonalNaive'].values,\n", + ")\n", + "last7 = ts.ga.take_from_groups(slice(-7, None)).data\n", + "assert 0 < np.abs(last7 / orig_last7 - 1).mean() < 0.5" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + "sys:1: UserWarning: Local categoricals have different encodings, expensive re-encoding is done to perform this merge operation. Consider using a StringCache or an Enum type if the categories are known in advance\n" + ] + } + ], "source": [ "#| hide\n", "#| polars\n", "two_series = generate_daily_series(2, n_static_features=2, engine='polars')\n", - "ts = TimeSeries(freq='1d', lags=[1])\n", + "ts = TimeSeries(freq='1d', lags=[1], date_features=['weekday'])\n", "ts.fit_transform(\n", " two_series,\n", " id_col='unique_id',\n", " time_col='ds',\n", " target_col='y',\n", ")\n", - "ts.predict({'zero': ZerosModel()}, 4)\n", "last_vals_two_series = two_series.join(\n", " two_series.group_by('unique_id').agg(pl.col('ds').max()), on=['unique_id', 'ds']\n", ")\n", "last_val_id0 = last_vals_two_series.filter(pl.col('unique_id') == 'id_0')\n", "new_values = last_val_id0.with_columns(\n", - " pl.col('unique_id').cast(pl.Utf8),\n", + " pl.col('unique_id').cast(pl.Categorical),\n", " pl.col('ds').dt.offset_by('1d'),\n", " pl.col('static_0').cast(pl.Int64),\n", " pl.col('static_1').cast(pl.Int64),\n", @@ -2200,7 +2256,10 @@ " 'y': [5.0, 6.0],\n", " 'static_0': [0, 0],\n", " 'static_1': [1, 1],\n", - "}).with_columns(pl.col('ds').dt.cast_time_unit('ns'))\n", + "}).with_columns(\n", + " pl.col('ds').dt.cast_time_unit('ns'),\n", + " pl.col('unique_id').cast(pl.Categorical),\n", + ")\n", "new_values = pl.concat([new_values, new_serie])\n", "ts.update(new_values)\n", "preds = ts.predict({'Naive': NaiveModel()}, 1)\n", @@ -2226,7 +2285,25 @@ " .astype(ts.static_features_.to_pandas().dtypes)\n", " .reset_index(drop=True)\n", " )\n", - ")" + ")\n", + "# with target transforms\n", + "ts = TimeSeries(\n", + " freq='1d',\n", + " lags=[7],\n", + " target_transforms=[Differences([1, 2]), LocalStandardScaler()],\n", + ")\n", + "ts.fit_transform(two_series, id_col='unique_id', time_col='ds', target_col='y')\n", + "new_values = two_series.group_by('unique_id').tail(7)\n", + "new_values = new_values.with_columns(pl.col('ds').dt.offset_by('7d'))\n", + "orig_last7 = ts.ga.take_from_groups(slice(-7, None)).data\n", + "ts.update(new_values)\n", + "preds = ts.predict({'SeasonalNaive': SeasonalNaiveModel()}, 7)\n", + "np.testing.assert_allclose(\n", + " new_values['y'].to_numpy(),\n", + " preds['SeasonalNaive'].to_numpy(),\n", + ")\n", + "last7 = ts.ga.take_from_groups(slice(-7, None)).data\n", + "assert 0 < np.abs(last7 / orig_last7 - 1).mean() < 0.5" ] }, { @@ -2343,6 +2420,36 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# test predict\n", + "class Lag1PlusOneModel:\n", + " def predict(self, X):\n", + " return X['lag1'] + 1\n", + "\n", + "ts = TimeSeries(freq='D', lags=[1])\n", + "for max_horizon in [None, 2]:\n", + " if max_horizon is None:\n", + " mod1 = Lag1PlusOneModel()\n", + " mod2 = Lag1PlusOneModel()\n", + " else:\n", + " mod1 = [Lag1PlusOneModel() for _ in range(max_horizon)]\n", + " mod2 = [Lag1PlusOneModel() for _ in range(max_horizon)]\n", + " ts.fit_transform(train, 'unique_id', 'ds', 'y', max_horizon=max_horizon)\n", + " # each model gets the correct historic values\n", + " preds = ts.predict(models={'mod1': mod1, 'mod2': mod2}, horizon=2)\n", + " np.testing.assert_allclose(preds['mod1'], preds['mod2'])\n", + " # idempotency\n", + " preds2 = ts.predict(models={'mod1': mod1, 'mod2': mod2}, horizon=2)\n", + " np.testing.assert_allclose(preds2['mod1'], preds2['mod2'])\n", + " 
pd.testing.assert_frame_equal(preds, preds2)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/distributed.forecast.ipynb b/nbs/distributed.forecast.ipynb index 6483def4..2aea1ed5 100644 --- a/nbs/distributed.forecast.ipynb +++ b/nbs/distributed.forecast.ipynb @@ -485,6 +485,7 @@ " horizon,\n", " before_predict_callback=None,\n", " after_predict_callback=None,\n", + " X_df=None, \n", " ) -> Iterable[pd.DataFrame]:\n", " for serialized_ts, _, serialized_valid in items:\n", " valid = cloudpickle.loads(serialized_valid)\n", @@ -494,6 +495,7 @@ " horizon=horizon,\n", " before_predict_callback=before_predict_callback,\n", " after_predict_callback=after_predict_callback,\n", + " X_df=X_df,\n", " )\n", " if valid is not None:\n", " res = res.merge(valid, how='left')\n", @@ -510,6 +512,7 @@ " h: int,\n", " before_predict_callback: Optional[Callable] = None,\n", " after_predict_callback: Optional[Callable] = None,\n", + " X_df: Optional[pd.DataFrame] = None,\n", " new_df: Optional[fugue.AnyDataFrame] = None,\n", " ) -> fugue.AnyDataFrame:\n", " \"\"\"Compute the predictions for the next `horizon` steps.\n", @@ -526,6 +529,8 @@ " Function to call on the predictions before updating the targets.\n", " This function will take a pandas Series with the predictions and should return another one with the same structure.\n", " The series identifier is on the index.\n", + " X_df : pandas DataFrame, optional (default=None)\n", + " Dataframe with the future exogenous features. Should have the id column and the time column. \n", " new_df : dask or spark DataFrame, optional (default=None)\n", " Series data of new observations for which forecasts are to be generated.\n", " This dataframe should have the same structure as the one used to fit the model, including any features and time series data.\n", @@ -550,6 +555,8 @@ " else:\n", " partition_results = self._partition_results\n", " schema = self._get_predict_schema()\n", + " if X_df is not None and not isinstance(X_df, pd.DataFrame):\n", + " raise ValueError('`X_df` should be a pandas DataFrame')\n", " res = fa.transform(\n", " partition_results,\n", " DistributedMLForecast._predict,\n", @@ -558,6 +565,7 @@ " 'horizon': h,\n", " 'before_predict_callback': before_predict_callback,\n", " 'after_predict_callback': after_predict_callback,\n", + " 'X_df': X_df,\n", " },\n", " schema=schema,\n", " engine=self.engine,\n", @@ -761,17 +769,20 @@ "text/markdown": [ "---\n", "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### DistributedMLForecast\n", "\n", "> DistributedMLForecast (models,\n", "> freq:Union[int,str,pandas._libs.tslibs.offsets.Bas\n", - "> eOffset,NoneType]=None,\n", - "> lags:Optional[Iterable[int]]=None, lag_transforms:\n", - "> Optional[Dict[int,List[Union[Callable,Tuple[Callab\n", - "> le,Any]]]]]=None, date_features:Optional[Iterable[\n", - "> Union[str,Callable]]]=None, num_threads:int=1, tar\n", - "> get_transforms:Optional[List[mlforecast.target_tra\n", - "> nsforms.BaseTargetTransform]]=None, engine=None,\n", + "> eOffset], lags:Optional[Iterable[int]]=None, lag_t\n", + "> ransforms:Optional[Dict[int,List[Union[Callable,Tu\n", + "> ple[Callable,Any]]]]]=None, date_features:Optional\n", + "> [Iterable[Union[str,Callable]]]=None,\n", + "> num_threads:int=1, target_transforms:Optional[List\n", + "> [Union[mlforecast.target_transforms.BaseTargetTran\n", + "> 
sform,mlforecast.target_transforms.BaseGroupedArra\n", + "> yTargetTransform]]]=None, engine=None,\n", "> num_partitions:Optional[int]=None)\n", "\n", "Multi backend distributed pipeline" @@ -779,17 +790,20 @@ "text/plain": [ "---\n", "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### DistributedMLForecast\n", "\n", "> DistributedMLForecast (models,\n", "> freq:Union[int,str,pandas._libs.tslibs.offsets.Bas\n", - "> eOffset,NoneType]=None,\n", - "> lags:Optional[Iterable[int]]=None, lag_transforms:\n", - "> Optional[Dict[int,List[Union[Callable,Tuple[Callab\n", - "> le,Any]]]]]=None, date_features:Optional[Iterable[\n", - "> Union[str,Callable]]]=None, num_threads:int=1, tar\n", - "> get_transforms:Optional[List[mlforecast.target_tra\n", - "> nsforms.BaseTargetTransform]]=None, engine=None,\n", + "> eOffset], lags:Optional[Iterable[int]]=None, lag_t\n", + "> ransforms:Optional[Dict[int,List[Union[Callable,Tu\n", + "> ple[Callable,Any]]]]]=None, date_features:Optional\n", + "> [Iterable[Union[str,Callable]]]=None,\n", + "> num_threads:int=1, target_transforms:Optional[List\n", + "> [Union[mlforecast.target_transforms.BaseTargetTran\n", + "> sform,mlforecast.target_transforms.BaseGroupedArra\n", + "> yTargetTransform]]]=None, engine=None,\n", "> num_partitions:Optional[int]=None)\n", "\n", "Multi backend distributed pipeline" @@ -815,14 +829,15 @@ "text/markdown": [ "---\n", "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L380){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### DistributedMLForecast.fit\n", "\n", "> DistributedMLForecast.fit (df:~AnyDataFrame, id_col:str='unique_id',\n", "> time_col:str='ds', target_col:str='y',\n", "> static_features:Optional[List[str]]=None,\n", "> dropna:bool=True,\n", - "> keep_last_n:Optional[int]=None,\n", - "> data:Optional[~AnyDataFrame]=None)\n", + "> keep_last_n:Optional[int]=None)\n", "\n", "Apply the feature engineering and train the models.\n", "\n", @@ -835,20 +850,20 @@ "| static_features | Optional | None | Names of the features that are static and will be repeated when forecasting. |\n", "| dropna | bool | True | Drop rows with missing values produced by the transformations. |\n", "| keep_last_n | Optional | None | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. |\n", - "| data | Optional | None | |\n", - "| **Returns** | **DistributedMLForecast** | | **noqa: ARG002** |" + "| **Returns** | **DistributedMLForecast** | | **Forecast object with series values and trained models.** |" ], "text/plain": [ "---\n", "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L380){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### DistributedMLForecast.fit\n", "\n", "> DistributedMLForecast.fit (df:~AnyDataFrame, id_col:str='unique_id',\n", "> time_col:str='ds', target_col:str='y',\n", "> static_features:Optional[List[str]]=None,\n", "> dropna:bool=True,\n", - "> keep_last_n:Optional[int]=None,\n", - "> data:Optional[~AnyDataFrame]=None)\n", + "> keep_last_n:Optional[int]=None)\n", "\n", "Apply the feature engineering and train the models.\n", "\n", @@ -861,8 +876,7 @@ "| static_features | Optional | None | Names of the features that are static and will be repeated when forecasting. 
|\n", "| dropna | bool | True | Drop rows with missing values produced by the transformations. |\n", "| keep_last_n | Optional | None | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. |\n", - "| data | Optional | None | |\n", - "| **Returns** | **DistributedMLForecast** | | **noqa: ARG002** |" + "| **Returns** | **DistributedMLForecast** | | **Forecast object with series values and trained models.** |" ] }, "execution_count": null, @@ -885,15 +899,16 @@ "text/markdown": [ "---\n", "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L451){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### DistributedMLForecast.predict\n", "\n", "> DistributedMLForecast.predict (h:int,\n", "> before_predict_callback:Optional[Callable]\n", "> =None, after_predict_callback:Optional[Cal\n", - "> lable]=None,\n", - "> new_df:Optional[~AnyDataFrame]=None,\n", - "> horizon:Optional[int]=None,\n", - "> new_data:Optional[~AnyDataFrame]=None)\n", + "> lable]=None, X_df:Optional[pandas.core.fra\n", + "> me.DataFrame]=None,\n", + "> new_df:Optional[~AnyDataFrame]=None)\n", "\n", "Compute the predictions for the next `horizon` steps.\n", "\n", @@ -902,23 +917,23 @@ "| h | int | | Forecast horizon. |\n", "| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n", "| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n", + "| X_df | Optional | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n", "| new_df | Optional | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n", - "| horizon | Optional | None | |\n", - "| new_data | Optional | None | |\n", "| **Returns** | **AnyDataFrame** | | **Predictions for each serie and timestep, with one column per model.** |" ], "text/plain": [ "---\n", "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L451){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### DistributedMLForecast.predict\n", "\n", "> DistributedMLForecast.predict (h:int,\n", "> before_predict_callback:Optional[Callable]\n", "> =None, after_predict_callback:Optional[Cal\n", - "> lable]=None,\n", - "> new_df:Optional[~AnyDataFrame]=None,\n", - "> horizon:Optional[int]=None,\n", - "> new_data:Optional[~AnyDataFrame]=None)\n", + "> lable]=None, X_df:Optional[pandas.core.fra\n", + "> me.DataFrame]=None,\n", + "> new_df:Optional[~AnyDataFrame]=None)\n", "\n", "Compute the predictions for the next `horizon` steps.\n", "\n", @@ -927,9 +942,8 @@ "| h | int | | Forecast horizon. |\n", "| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n", "| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n", + "| X_df | Optional | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n", "| new_df | Optional | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n", - "| horizon | Optional | None | |\n", - "| new_data | Optional | None | |\n", "| **Returns** | **AnyDataFrame** | | **Predictions for each serie and timestep, with one column per model.** |" ] }, @@ -953,6 +967,8 @@ "text/markdown": [ "---\n", "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L284){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### DistributedMLForecast.preprocess\n", "\n", "> DistributedMLForecast.preprocess (df:~AnyDataFrame,\n", @@ -960,8 +976,7 @@ "> time_col:str='ds', target_col:str='y', \n", "> static_features:Optional[List[str]]=Non\n", "> e, dropna:bool=True,\n", - "> keep_last_n:Optional[int]=None,\n", - "> data:Optional[~AnyDataFrame]=None)\n", + "> keep_last_n:Optional[int]=None)\n", "\n", "Add the features to `data`.\n", "\n", @@ -974,12 +989,13 @@ "| static_features | Optional | None | Names of the features that are static and will be repeated when forecasting. |\n", "| dropna | bool | True | Drop rows with missing values produced by the transformations. |\n", "| keep_last_n | Optional | None | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. |\n", - "| data | Optional | None | |\n", - "| **Returns** | **AnyDataFrame** | | **noqa: ARG002** |" + "| **Returns** | **AnyDataFrame** | | **`df` with added features.** |" ], "text/plain": [ "---\n", "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L284){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### DistributedMLForecast.preprocess\n", "\n", "> DistributedMLForecast.preprocess (df:~AnyDataFrame,\n", @@ -987,8 +1003,7 @@ "> time_col:str='ds', target_col:str='y', \n", "> static_features:Optional[List[str]]=Non\n", "> e, dropna:bool=True,\n", - "> keep_last_n:Optional[int]=None,\n", - "> data:Optional[~AnyDataFrame]=None)\n", + "> keep_last_n:Optional[int]=None)\n", "\n", "Add the features to `data`.\n", "\n", @@ -1001,8 +1016,7 @@ "| static_features | Optional | None | Names of the features that are static and will be repeated when forecasting. |\n", "| dropna | bool | True | Drop rows with missing values produced by the transformations. |\n", "| keep_last_n | Optional | None | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. 
|\n", - "| data | Optional | None | |\n", - "| **Returns** | **AnyDataFrame** | | **noqa: ARG002** |" + "| **Returns** | **AnyDataFrame** | | **`df` with added features.** |" ] }, "execution_count": null, @@ -1025,6 +1039,8 @@ "text/markdown": [ "---\n", "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L510){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### DistributedMLForecast.cross_validation\n", "\n", "> DistributedMLForecast.cross_validation (df:~AnyDataFrame, n_windows:int,\n", @@ -1039,9 +1055,7 @@ "> allback:Optional[Callable]=None, \n", "> after_predict_callback:Optional[C\n", "> allable]=None,\n", - "> input_size:Optional[int]=None, da\n", - "> ta:Optional[~AnyDataFrame]=None,\n", - "> window_size:Optional[int]=None)\n", + "> input_size:Optional[int]=None)\n", "\n", "Perform time series cross validation.\n", "Creates `n_windows` splits where each window has `h` test periods,\n", @@ -1063,13 +1077,13 @@ "| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n", "| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n", "| input_size | Optional | None | Maximum training samples per serie in each window. If None, will use an expanding window. |\n", - "| data | Optional | None | |\n", - "| window_size | Optional | None | |\n", - "| **Returns** | **AnyDataFrame** | | **noqa: ARG002
noqa: ARG002** |" + "| **Returns** | **AnyDataFrame** | | **Predictions for each window with the series id, timestamp, target value and predictions from each model.** |" ], "text/plain": [ "---\n", "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L510){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "### DistributedMLForecast.cross_validation\n", "\n", "> DistributedMLForecast.cross_validation (df:~AnyDataFrame, n_windows:int,\n", @@ -1084,9 +1098,7 @@ "> allback:Optional[Callable]=None, \n", "> after_predict_callback:Optional[C\n", "> allable]=None,\n", - "> input_size:Optional[int]=None, da\n", - "> ta:Optional[~AnyDataFrame]=None,\n", - "> window_size:Optional[int]=None)\n", + "> input_size:Optional[int]=None)\n", "\n", "Perform time series cross validation.\n", "Creates `n_windows` splits where each window has `h` test periods,\n", @@ -1108,9 +1120,7 @@ "| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n", "| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n", "| input_size | Optional | None | Maximum training samples per serie in each window. If None, will use an expanding window. |\n", - "| data | Optional | None | |\n", - "| window_size | Optional | None | |\n", - "| **Returns** | **AnyDataFrame** | | **noqa: ARG002
noqa: ARG002** |" + "| **Returns** | **AnyDataFrame** | | **Predictions for each window with the series id, timestamp, target value and predictions from each model.** |" ] }, "execution_count": null, diff --git a/nbs/docs/getting-started/quick_start_distributed.ipynb b/nbs/docs/getting-started/quick_start_distributed.ipynb index 5941d6c8..1a8be2ea 100644 --- a/nbs/docs/getting-started/quick_start_distributed.ipynb +++ b/nbs/docs/getting-started/quick_start_distributed.ipynb @@ -492,6 +492,41 @@ "pd.testing.assert_frame_equal(preds, preds3)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d26427d-314e-4589-8567-097ddf5adce1", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# test X_df\n", + "prices = generate_prices_for_series(series)\n", + "series_wexog = series.merge(prices, on=['unique_id', 'ds'])\n", + "npartitions = 10\n", + "partitioned_series_exog = dd.from_pandas(series_wexog.set_index('unique_id'), npartitions=npartitions)\n", + "partitioned_series_exog = partitioned_series_exog.map_partitions(lambda df: df.reset_index())\n", + "partitioned_series_exog['unique_id'] = partitioned_series_exog['unique_id'].astype(str)\n", + "fcst_exog = DistributedMLForecast(\n", + " models=models,\n", + " freq='D',\n", + " lags=[7],\n", + " lag_transforms={\n", + " 1: [expanding_mean],\n", + " 7: [(rolling_mean, 14)]\n", + " },\n", + " date_features=['dayofweek', 'month'],\n", + " num_threads=1,\n", + " engine=client,\n", + ")\n", + "fcst_exog.fit(partitioned_series_exog, static_features=['static_0', 'static_1'])\n", + "preds_exog = fcst_exog.predict(h=7, X_df=prices).compute()\n", + "full_preds = preds.merge(preds_exog, on=['unique_id', 'ds'], suffixes=('', '_exog'))\n", + "for model in ('DaskXGBForecast', 'DaskLGBMForecast'):\n", + " pct_diff = abs(1 - full_preds[f'{model}_exog'].div(full_preds[f'{model}']).mean())\n", + " assert 0 < pct_diff < 0.1" + ] + }, { "cell_type": "markdown", "id": "502aeadd-2fd5-4d16-8dfb-56a77080c072", @@ -1754,48 +1789,48 @@ " \n", " \n", " 0\n", - " id_01\n", + " id_00\n", " 2001-05-01\n", " 124.758319\n", " 152.856125\n", " 2001-04-30\n", - " 117.876479\n", + " 400.369603\n", " \n", " \n", " 1\n", - " id_01\n", + " id_00\n", " 2001-05-02\n", " 145.041000\n", " 177.355331\n", " 2001-04-30\n", - " 153.394375\n", + " 513.321524\n", " \n", " \n", " 2\n", - " id_01\n", + " id_00\n", " 2001-05-03\n", " 178.838681\n", " 66.459068\n", " 2001-04-30\n", - " 175.337772\n", + " 39.373177\n", " \n", " \n", " 3\n", - " id_01\n", + " id_00\n", " 2001-05-04\n", " 27.212783\n", " 94.735237\n", " 2001-04-30\n", - " 13.202898\n", + " 108.139791\n", " \n", " \n", " 4\n", - " id_01\n", + " id_00\n", " 2001-05-05\n", " 56.624979\n", " 125.717896\n", " 2001-04-30\n", - " 30.203090\n", + " 167.265248\n", " \n", " \n", "\n", diff --git a/nbs/docs/how-to-guides/exogenous_features.ipynb b/nbs/docs/how-to-guides/exogenous_features.ipynb index 2f91c90d..3d6fb630 100644 --- a/nbs/docs/how-to-guides/exogenous_features.ipynb +++ b/nbs/docs/how-to-guides/exogenous_features.ipynb @@ -144,6 +144,14 @@ "series.head()" ] }, + { + "cell_type": "markdown", + "id": "782a6bf1-b360-4aff-b37f-b9ed3ee131e5", + "metadata": {}, + "source": [ + "## Use existing exogenous features" + ] + }, { "cell_type": "markdown", "id": "96fe6893-70c5-4ef5-b318-686c69660a1d", @@ -369,7 +377,7 @@ { "data": { "text/plain": [ - "MLForecast(models=[LGBMRegressor], freq=, lag_features=['lag7', 'expanding_mean_lag1', 'rolling_mean_lag7_window_size14'], date_features=['dayofweek', 'month'], 
num_threads=2)" + "MLForecast(models=[LGBMRegressor], freq=D, lag_features=['lag7', 'expanding_mean_lag1', 'rolling_mean_lag7_window_size14'], date_features=['dayofweek', 'month'], num_threads=2)" ] }, "execution_count": null, @@ -522,6 +530,526 @@ "preds.head()" ] }, + { + "cell_type": "markdown", + "id": "39f21994-af9a-4376-9e96-1d67d87e6336", + "metadata": {}, + "source": [ + "## Generating exogenous features" + ] + }, + { + "cell_type": "markdown", + "id": "f9bb4f27-0a3d-43a8-9f48-57e1983e67d7", + "metadata": {}, + "source": [ + "Nixtla provides some utilities to generate exogenous features for both training and forecasting such as [statsforecast's mstl_decomposition](https://nixtlaverse.nixtla.io/statsforecast/docs/how-to-guides/generating_features.html) or the [transform_exog function](transforming_exog.ipynb). We also have [utilsforecast's fourier function](https://nixtlaverse.nixtla.io/utilsforecast/feature_engineering.html#fourier), which we'll demonstrate here." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "625b300c-9a92-4cf4-af09-d8759e1cef85", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.linear_model import LinearRegression\n", + "from utilsforecast.feature_engineering import fourier" + ] + }, + { + "cell_type": "markdown", + "id": "f482e987-401c-4a75-bcc6-ed2c77021baa", + "metadata": {}, + "source": [ + "Suppose you start with some data like the one above where we have a couple of static features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90b0f9eb-6c07-417f-a5d3-2dc3d16b1d24", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
unique_iddsystatic_0product_id
0id_002000-10-0539.8119837945
1id_002000-10-06103.2740137945
2id_002000-10-07176.5747447945
3id_002000-10-08258.9879007945
4id_002000-10-09344.9404047945
\n", + "
" + ], + "text/plain": [ + " unique_id ds y static_0 product_id\n", + "0 id_00 2000-10-05 39.811983 79 45\n", + "1 id_00 2000-10-06 103.274013 79 45\n", + "2 id_00 2000-10-07 176.574744 79 45\n", + "3 id_00 2000-10-08 258.987900 79 45\n", + "4 id_00 2000-10-09 344.940404 79 45" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "series.head()" + ] + }, + { + "cell_type": "markdown", + "id": "7855322f-1d14-4c9d-9b5f-c8d464896c6f", + "metadata": {}, + "source": [ + "Now we'd like to add some fourier terms to model the seasonality. We can do that with the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "535ae0fd-d16a-49b6-8bdb-c750fdd9a5e1", + "metadata": {}, + "outputs": [], + "source": [ + "transformed_df, future_df = fourier(series, freq='D', season_length=7, k=2, h=7)" + ] + }, + { + "cell_type": "markdown", + "id": "7e0c42c7-5d9e-45b4-b7db-19e72086f5d3", + "metadata": {}, + "source": [ + "This provides an extended training dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c4c662e-1a35-46b0-a5f1-19e52584b0f5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
unique_iddsystatic_0product_idsin1_7sin2_7cos1_7cos2_7
0id_002000-10-0539.81198379450.7818320.9749280.623490-0.222521
1id_002000-10-06103.27401379450.974928-0.433884-0.222521-0.900969
2id_002000-10-07176.57474479450.433884-0.781831-0.9009690.623490
3id_002000-10-08258.9879007945-0.4338840.781832-0.9009690.623490
4id_002000-10-09344.9404047945-0.9749280.433884-0.222521-0.900969
\n", + "
" + ], + "text/plain": [ + " unique_id ds y static_0 product_id sin1_7 sin2_7 \\\n", + "0 id_00 2000-10-05 39.811983 79 45 0.781832 0.974928 \n", + "1 id_00 2000-10-06 103.274013 79 45 0.974928 -0.433884 \n", + "2 id_00 2000-10-07 176.574744 79 45 0.433884 -0.781831 \n", + "3 id_00 2000-10-08 258.987900 79 45 -0.433884 0.781832 \n", + "4 id_00 2000-10-09 344.940404 79 45 -0.974928 0.433884 \n", + "\n", + " cos1_7 cos2_7 \n", + "0 0.623490 -0.222521 \n", + "1 -0.222521 -0.900969 \n", + "2 -0.900969 0.623490 \n", + "3 -0.900969 0.623490 \n", + "4 -0.222521 -0.900969 " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transformed_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "8e311c3c-392e-4f96-90b8-c1095fe2c21e", + "metadata": {}, + "source": [ + "Along with the future values of the features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53b1e42a-9d13-4469-882b-269b21a00964", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
unique_iddssin1_7sin2_7cos1_7cos2_7
0id_002001-05-15-0.781828-0.9749300.623494-0.222511
1id_002001-05-160.0000060.0000111.0000001.000000
2id_002001-05-170.7818350.9749250.623485-0.222533
3id_002001-05-180.974927-0.433895-0.222527-0.900963
4id_002001-05-190.433878-0.781823-0.9009720.623500
\n", + "
" + ], + "text/plain": [ + " unique_id ds sin1_7 sin2_7 cos1_7 cos2_7\n", + "0 id_00 2001-05-15 -0.781828 -0.974930 0.623494 -0.222511\n", + "1 id_00 2001-05-16 0.000006 0.000011 1.000000 1.000000\n", + "2 id_00 2001-05-17 0.781835 0.974925 0.623485 -0.222533\n", + "3 id_00 2001-05-18 0.974927 -0.433895 -0.222527 -0.900963\n", + "4 id_00 2001-05-19 0.433878 -0.781823 -0.900972 0.623500" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "future_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "8df48b27-7765-4107-aa7d-09c8737669f6", + "metadata": {}, + "source": [ + "We can now train using only these features (and the static ones)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53874524-d2d7-40bc-8d82-93575c8c35f0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "MLForecast(models=[LinearRegression], freq=D, lag_features=[], date_features=[], num_threads=1)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fcst2 = MLForecast(models=LinearRegression(), freq='D')\n", + "fcst2.fit(transformed_df, static_features=['static_0', 'product_id'])" + ] + }, + { + "cell_type": "markdown", + "id": "a16a719d-dfc0-4800-bac5-345ad2dc2681", + "metadata": {}, + "source": [ + "And provide the future values to the predict method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "666c78ff-3814-44ee-acd3-ae61bd9e771b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
unique_iddsLinearRegression
0id_002001-05-15275.822342
1id_002001-05-16262.258117
2id_002001-05-17238.195850
3id_002001-05-18240.997814
4id_002001-05-19262.247123
\n", + "
" + ], + "text/plain": [ + " unique_id ds LinearRegression\n", + "0 id_00 2001-05-15 275.822342\n", + "1 id_00 2001-05-16 262.258117\n", + "2 id_00 2001-05-17 238.195850\n", + "3 id_00 2001-05-18 240.997814\n", + "4 id_00 2001-05-19 262.247123" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fcst2.predict(h=7, X_df=future_df).head()" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/docs/how-to-guides/predict_callbacks.ipynb b/nbs/docs/how-to-guides/predict_callbacks.ipynb index 6fa5a40e..fa7464d6 100644 --- a/nbs/docs/how-to-guides/predict_callbacks.ipynb +++ b/nbs/docs/how-to-guides/predict_callbacks.ipynb @@ -38,6 +38,8 @@ "metadata": {}, "outputs": [], "source": [ + "import copy\n", + "\n", "import lightgbm as lgb\n", "import numpy as np\n", "from IPython.display import display\n", @@ -442,6 +444,7 @@ "fcst.ts._uids = fcst.ts.uids\n", "fcst.ts._idxs = None\n", "fcst.ts._static_features = fcst.ts.static_features_\n", + "fcst.ts._ga = copy.copy(fcst.ts.ga)\n", "fcst.ts._predict_setup()\n", "\n", "for attr in ('head', 'tail'):\n", diff --git a/nbs/forecast.ipynb b/nbs/forecast.ipynb index 3ae4ebd1..8000e6ab 100644 --- a/nbs/forecast.ipynb +++ b/nbs/forecast.ipynb @@ -625,11 +625,31 @@ " self.fcst_fitted_values_ = fitted_values\n", " return self\n", "\n", - " def forecast_fitted_values(self):\n", - " \"\"\"Access in-sample predictions.\"\"\"\n", + " def forecast_fitted_values(self, level: Optional[List[Union[int, float]]] = None) -> DataFrame:\n", + " \"\"\"Access in-sample predictions.\n", + " \n", + " Parameters\n", + " ----------\n", + " level : list of ints or floats, optional (default=None)\n", + " Confidence levels between 0 and 100 for prediction intervals.\n", + "\n", + " Returns\n", + " -------\n", + " pandas or polars DataFrame\n", + " Dataframe with predictions for the training set\n", + " \"\"\"\n", " if not hasattr(self, 'fcst_fitted_values_'):\n", " raise Exception('Please run the `fit` method using `fitted=True`')\n", - " return self.fcst_fitted_values_\n", + " res = self.fcst_fitted_values_\n", + " if level is not None:\n", + " res = ufp.add_insample_levels(\n", + " res,\n", + " models=self.models_.keys(),\n", + " level=level,\n", + " id_col=self.ts.id_col,\n", + " target_col=self.ts.target_col,\n", + " )\n", + " return res\n", "\n", " def make_future_dataframe(self, h: int) -> DataFrame:\n", " \"\"\"Create a dataframe with all ids and future times in the forecasting horizon.\n", @@ -1686,9 +1706,16 @@ "\n", "### MLForecast.forecast_fitted_values\n", "\n", - "> MLForecast.forecast_fitted_values ()\n", + "> MLForecast.forecast_fitted_values\n", + "> (level:Optional[List[Union[int,float]]\n", + "> ]=None)\n", + "\n", + "Access in-sample predictions.\n", "\n", - "Access in-sample predictions." + "| | **Type** | **Default** | **Details** |\n", + "| -- | -------- | ----------- | ----------- |\n", + "| level | Optional | None | Confidence levels between 0 and 100 for prediction intervals. |\n", + "| **Returns** | **Union** | | **Dataframe with predictions for the training set** |" ], "text/plain": [ "---\n", @@ -1697,9 +1724,16 @@ "\n", "### MLForecast.forecast_fitted_values\n", "\n", - "> MLForecast.forecast_fitted_values ()\n", + "> MLForecast.forecast_fitted_values\n", + "> (level:Optional[List[Union[int,float]]\n", + "> ]=None)\n", "\n", - "Access in-sample predictions." 
+ "Access in-sample predictions.\n", + "\n", + "| | **Type** | **Default** | **Details** |\n", + "| -- | -------- | ----------- | ----------- |\n", + "| level | Optional | None | Confidence levels between 0 and 100 for prediction intervals. |\n", + "| **Returns** | **Union** | | **Dataframe with predictions for the training set** |" ] }, "execution_count": null, @@ -1853,6 +1887,185 @@ "fcst.forecast_fitted_values()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9e84d1e-5c92-4d3c-99e8-e9d897ac0f6f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
unique_iddsyLGBMRegressorLGBMRegressor-lo-90LGBMRegressor-hi-90
0H19619312.712.67127112.54063412.801909
1H19619412.312.27127112.14063412.401909
2H19619511.911.87127111.74063412.001909
3H19619611.711.67127111.54063411.801909
4H19619711.411.47127111.34063411.601909
.....................
3067H41395659.068.28057458.84664077.714509
3068H41395758.070.42757060.99363679.861504
3069H41395853.044.76796535.33403154.201899
3070H41395938.048.69125739.25732358.125191
3071H41396046.046.65223837.21830456.086172
\n", + "

3072 rows × 6 columns

\n", + "
" + ], + "text/plain": [ + " unique_id ds y LGBMRegressor LGBMRegressor-lo-90 \\\n", + "0 H196 193 12.7 12.671271 12.540634 \n", + "1 H196 194 12.3 12.271271 12.140634 \n", + "2 H196 195 11.9 11.871271 11.740634 \n", + "3 H196 196 11.7 11.671271 11.540634 \n", + "4 H196 197 11.4 11.471271 11.340634 \n", + "... ... ... ... ... ... \n", + "3067 H413 956 59.0 68.280574 58.846640 \n", + "3068 H413 957 58.0 70.427570 60.993636 \n", + "3069 H413 958 53.0 44.767965 35.334031 \n", + "3070 H413 959 38.0 48.691257 39.257323 \n", + "3071 H413 960 46.0 46.652238 37.218304 \n", + "\n", + " LGBMRegressor-hi-90 \n", + "0 12.801909 \n", + "1 12.401909 \n", + "2 12.001909 \n", + "3 11.801909 \n", + "4 11.601909 \n", + "... ... \n", + "3067 77.714509 \n", + "3068 79.861504 \n", + "3069 54.201899 \n", + "3070 58.125191 \n", + "3071 56.086172 \n", + "\n", + "[3072 rows x 6 columns]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fcst.forecast_fitted_values(level=[90])" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1930,7 +2143,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L554){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L607){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.predict\n", "\n", @@ -1960,7 +2173,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L554){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L607){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.predict\n", "\n", @@ -2519,7 +2732,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L216){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.preprocess\n", "\n", @@ -2551,7 +2764,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L216){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.preprocess\n", "\n", @@ -2871,7 +3084,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L272){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L258){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.fit_models\n", "\n", @@ -2890,7 +3103,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L272){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L258){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.fit_models\n", "\n", @@ -3015,7 +3228,7 @@ "text/markdown": [ "---\n", "\n", - 
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L693){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L746){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.cross_validation\n", "\n", @@ -3068,7 +3281,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L693){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L746){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.cross_validation\n", "\n", @@ -4009,7 +4222,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L188){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.from_cv\n", "\n", @@ -4018,7 +4231,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L188){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.from_cv\n", "\n", diff --git a/nbs/grouped_array.ipynb b/nbs/grouped_array.ipynb index 84e1a5aa..981fd3d5 100644 --- a/nbs/grouped_array.ipynb +++ b/nbs/grouped_array.ipynb @@ -97,6 +97,16 @@ " serie[j] += diffs_data[diffs_indptr[i + 1] - serie.size - d + j]\n", "\n", "@njit\n", + "def _update_difference(d: int, orig_data: np.ndarray, orig_indptr: np.ndarray, data: np.ndarray, indptr: np.ndarray):\n", + " n_series = len(indptr) - 1\n", + " for i in range(n_series):\n", + " orig = orig_data[orig_indptr[i] : orig_indptr[i + 1]]\n", + " transformed = data[indptr[i] : indptr[i + 1]]\n", + " combined = np.append(orig, transformed)\n", + " data[indptr[i] : indptr[i + 1]] = _diff(combined, d)[-transformed.size:]\n", + " orig_data[orig_indptr[i] : orig_indptr[i + 1]] = combined[-d:]\n", + "\n", + "@njit\n", "def _expand_target(data, indptr, max_horizon):\n", " out = np.empty((data.size, max_horizon), dtype=data.dtype)\n", " n_series = len(indptr) - 1\n", @@ -310,6 +320,9 @@ " d,\n", " )\n", "\n", + " def update_difference(self, d: int, ga: 'GroupedArray') -> None:\n", + " _update_difference(d, self.data, self.indptr, ga.data, ga.indptr)\n", + "\n", " def expand_target(self, max_horizon: int) -> np.ndarray:\n", " return _expand_target(self.data, self.indptr, max_horizon)\n", " \n", diff --git a/nbs/target_transforms.ipynb b/nbs/target_transforms.ipynb index d5878423..8206eeb1 100644 --- a/nbs/target_transforms.ipynb +++ b/nbs/target_transforms.ipynb @@ -48,6 +48,7 @@ "from sklearn.base import TransformerMixin, clone\n", "from utilsforecast.compat import DataFrame\n", "from utilsforecast.target_transforms import (\n", + " BaseTargetTransform as UtilsTargetTransform,\n", " LocalBoxCox as BoxCox,\n", " LocalMinMaxScaler as MinMaxScaler,\n", " LocalRobustScaler as RobustScaler,\n", @@ -56,6 +57,7 @@ " _transform,\n", ")\n", "\n", + "from mlforecast.compat import CORE_INSTALLED, CoreGroupedArray, core_scalers\n", "from mlforecast.grouped_array import GroupedArray, _apply_difference\n", "from 
mlforecast.utils import _ShortSeriesException" ] @@ -92,13 +94,16 @@ " self.time_col = time_col\n", " self.target_col = target_col\n", "\n", + " def update(self, df: DataFrame) -> DataFrame:\n", + " raise NotImplementedError\n", + "\n", " @abc.abstractmethod\n", " def fit_transform(self, df: DataFrame) -> DataFrame:\n", - " raise NotImplementedError\n", + " ...\n", " \n", " @abc.abstractmethod\n", " def inverse_transform(self, df: DataFrame) -> DataFrame:\n", - " raise NotImplementedError" + " ..." ] }, { @@ -114,12 +119,16 @@ " idxs: Optional[np.ndarray] = None\n", "\n", " @abc.abstractmethod\n", + " def update(self, ga: GroupedArray) -> GroupedArray:\n", + " ...\n", + " \n", + " @abc.abstractmethod\n", " def fit_transform(self, ga: GroupedArray) -> GroupedArray:\n", - " raise NotImplementedError\n", + " ...\n", " \n", " @abc.abstractmethod\n", " def inverse_transform(self, ga: GroupedArray) -> GroupedArray:\n", - " raise NotImplementedError\n", + " ...\n", "\n", " def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray:\n", " return self.inverse_transform(ga)" @@ -161,6 +170,12 @@ " self.original_values_.append(GroupedArray(new_data, new_indptr))\n", " return ga\n", "\n", + " def update(self, ga: GroupedArray) -> GroupedArray:\n", + " transformed = copy.copy(ga)\n", + " for d, orig_ga in zip(self.differences, self.original_values_):\n", + " orig_ga.update_difference(d, transformed)\n", + " return transformed\n", + "\n", " def inverse_transform(self, ga: GroupedArray) -> GroupedArray:\n", " ga = copy.copy(ga)\n", " for d, orig_vals_ga in zip(reversed(self.differences), reversed(self.original_values_)):\n", @@ -217,6 +232,24 @@ "restored_subs = diffs.inverse_transform_fitted(transformed.take_from_groups(slice(8, None)))\n", "np.testing.assert_allclose(ga.data[keep_mask], restored_subs.data)\n", "\n", + "# test transform\n", + "new_ga = GroupedArray(np.random.rand(10), np.arange(11))\n", + "prev_orig = [diffs.original_values_[i].data[::d].copy() for i, d in enumerate(diffs.differences)]\n", + "expected = new_ga.data - np.add.reduce(prev_orig)\n", + "updates = diffs.update(new_ga)\n", + "np.testing.assert_allclose(expected, updates.data)\n", + "np.testing.assert_allclose(diffs.original_values_[0].data, new_ga.data)\n", + "np.testing.assert_allclose(diffs.original_values_[1].data[1::2], new_ga.data - prev_orig[0])\n", + "np.testing.assert_allclose(diffs.original_values_[2].data[4::5], new_ga.data - np.add.reduce(prev_orig[:2]))\n", + "# variable sizes\n", + "diff1 = Differences([1])\n", + "ga = GroupedArray(np.arange(10), np.array([0, 3, 10]))\n", + "diff1.fit_transform(ga)\n", + "new_ga = GroupedArray(np.arange(4), np.array([0, 1, 4]))\n", + "updates = diff1.update(new_ga)\n", + "np.testing.assert_allclose(updates.data, np.array([0 - 2, 1 - 9, 2 - 1, 3 - 2]))\n", + "np.testing.assert_allclose(diff1.original_values_[0].data, np.array([0, 3]))\n", + "\n", "# short series\n", "ga = GroupedArray(np.arange(20), np.array([0, 2, 20]))\n", "test_fail(lambda: diffs.fit_transform(ga), contains=\"[0]\")" @@ -231,12 +264,24 @@ "source": [ "#| exporti\n", "class BaseLocalScaler(BaseGroupedArrayTargetTransform):\n", - " \"\"\"Standardizes each serie by subtracting its mean and dividing by its standard deviation.\"\"\"\n", " scaler_factory: type\n", + "\n", + " def _is_utils_tfm(self):\n", + " return isinstance(self.scaler_, UtilsTargetTransform)\n", + "\n", + " def update(self, ga: GroupedArray) -> GroupedArray:\n", + " if not self._is_utils_tfm():\n", + " ga = CoreGroupedArray(ga.data, 
ga.indptr)\n", + " return GroupedArray(self.scaler_.transform(ga), ga.indptr)\n", " \n", " def fit_transform(self, ga: GroupedArray) -> GroupedArray:\n", " self.scaler_ = self.scaler_factory()\n", - " transformed = self.scaler_.fit_transform(ga)\n", + " if self._is_utils_tfm():\n", + " transformed = self.scaler_.fit_transform(ga)\n", + " else:\n", + " core_ga = CoreGroupedArray(ga.data, ga.indptr)\n", + " self.scaler_.fit(core_ga)\n", + " transformed = self.scaler_.transform(core_ga)\n", " return GroupedArray(transformed, ga.indptr)\n", "\n", " def inverse_transform(self, ga: GroupedArray) -> GroupedArray:\n", @@ -245,7 +290,12 @@ " stats = stats[self.idxs]\n", " if stats.shape[0] != ga.n_groups:\n", " raise ValueError('Found different number of groups in scaler.')\n", - " transformed = _transform(ga.data, ga.indptr, stats, _common_scaler_inverse_transform)\n", + " if self._is_utils_tfm() or self.idxs is not None:\n", + " # core scalers can't transform a subset\n", + " transformed = _transform(ga.data, ga.indptr, stats, _common_scaler_inverse_transform)\n", + " else:\n", + " core_ga = CoreGroupedArray(ga.data, ga.indptr)\n", + " transformed = self.scaler_.inverse_transform(core_ga)\n", " return GroupedArray(transformed, ga.indptr)\n", "\n", " def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray:\n", @@ -268,13 +318,8 @@ " sc.inverse_transform(transformed).data,\n", " ga.data,\n", " )\n", - " \n", - " def filter_df(df):\n", - " return (\n", - " df[df['unique_id'].isin(['id_0', 'id_7'])]\n", - " .groupby('unique_id', observed=True)\n", - " .head(10)\n", - " )\n", + " transformed2 = sc.update(ga)\n", + " np.testing.assert_allclose(transformed.data, transformed2.data)\n", " \n", " idxs = [0, 7]\n", " subset = ga.take(idxs)\n", @@ -296,7 +341,7 @@ "#| export\n", "class LocalStandardScaler(BaseLocalScaler):\n", " \"\"\"Standardizes each serie by subtracting its mean and dividing by its standard deviation.\"\"\"\n", - " scaler_factory = StandardScaler" + " scaler_factory = core_scalers.LocalStandardScaler if CORE_INSTALLED else StandardScaler" ] }, { @@ -319,7 +364,7 @@ "#| export\n", "class LocalMinMaxScaler(BaseLocalScaler):\n", " \"\"\"Scales each serie to be in the [0, 1] interval.\"\"\"\n", - " scaler_factory = MinMaxScaler" + " scaler_factory = core_scalers.LocalMinMaxScaler if CORE_INSTALLED else MinMaxScaler" ] }, { @@ -350,7 +395,7 @@ " \"\"\"\n", "\n", " def __init__(self, scale: str):\n", - " self.scaler_factory = lambda: RobustScaler(scale) # type: ignore" + " self.scaler_factory = lambda: core_scalers.LocalRobustScaler(scale) if CORE_INSTALLED else RobustScaler(scale) # type: ignore" ] }, { @@ -384,16 +429,16 @@ "class LocalBoxCox(BaseLocalScaler):\n", " \"\"\"Finds the optimum lambda for each serie and applies the Box-Cox transformation\"\"\"\n", " def __init__(self):\n", - " self.scaler = BoxCox()\n", + " self.scaler_ = BoxCox()\n", " \n", " def fit_transform(self, ga: GroupedArray) -> GroupedArray:\n", - " return GroupedArray(self.scaler.fit_transform(ga), ga.indptr)\n", + " return GroupedArray(self.scaler_.fit_transform(ga), ga.indptr)\n", "\n", " def inverse_transform(self, ga: GroupedArray) -> GroupedArray:\n", " from scipy.special import inv_boxcox1p\n", "\n", " sizes = np.diff(ga.indptr)\n", - " lmbdas = self.scaler.lmbdas_\n", + " lmbdas = self.scaler_.lmbdas_\n", " if self.idxs is not None:\n", " lmbdas = lmbdas[self.idxs]\n", " lmbdas = np.repeat(lmbdas, sizes, axis=0)\n", @@ -423,6 +468,11 @@ " def __init__(self, transformer: TransformerMixin):\n", " 
self.transformer = transformer\n", "\n", + " def update(self, df: pd.DataFrame) -> pd.DataFrame:\n", + " df = df.copy(deep=False)\n", + " df[self.target_col] = self.transformer_.transform(df[[self.target_col]].values)\n", + " return df\n", + "\n", " def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:\n", " df = df.copy(deep=False)\n", " self.transformer_ = clone(self.transformer)\n", diff --git a/settings.ini b/settings.ini index 54f77f8f..5d80afca 100644 --- a/settings.ini +++ b/settings.ini @@ -8,7 +8,7 @@ author = José Morales author_email = jmoralz92@gmail.com copyright = Nixtla branch = main -version = 0.11.2 +version = 0.11.5 min_python = 3.8 audience = Developers language = English
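Returning to the `update` methods added to the target transforms above: they let a fitted transform process observations that arrive after `fit_transform`. The semantics are easiest to see on `Differences`; this sketch mirrors the variable-sized-series test included in this diff:

    import numpy as np
    from mlforecast.grouped_array import GroupedArray
    from mlforecast.target_transforms import Differences

    diffs = Differences([1])
    # two series: [0, 1, 2] and [3, 4, ..., 9]
    ga = GroupedArray(np.arange(10), np.array([0, 3, 10]))
    diffs.fit_transform(ga)

    # new observations: [0] for the first series, [1, 2, 3] for the second
    new_ga = GroupedArray(np.arange(4), np.array([0, 1, 4]))
    updates = diffs.update(new_ga)
    # each new value is differenced against its predecessor, using the stored
    # tail of the training data for the first one: [0 - 2, 1 - 9, 2 - 1, 3 - 2]
    np.testing.assert_allclose(updates.data, np.array([-2.0, -8.0, 1.0, 1.0]))
    # the stored last values now come from the new data: [0] and [3]
    np.testing.assert_allclose(diffs.original_values_[0].data, np.array([0.0, 3.0]))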