Skip to content

Commit

Permalink
add to_local method to distributed forecast
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Jan 19, 2024
1 parent 330d488 commit cb4952c
Show file tree
Hide file tree
Showing 4 changed files with 428 additions and 59 deletions.
4 changes: 3 additions & 1 deletion mlforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@
'mlforecast.distributed.forecast.DistributedMLForecast.preprocess': ( 'distributed.forecast.html#distributedmlforecast.preprocess',
'mlforecast/distributed/forecast.py'),
'mlforecast.distributed.forecast.DistributedMLForecast.save': ( 'distributed.forecast.html#distributedmlforecast.save',
'mlforecast/distributed/forecast.py')},
'mlforecast/distributed/forecast.py'),
'mlforecast.distributed.forecast.DistributedMLForecast.to_local': ( 'distributed.forecast.html#distributedmlforecast.to_local',
'mlforecast/distributed/forecast.py')},
'mlforecast.distributed.models.dask.lgb': { 'mlforecast.distributed.models.dask.lgb.DaskLGBMForecast': ( 'distributed.models.dask.lgb.html#dasklgbmforecast',
'mlforecast/distributed/models/dask/lgb.py'),
'mlforecast.distributed.models.dask.lgb.DaskLGBMForecast.model_': ( 'distributed.models.dask.lgb.html#dasklgbmforecast.model_',
Expand Down
82 changes: 80 additions & 2 deletions mlforecast/distributed/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
DASK_INSTALLED = False
import fugue
import fugue.api as fa
import numpy as np
import pandas as pd
import utilsforecast.processing as ufp

try:
from pyspark.ml.feature import VectorAssembler
Expand All @@ -36,7 +38,6 @@
except ModuleNotFoundError:
RAY_INSTALLED = False
from sklearn.base import clone
from utilsforecast.processing import _single_split

from mlforecast.core import (
DateFeature,
Expand All @@ -47,6 +48,8 @@
TimeSeries,
_name_models,
)
from ..forecast import MLForecast
from ..grouped_array import GroupedArray

# %% ../../nbs/distributed.forecast.ipynb 6
WindowInfo = namedtuple(
Expand Down Expand Up @@ -161,7 +164,7 @@ def _preprocess_partition(
valid = None
else:
max_dates = part.groupby(id_col, observed=True)[time_col].transform("max")
cutoffs, train_mask, valid_mask = _single_split(
cutoffs, train_mask, valid_mask = ufp._single_split(
part,
i_window=window_info.i_window,
n_windows=window_info.n_windows,
Expand Down Expand Up @@ -708,3 +711,78 @@ def load(path: str, engine) -> "DistributedMLForecast":
fcst.engine = engine
fcst.num_partitions = len(paths)
return fcst

def to_local(self) -> MLForecast:
"""Convert this distributed forecast object into a local one
This pulls all the data from the remote machines, so you have to be sure that
it fits in the scheduler/driver. If you're not sure use the save method instead.
Returns
-------
MLForecast
Local forecast object."""
serialized_ts = (
fa.select_columns(
self._partition_results,
columns=["ts"],
as_fugue=True,
)
.as_pandas()["ts"]
.tolist()
)
all_ts = [cloudpickle.loads(ts) for ts in serialized_ts]
# sort by ids (these should already be sorted within each partition)
all_ts = sorted(all_ts, key=lambda ts: ts.uids[0])

# combine attributes. since fugue works on pandas these are all pandas.
# we're using utilsforecast here in case we add support for polars
def possibly_concat_indices(collection):
items_are_indices = isinstance(collection[0], pd.Index)
if items_are_indices:
collection = [pd.Series(item) for item in collection]
combined = ufp.vertical_concat(collection)
if items_are_indices:
combined = pd.Index(combined)
return combined

uids = possibly_concat_indices([ts.uids for ts in all_ts])
last_dates = possibly_concat_indices([ts.last_dates for ts in all_ts])
statics = ufp.vertical_concat([ts.static_features_ for ts in all_ts])
sizes = np.hstack([np.diff(ts.ga.indptr) for ts in all_ts])
data = np.hstack([ts.ga.data for ts in all_ts])
indptr = np.append(0, sizes).cumsum()
if isinstance(uids, pd.Index):
uids_idx = uids
else:
# uids is polars series
uids_idx = pd.Index(uids)
if not uids_idx.is_monotonic_increasing:
# this seems to happen only with ray
# we have to sort all data related to the series
sort_idxs = uids_idx.argsort()
uids = uids[sort_idxs]
last_dates = last_dates[sort_idxs]
statics = ufp.take_rows(statics, sort_idxs)
statics = ufp.drop_index_if_pandas(statics)
old_data = data.copy()
old_indptr = indptr.copy()
indptr = np.append(0, sizes[sort_idxs]).cumsum()
# this loop takes 500ms for 100,000 series of sizes between 500 and 2,000
# so it may not be that much of a bottleneck, but try to implement in core
for i, sort_idx in enumerate(sort_idxs):
old_slice = slice(old_indptr[sort_idx], old_indptr[sort_idx + 1])
new_slice = slice(indptr[i], indptr[i + 1])
data[new_slice] = old_data[old_slice]
ga = GroupedArray(data, indptr)

# all other attributes should be the same, so we just override the first serie
ts = all_ts[0]
ts.uids = uids
ts.last_dates = last_dates
ts.ga = ga
ts.static_features_ = statics
fcst = MLForecast(models=self.models_, freq=ts.freq)
fcst.ts = ts
fcst.models_ = self.models_
return fcst
Loading

0 comments on commit cb4952c

Please sign in to comment.