From cb4952c405e70661fbee2a406a950dd67f85fc12 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Thu, 18 Jan 2024 19:33:02 -0600
Subject: [PATCH] add to_local method to distributed forecast

---
 mlforecast/_modidx.py              |   4 +-
 mlforecast/distributed/forecast.py |  82 +++++-
 nbs/distributed.forecast.ipynb     | 248 +++++++++++++++++-
 .../quick_start_distributed.ipynb  | 153 ++++++++---
 4 files changed, 428 insertions(+), 59 deletions(-)

diff --git a/mlforecast/_modidx.py b/mlforecast/_modidx.py
index b596715b..194e2ed0 100644
--- a/mlforecast/_modidx.py
+++ b/mlforecast/_modidx.py
@@ -92,7 +92,9 @@
     'mlforecast.distributed.forecast.DistributedMLForecast.preprocess': ( 'distributed.forecast.html#distributedmlforecast.preprocess',
                                                                           'mlforecast/distributed/forecast.py'),
     'mlforecast.distributed.forecast.DistributedMLForecast.save': ( 'distributed.forecast.html#distributedmlforecast.save',
-                                                                    'mlforecast/distributed/forecast.py')},
+                                                                    'mlforecast/distributed/forecast.py'),
+    'mlforecast.distributed.forecast.DistributedMLForecast.to_local': ( 'distributed.forecast.html#distributedmlforecast.to_local',
+                                                                        'mlforecast/distributed/forecast.py')},
     'mlforecast.distributed.models.dask.lgb': { 'mlforecast.distributed.models.dask.lgb.DaskLGBMForecast': ( 'distributed.models.dask.lgb.html#dasklgbmforecast',
                                                                                                              'mlforecast/distributed/models/dask/lgb.py'),
                                                 'mlforecast.distributed.models.dask.lgb.DaskLGBMForecast.model_': ( 'distributed.models.dask.lgb.html#dasklgbmforecast.model_',
diff --git a/mlforecast/distributed/forecast.py b/mlforecast/distributed/forecast.py
index 28a55a3c..f700df74 100644
--- a/mlforecast/distributed/forecast.py
+++ b/mlforecast/distributed/forecast.py
@@ -19,7 +19,9 @@
     DASK_INSTALLED = False
 import fugue
 import fugue.api as fa
+import numpy as np
 import pandas as pd
+import utilsforecast.processing as ufp
 try:
     from pyspark.ml.feature import VectorAssembler
     from pyspark.sql import DataFrame as SparkDataFrame
@@ -36,7 +38,6 @@
 except ModuleNotFoundError:
     RAY_INSTALLED = False
 from sklearn.base import clone
-from utilsforecast.processing import _single_split
 
 from mlforecast.core import (
     DateFeature,
@@ -47,6 +48,8 @@
     TimeSeries,
     _name_models,
 )
+from ..forecast import MLForecast
+from ..grouped_array import GroupedArray
 
 # %% ../../nbs/distributed.forecast.ipynb 6
 WindowInfo = namedtuple(
@@ -161,7 +164,7 @@
             valid = None
         else:
             max_dates = part.groupby(id_col, observed=True)[time_col].transform("max")
-            cutoffs, train_mask, valid_mask = _single_split(
+            cutoffs, train_mask, valid_mask = ufp._single_split(
                 part,
                 i_window=window_info.i_window,
                 n_windows=window_info.n_windows,
@@ -708,3 +711,78 @@
         fcst.engine = engine
         fcst.num_partitions = len(paths)
         return fcst
+
+    def to_local(self) -> MLForecast:
+        """Convert this distributed forecast object into a local one
+
+        This pulls all the data from the remote machines, so you have to be sure that
+        it fits in the scheduler/driver. If you're not sure, use the save method instead.
+
+        Returns
+        -------
+        MLForecast
+            Local forecast object."""
+        serialized_ts = (
+            fa.select_columns(
+                self._partition_results,
+                columns=["ts"],
+                as_fugue=True,
+            )
+            .as_pandas()["ts"]
+            .tolist()
+        )
+        all_ts = [cloudpickle.loads(ts) for ts in serialized_ts]
+        # sort by ids (these should already be sorted within each partition)
+        all_ts = sorted(all_ts, key=lambda ts: ts.uids[0])
+
+        # combine attributes. since fugue works on pandas these are all pandas.
+        # we're using utilsforecast here in case we add support for polars
+        def possibly_concat_indices(collection):
+            items_are_indices = isinstance(collection[0], pd.Index)
+            if items_are_indices:
+                collection = [pd.Series(item) for item in collection]
+            combined = ufp.vertical_concat(collection)
+            if items_are_indices:
+                combined = pd.Index(combined)
+            return combined
+
+        uids = possibly_concat_indices([ts.uids for ts in all_ts])
+        last_dates = possibly_concat_indices([ts.last_dates for ts in all_ts])
+        statics = ufp.vertical_concat([ts.static_features_ for ts in all_ts])
+        sizes = np.hstack([np.diff(ts.ga.indptr) for ts in all_ts])
+        data = np.hstack([ts.ga.data for ts in all_ts])
+        indptr = np.append(0, sizes).cumsum()
+        if isinstance(uids, pd.Index):
+            uids_idx = uids
+        else:
+            # uids is a polars series
+            uids_idx = pd.Index(uids)
+        if not uids_idx.is_monotonic_increasing:
+            # this seems to happen only with ray
+            # we have to sort all data related to the series
+            sort_idxs = uids_idx.argsort()
+            uids = uids[sort_idxs]
+            last_dates = last_dates[sort_idxs]
+            statics = ufp.take_rows(statics, sort_idxs)
+            statics = ufp.drop_index_if_pandas(statics)
+            old_data = data.copy()
+            old_indptr = indptr.copy()
+            indptr = np.append(0, sizes[sort_idxs]).cumsum()
+            # this loop takes 500ms for 100,000 series of sizes between 500 and 2,000,
+            # so it may not be that much of a bottleneck, but try to implement it in core
+            for i, sort_idx in enumerate(sort_idxs):
+                old_slice = slice(old_indptr[sort_idx], old_indptr[sort_idx + 1])
+                new_slice = slice(indptr[i], indptr[i + 1])
+                data[new_slice] = old_data[old_slice]
+        ga = GroupedArray(data, indptr)
+
+        # all other attributes should be the same, so we just override the first series
+        ts = all_ts[0]
+        ts.uids = uids
+        ts.last_dates = last_dates
+        ts.ga = ga
+        ts.static_features_ = statics
+        fcst = MLForecast(models=self.models_, freq=ts.freq)
+        fcst.ts = ts
+        fcst.models_ = self.models_
+        return fcst
diff --git a/nbs/distributed.forecast.ipynb b/nbs/distributed.forecast.ipynb
index 2aea1ed5..f7e80d3a 100644
--- a/nbs/distributed.forecast.ipynb
+++ b/nbs/distributed.forecast.ipynb
@@ -82,7 +82,9 @@
     "    DASK_INSTALLED = False\n",
     "import fugue\n",
     "import fugue.api as fa\n",
+    "import numpy as np\n",
     "import pandas as pd\n",
+    "import utilsforecast.processing as ufp\n",
     "try:\n",
     "    from pyspark.ml.feature import VectorAssembler\n",
     "    from pyspark.sql import DataFrame as SparkDataFrame\n",
@@ -96,7 +98,6 @@
     "except ModuleNotFoundError:\n",
     "    RAY_INSTALLED = False\n",
     "from sklearn.base import clone\n",
-    "from utilsforecast.processing import _single_split\n",
     "\n",
     "from mlforecast.core import (\n",
     "    DateFeature,\n",
@@ -106,7 +107,9 @@
     "    TimeSeries,\n",
     "    _name_models,\n",
-    ")"
+    ")\n",
+    "from mlforecast.forecast import MLForecast\n",
+    "from mlforecast.grouped_array import GroupedArray"
    ]
   },
   {
@@ -229,7 +232,7 @@
     "            valid = None\n",
     "        else:\n",
     "            max_dates = part.groupby(id_col, observed=True)[time_col].transform('max')\n",
-    "            cutoffs, train_mask, valid_mask = _single_split(\n",
+    "            cutoffs, train_mask, valid_mask = ufp._single_split(\n",
     "                part,\n",
     "                i_window=window_info.i_window,\n",
     "                n_windows=window_info.n_windows,\n",
@@ -755,6 +758,77 @@
     "        fcst.models_ = models\n",
     "        fcst.engine = engine\n",
     "        fcst.num_partitions = len(paths)\n",
+    "        return fcst\n",
+    "\n",
+    "    def to_local(self) -> MLForecast:\n",
+    "        \"\"\"Convert this distributed forecast object into a local one\n",
+    "\n",
+    "        This pulls all the data from the remote machines, so
 you have to be sure that\n",
+    "        it fits in the scheduler/driver. If you're not sure, use the save method instead.\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        MLForecast\n",
+    "            Local forecast object.\"\"\"\n",
+    "        serialized_ts = fa.select_columns(\n",
+    "            self._partition_results,\n",
+    "            columns=['ts'],\n",
+    "            as_fugue=True,\n",
+    "        ).as_pandas()['ts'].tolist()\n",
+    "        all_ts = [cloudpickle.loads(ts) for ts in serialized_ts]\n",
+    "        # sort by ids (these should already be sorted within each partition)\n",
+    "        all_ts = sorted(all_ts, key=lambda ts: ts.uids[0])\n",
+    "\n",
+    "        # combine attributes. since fugue works on pandas these are all pandas.\n",
+    "        # we're using utilsforecast here in case we add support for polars\n",
+    "        def possibly_concat_indices(collection):\n",
+    "            items_are_indices = isinstance(collection[0], pd.Index)\n",
+    "            if items_are_indices:\n",
+    "                collection = [pd.Series(item) for item in collection]\n",
+    "            combined = ufp.vertical_concat(collection)\n",
+    "            if items_are_indices:\n",
+    "                combined = pd.Index(combined)\n",
+    "            return combined\n",
+    "\n",
+    "        uids = possibly_concat_indices([ts.uids for ts in all_ts])\n",
+    "        last_dates = possibly_concat_indices([ts.last_dates for ts in all_ts])\n",
+    "        statics = ufp.vertical_concat([ts.static_features_ for ts in all_ts])\n",
+    "        sizes = np.hstack([np.diff(ts.ga.indptr) for ts in all_ts])\n",
+    "        data = np.hstack([ts.ga.data for ts in all_ts])\n",
+    "        indptr = np.append(0, sizes).cumsum()\n",
+    "        if isinstance(uids, pd.Index):\n",
+    "            uids_idx = uids\n",
+    "        else:\n",
+    "            # uids is a polars series\n",
+    "            uids_idx = pd.Index(uids)\n",
+    "        if not uids_idx.is_monotonic_increasing:\n",
+    "            # this seems to happen only with ray\n",
+    "            # we have to sort all data related to the series\n",
+    "            sort_idxs = uids_idx.argsort()\n",
+    "            uids = uids[sort_idxs]\n",
+    "            last_dates = last_dates[sort_idxs]\n",
+    "            statics = ufp.take_rows(statics, sort_idxs)\n",
+    "            statics = ufp.drop_index_if_pandas(statics)\n",
+    "            old_data = data.copy()\n",
+    "            old_indptr = indptr.copy()\n",
+    "            indptr = np.append(0, sizes[sort_idxs]).cumsum()\n",
+    "            # this loop takes 500ms for 100,000 series of sizes between 500 and 2,000,\n",
+    "            # so it may not be that much of a bottleneck, but try to implement it in core\n",
+    "            for i, sort_idx in enumerate(sort_idxs):\n",
+    "                old_slice = slice(old_indptr[sort_idx], old_indptr[sort_idx + 1])\n",
+    "                new_slice = slice(indptr[i], indptr[i + 1])\n",
+    "                data[new_slice] = old_data[old_slice]\n",
+    "        ga = GroupedArray(data, indptr)\n",
+    "\n",
+    "        # all other attributes should be the same, so we just override the first series\n",
+    "        ts = all_ts[0]\n",
+    "        ts.uids = uids\n",
+    "        ts.last_dates = last_dates\n",
+    "        ts.ga = ga\n",
+    "        ts.static_features_ = statics\n",
+    "        fcst = MLForecast(models=self.models_, freq=ts.freq)\n",
+    "        fcst.ts = ts\n",
+    "        fcst.models_ = self.models_\n",
     "        return fcst"
    ]
   },
   {
     "text/markdown": [
      "---\n",
      "\n",
-     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L60){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
      "\n",
      "### DistributedMLForecast\n",
      "\n",
     "text/plain": [
      "---\n",
      "\n",
-     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L56){target=\"_blank\" style=\"float:right;
 font-size:smaller\"}\n",
+     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L60){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
      "\n",
      "### DistributedMLForecast\n",
      "\n",
@@ -829,7 +903,7 @@
     "text/markdown": [
      "---\n",
      "\n",
-     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L380){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L386){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
      "\n",
      "### DistributedMLForecast.fit\n",
      "\n",
     "text/plain": [
      "---\n",
      "\n",
-     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L380){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L386){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
      "\n",
      "### DistributedMLForecast.fit\n",
      "\n",
@@ -899,7 +973,7 @@
     "text/markdown": [
      "---\n",
      "\n",
-     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L451){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L462){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
      "\n",
      "### DistributedMLForecast.predict\n",
      "\n",
     "text/plain": [
      "---\n",
      "\n",
-     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L451){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L462){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
      "\n",
      "### DistributedMLForecast.predict\n",
      "\n",
@@ -956,6 +1030,154 @@
     "show_doc(DistributedMLForecast.predict)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6a3326fd-638d-43ba-93d7-3b3d289c68ba",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/markdown": [
+       "---\n",
+       "\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L645){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "\n",
+       "### DistributedMLForecast.save\n",
+       "\n",
+       "> DistributedMLForecast.save (path:str)\n",
+       "\n",
+       "Save forecast object\n",
+       "\n",
+       "| | **Type** | **Details** |\n",
+       "| -- | -------- | ----------- |\n",
+       "| path | str | Directory where artifacts will be stored. |\n",
+       "| **Returns** | **None** | |"
+      ],
+      "text/plain": [
+       "---\n",
+       "\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L645){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "\n",
+       "### DistributedMLForecast.save\n",
+       "\n",
+       "> DistributedMLForecast.save (path:str)\n",
+       "\n",
+       "Save forecast object\n",
+       "\n",
+       "| | **Type** | **Details** |\n",
+       "| -- | -------- | ----------- |\n",
+       "| path | str | Directory where artifacts will be stored.
 |\n",
+       "| **Returns** | **None** | |"
+      ]
+     },
+     "execution_count": null,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "show_doc(DistributedMLForecast.save)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "309d0b0f-f2ce-4afa-b192-f7bcd6f5044f",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/markdown": [
+       "---\n",
+       "\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L678){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "\n",
+       "### DistributedMLForecast.load\n",
+       "\n",
+       "> DistributedMLForecast.load (path:str, engine)\n",
+       "\n",
+       "Load forecast object\n",
+       "\n",
+       "| | **Type** | **Details** |\n",
+       "| -- | -------- | ----------- |\n",
+       "| path | str | Directory with saved artifacts. |\n",
+       "| engine | fugue execution engine | Dask Client, Spark Session, etc. to use for the distributed computation. |\n",
+       "| **Returns** | **DistributedMLForecast** | |"
+      ],
+      "text/plain": [
+       "---\n",
+       "\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L678){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "\n",
+       "### DistributedMLForecast.load\n",
+       "\n",
+       "> DistributedMLForecast.load (path:str, engine)\n",
+       "\n",
+       "Load forecast object\n",
+       "\n",
+       "| | **Type** | **Details** |\n",
+       "| -- | -------- | ----------- |\n",
+       "| path | str | Directory with saved artifacts. |\n",
+       "| engine | fugue execution engine | Dask Client, Spark Session, etc. to use for the distributed computation. |\n",
+       "| **Returns** | **DistributedMLForecast** | |"
+      ]
+     },
+     "execution_count": null,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "show_doc(DistributedMLForecast.load)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ad4916ff-ad5c-4fb5-a8c4-e7befb52fa67",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/markdown": [
+       "---\n",
+       "\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L715){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "\n",
+       "### DistributedMLForecast.to_local\n",
+       "\n",
+       "> DistributedMLForecast.to_local ()\n",
+       "\n",
+       "Convert this distributed forecast object into a local one\n",
+       "\n",
+       "This pulls all the data from the remote machines, so you have to be sure that\n",
+       "it fits in the scheduler/driver. If you're not sure, use the save method instead."
+      ],
+      "text/plain": [
+       "---\n",
+       "\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L715){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "\n",
+       "### DistributedMLForecast.to_local\n",
+       "\n",
+       "> DistributedMLForecast.to_local ()\n",
+       "\n",
+       "Convert this distributed forecast object into a local one\n",
+       "\n",
+       "This pulls all the data from the remote machines, so you have to be sure that\n",
+       "it fits in the scheduler/driver. If you're not sure, use the save method instead."
+      ]
+     },
+     "execution_count": null,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "show_doc(DistributedMLForecast.to_local)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -967,7 +1189,7 @@
     "text/markdown": [
      "---\n",
      "\n",
-     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L284){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L290){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
      "\n",
      "### DistributedMLForecast.preprocess\n",
      "\n",
     "text/plain": [
      "---\n",
      "\n",
-     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L284){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L290){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
      "\n",
      "### DistributedMLForecast.preprocess\n",
      "\n",
@@ -1039,7 +1261,7 @@
     "text/markdown": [
      "---\n",
      "\n",
-     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L510){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L527){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
      "\n",
      "### DistributedMLForecast.cross_validation\n",
      "\n",
     "text/plain": [
      "---\n",
      "\n",
-     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L510){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+     "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L527){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
      "\n",
      "### DistributedMLForecast.cross_validation\n",
      "\n",
diff --git a/nbs/docs/getting-started/quick_start_distributed.ipynb b/nbs/docs/getting-started/quick_start_distributed.ipynb
index 1a8be2ea..632968ff 100644
--- a/nbs/docs/getting-started/quick_start_distributed.ipynb
+++ b/nbs/docs/getting-started/quick_start_distributed.ipynb
@@ -424,36 +424,36 @@
     "      0\n",
     "      id_00\n",
     "      2002-09-27\n",
-    "      18.819103\n",
-    "      17.900281\n",
+    "      19.435570\n",
+    "      17.420052\n",
     "    \n",
     "    \n",
     "      1\n",
     "      id_00\n",
     "      2002-09-28\n",
-    "      89.682961\n",
-    "      91.353827\n",
+    "      90.069359\n",
+    "      89.933465\n",
     "    \n",
     "    \n",
     "      2\n",
     "      id_00\n",
     "      2002-09-29\n",
-    "      167.320984\n",
-    "      167.335792\n",
+    "      166.154678\n",
+    "      165.127939\n",
     "    \n",
     "    \n",
     "      3\n",
     "      id_00\n",
     "      2002-09-30\n",
-    "      245.242462\n",
-    "      243.613032\n",
+    "      244.439392\n",
+    "      246.553488\n",
     "    \n",
     "    \n",
     "      4\n",
     "      id_00\n",
     "      2002-10-01\n",
-    "      315.341370\n",
-    "      313.804709\n",
+    "      317.734375\n",
+    "      319.077260\n",
     "    \n",
     "    \n",
     "\n",
     ],
     "text/plain": [
     " unique_id ds DaskXGBForecast DaskLGBMForecast\n",
-    "0 id_00 2002-09-27 18.819103 17.900281\n",
-    "1 id_00 2002-09-28 89.682961 91.353827\n",
-    "2 id_00 2002-09-29 167.320984 167.335792\n",
-    "3 id_00 2002-09-30 245.242462 243.613032\n",
-    "4 id_00 2002-10-01 315.341370 313.804709"
+    "0 id_00 2002-09-27 19.435570 17.420052\n",
+    "1 id_00 2002-09-28 90.069359 89.933465\n",
+    "2 id_00 2002-09-29 166.154678 165.127939\n",
+    "3 id_00 2002-09-30 244.439392 246.553488\n",
+    "4 id_00 2002-10-01 317.734375 319.077260"
    ]
   },
   "execution_count": null,
@@ -610,6 +610,30 @@
     "pd.testing.assert_frame_equal(preds,
 preds2)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "4190830c-b2e5-4343-97da-023e9f532ef6",
+   "metadata": {},
+   "source": [
+    "### Converting to local\n",
+    "\n",
+    "Another option to store your distributed forecast object is to first turn it into a local one and then save it. Keep in mind that all the remote data stored for the series will have to be pulled into a single machine (the scheduler in dask, the driver in spark, etc.), so you have to be sure that it fits in memory; it should consume about 2x the size of your target column (you can reduce this further by using the `keep_last_n` argument in the `fit` method)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "05f2cefe-d032-413e-89d3-5d01271c5d39",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "local_fcst = fcst.to_local()\n",
+    "local_preds = local_fcst.predict(10)\n",
+    "# we don't check the dtype because sometimes these are arrow dtypes\n",
+    "# or different precisions of float\n",
+    "pd.testing.assert_frame_equal(preds, local_preds, check_dtype=False)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "29841c02-b0bc-44cc-a8f3-da31b442584b",
@@ -925,19 +949,22 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "9609dee3-f959-4c28-b913-0871f94e2863",
+   "id": "32cea90e-ea7b-4244-b26f-a16d3e85fe93",
    "metadata": {},
    "outputs": [],
    "source": [
     "from mlforecast.distributed.models.spark.lgb import SparkLGBMForecast\n",
-    "\n",
-    "models = [SparkLGBMForecast()]\n",
-    "try:\n",
-    "    from xgboost.spark import SparkXGBRegressor\n",
-    "    from mlforecast.distributed.models.spark.xgb import SparkXGBForecast\n",
-    "    models.append(SparkXGBForecast())\n",
-    "except ModuleNotFoundError:  # py < 38\n",
-    "    pass"
+    "from mlforecast.distributed.models.spark.xgb import SparkXGBForecast"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "03231a34-30a0-44a3-805a-c91ce6eba393",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "models = [SparkLGBMForecast(), SparkXGBForecast()]"
+   ]
+  },
@@ -1029,7 +1056,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "preds = fcst.predict(14)"
+    "preds = fcst.predict(14).toPandas()"
    ]
   },
     "      id_00\n",
     "      2001-05-15\n",
     "      422.139843\n",
-    "      417.848083\n",
+    "      424.463562\n",
     "    \n",
     "    \n",
     "      1\n",
     "      id_00\n",
     "      2001-05-16\n",
     "      497.180212\n",
-    "      503.371185\n",
+    "      505.564667\n",
     "    \n",
     "    \n",
     "      2\n",
     ],
     "text/plain": [
     " unique_id ds SparkLGBMForecast SparkXGBForecast\n",
-    "0 id_00 2001-05-15 422.139843 417.848083\n",
-    "1 id_00 2001-05-16 497.180212 503.371185\n",
+    "0 id_00 2001-05-15 422.139843 424.463562\n",
+    "1 id_00 2001-05-16 497.180212 505.564667\n",
     "2 id_00 2001-05-17 13.062478 18.514997\n",
     "3 id_00 2001-05-18 100.601041 109.317825\n",
     "4 id_00 2001-05-19 180.707848 181.431747"
    ]
   },
    "source": [
-    "preds.toPandas().head()"
+    "preds.head()"
    ]
   },
