Skip to content

Commit

Permalink
add quantile forests
Browse files Browse the repository at this point in the history
  • Loading branch information
MoritzM00 committed Jan 20, 2025
1 parent 3dc7615 commit 8f0fa18
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 30 deletions.
96 changes: 66 additions & 30 deletions dvc.lock
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ stages:
size: 1280486
- path: src/probafcst//models/
hash: md5
md5: 9e4fc803a61ab97704a486aef34eee91.dir
size: 60048
nfiles: 17
md5: 74c7ac520b842f6244b4bc9012232a8a.dir
size: 62864
nfiles: 19
- path: src/probafcst//utils/tabularization.py
hash: md5
md5: 20b1a00caed3c407c77056b04bd5fa4a
Expand Down Expand Up @@ -170,11 +170,31 @@ stages:
n_estimators: 200
random_state: 0
verbose: 0
qrf:
lags:
- 24
- 48
- 72
- 96
- 120
- 144
- 168
- 336
- 504
- 672
include_seasonal_dummies: true
cyclical_encodings: true
X_lag_cols: []
include_rolling_stats: false
kwargs:
n_estimators: 100
random_state: 0
n_jobs: -1
outs:
- path: models/energy_model.pkl
hash: md5
md5: 818c68a30e12899f3f209a88aa8e96df
size: 2273354
md5: 7536442fe6561688c046c8a3bb75efba
size: 2276094
train@bikes:
cmd: python src/probafcst/pipeline/train.py --target bikes
deps:
Expand All @@ -184,9 +204,9 @@ stages:
size: 148879
- path: src/probafcst//models/
hash: md5
md5: 9e4fc803a61ab97704a486aef34eee91.dir
size: 60048
nfiles: 17
md5: 74c7ac520b842f6244b4bc9012232a8a.dir
size: 62864
nfiles: 19
- path: src/probafcst//utils/tabularization.py
hash: md5
md5: 20b1a00caed3c407c77056b04bd5fa4a
Expand Down Expand Up @@ -283,11 +303,27 @@ stages:
n_estimators: 100
verbose: 0
random_state: 0
qrf:
lags:
- 1
- 2
- 3
- 7
- 14
- 21
include_seasonal_dummies: true
cyclical_encodings: true
include_rolling_stats: false
X_lag_cols: []
kwargs:
n_estimators: 100
random_state: 0
n_jobs: -1
outs:
- path: models/bikes_model.pkl
hash: md5
md5: 4537d54e03e658bc96b62baf3a52db78
size: 1219119
md5: a7529edfdb07b51c22a27e459fcdaa09
size: 1221859
eval@energy:
cmd: python src/probafcst/pipeline/evaluate.py --target energy
deps:
Expand All @@ -297,8 +333,8 @@ stages:
size: 1280486
- path: models/energy_model.pkl
hash: md5
md5: 818c68a30e12899f3f209a88aa8e96df
size: 2273354
md5: 7536442fe6561688c046c8a3bb75efba
size: 2276094
- path: src/probafcst//backtest.py
hash: md5
md5: c352b34254b7299b5ec810bac65b0949
Expand Down Expand Up @@ -336,19 +372,19 @@ stages:
outs:
- path: output/energy_eval_results.csv
hash: md5
md5: 918e9b4b808ef3a8411a160eb1708091
md5: ef24a00ba647a7a4d9149f3d012fee10
size: 5279
- path: output/energy_metrics.json
hash: md5
md5: b83a61cf6499da1efbfef176b912dd59
md5: fde281598f36df5135c5dc745cb83f5c
size: 583
- path: output/energy_pinball_losses.svg
hash: md5
md5: 8ba26fca45d6eaa56bfb689a0738bb9f
md5: 3293bd89ff7a37ae2756c6c331b1b4ce
size: 26352
- path: output/eval_plots/energy/
hash: md5
md5: 3c8779e7c44205be26773e9a68dd75c1.dir
md5: 019ee565f374587f408ad9c4fa0ee875.dir
size: 254098
nfiles: 4
eval@bikes:
Expand All @@ -360,8 +396,8 @@ stages:
size: 148879
- path: models/bikes_model.pkl
hash: md5
md5: 4537d54e03e658bc96b62baf3a52db78
size: 1219119
md5: a7529edfdb07b51c22a27e459fcdaa09
size: 1221859
- path: src/probafcst//backtest.py
hash: md5
md5: c352b34254b7299b5ec810bac65b0949
Expand Down Expand Up @@ -399,32 +435,32 @@ stages:
outs:
- path: output/bikes_eval_results.csv
hash: md5
md5: 7c624729cc89dd89c2f0c59d15ac6bd1
size: 19770
md5: 0bbace17ec7598b9fd6484280e1a6741
size: 19755
- path: output/bikes_metrics.json
hash: md5
md5: 9887af223700c37da1aa8dda455f58d8
md5: bfa27abac68c38c55bb5db43a8e299b0
size: 587
- path: output/bikes_pinball_losses.svg
hash: md5
md5: 9b9b6befc4fb9b62c395394a913b6f2d
md5: 64c1b6849335c08205809f9b1051df09
size: 30193
- path: output/eval_plots/bikes/
hash: md5
md5: 8acfbacb2865c4b620ffbc22a4f0d299.dir
md5: 97409185f3a29ecf2bd19f39bb551c89.dir
size: 136584
nfiles: 4
submit:
cmd: python src/probafcst/pipeline/submit.py
deps:
- path: models/bikes_model.pkl
hash: md5
md5: 4537d54e03e658bc96b62baf3a52db78
size: 1219119
md5: a7529edfdb07b51c22a27e459fcdaa09
size: 1221859
- path: models/energy_model.pkl
hash: md5
md5: 818c68a30e12899f3f209a88aa8e96df
size: 2273354
md5: 7536442fe6561688c046c8a3bb75efba
size: 2276094
- path: src/probafcst//plotting.py
hash: md5
md5: 482a42cf8b0b9196d98b0d8e772d83d2
Expand Down Expand Up @@ -452,13 +488,13 @@ stages:
outs:
- path: output/bikes_forecast.svg
hash: md5
md5: 054aa7d71bf8ddfac1a4fbe0dcb11dda
md5: 3ac12752d873a6a26d1f9e057cb1389c
size: 31828
- path: output/energy_forecast.svg
hash: md5
md5: 204c87458b6bfb51a55d1964d0dfa724
md5: 3d27f873b21994ca1a6c636eb9fa27cb
size: 68043
- path: output/submission.csv
hash: md5
md5: c2881c6556d14c23d5bf848ff6808ff2
md5: 5ae5f5042c96a647279d3234588cf15f
size: 1605
9 changes: 9 additions & 0 deletions notebooks/dvc.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ruff: noqa"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
37 changes: 37 additions & 0 deletions params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ train:
n_estimators: 100
verbose: 0
random_state: 0
qrf:
lags:
- 1
- 2
- 3
- 7
- 14
- 21
include_seasonal_dummies: true
cyclical_encodings: true
include_rolling_stats: false
X_lag_cols: []
kwargs:
n_estimators: 100
random_state: 0
n_jobs: -1



