Skip to content

Commit

Permalink
Merge branch 'main' into fix-predict
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Jan 2, 2024
2 parents 8561672 + 6d4c59d commit 3a3c07d
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 21 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ dependencies:
- polars
- ray<2.8
- triad==0.9.1
- utilsforecast>=0.0.22
- utilsforecast>=0.0.24
- xgboost_ray
2 changes: 1 addition & 1 deletion local_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ dependencies:
- datasetsforecast
- nbdev
- polars
- utilsforecast>=0.0.22
- utilsforecast>=0.0.24
6 changes: 5 additions & 1 deletion mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,11 @@ def update(self, df: DataFrame) -> None:
df = ufp.sort(df, by=[self.id_col, self.time_col])
values = df[self.target_col].to_numpy()
id_counts = ufp.counts_by_id(df, self.id_col)
sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer")
try:
sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer_coalesce")
except (KeyError, ValueError):
# pandas raises key error, polars before coalesce raises value error
sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer")
sizes = ufp.fill_null(sizes, {"counts": 0})
sizes = ufp.sort(sizes, by=self.id_col)
new_groups = ~ufp.is_in(sizes[self.id_col], uids)
Expand Down
13 changes: 10 additions & 3 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,11 @@
" df = ufp.sort(df, by=[self.id_col, self.time_col])\n",
" values = df[self.target_col].to_numpy() \n",
" id_counts = ufp.counts_by_id(df, self.id_col)\n",
" sizes = ufp.join(uids, id_counts, on=self.id_col, how='outer')\n",
" try:\n",
" sizes = ufp.join(uids, id_counts, on=self.id_col, how='outer_coalesce')\n",
" except (KeyError, ValueError):\n",
" # pandas raises key error, polars before coalesce raises value error\n",
" sizes = ufp.join(uids, id_counts, on=self.id_col, how='outer')\n",
" sizes = ufp.fill_null(sizes, {'counts': 0})\n",
" sizes = ufp.sort(sizes, by=self.id_col)\n",
" new_groups = ~ufp.is_in(sizes[self.id_col], uids)\n",
Expand Down Expand Up @@ -2181,7 +2185,7 @@
")\n",
"last_val_id0 = last_vals_two_series.filter(pl.col('unique_id') == 'id_0')\n",
"new_values = last_val_id0.with_columns(\n",
" pl.col('unique_id').cast(pl.Utf8),\n",
" pl.col('unique_id').cast(pl.Categorical),\n",
" pl.col('ds').dt.offset_by('1d'),\n",
" pl.col('static_0').cast(pl.Int64),\n",
" pl.col('static_1').cast(pl.Int64),\n",
Expand All @@ -2192,7 +2196,10 @@
" 'y': [5.0, 6.0],\n",
" 'static_0': [0, 0],\n",
" 'static_1': [1, 1],\n",
"}).with_columns(pl.col('ds').dt.cast_time_unit('ns'))\n",
"}).with_columns(\n",
" pl.col('ds').dt.cast_time_unit('ns'),\n",
" pl.col('unique_id').cast(pl.Categorical),\n",
")\n",
"new_values = pl.concat([new_values, new_serie])\n",
"ts.update(new_values)\n",
"preds = ts.predict({'Naive': NaiveModel()}, 1)\n",
Expand Down
28 changes: 14 additions & 14 deletions nbs/forecast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1476,7 +1476,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L542){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L564){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.make_future_dataframe\n",
"\n",
Expand All @@ -1492,7 +1492,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L542){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L564){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.make_future_dataframe\n",
"\n",
Expand Down Expand Up @@ -1600,7 +1600,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L566){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L588){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.get_missing_future\n",
"\n",
Expand All @@ -1619,7 +1619,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L566){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L588){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.get_missing_future\n",
"\n",
Expand Down Expand Up @@ -2111,7 +2111,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L554){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L607){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.predict\n",
"\n",
Expand Down Expand Up @@ -2141,7 +2141,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L554){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L607){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.predict\n",
"\n",
Expand Down Expand Up @@ -2700,7 +2700,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L216){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.preprocess\n",
"\n",
Expand Down Expand Up @@ -2732,7 +2732,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L216){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.preprocess\n",
"\n",
Expand Down Expand Up @@ -3052,7 +3052,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L272){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L258){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.fit_models\n",
"\n",
Expand All @@ -3071,7 +3071,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L272){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L258){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.fit_models\n",
"\n",
Expand Down Expand Up @@ -3196,7 +3196,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L693){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L746){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.cross_validation\n",
"\n",
Expand Down Expand Up @@ -3249,7 +3249,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L693){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L746){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.cross_validation\n",
"\n",
Expand Down Expand Up @@ -4190,7 +4190,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L188){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.from_cv\n",
"\n",
Expand All @@ -4199,7 +4199,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L188){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.from_cv\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ language = English
custom_sidebar = True
license = apache2
status = 3
requirements = numba packaging pandas scikit-learn utilsforecast>=0.0.22 window-ops
requirements = numba packaging pandas scikit-learn utilsforecast>=0.0.24 window-ops
dask_requirements = fugue dask[complete] lightgbm xgboost
ray_requirements = fugue[ray] lightgbm_ray xgboost_ray
spark_requirements = fugue pyspark lightgbm xgboost
Expand Down

0 comments on commit 3a3c07d

Please sign in to comment.