@@ -1200,21 +1227,37 @@
    "execution_count": null,
    "id": "04005646-d244-48c8-b11c-124a41d9740c",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      " \r"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "preds = fa.as_pandas(fcst.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
     "preds2 = fa.as_pandas(fcst2.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
     "pd.testing.assert_frame_equal(preds, preds2)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "21dd6f64-834f-48af-8cfa-4da2d7adfffa",
+   "metadata": {},
+   "source": [
+    "### Converting to local\n",
+ "\n", + "Another option to store your distributed forecast object is to first turn it into a local one and then save it. Keep in mind that in order to do that all the remote data that is stored from the series will have to be pulled into a single machine (the scheduler in dask, driver in spark, etc.), so you have to be sure that it'll fit in memory, it should consume about 2x the size of your target column (you can reduce this further by using the `keep_last_n` argument in the `fit` method)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7af4994-44dd-48e6-b48d-e8faa1a2a846", + "metadata": {}, + "outputs": [], + "source": [ + "local_fcst = fcst.to_local()\n", + "local_preds = local_fcst.predict(10)\n", + "# we don't check the dtype because sometimes these are arrow dtypes\n", + "# or different precisions of float\n", + "pd.testing.assert_frame_equal(preds, local_preds, check_dtype=False)" + ] + }, { "cell_type": "markdown", "id": "5ad24338-2634-4b27-8f73-dc162338dd38", @@ -1609,14 +1652,14 @@ " id_00\n", " 2001-05-15\n", " 422.139843\n", - " 418.110107\n", + " 424.179993\n", " \n", " \n", " 1\n", " id_00\n", " 2001-05-16\n", " 497.180212\n", - " 502.229492\n", + " 501.030060\n", " \n", " \n", " 2\n", @@ -1645,8 +1688,8 @@ ], "text/plain": [ " unique_id ds RayLGBMForecast RayXGBForecast\n", - "0 id_00 2001-05-15 422.139843 418.110107\n", - "1 id_00 2001-05-16 497.180212 502.229492\n", + "0 id_00 2001-05-15 422.139843 424.179993\n", + "1 id_00 2001-05-16 497.180212 501.030060\n", "2 id_00 2001-05-17 13.062478 18.364956\n", "3 id_00 2001-05-18 100.601041 102.921730\n", "4 id_00 2001-05-19 180.707848 183.109436" @@ -1729,6 +1772,30 @@ "pd.testing.assert_frame_equal(preds, preds2)" ] }, + { + "cell_type": "markdown", + "id": "01ea42c6-8a0a-4a6d-8416-e78e0f7eccd8", + "metadata": {}, + "source": [ + "### Converting to local\n", + "\n", + "Another option to store your distributed forecast object is to first turn it into a local one and then save it. Keep in mind that in order to do that all the remote data that is stored from the series will have to be pulled into a single machine (the scheduler in dask, driver in spark, etc.), so you have to be sure that it'll fit in memory, it should consume about 2x the size of your target column (you can reduce this further by using the `keep_last_n` argument in the `fit` method)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f844ebe4-423f-4790-b5b1-d9550e4f1835", + "metadata": {}, + "outputs": [], + "source": [ + "local_fcst = fcst.to_local()\n", + "local_preds = local_fcst.predict(10)\n", + "# we don't check the dtype because sometimes these are arrow dtypes\n", + "# or different precisions of float\n", + "pd.testing.assert_frame_equal(preds, local_preds, check_dtype=False)" + ] + }, { "cell_type": "markdown", "id": "cb901de2-c98e-47fb-81e2-714010114a91",