Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract Python and Dask Executor classes from Workflow #1609

Merged
merged 16 commits into from
Aug 15, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
Move ensure_optimize_dataframe_graph into MerlinDaskExecutor
karlhigley committed Jul 12, 2022
commit 32512d8642f4658445e931ac217fee3f6fdec35d
18 changes: 10 additions & 8 deletions nvtabular/workflow/executor.py
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@
from dask.core import flatten

from merlin.core.dispatch import concat_columns, is_list_dtype, list_val_dtype
-from merlin.core.utils import global_dask_client
+from merlin.core.utils import ensure_optimize_dataframe_graph, global_dask_client
from merlin.dag import Node
from merlin.io.worker import clean_worker_cache

@@ -165,13 +165,15 @@ def apply(self, ddf, nodes, output_dtypes=None, additional_columns=None, capture
# don't require dtype information on the DDF this doesn't matter all that much
output_dtypes = type(ddf._meta)({k: [] for k in columns})

-return ddf.map_partitions(
-    self._executor.apply,
-    nodes,
-    additional_columns=additional_columns,
-    capture_dtypes=capture_dtypes,
-    meta=output_dtypes,
-    enforce_metadata=False,
+return ensure_optimize_dataframe_graph(
+    ddf=ddf.map_partitions(
+        self._executor.apply,
+        nodes,
+        additional_columns=additional_columns,
+        capture_dtypes=capture_dtypes,
+        meta=output_dtypes,
+        enforce_metadata=False,
+    )
 )

def _clear_worker_cache(self):
18 changes: 6 additions & 12 deletions nvtabular/workflow/workflow.py
Original file line number Diff line number Diff line change
@@ -31,11 +31,7 @@
import pandas as pd

import nvtabular
-from merlin.core.utils import (
-    ensure_optimize_dataframe_graph,
-    global_dask_client,
-    set_client_deprecated,
-)
+from merlin.core.utils import global_dask_client, set_client_deprecated
from merlin.dag import Graph
from merlin.io import Dataset
from merlin.schema import Schema
@@ -221,13 +217,11 @@ def fit(self, dataset: Dataset) -> "Workflow":

# apply transforms necessary for the inputs to the current column group, ignoring
# the transforms from the statop itself
-transformed_ddf = ensure_optimize_dataframe_graph(
-    ddf=self.executor.apply(
-        ddf,
-        workflow_node.parents_with_dependencies,
-        additional_columns=addl_input_cols,
-        capture_dtypes=True,
-    )
+transformed_ddf = self.executor.apply(
+    ddf,
+    workflow_node.parents_with_dependencies,
+    additional_columns=addl_input_cols,
+    capture_dtypes=True,
 )

op = workflow_node.op