Skip to content

Commit

Permalink
enh(distributed): propagate null features in spark (#448)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Nov 12, 2024
1 parent ca67b98 commit e869465
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion mlforecast/distributed/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,9 @@ def _fit(
]
self.models_ = {}
if SPARK_INSTALLED and isinstance(data, SparkDataFrame):
featurizer = VectorAssembler(inputCols=features, outputCol="features")
featurizer = VectorAssembler(
inputCols=features, outputCol="features", handleInvalid="keep"
)
train_data = featurizer.transform(prep)[target_col, "features"]
for name, model in self.models.items():
trained_model = model._pre_fit(target_col).fit(train_data)
Expand Down
4 changes: 3 additions & 1 deletion nbs/distributed.forecast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,9 @@
" features = [x for x in fa.get_column_names(prep) if x not in {id_col, time_col, target_col}]\n",
" self.models_ = {}\n",
" if SPARK_INSTALLED and isinstance(data, SparkDataFrame):\n",
" featurizer = VectorAssembler(inputCols=features, outputCol=\"features\")\n",
" featurizer = VectorAssembler(\n",
" inputCols=features, outputCol=\"features\", handleInvalid=\"keep\"\n",
" )\n",
" train_data = featurizer.transform(prep)[target_col, \"features\"]\n",
" for name, model in self.models.items():\n",
" trained_model = model._pre_fit(target_col).fit(train_data)\n",
Expand Down

0 comments on commit e869465

Please sign in to comment.