Skip to content

Commit

Permalink
support step_size selection in optimization.mlforecast_objective (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bchaoss authored Oct 2, 2024
1 parent ce2130e commit fa40207
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mlforecast/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def mlforecast_objective(
freq: Freq,
n_windows: int,
h: int,
step_size: Optional[int] = None,
refit: Union[bool, int] = False,
id_col: str = "unique_id",
time_col: str = "ds",
Expand All @@ -53,6 +54,8 @@ def mlforecast_objective(
Number of windows to evaluate.
h : int
Forecast horizon.
step_size : int, optional (default=None)
Step size between each cross validation window. If None it will be equal to `h`.
refit : bool or int (default=False)
Retrain model for each cross validation window.
If False, the models are trained at the beginning and then used to predict each window.
Expand Down Expand Up @@ -82,6 +85,7 @@ def objective(trial: optuna.Trial) -> float:
id_col=id_col,
time_col=time_col,
freq=freq,
step_size=step_size,
)
model_copy = clone(model)
model_params = config["model_params"]
Expand Down
4 changes: 4 additions & 0 deletions nbs/optimization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
" freq: Freq,\n",
" n_windows: int,\n",
" h: int,\n",
" step_size: Optional[int] = None,\n",
" refit: Union[bool, int] = False,\n",
" id_col: str = 'unique_id',\n",
" time_col: str = 'ds',\n",
Expand All @@ -104,6 +105,8 @@
" Number of windows to evaluate.\n",
" h : int\n",
" Forecast horizon.\n",
" step_size : int, optional (default=None)\n",
" Step size between each cross validation window. If None it will be equal to `h`.\n",
" refit : bool or int (default=False)\n",
" Retrain model for each cross validation window.\n",
" If False, the models are trained at the beginning and then used to predict each window.\n",
Expand Down Expand Up @@ -132,6 +135,7 @@
" id_col=id_col,\n",
" time_col=time_col,\n",
" freq=freq,\n",
" step_size=step_size,\n",
" )\n",
" model_copy = clone(model)\n",
" model_params = config['model_params']\n",
Expand Down

0 comments on commit fa40207

Please sign in to comment.