Skip to content

Commit

Permalink
enh: add step_size to AutoMLForecast (#426)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Oct 10, 2024
1 parent 282d91f commit 95b91c4
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 22 deletions.
4 changes: 4 additions & 0 deletions mlforecast/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ def fit(
n_windows: int,
h: int,
num_samples: int,
step_size: Optional[int] = None,
refit: Union[bool, int] = False,
loss: Optional[Callable[[DataFrame, DataFrame], float]] = None,
id_col: str = "unique_id",
Expand All @@ -467,6 +468,8 @@ def fit(
Forecast horizon.
num_samples : int
Number of trials to run
step_size : int, optional (default=None)
Step size between each cross validation window. If None it will be equal to `h`.
refit : bool or int (default=False)
Retrain model for each cross validation window.
If False, the models are trained at the beginning and then used to predict each window.
Expand Down Expand Up @@ -541,6 +544,7 @@ def config_fn(trial: optuna.Trial) -> Dict[str, Any]:
freq=self.freq,
n_windows=n_windows,
h=h,
step_size=step_size,
refit=refit,
id_col=id_col,
time_col=time_col,
Expand Down
82 changes: 60 additions & 22 deletions nbs/auto.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L113){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L114){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoModel\n",
"\n",
Expand All @@ -289,7 +289,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L113){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L114){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoModel\n",
"\n",
Expand Down Expand Up @@ -522,6 +522,7 @@
" n_windows: int,\n",
" h: int,\n",
" num_samples: int,\n",
" step_size: Optional[int] = None,\n",
" refit: Union[bool, int] = False,\n",
" loss: Optional[Callable[[DataFrame, DataFrame], float]] = None,\n",
" id_col: str = 'unique_id',\n",
Expand All @@ -545,6 +546,8 @@
" Forecast horizon.\n",
" num_samples : int\n",
" Number of trials to run\n",
" step_size : int, optional (default=None)\n",
" Step size between each cross validation window. If None it will be equal to `h`.\n",
" refit : bool or int (default=False)\n",
" Retrain model for each cross validation window.\n",
" If False, the models are trained at the beginning and then used to predict each window.\n",
Expand Down Expand Up @@ -616,6 +619,7 @@
" freq=self.freq,\n",
" n_windows=n_windows,\n",
" h=h,\n",
" step_size=step_size,\n",
" refit=refit,\n",
" id_col=id_col,\n",
" time_col=time_col,\n",
Expand Down Expand Up @@ -726,7 +730,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L240){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L241){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast\n",
"\n",
Expand All @@ -752,7 +756,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L240){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L241){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast\n",
"\n",
Expand Down Expand Up @@ -796,18 +800,19 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L432){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L441){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.fit\n",
"\n",
"> AutoMLForecast.fit\n",
"> (df:Union[pandas.core.frame.DataFrame,polars.datafram\n",
"> e.frame.DataFrame], n_windows:int, h:int,\n",
"> num_samples:int, refit:Union[bool,int]=False, loss:Op\n",
"> tional[Callable[[Union[pandas.core.frame.DataFrame,po\n",
"> lars.dataframe.frame.DataFrame],Union[pandas.core.fra\n",
"> me.DataFrame,polars.dataframe.frame.DataFrame]],float\n",
"> ]]=None, id_col:str='unique_id', time_col:str='ds',\n",
"> num_samples:int, step_size:Optional[int]=None,\n",
"> refit:Union[bool,int]=False, loss:Optional[Callable[[\n",
"> Union[pandas.core.frame.DataFrame,polars.dataframe.fr\n",
"> ame.DataFrame],Union[pandas.core.frame.DataFrame,pola\n",
"> rs.dataframe.frame.DataFrame]],float]]=None,\n",
"> id_col:str='unique_id', time_col:str='ds',\n",
"> target_col:str='y',\n",
"> study_kwargs:Optional[Dict[str,Any]]=None,\n",
"> optimize_kwargs:Optional[Dict[str,Any]]=None,\n",
Expand All @@ -823,6 +828,7 @@
"| n_windows | int | | Number of windows to evaluate. |\n",
"| h | int | | Forecast horizon. |\n",
"| num_samples | int | | Number of trials to run |\n",
"| step_size | Optional | None | Step size between each cross validation window. If None it will be equal to `h`. |\n",
"| refit | Union | False | Retrain model for each cross validation window.<br>If False, the models are trained at the beginning and then used to predict each window.<br>If positive int, the models are retrained every `refit` windows. |\n",
"| loss | Optional | None | Function that takes the validation and train dataframes and produces a float.<br>If `None` will use the average SMAPE across series. |\n",
"| id_col | str | unique_id | Column that identifies each serie. |\n",
Expand All @@ -837,18 +843,19 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L432){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L441){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.fit\n",
"\n",
"> AutoMLForecast.fit\n",
"> (df:Union[pandas.core.frame.DataFrame,polars.datafram\n",
"> e.frame.DataFrame], n_windows:int, h:int,\n",
"> num_samples:int, refit:Union[bool,int]=False, loss:Op\n",
"> tional[Callable[[Union[pandas.core.frame.DataFrame,po\n",
"> lars.dataframe.frame.DataFrame],Union[pandas.core.fra\n",
"> me.DataFrame,polars.dataframe.frame.DataFrame]],float\n",
"> ]]=None, id_col:str='unique_id', time_col:str='ds',\n",
"> num_samples:int, step_size:Optional[int]=None,\n",
"> refit:Union[bool,int]=False, loss:Optional[Callable[[\n",
"> Union[pandas.core.frame.DataFrame,polars.dataframe.fr\n",
"> ame.DataFrame],Union[pandas.core.frame.DataFrame,pola\n",
"> rs.dataframe.frame.DataFrame]],float]]=None,\n",
"> id_col:str='unique_id', time_col:str='ds',\n",
"> target_col:str='y',\n",
"> study_kwargs:Optional[Dict[str,Any]]=None,\n",
"> optimize_kwargs:Optional[Dict[str,Any]]=None,\n",
Expand All @@ -864,6 +871,7 @@
"| n_windows | int | | Number of windows to evaluate. |\n",
"| h | int | | Forecast horizon. |\n",
"| num_samples | int | | Number of trials to run |\n",
"| step_size | Optional | None | Step size between each cross validation window. If None it will be equal to `h`. |\n",
"| refit | Union | False | Retrain model for each cross validation window.<br>If False, the models are trained at the beginning and then used to predict each window.<br>If positive int, the models are retrained every `refit` windows. |\n",
"| loss | Optional | None | Function that takes the validation and train dataframes and produces a float.<br>If `None` will use the average SMAPE across series. |\n",
"| id_col | str | unique_id | Column that identifies each serie. |\n",
Expand Down Expand Up @@ -896,7 +904,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L561){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L570){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.predict\n",
"\n",
Expand All @@ -916,7 +924,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L561){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L570){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.predict\n",
"\n",
Expand Down Expand Up @@ -954,7 +962,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L593){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L602){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.save\n",
"\n",
Expand All @@ -970,7 +978,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L593){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L602){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.save\n",
"\n",
Expand Down Expand Up @@ -1004,7 +1012,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L603){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L612){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.forecast_fitted_values\n",
"\n",
Expand All @@ -1022,7 +1030,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L603){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L612){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.forecast_fitted_values\n",
"\n",
Expand Down Expand Up @@ -1702,6 +1710,36 @@
"#| polars\n",
"auto_mlf.forecast_fitted_values(level=[95])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0dfe2c18-d3df-41f2-a3b8-1cf73ef9765c",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"#| polars\n",
"auto_mlf2 = AutoMLForecast(\n",
" freq=1,\n",
" season_length=season_length,\n",
" models={'ridge': AutoRidge()},\n",
" num_threads=2,\n",
")\n",
"auto_mlf2.fit(\n",
" df=train_pl,\n",
" n_windows=2,\n",
" h=h,\n",
" step_size=1,\n",
" num_samples=2,\n",
" optimize_kwargs={'timeout': 60},\n",
" fitted=True,\n",
" prediction_intervals=PredictionIntervals(n_windows=2, h=h),\n",
")\n",
"metric_step_h = auto_mlf.results_['ridge'].best_trial.value\n",
"metric_step_1 = auto_mlf2.results_['ridge'].best_trial.value\n",
"assert abs(metric_step_h / metric_step_1 - 1) > 0.02"
]
}
],
"metadata": {
Expand Down

0 comments on commit 95b91c4

Please sign in to comment.