Skip to content

Commit

Permalink
support lag transformations from coreforecast (#265)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Nov 6, 2023
1 parent 0220cbc commit 68c33b1
Show file tree
Hide file tree
Showing 19 changed files with 1,506 additions and 62 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ jobs:
run-all-tests:
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
Expand All @@ -59,10 +60,11 @@ jobs:
run: pip install ./

- name: Run all tests
run: nbdev_test --n_workers 1 --do_print --timing --flags polars
run: nbdev_test --n_workers 0 --do_print --timing --flags 'polars core'

run-local-tests:
runs-on: ${{ matrix.os }}
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
Expand All @@ -83,7 +85,7 @@ jobs:
run: pip install ./

- name: Run local tests
run: nbdev_test --n_workers 1 --do_print --timing --skip_file_glob "*distributed*" --flags polars
run: nbdev_test --n_workers 0 --do_print --timing --skip_file_glob "*distributed*" --flags 'polars core'

check-deps:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: mlforecast
channels:
- conda-forge
dependencies:
- coreforecast>=0.0.2
- dask<2023.1.1
- holidays<0.21
- lightgbm
Expand Down
1 change: 1 addition & 0 deletions local_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: mlforecast
channels:
- conda-forge
dependencies:
- coreforecast>=0.0.2
- holidays<0.21
- lightgbm
- matplotlib
Expand Down
64 changes: 63 additions & 1 deletion mlforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
'mlforecast/callbacks.py'),
'mlforecast.callbacks.SaveFeatures.get_features': ( 'callbacks.html#savefeatures.get_features',
'mlforecast/callbacks.py')},
'mlforecast.compat': {},
'mlforecast.core': { 'mlforecast.core.TimeSeries': ('core.html#timeseries', 'mlforecast/core.py'),
'mlforecast.core.TimeSeries.__init__': ('core.html#timeseries.__init__', 'mlforecast/core.py'),
'mlforecast.core.TimeSeries.__repr__': ('core.html#timeseries.__repr__', 'mlforecast/core.py'),
Expand Down Expand Up @@ -44,11 +45,13 @@
'mlforecast.core.TimeSeries.predict': ('core.html#timeseries.predict', 'mlforecast/core.py'),
'mlforecast.core.TimeSeries.update': ('core.html#timeseries.update', 'mlforecast/core.py'),
'mlforecast.core._as_tuple': ('core.html#_as_tuple', 'mlforecast/core.py'),
'mlforecast.core._build_lag_transform_name': ('core.html#_build_lag_transform_name', 'mlforecast/core.py'),
'mlforecast.core._build_transform_name': ('core.html#_build_transform_name', 'mlforecast/core.py'),
'mlforecast.core._expand_target': ('core.html#_expand_target', 'mlforecast/core.py'),
'mlforecast.core._identity': ('core.html#_identity', 'mlforecast/core.py'),
'mlforecast.core._name_models': ('core.html#_name_models', 'mlforecast/core.py'),
'mlforecast.core._parse_transforms': ('core.html#_parse_transforms', 'mlforecast/core.py')},
'mlforecast.core._parse_transforms': ('core.html#_parse_transforms', 'mlforecast/core.py'),
'mlforecast.core._pascal2camel': ('core.html#_pascal2camel', 'mlforecast/core.py')},
'mlforecast.distributed.forecast': { 'mlforecast.distributed.forecast.DistributedMLForecast': ( 'distributed.forecast.html#distributedmlforecast',
'mlforecast/distributed/forecast.py'),
'mlforecast.distributed.forecast.DistributedMLForecast.__init__': ( 'distributed.forecast.html#distributedmlforecast.__init__',
Expand Down Expand Up @@ -189,6 +192,65 @@
'mlforecast/grouped_array.py'),
'mlforecast.grouped_array._transform_series': ( 'grouped_array.html#_transform_series',
'mlforecast/grouped_array.py')},
'mlforecast.lag_transforms': { 'mlforecast.lag_transforms.BaseLagTransform': ( 'lag_transforms.html#baselagtransform',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.BaseLagTransform.transform': ( 'lag_transforms.html#baselagtransform.transform',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.BaseLagTransform.update': ( 'lag_transforms.html#baselagtransform.update',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingBase': ( 'lag_transforms.html#expandingbase',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingBase.__init__': ( 'lag_transforms.html#expandingbase.__init__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingBase._set_core_tfm': ( 'lag_transforms.html#expandingbase._set_core_tfm',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingMax': ( 'lag_transforms.html#expandingmax',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingMean': ( 'lag_transforms.html#expandingmean',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingMin': ( 'lag_transforms.html#expandingmin',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingStd': ( 'lag_transforms.html#expandingstd',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExponentiallyWeightedMean': ( 'lag_transforms.html#exponentiallyweightedmean',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExponentiallyWeightedMean.__init__': ( 'lag_transforms.html#exponentiallyweightedmean.__init__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExponentiallyWeightedMean._set_core_tfm': ( 'lag_transforms.html#exponentiallyweightedmean._set_core_tfm',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Lag': ('lag_transforms.html#lag', 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Lag.__eq__': ( 'lag_transforms.html#lag.__eq__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Lag.__init__': ( 'lag_transforms.html#lag.__init__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingBase': ( 'lag_transforms.html#rollingbase',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingBase.__init__': ( 'lag_transforms.html#rollingbase.__init__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingBase._set_core_tfm': ( 'lag_transforms.html#rollingbase._set_core_tfm',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingMax': ( 'lag_transforms.html#rollingmax',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingMean': ( 'lag_transforms.html#rollingmean',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingMin': ( 'lag_transforms.html#rollingmin',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingStd': ( 'lag_transforms.html#rollingstd',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.SeasonalRollingBase': ( 'lag_transforms.html#seasonalrollingbase',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.SeasonalRollingBase.__init__': ( 'lag_transforms.html#seasonalrollingbase.__init__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.SeasonalRollingBase._set_core_tfm': ( 'lag_transforms.html#seasonalrollingbase._set_core_tfm',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.SeasonalRollingMax': ( 'lag_transforms.html#seasonalrollingmax',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.SeasonalRollingMean': ( 'lag_transforms.html#seasonalrollingmean',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.SeasonalRollingMin': ( 'lag_transforms.html#seasonalrollingmin',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.SeasonalRollingStd': ( 'lag_transforms.html#seasonalrollingstd',
'mlforecast/lag_transforms.py')},
'mlforecast.lgb_cv': { 'mlforecast.lgb_cv.LightGBMCV': ('lgb_cv.html#lightgbmcv', 'mlforecast/lgb_cv.py'),
'mlforecast.lgb_cv.LightGBMCV.__init__': ('lgb_cv.html#lightgbmcv.__init__', 'mlforecast/lgb_cv.py'),
'mlforecast.lgb_cv.LightGBMCV.__repr__': ('lgb_cv.html#lightgbmcv.__repr__', 'mlforecast/lgb_cv.py'),
Expand Down
23 changes: 23 additions & 0 deletions mlforecast/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/compat.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/compat.ipynb 1
# coreforecast is optional: when any of the imports below fails, fall back to
# placeholders so that `from .compat import ...` keeps working and callers can
# branch on CORE_INSTALLED.
try:
    import coreforecast.lag_transforms as core_tfms
    from coreforecast.grouped_array import GroupedArray as CoreGroupedArray

    from mlforecast.lag_transforms import BaseLagTransform, Lag

    CORE_INSTALLED = True
except ImportError:
    core_tfms = None
    CoreGroupedArray = None

    # Empty stand-in class so `isinstance(x, BaseLagTransform)` checks remain
    # valid (and are simply never true) without the optional dependency.
    class BaseLagTransform:
        ...

    # Not callable in this branch; check CORE_INSTALLED before using it.
    Lag = None

    CORE_INSTALLED = False
63 changes: 42 additions & 21 deletions mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# %% ../nbs/core.ipynb 3
import copy
import inspect
import re
import reprlib
import warnings
from collections import Counter, OrderedDict
Expand All @@ -14,7 +15,7 @@
import numpy as np
import pandas as pd
from numba import njit
from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, clone
from utilsforecast.compat import (
DataFrame,
pl,
Expand Down Expand Up @@ -46,6 +47,7 @@
)
from utilsforecast.validation import validate_format

from .compat import CORE_INSTALLED, BaseLagTransform, Lag
from .grouped_array import GroupedArray
from mlforecast.target_transforms import (
BaseGroupedArrayTargetTransform,
Expand Down Expand Up @@ -94,6 +96,23 @@ def _build_transform_name(lag, tfm, *args) -> str:
return tfm_name

# %% ../nbs/core.ipynb 13
def _pascal2camel(pascal_str: str) -> str:
return re.sub(r"(?<!^)(?=[A-Z])", "_", pascal_str).lower()

# %% ../nbs/core.ipynb 14
def _build_lag_transform_name(tfm: BaseLagTransform, lag: int) -> str:
    """Build the feature name for `tfm` applied at `lag`.

    The name is ``<converted class name>_lag<lag>``, followed by one
    ``<param><value>`` piece for every ``__init__`` parameter whose current
    value on `tfm` differs from its declared default; pieces are joined with
    underscores. Parameters without a default are always included, because
    their default is the `inspect` "empty" sentinel.

    Every ``__init__`` parameter is read back from `tfm` as an attribute of the
    same name, so a missing attribute raises `AttributeError`.
    """
    init_params = inspect.signature(tfm.__init__).parameters  # type: ignore
    tfm_name = f"{_pascal2camel(tfm.__class__.__name__)}_lag{lag}"
    changed_params = []
    for name, param in init_params.items():
        # Read the attribute once and reuse it for the comparison and the name.
        value = getattr(tfm, name)
        if param.default != value:
            changed_params.append(f"{name}{value}")
    if changed_params:
        tfm_name += "_" + "_".join(changed_params)
    return tfm_name

# %% ../nbs/core.ipynb 16
def _name_models(current_names):
ctr = Counter(current_names)
if not ctr:
Expand All @@ -111,7 +130,7 @@ def _name_models(current_names):
names[-i] = name
return names

# %% ../nbs/core.ipynb 15
# %% ../nbs/core.ipynb 18
@njit
def _identity(x: np.ndarray) -> np.ndarray:
"""Do nothing to the input."""
Expand Down Expand Up @@ -141,7 +160,7 @@ def _expand_target(data, indptr, max_horizon):
n += 1
return out

# %% ../nbs/core.ipynb 16
# %% ../nbs/core.ipynb 19
Freq = Union[int, str, pd.offsets.BaseOffset]
Lags = Iterable[int]
LagTransform = Union[Callable, Tuple[Callable, Any]]
Expand All @@ -150,22 +169,29 @@ def _expand_target(data, indptr, max_horizon):
Models = Union[BaseEstimator, List[BaseEstimator], Dict[str, BaseEstimator]]
TargetTransform = Union[BaseTargetTransform, BaseGroupedArrayTargetTransform]

# %% ../nbs/core.ipynb 17
# %% ../nbs/core.ipynb 20
def _parse_transforms(
    lags: Lags,
    lag_transforms: LagTransforms,
) -> Dict[str, Union[Tuple[Any, ...], BaseLagTransform]]:
    """Map feature names to the transformation that produces each feature.

    Plain lags are stored under ``lag<lag>``: as a `Lag` object when
    `CORE_INSTALLED` is true, otherwise as the tuple ``(lag, _identity)``.

    Each entry of `lag_transforms` is handled by type:

    * `BaseLagTransform` instances are cloned (the caller's object is not
      modified), the lag is set on the clone through `_set_core_tfm`, and the
      value returned by that call is stored.
      NOTE(review): this relies on `_set_core_tfm` returning the transform
      itself -- confirm in `mlforecast.lag_transforms`.
    * anything else is treated as a function, or a ``(function, *args)`` tuple,
      and stored as ``(lag, function, *args)``.

    Entries are kept in insertion order: all plain lags first, then the lag
    transforms in the iteration order of `lag_transforms`.
    """
    transforms: Dict[str, Union[Tuple[Any, ...], BaseLagTransform]] = OrderedDict()
    for lag in lags:
        if CORE_INSTALLED:
            transforms[f"lag{lag}"] = Lag(lag)
        else:
            transforms[f"lag{lag}"] = (lag, _identity)
    for lag, lag_tfms in lag_transforms.items():
        for tfm in lag_tfms:
            if isinstance(tfm, BaseLagTransform):
                tfm_name = _build_lag_transform_name(tfm, lag)
                transforms[tfm_name] = clone(tfm)._set_core_tfm(lag)
            else:
                tfm, *args = _as_tuple(tfm)
                tfm_name = _build_transform_name(lag, tfm, *args)
                transforms[tfm_name] = (lag, tfm, *args)
    return transforms

# %% ../nbs/core.ipynb 18
# %% ../nbs/core.ipynb 21
class TimeSeries:
"""Utility class for storing and transforming time series data."""

Expand Down Expand Up @@ -331,15 +357,15 @@ def _fit(
] + self.features
return self

def _compute_transforms(self) -> Dict[str, np.ndarray]:
def _compute_transforms(self, updates_only: bool) -> Dict[str, np.ndarray]:
"""Compute the transformations defined in the constructor.
If `self.num_threads > 1` these are computed using multithreading."""
if self.num_threads == 1 or len(self.transforms) == 1:
out = self.ga.apply_transforms(self.transforms, updates_only=False)
out = self.ga.apply_transforms(self.transforms, updates_only=updates_only)
else:
out = self.ga.apply_multithreaded_transforms(
self.transforms, num_threads=self.num_threads, updates_only=False
self.transforms, num_threads=self.num_threads, updates_only=updates_only
)
return out

Expand Down Expand Up @@ -373,7 +399,7 @@ def _transform(
"""Add the features to `df`.
if `dropna=True` then all the null rows are dropped."""
features = self._compute_transforms()
features = self._compute_transforms(updates_only=False)
if self._restore_idxs is not None:
for k, v in features.items():
features[k] = v[self._restore_idxs]
Expand Down Expand Up @@ -515,12 +541,7 @@ def _update_features(self) -> DataFrame:
)
self.test_dates.append(self.curr_dates)

if self.num_threads == 1 or len(self.transforms) == 1:
features = self.ga.apply_transforms(self.transforms, updates_only=True)
else:
features = self.ga.apply_multithreaded_transforms(
self.transforms, num_threads=self.num_threads, updates_only=True
)
features = self._compute_transforms(updates_only=True)

for feature in self.date_features:
feat_name, feat_vals = self._compute_date_feature(self.curr_dates, feature)
Expand Down
Loading

0 comments on commit 68c33b1

Please sign in to comment.