energy:
Expand Down Expand Up @@ -212,6 +229,26 @@ train:
n_estimators: 200
random_state: 0
verbose: 0
qrf:
lags:
- 24
- 48
- 72
- 96
- 120
- 144
- 168
- 336 # 2 weeks
- 504 # 3 weeks
- 672
include_seasonal_dummies: true
cyclical_encodings: true
X_lag_cols: []
include_rolling_stats: false
kwargs:
n_estimators: 100
random_state: 0
n_jobs: -1

eval:
backend: loky # null or 'loky' for parallel evaluate
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"lightgbm>=4.5.0",
"scikit-learn>=1.5.2",
"catboost>=1.2.7",
"quantile-forest>=1.3.11",
]


Expand Down
7 changes: 7 additions & 0 deletions src/probafcst/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from probafcst.models.darts import get_xgboost_model
from probafcst.models.lgbm import LGBMQuantileForecaster
from probafcst.models.linear_qr import LinearQuantileForecaster
from probafcst.models.qrf import RandomForestQuantileForecaster
from probafcst.models.xgboost import XGBQuantileForecaster


Expand Down Expand Up @@ -51,6 +52,12 @@ def get_model(
model = LGBMQuantileForecaster(**model_params, quantiles=quantiles)
case "catboost":
model = CatBoostQuantileForecaster(**model_params, quantiles=quantiles)
case "qrf":
if n_jobs is not None:
model_params["kwargs"]["n_jobs"] = n_jobs
model = RandomForestQuantileForecaster(**model_params, quantiles=quantiles)
case _:
raise ValueError(f"Unsupported model: {params.selected}")

return model

Expand Down
34 changes: 34 additions & 0 deletions src/probafcst/models/qrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Quantile regression forecaster using a random forest model."""

from quantile_forest import RandomForestQuantileRegressor

from probafcst.models.regression import QuantileRegressionForecaster


class RandomForestQuantileForecaster(QuantileRegressionForecaster):
"""Quantile regression forecaster using a random forest model."""

def __init__(
self,
lags: list[int],
quantiles: list[float],
include_seasonal_dummies=True,
cyclical_encodings=True,
include_rolling_stats=False,
X_lag_cols: list[str] | None = None,
kwargs: dict | None = None,
):
self.kwargs = kwargs or {}
model = RandomForestQuantileRegressor(
default_quantiles=quantiles,
**self.kwargs,
)
super().__init__(
model=model,
lags=lags,
quantiles=quantiles,
include_seasonal_dummies=include_seasonal_dummies,
include_rolling_stats=include_rolling_stats,
cyclical_encodings=cyclical_encodings,
X_lag_cols=X_lag_cols,
)
35 changes: 35 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8f0fa18

Please sign in to comment.