[FEAT] Conformal Predictions in NeuralForecast #1171

Merged
merged 27 commits into main from feat/conformal-prediction
Oct 11, 2024
Changes from 21 commits
Commits
27 commits
fe65a1e
Add ConformalIntervals class
JQGoh Sep 24, 2024
df9528f
remove level parameter from conformalInterval class
JQGoh Sep 30, 2024
6a4da12
conformal fit and predict logic stored in utils file
JQGoh Oct 3, 2024
e2b0c4a
Specify losses that do not support conformal prediction
JQGoh Oct 3, 2024
2ea83eb
conformal prediction integrated to NeuralForecast class
JQGoh Oct 3, 2024
6aaa80b
HuberQLoss does not support conformal prediction
JQGoh Oct 3, 2024
20e1672
Remove unnecessary illustration
JQGoh Oct 3, 2024
37d4f28
Add test for model saving & loading for conformal predictions
JQGoh Oct 3, 2024
c1140f6
Fix model saving/loading missing conformal_intervals
JQGoh Oct 3, 2024
952c332
Add tutorial on conformal prediction
JQGoh Oct 3, 2024
acd4f9c
Fix attribute error during model saving
JQGoh Oct 3, 2024
0cf2802
Review: clear core.ipynb output
JQGoh Oct 3, 2024
5eec585
Review: Corrections to undesired copy-and-paste notes; revision of
JQGoh Oct 3, 2024
d514058
Improve example
JQGoh Oct 3, 2024
8e4ab58
Review: Use DFType instead
JQGoh Oct 4, 2024
5503b6e
Improve example with better illustration
JQGoh Oct 4, 2024
f497c54
Review: Simplify without using UNSUPORTED_LOSSED_CONFORMAL,
JQGoh Oct 7, 2024
d714e51
Revise example with the remark on conformalize_quantiles argument
JQGoh Oct 7, 2024
be38777
clean nbs/core.ipynb
JQGoh Oct 7, 2024
c0c24ec
Missed: Revise type to DFType
JQGoh Oct 7, 2024
079a804
Rename to enable_quantiles; avoid confusing interpretation
JQGoh Oct 7, 2024
4561ef1
fix_issues
elephaint Oct 8, 2024
f23661a
fix_example
elephaint Oct 8, 2024
241d76a
CrossValidation can provide conformal-intervals outputs if refit=True
JQGoh Oct 8, 2024
717916e
Merge branch 'main' into feat/conformal-prediction
elephaint Oct 10, 2024
bf62b3c
add_protections
elephaint Oct 10, 2024
51a3eb2
Merge branch 'main' into feat/conformal-prediction
elephaint Oct 11, 2024
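
This PR threads conformal prediction through the NeuralForecast workflow: fit accepts a ConformalIntervals configuration and computes conformity scores via cross validation, and predict gains a conformal_level argument that appends interval columns. A minimal usage sketch, assembled from the test cells added in nbs/core.ipynb below (the data split and NHITS hyperparameters are illustrative assumptions, not prescribed by the PR):

# Minimal usage sketch based on the test cells in this diff; the data split and
# NHITS hyperparameters are illustrative choices, not part of the PR itself.
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.utils import AirPassengersPanel, ConformalIntervals

# Hold out the last year of the panel for prediction.
train_df = AirPassengersPanel[AirPassengersPanel["ds"] < "1960-01-01"]

# Conformity scores are computed during fit via cross validation over n_windows windows.
# The tests also exercise method="conformal_error" as an alternative to the default.
intervals = ConformalIntervals(n_windows=2)

nf = NeuralForecast(models=[NHITS(h=12, input_size=24, max_steps=10)], freq="M")
nf.fit(train_df, conformal_intervals=intervals)

# Request conformal intervals at the desired confidence levels; this appends
# columns such as 'NHITS-conformal-lo-90' and 'NHITS-conformal-hi-90'.
preds = nf.predict(conformal_level=[80, 90])

Models with multi-quantile losses (e.g. MQLoss) are skipped by default; setting ConformalIntervals(enable_quantiles=True) conformalizes each quantile output as well, as the tests at the bottom of the diff show.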
272 changes: 265 additions & 7 deletions nbs/core.ipynb
@@ -79,7 +79,7 @@
" LocalRobustScaler,\n",
" LocalStandardScaler,\n",
")\n",
"from utilsforecast.compat import DataFrame, Series, pl_DataFrame, pl_Series\n",
"from utilsforecast.compat import DataFrame, DFType, Series, pl_DataFrame, pl_Series\n",
"from utilsforecast.validation import validate_freq\n",
"\n",
"from neuralforecast.common._base_model import DistributedConfig\n",
@@ -95,7 +95,8 @@
" BiTCN, TiDE, DeepNPTS, SOFTS,\n",
" TimeMixer, KAN, RMoK\n",
")\n",
"from neuralforecast.common._base_auto import BaseAuto, MockTrial"
"from neuralforecast.common._base_auto import BaseAuto, MockTrial\n",
"from neuralforecast.utils import ConformalIntervals, get_conformal_method"
]
},
{
@@ -506,6 +507,7 @@
" time_col: str = 'ds',\n",
" target_col: str = 'y',\n",
" distributed_config: Optional[DistributedConfig] = None,\n",
" conformal_intervals: Optional[ConformalIntervals] = None,\n",
" ) -> None:\n",
" \"\"\"Fit the core.NeuralForecast.\n",
"\n",
@@ -535,6 +537,8 @@
" Column that contains the target.\n",
" distributed_config : neuralforecast.DistributedConfig\n",
" Configuration to use for DDP training. Currently only spark is supported.\n",
" conformal_intervals : ConformalIntervals, optional (default=None)\n",
" Configuration to calibrate prediction intervals (Conformal Prediction). \n",
"\n",
" Returns\n",
" -------\n",
@@ -550,6 +554,9 @@
" and val_size == 0\n",
" ):\n",
" raise Exception('Set val_size>0 if early stopping is enabled.')\n",
" \n",
" self._cs_df: Optional[DFType] = None\n",
" self.conformal_intervals: Optional[ConformalIntervals] = None\n",
"\n",
" # Process and save new dataset (in self)\n",
" if isinstance(df, (pd.DataFrame, pl_DataFrame)):\n",
@@ -603,6 +610,17 @@
" if self.dataset.min_size < val_size:\n",
" warnings.warn('Validation set size is larger than the shorter time-series.')\n",
"\n",
" if conformal_intervals is not None:\n",
" # conformal prediction\n",
" self.conformal_intervals = conformal_intervals\n",
" self._cs_df = self._conformity_scores(\n",
" df=df,\n",
" id_col=id_col,\n",
" time_col=time_col,\n",
" target_col=target_col,\n",
" static_df=static_df,\n",
" )\n",
"\n",
" # Recover initial model if use_init_models\n",
" if use_init_models:\n",
" self._reset_models()\n",
@@ -708,10 +726,14 @@
"\n",
" return futr_exog | set(hist_exog)\n",
" \n",
" def _get_model_names(self) -> List[str]:\n",
" def _get_model_names(self, conformal=False, enable_quantiles=False) -> List[str]:\n",
" names: List[str] = []\n",
" count_names = {'model': 0}\n",
" for model in self.models:\n",
" if conformal and not enable_quantiles and model.loss.outputsize_multiplier > 1:\n",
" # skip prediction intervals on quantile outputs\n",
" continue\n",
"\n",
" model_name = repr(model)\n",
" count_names[model_name] = count_names.get(model_name, -1) + 1\n",
" if count_names[model_name] > 0:\n",
@@ -834,6 +856,7 @@
" sort_df: bool = True,\n",
" verbose: bool = False,\n",
" engine = None,\n",
" conformal_level: Optional[List[Union[int, float]]] = None,\n",
" **data_kwargs\n",
" ):\n",
" \"\"\"Predict with core.NeuralForecast.\n",
@@ -855,6 +878,8 @@
" Print processing steps.\n",
" engine : spark session\n",
" Distributed engine for inference. Only used if df is a spark dataframe or if fit was called on a spark dataframe.\n",
" conformal_level : list of ints or floats, optional (default=None)\n",
" Confidence levels between 0 and 100 for conformal intervals.\n",
" data_kwargs : kwargs\n",
" Extra arguments to be passed to the dataset within each model.\n",
"\n",
@@ -989,6 +1014,29 @@
" if isinstance(fcsts_df, pd.DataFrame) and _id_as_idx():\n",
" _warn_id_as_idx()\n",
" fcsts_df = fcsts_df.set_index(self.id_col)\n",
"\n",
" # perform conformal predictions\n",
" if conformal_level is not None:\n",
" if self._cs_df is None or self.conformal_intervals is None:\n",
" warn_msg = (\n",
" 'Please rerun the `fit` method passing a valid conformal_interval settings to compute conformity scores'\n",
" )\n",
" warnings.warn(warn_msg, UserWarning)\n",
" else:\n",
" level_ = sorted(conformal_level)\n",
" model_names = self._get_model_names(conformal=True, enable_quantiles=self.conformal_intervals.enable_quantiles)\n",
" conformal_method = get_conformal_method(self.conformal_intervals.method)\n",
"\n",
" fcsts_df = conformal_method(\n",
" fcsts_df,\n",
" self._cs_df,\n",
" model_names=list(model_names),\n",
" level=level_,\n",
" cs_n_windows=self.conformal_intervals.n_windows,\n",
" n_series=len(uids),\n",
" horizon=self.h,\n",
" )\n",
"\n",
" return fcsts_df\n",
"\n",
" def _reset_models(self):\n",
@@ -1474,6 +1522,9 @@
" \"id_col\": self.id_col,\n",
" \"time_col\": self.time_col,\n",
" \"target_col\": self.target_col,\n",
" # conformal prediction\n",
" \"conformal_intervals\": self.conformal_intervals,\n",
" \"_cs_df\": self._cs_df, # conformity score\n",
" }\n",
" if save_dataset:\n",
" config_dict.update(\n",
@@ -1561,6 +1612,10 @@
"\n",
" for attr in ['id_col', 'time_col', 'target_col']:\n",
" setattr(neuralforecast, attr, config_dict[attr])\n",
" # only restore attribute if available\n",
" for attr in ['conformal_intervals', '_cs_df']:\n",
" if attr in config_dict.keys():\n",
" setattr(neuralforecast, attr, config_dict[attr])\n",
"\n",
" # Dataset\n",
" if dataset is not None:\n",
@@ -1579,7 +1634,60 @@
"\n",
" neuralforecast.scalers_ = config_dict['scalers_']\n",
"\n",
" return neuralforecast"
" return neuralforecast\n",
" \n",
" def _conformity_scores(\n",
" self,\n",
" df: DataFrame,\n",
" id_col: str, \n",
" time_col: str,\n",
" target_col: str,\n",
" static_df: Optional[Union[DataFrame, SparkDataFrame]] = None,\n",
" ) -> DataFrame:\n",
" \"\"\"Compute conformity scores.\n",
" \n",
" We need at least two cross validation errors to compute\n",
" quantiles for prediction intervals (`n_windows=2`, specified by self.conformal_intervals).\n",
" \n",
" The exception is raised by the ConformalIntervals data class.\n",
"\n",
" df: Optional[Union[DataFrame, SparkDataFrame, Sequence[str]]] = None,\n",
" id_col: str = 'unique_id',\n",
" time_col: str = 'ds',\n",
" target_col: str = 'y',\n",
" static_df: Optional[Union[DataFrame, SparkDataFrame]] = None,\n",
" \"\"\"\n",
" if self.conformal_intervals is None:\n",
" raise AttributeError('Please rerun the `fit` method passing a valid conformal_interval settings to compute conformity scores')\n",
" \n",
" min_size = ufp.counts_by_id(df, id_col)['counts'].min()\n",
" min_samples = self.h * self.conformal_intervals.n_windows + 1\n",
" if min_size < min_samples:\n",
" raise ValueError(\n",
" \"Minimum required samples in each serie for the prediction intervals \"\n",
" f\"settings are: {min_samples}, shortest serie has: {min_size}. \"\n",
" \"Please reduce the number of windows, horizon or remove those series.\"\n",
" )\n",
" \n",
" cv_results = self.cross_validation(\n",
" df=df,\n",
" static_df=static_df,\n",
" n_windows=self.conformal_intervals.n_windows,\n",
" id_col=id_col,\n",
" time_col=time_col,\n",
" target_col=target_col,\n",
" )\n",
" \n",
" kept = [time_col, id_col, 'cutoff']\n",
" # conformity score for each model\n",
" for model in self._get_model_names(conformal=True, enable_quantiles=self.conformal_intervals.enable_quantiles):\n",
" kept.append(model)\n",
"\n",
" # compute absolute error for each model\n",
" abs_err = abs(cv_results[model] - cv_results[target_col])\n",
" cv_results = ufp.assign_columns(cv_results, model, abs_err)\n",
" dropped = list(set(cv_results.columns) - set(kept))\n",
" return ufp.drop_columns(cv_results, dropped) "
]
},
{
@@ -1806,6 +1914,14 @@
"test_eq(init_fcst, after_fcst)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d94486f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -2506,8 +2622,9 @@
" ],\n",
" freq='M'\n",
")\n",
"fcst.fit(AirPassengersPanel_train)\n",
"forecasts1 = fcst.predict(futr_df=AirPassengersPanel_test)\n",
"conformal_intervals = ConformalIntervals()\n",
"fcst.fit(AirPassengersPanel_train, conformal_intervals=conformal_intervals)\n",
"forecasts1 = fcst.predict(futr_df=AirPassengersPanel_test, conformal_level=[50])\n",
"save_paths = ['./examples/debug_run/']\n",
"try:\n",
" s3fs.S3FileSystem().ls('s3://nixtla-tmp') \n",
@@ -2521,7 +2638,7 @@
"for path in save_paths:\n",
" fcst.save(path=path, model_index=None, overwrite=True, save_dataset=True)\n",
" fcst2 = NeuralForecast.load(path=path)\n",
" forecasts2 = fcst2.predict(futr_df=AirPassengersPanel_test)\n",
" forecasts2 = fcst2.predict(futr_df=AirPassengersPanel_test, conformal_level=[50])\n",
" pd.testing.assert_frame_equal(forecasts1, forecasts2[forecasts1.columns])"
]
},
@@ -3204,6 +3321,147 @@
" nf.fit(AirPassengersPanel_train)\n",
" assert any(\"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\" in str(w.message) for w in issued_warnings)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0441e9d5",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test conformal prediction, method=conformal_distribution\n",
"\n",
"conformal_intervals = ConformalIntervals()\n",
"\n",
"models = []\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models.append(nf_model(**params))\n",
"\n",
"\n",
"nf = NeuralForecast(models=models, freq='M')\n",
"nf.fit(AirPassengersPanel_train, conformal_intervals=conformal_intervals)\n",
"preds = nf.predict(futr_df=AirPassengersPanel_test, conformal_level=[10, 50, 90])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9eb89f3b",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"#| polars\n",
"# test conformal prediction works for polar dataframe\n",
"\n",
"conformal_intervals = ConformalIntervals()\n",
"\n",
"models = []\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models.append(nf_model(**params))\n",
"\n",
"\n",
"nf = NeuralForecast(models=models, freq='1mo')\n",
"nf.fit(AirPassengers_pl, conformal_intervals=conformal_intervals, time_col='time', id_col='uid', target_col='target')\n",
"preds = nf.predict(conformal_level=[10, 50, 90])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0db88cac",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test conformal prediction, method=conformal_error\n",
"\n",
"conformal_intervals = ConformalIntervals(method=\"conformal_error\")\n",
"\n",
"models = []\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models.append(nf_model(**params))\n",
"\n",
"\n",
"nf = NeuralForecast(models=models, freq='M')\n",
"nf.fit(AirPassengersPanel_train, conformal_intervals=conformal_intervals)\n",
"preds = nf.predict(futr_df=AirPassengersPanel_test, conformal_level=[10, 50, 90])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d25b2cd2",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test conformal prediction are not applied for models with quantiled-related loss\n",
"# by default (ConformalIntervals.enable_quantiles=False)\n",
"\n",
"conformal_intervals = ConformalIntervals()\n",
"\n",
"models = []\n",
"for nf_model in [NHITS, RNN]:\n",
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n",
" if nf_model.__name__ == \"NHITS\":\n",
" params.update({\"loss\": MQLoss(level=[80])})\n",
" models.append(nf_model(**params))\n",
"\n",
"\n",
"nf = NeuralForecast(models=models, freq='M')\n",
"nf.fit(AirPassengersPanel_train, conformal_intervals=conformal_intervals)\n",
"preds = nf.predict(futr_df=AirPassengersPanel_test, conformal_level=[10, 50, 90])\n",
"\n",
"pred_cols = [\n",
" 'NHITS-median', 'NHITS-lo-80', 'NHITS-hi-80', 'RNN',\n",
" 'RNN-conformal-lo-90', 'RNN-conformal-lo-50', 'RNN-conformal-lo-10',\n",
" 'RNN-conformal-hi-10', 'RNN-conformal-hi-50', 'RNN-conformal-hi-90'\n",
"]\n",
"assert all([col in preds.columns for col in pred_cols])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b980087",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test conformal predictions applied to quantiles if ConformalIntervals.enable_quantiles=True\n",
"\n",
"conformal_intervals = ConformalIntervals(enable_quantiles=True)\n",
"\n",
"nf = NeuralForecast(models=[NHITS(h=12, input_size=24, max_steps=1, loss=MQLoss(level=[80]))], freq='M')\n",
"nf.fit(AirPassengersPanel_train, conformal_intervals=conformal_intervals)\n",
"preds = nf.predict(futr_df=AirPassengersPanel_test, conformal_level=[10, 50, 90])\n",
"\n",
"pred_cols = [\n",
" 'NHITS-median', 'NHITS-lo-80', 'NHITS-hi-80',\n",
" 'NHITS-median-conformal-lo-90', 'NHITS-median-conformal-lo-50',\n",
" 'NHITS-median-conformal-lo-10', 'NHITS-median-conformal-hi-10',\n",
" 'NHITS-median-conformal-hi-50', 'NHITS-median-conformal-hi-90',\n",
" 'NHITS-lo-80-conformal-lo-90', 'NHITS-lo-80-conformal-lo-50',\n",
" 'NHITS-lo-80-conformal-lo-10', 'NHITS-lo-80-conformal-hi-10',\n",
" 'NHITS-lo-80-conformal-hi-50', 'NHITS-lo-80-conformal-hi-90',\n",
" 'NHITS-hi-80-conformal-lo-90', 'NHITS-hi-80-conformal-lo-50',\n",
" 'NHITS-hi-80-conformal-lo-10', 'NHITS-hi-80-conformal-hi-10',\n",
" 'NHITS-hi-80-conformal-hi-50', 'NHITS-hi-80-conformal-hi-90'\n",
"]\n",
"\n",
"assert all([col in preds.columns for col in pred_cols])\n"
]
}
],
"metadata": {
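
The conformal methods themselves (get_conformal_method) live outside nbs/core.ipynb, so they do not appear in this part of the diff. For orientation only, here is a rough, hypothetical sketch of split-conformal intervals built from absolute-residual conformity scores; the function and variable names are invented, and it omits the per-series, per-horizon grouping (cs_n_windows, n_series, horizon) that the real methods receive:

# Hypothetical sketch of split-conformal intervals from absolute residuals.
# This is NOT the library's implementation; names are invented and the
# per-series, per-horizon grouping done by the real conformal methods is omitted.
import numpy as np
import pandas as pd

def add_conformal_intervals(fcsts, scores, model, levels):
    # fcsts: point forecasts with one column per model
    # scores: conformity scores |y - y_hat| collected from cross-validation windows
    out = fcsts.copy()
    for level in sorted(levels):
        # Interval half-width = empirical quantile of the absolute residuals.
        width = np.quantile(scores[model].to_numpy(), level / 100)
        out[f"{model}-conformal-lo-{level}"] = out[model] - width
        out[f"{model}-conformal-hi-{level}"] = out[model] + width
    return out

# Toy usage with made-up numbers.
fcsts = pd.DataFrame({"NHITS": [420.0, 432.0, 445.0]})
scores = pd.DataFrame({"NHITS": [3.1, 5.4, 2.0, 7.8, 4.2, 6.0]})
print(add_conformal_intervals(fcsts, scores, "NHITS", levels=[80, 90]))

Here the scores frame stands in for the _cs_df that fit stores after running cross_validation; the real predict path passes that frame, together with the point forecasts, to the method selected via ConformalIntervals.method.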