Skip to content

Commit

Permalink
feat: add weight_col to MLForecast.fit and `MLForecast.cross_vali…
Browse files Browse the repository at this point in the history
…dation` (#444)
  • Loading branch information
jmoralez authored Nov 11, 2024
1 parent 3edd9df commit 2164ce4
Show file tree
Hide file tree
Showing 7 changed files with 1,180 additions and 118 deletions.
15 changes: 13 additions & 2 deletions mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def _fit(
target_col: str,
static_features: Optional[List[str]] = None,
keep_last_n: Optional[int] = None,
weight_col: Optional[str] = None,
) -> "TimeSeries":
"""Save the series values, ids and last dates."""
validate_format(df, id_col, time_col, target_col)
Expand All @@ -251,6 +252,7 @@ def _fit(
self.id_col = id_col
self.target_col = target_col
self.time_col = time_col
self.weight_col = weight_col
self.keep_last_n = keep_last_n
self.static_features = static_features
sorted_df = df[[id_col, time_col, target_col]]
Expand Down Expand Up @@ -298,9 +300,12 @@ def _fit(
if static_features is None:
static_features = [c for c in df.columns if c not in [time_col, target_col]]
elif id_col not in static_features:
static_features = [id_col] + static_features
static_features = [id_col, *static_features]
else: # static_features defined and contain id_col
to_drop = [time_col, target_col]
if weight_col is not None:
to_drop.append(weight_col)
static_features = [f for f in static_features if f != weight_col]
self.ga = ga
series_starts = ga.indptr[:-1]
series_ends = ga.indptr[1:] - 1
Expand Down Expand Up @@ -478,7 +483,11 @@ def _transform(

# assemble return
if return_X_y:
X = df[self.features_order_]
if self.weight_col is not None:
x_cols = [self.weight_col, *self.features_order_]
else:
x_cols = self.features_order_
X = df[x_cols]
if as_numpy:
X = ufp.to_numpy(X)
return X, target
Expand Down Expand Up @@ -506,6 +515,7 @@ def fit_transform(
max_horizon: Optional[int] = None,
return_X_y: bool = False,
as_numpy: bool = False,
weight_col: Optional[str] = None,
) -> Union[DFType, Tuple[DFType, np.ndarray]]:
"""Add the features to `data` and save the required information for the predictions step.
Expand All @@ -522,6 +532,7 @@ def fit_transform(
target_col=target_col,
static_features=static_features,
keep_last_n=keep_last_n,
weight_col=weight_col,
)
return self._transform(
df=data,
Expand Down
55 changes: 45 additions & 10 deletions mlforecast/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def preprocess(
max_horizon: Optional[int] = None,
return_X_y: bool = False,
as_numpy: bool = False,
weight_col: Optional[str] = None,
) -> Union[DFType, Tuple[DFType, np.ndarray]]:
"""Add the features to `data`.
Expand All @@ -239,6 +240,8 @@ def preprocess(
Return a tuple with the features and the target. If False will return a single dataframe.
as_numpy : bool (default = False)
Cast features to numpy array. Only works for `return_X_y=True`.
weight_col : str, optional (default=None)
Column that contains the sample weights.
Returns
-------
Expand All @@ -256,6 +259,7 @@ def preprocess(
max_horizon=max_horizon,
return_X_y=return_X_y,
as_numpy=as_numpy,
weight_col=weight_col,
)

def fit_models(
Expand All @@ -277,21 +281,31 @@ def fit_models(
self : MLForecast
Forecast object with trained models.
"""

def fit_model(model, X, y, weight_col):
fit_kwargs = {}
if weight_col is not None:
if isinstance(X, np.ndarray):
fit_kwargs["sample_weight"] = X[:, 0]
X = X[:, 1:]
else:
fit_kwargs["sample_weight"] = X[weight_col]
X = ufp.drop_columns(X, weight_col)
return clone(model).fit(X, y, **fit_kwargs)

self.models_: Dict[str, Union[BaseEstimator, List[BaseEstimator]]] = {}
for name, model in self.models.items():
if y.ndim == 2 and y.shape[1] > 1:
self.models_[name] = []
for col in range(y.shape[1]):
keep = ~np.isnan(y[:, col])
if isinstance(X, np.ndarray):
# TODO: migrate to utils
Xh = X[keep]
else:
Xh = ufp.filter_with_mask(X, keep)
Xh = ufp.filter_with_mask(X, keep)
yh = y[keep, col]
self.models_[name].append(clone(model).fit(Xh, yh))
self.models_[name].append(
fit_model(model, Xh, yh, self.ts.weight_col)
)
else:
self.models_[name] = clone(model).fit(X, y)
self.models_[name] = fit_model(model, X, y, self.ts.weight_col)
return self

def _conformity_scores(
Expand Down Expand Up @@ -380,8 +394,12 @@ def _extract_X_y(
self,
prep: DFType,
target_col: str,
weight_col: Optional[str],
) -> Tuple[Union[DFType, np.ndarray], np.ndarray]:
X = prep[self.ts.features_order_]
x_cols = self.ts.features_order_
if weight_col is not None:
x_cols = [weight_col, *x_cols]
X = prep[x_cols]
targets = [c for c in prep.columns if re.match(rf"^{target_col}\d*$", c)]
if len(targets) == 1:
targets = targets[0]
Expand All @@ -397,7 +415,13 @@ def _compute_fitted_values(
time_col: str,
target_col: str,
max_horizon: Optional[int],
weight_col: Optional[str],
) -> DFType:
if weight_col is not None:
if isinstance(X, np.ndarray):
X = X[:, 1:]
else:
X = ufp.drop_columns(X, weight_col)
base = ufp.copy_if_pandas(base, deep=False)
sort_idxs = ufp.maybe_compute_sort_indices(base, id_col, time_col)
if sort_idxs is not None:
Expand Down Expand Up @@ -456,6 +480,7 @@ def fit(
prediction_intervals: Optional[PredictionIntervals] = None,
fitted: bool = False,
as_numpy: bool = False,
weight_col: Optional[str] = None,
) -> "MLForecast":
"""Apply the feature engineering and train the models.
Expand Down Expand Up @@ -484,6 +509,8 @@ def fit(
Save in-sample predictions.
as_numpy : bool (default = False)
Cast features to numpy array.
weight_col : str, optional (default=None)
Column that contains the sample weights.
Returns
-------
Expand Down Expand Up @@ -520,12 +547,13 @@ def fit(
max_horizon=max_horizon,
return_X_y=not fitted,
as_numpy=as_numpy,
weight_col=weight_col,
)
if isinstance(prep, tuple):
X, y = prep
else:
base = prep[[id_col, time_col]]
X, y = self._extract_X_y(prep, target_col)
X, y = self._extract_X_y(prep, target_col, weight_col)
if as_numpy:
X = ufp.to_numpy(X)
del prep
Expand All @@ -539,6 +567,7 @@ def fit(
time_col=time_col,
target_col=target_col,
max_horizon=max_horizon,
weight_col=self.ts.weight_col,
)
fitted_values = ufp.drop_index_if_pandas(fitted_values)
self.fcst_fitted_values_ = fitted_values
Expand Down Expand Up @@ -784,6 +813,7 @@ def cross_validation(
input_size: Optional[int] = None,
fitted: bool = False,
as_numpy: bool = False,
weight_col: Optional[str] = None,
) -> DFType:
"""Perform time series cross validation.
Creates `n_windows` splits where each window has `h` test periods,
Expand Down Expand Up @@ -835,6 +865,8 @@ def cross_validation(
Store the in-sample predictions.
as_numpy : bool (default = False)
Cast features to numpy array.
weight_col : str, optional (default=None)
Column that contains the sample weights.
Returns
-------
Expand Down Expand Up @@ -869,6 +901,7 @@ def cross_validation(
prediction_intervals=prediction_intervals,
fitted=fitted,
as_numpy=as_numpy,
weight_col=weight_col,
)
cv_models.append(self.models_)
if fitted:
Expand All @@ -890,10 +923,11 @@ def cross_validation(
keep_last_n=keep_last_n,
max_horizon=max_horizon,
return_X_y=False,
weight_col=weight_col,
)
assert not isinstance(prep, tuple)
base = prep[[id_col, time_col]]
train_X, train_y = self._extract_X_y(prep, target_col)
train_X, train_y = self._extract_X_y(prep, target_col, weight_col)
if as_numpy:
train_X = ufp.to_numpy(train_X)
del prep
Expand All @@ -905,6 +939,7 @@ def cross_validation(
time_col=time_col,
target_col=target_col,
max_horizon=max_horizon,
weight_col=weight_col,
)
fitted_values = ufp.assign_columns(fitted_values, "fold", i_window)
cv_fitted_values.append(fitted_values)
Expand Down
33 changes: 23 additions & 10 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@
" target_col: str,\n",
" static_features: Optional[List[str]] = None,\n",
" keep_last_n: Optional[int] = None,\n",
" weight_col: Optional[str] = None,\n",
" ) -> 'TimeSeries':\n",
" \"\"\"Save the series values, ids and last dates.\"\"\"\n",
" validate_format(df, id_col, time_col, target_col)\n",
Expand All @@ -745,6 +746,7 @@
" self.id_col = id_col\n",
" self.target_col = target_col\n",
" self.time_col = time_col\n",
" self.weight_col = weight_col\n",
" self.keep_last_n = keep_last_n\n",
" self.static_features = static_features\n",
" sorted_df = df[[id_col, time_col, target_col]]\n",
Expand Down Expand Up @@ -790,9 +792,12 @@
" if static_features is None:\n",
" static_features = [c for c in df.columns if c not in [time_col, target_col]]\n",
" elif id_col not in static_features:\n",
" static_features = [id_col] + static_features\n",
" static_features = [id_col, *static_features]\n",
" else: # static_features defined and contain id_col\n",
" to_drop = [time_col, target_col]\n",
" if weight_col is not None:\n",
" to_drop.append(weight_col)\n",
" static_features = [f for f in static_features if f != weight_col]\n",
" self.ga = ga\n",
" series_starts = ga.indptr[:-1]\n",
" series_ends = ga.indptr[1:] - 1\n",
Expand Down Expand Up @@ -967,7 +972,11 @@
"\n",
" # assemble return\n",
" if return_X_y:\n",
" X = df[self.features_order_]\n",
" if self.weight_col is not None:\n",
" x_cols = [self.weight_col, *self.features_order_]\n",
" else:\n",
" x_cols = self.features_order_\n",
" X = df[x_cols]\n",
" if as_numpy:\n",
" X = ufp.to_numpy(X)\n",
" return X, target\n",
Expand Down Expand Up @@ -996,6 +1005,7 @@
" max_horizon: Optional[int] = None,\n",
" return_X_y: bool = False,\n",
" as_numpy: bool = False,\n",
" weight_col: Optional[str] = None,\n",
" ) -> Union[DFType, Tuple[DFType, np.ndarray]]:\n",
" \"\"\"Add the features to `data` and save the required information for the predictions step.\n",
" \n",
Expand All @@ -1012,6 +1022,7 @@
" target_col=target_col,\n",
" static_features=static_features,\n",
" keep_last_n=keep_last_n,\n",
" weight_col=weight_col,\n",
" )\n",
" return self._transform(\n",
" df=data,\n",
Expand Down Expand Up @@ -1690,7 +1701,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L487){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L496){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.fit_transform\n",
"\n",
Expand All @@ -1700,7 +1711,8 @@
"> dropna:bool=True,\n",
"> keep_last_n:Optional[int]=None,\n",
"> max_horizon:Optional[int]=None,\n",
"> return_X_y:bool=False, as_numpy:bool=False)\n",
"> return_X_y:bool=False, as_numpy:bool=False,\n",
"> weight_col:Optional[str]=None)\n",
"\n",
"*Add the features to `data` and save the required information for the predictions step.\n",
"\n",
Expand All @@ -1711,7 +1723,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L487){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L496){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.fit_transform\n",
"\n",
Expand All @@ -1721,7 +1733,8 @@
"> dropna:bool=True,\n",
"> keep_last_n:Optional[int]=None,\n",
"> max_horizon:Optional[int]=None,\n",
"> return_X_y:bool=False, as_numpy:bool=False)\n",
"> return_X_y:bool=False, as_numpy:bool=False,\n",
"> weight_col:Optional[str]=None)\n",
"\n",
"*Add the features to `data` and save the required information for the predictions step.\n",
"\n",
Expand Down Expand Up @@ -2003,7 +2016,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L732){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L743){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.predict\n",
"\n",
Expand All @@ -2017,7 +2030,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L732){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L743){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.predict\n",
"\n",
Expand Down Expand Up @@ -2155,7 +2168,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L837){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L848){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.update\n",
"\n",
Expand All @@ -2168,7 +2181,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L837){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L848){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.update\n",
"\n",
Expand Down
Loading

0 comments on commit 2164ce4

Please sign in to comment.