DNM: Handle pipeline breakers through avoiding reuse #873

Open · wants to merge 2 commits into main
7 changes: 7 additions & 0 deletions dask_expr/_core.py
@@ -53,6 +53,12 @@ def _tune_down(self):
    def _tune_up(self, parent):
        return None

    def _pipe_down(self):
        return None

    def _pipe_up(self, parent):
        return None

    def _cull_down(self):
        return None
@@ -342,6 +348,7 @@ def simplify(self) -> Expr:
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            new = new.rewrite("pipe")
            if new._name == expr._name:
                break
            expr = new
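
The _pipe_down/_pipe_up pair mirrors the existing _tune_down/_tune_up hooks, and simplify now interleaves a rewrite("pipe") pass with each simplification round, so pipeline adjustments are re-applied until the expression reaches a fixed point. A minimal, self-contained sketch of the down/up dispatch protocol these hooks plug into (toy classes, not dask_expr's actual traversal):

class Node:
    """Toy stand-in for dask_expr's Expr."""

    def __init__(self, *deps):
        self.deps = list(deps)

    def _pipe_down(self):
        return None  # return a replacement node, or None for "no change"

    def _pipe_up(self, parent):
        return None  # like _pipe_down, but may inspect the parent


def rewrite(node, kind="pipe"):
    # Give the node a chance to rewrite itself on the way down ...
    replacement = getattr(node, f"_{kind}_down")()
    if replacement is not None:
        node = replacement
    # ... then let each child react to its parent on the way up,
    # recursing into children that returned no replacement.
    node.deps = [
        getattr(dep, f"_{kind}_up")(node) or rewrite(dep, kind)
        for dep in node.deps
    ]
    return node
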
9 changes: 7 additions & 2 deletions dask_expr/_expr.py
@@ -1205,7 +1205,7 @@ def _meta(self):
        args = [
            meta_nonempty(op._meta) if isinstance(op, Expr) else op for op in self._args
        ]
-        return self.operation(*args, **self._kwargs)
        return make_meta(self.operation(*args, **self._kwargs))

    @staticmethod
    def operation(df, index, sorted_index):
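
Wrapping the result in make_meta matters because operation runs on nonempty meta inputs and can therefore hand back a populated object; make_meta strips that back down to the empty, schema-only form the expression machinery expects from _meta. A quick illustration with plain pandas (make_meta is dask's real utility; the frame here is made up):

import pandas as pd
from dask.dataframe.utils import make_meta

# operation() ran on nonempty meta, so its result can carry rows ...
populated = pd.DataFrame({"a": [1, 2], "b": [1.0, 2.0]})

# ... and make_meta reduces it to a zero-row frame with the same schema.
meta = make_meta(populated)
assert len(meta) == 0
assert list(meta.columns) == ["a", "b"]
assert meta.dtypes.equals(populated.dtypes)
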
@@ -2062,6 +2062,9 @@ class ResetIndex(Elemwise):
    operation = M.reset_index
    _filter_passthrough = True

    def __new__(cls, *args, **kwargs):
        return super().__new__(cls, *args, **kwargs)

    @functools.cached_property
    def _kwargs(self) -> dict:
        kwargs = {"drop": self.drop}
@@ -2099,7 +2102,9 @@ def _simplify_up(self, parent, dependents):
            return self._filter_simplification(parent, predicate)

        if isinstance(parent, Projection):
-            if self.frame.ndim == 1 and not self.drop and not isinstance(parent, list):
            if self.frame.ndim == 1 and not self.drop:
                if isinstance(parent.operand("columns"), list):
                    return
                col = parent.operand("columns")
                if col in (self.name, "index"):
                    return
12 changes: 10 additions & 2 deletions dask_expr/_groupby.py
@@ -136,7 +136,7 @@ class GroupByApplyConcatApply(ApplyConcatApply, GroupByBase):
    @functools.cached_property
    def _meta_chunk(self):
        meta = meta_nonempty(self.frame._meta)
-        return self.chunk(meta, *self._by_meta, **self.chunk_kwargs)
        return make_meta(self.chunk(meta, *self._by_meta, **self.chunk_kwargs))

    @property
    def _chunk_cls_args(self):
@@ -201,6 +201,7 @@ class SingleAggregation(GroupByApplyConcatApply, GroupByBase):
        "split_out",
        "sort",
        "shuffle_method",
        "_pipeline_breaker_counter",
    ]
    _defaults = {
        "observed": None,
@@ -212,6 +213,7 @@ class SingleAggregation(GroupByApplyConcatApply, GroupByBase):
        "split_out": None,
        "sort": None,
        "shuffle_method": None,
        "_pipeline_breaker_counter": None,
    }

    groupby_chunk = None
@@ -251,7 +253,11 @@ def aggregate_kwargs(self) -> dict:
        }

    def _simplify_up(self, parent, dependents):
-        return groupby_projection(self, parent, dependents)
        if isinstance(parent, Projection):
            return groupby_projection(self, parent, dependents)

    def _pipe_down(self):
        return self._adjust_for_pipelinebreaker()


class GroupbyAggregationBase(GroupByApplyConcatApply, GroupByBase):
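
SingleAggregation is a pipeline breaker in the sense that its chunk/aggregate tree must consume every input partition before the first output partition exists, so _pipe_down hands the expression to _adjust_for_pipelinebreaker (defined on ApplyConcatApply in _reductions.py below) to give the aggregation a private copy of its IO subtree. A hedged illustration of the reuse problem; the API calls are real dask_expr entry points, while the before/after behavior describes this PR's intent as read from the diff:

import pandas as pd
import dask_expr as dx

pdf = pd.DataFrame({"a": [1, 1, 2, 2], "b": range(4)})
df = dx.from_pandas(pdf, npartitions=2)

streaming = df["b"] + 1          # blockwise: pipelines partition by partition
breaker = df.groupby("a").sum()  # ApplyConcatApply: a pipeline breaker

# Without this change, both branches hang off the same FromPandas
# expression (identical _name), so a combined compute() shares one set
# of read tasks between a streaming consumer and the breaker. After
# rewrite("pipe"), the breaker branch substitutes a counter into its IO
# ancestor, so its reads become separate tasks that can be scheduled
# (and released) independently of the streaming branch.
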
@@ -1479,6 +1485,7 @@ def _single_agg(
                split_out,
                self.sort,
                shuffle_method,
                None,  # _pipeline_breaker_counter
                *self.by,
            )
        )
@@ -2161,6 +2168,7 @@ def nunique(self, split_every=None, split_out=True, shuffle_method=None):
                split_out,
                self.sort,
                shuffle_method,
                None,  # _pipeline_breaker_counter
                *self.by,
            )
        )
81 changes: 76 additions & 5 deletions dask_expr/_reductions.py
@@ -507,6 +507,41 @@ def _lower(self):
            ignore_index=getattr(self, "ignore_index", True),
        )

    def _adjust_for_pipelinebreaker(self):
        if self._pipeline_breaker_counter is not None:
            return
        from dask_expr.io.io import IO

        seen = set()
        stack = self.dependencies()
        io_nodes = []
        counter = 1

        while stack:
            node = stack.pop()

            if node._name in seen:
                continue
            seen.add(node._name)

            if isinstance(node, IO):
                io_nodes.append(node)
                continue
            elif isinstance(node, ApplyConcatApply):
                counter += 1
                continue
            stack.extend(node.dependencies())
        if len(io_nodes) == 0:
            return
        io_nodes_new = [
            io.substitute_parameters({"_pipeline_breaker_counter": counter})
            for io in io_nodes
        ]
        expr = self
        for io_node_old, io_node_new in zip(io_nodes, io_nodes_new):
            expr = expr.substitute(io_node_old, io_node_new)
        return expr.substitute_parameters({"_pipeline_breaker_counter": counter})


class Unique(ApplyConcatApply):
    _parameters = ["frame", "split_every", "split_out", "shuffle_method"]
@@ -773,13 +808,23 @@ def _simplify_up(self, parent, dependents):
        if isinstance(parent, Projection):
            return plain_column_projection(self, parent, dependents)

    def _pipe_down(self):
        return self._adjust_for_pipelinebreaker()


class Sum(Reduction):
-    _parameters = ["frame", "skipna", "numeric_only", "split_every"]
    _parameters = [
        "frame",
        "skipna",
        "numeric_only",
        "split_every",
        "_pipeline_breaker_counter",
    ]
    _defaults = {
        "split_every": False,
        "numeric_only": False,
        "skipna": True,
        "_pipeline_breaker_counter": None,
    }
    reduction_chunk = M.sum
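
The repetitive _parameters/_defaults additions here, in Mean and NFirst below, and in the IO classes are load-bearing: substitute_parameters appears to swap only keys that an expression actually declares, so both the breaker classes and the IO classes must declare _pipeline_breaker_counter (defaulting to None, i.e. untagged). A simplified sketch of that substitution, assuming positional operands line up with _parameters and ignoring variadic operands such as *by:

def substitute_parameters(expr, substitutions):
    # Rebuild the expression with matching operands swapped; keys absent
    # from _parameters are silently ignored, which is why the counter must
    # be declared on every participating class.
    new_operands = [
        substitutions.get(param, expr.operand(param))
        for param in expr._parameters
    ]
    return type(expr)(*new_operands)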

@@ -1090,8 +1135,21 @@ def reduction_aggregate(cls, vals, order):


class Mean(Reduction):
-    _parameters = ["frame", "skipna", "numeric_only", "split_every", "axis"]
-    _defaults = {"skipna": True, "numeric_only": False, "split_every": False, "axis": 0}
    _parameters = [
        "frame",
        "skipna",
        "numeric_only",
        "split_every",
        "axis",
        "_pipeline_breaker_counter",
    ]
    _defaults = {
        "skipna": True,
        "numeric_only": False,
        "split_every": False,
        "axis": 0,
        "_pipeline_breaker_counter": None,
    }

    @functools.cached_property
    def _meta(self):
@@ -1267,8 +1325,21 @@ def _nlast(df, columns, n, ascending):


class NFirst(NLargest):
-    _parameters = ["frame", "n", "_columns", "ascending", "split_every"]
-    _defaults = {"n": 5, "_columns": None, "ascending": None, "split_every": None}
    _parameters = [
        "frame",
        "n",
        "_columns",
        "ascending",
        "split_every",
        "_pipeline_breaker_counter",
    ]
    _defaults = {
        "n": 5,
        "_columns": None,
        "ascending": None,
        "split_every": None,
        "_pipeline_breaker_counter": None,
    }
    reduction_chunk = _nfirst
    reduction_aggregate = _nfirst
2 changes: 2 additions & 0 deletions dask_expr/io/io.py
@@ -320,6 +320,7 @@ class FromPandas(PartitionsFiltered, BlockwiseIO):
        "columns",
        "_partitions",
        "_series",
        "_pipeline_breaker_counter",
    ]
    _defaults = {
        "npartitions": None,
@@ -328,6 +329,7 @@ class FromPandas(PartitionsFiltered, BlockwiseIO):
        "_partitions": None,
        "_series": False,
        "chunksize": None,
        "_pipeline_breaker_counter": None,
    }
    _pd_length_stats = None
    _absorb_projections = True
2 changes: 2 additions & 0 deletions dask_expr/io/parquet.py
@@ -402,6 +402,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
        "_partitions",
        "_series",
        "_dataset_info_cache",
        "_pipeline_breaker_counter",
    ]
    _defaults = {
        "columns": None,
@@ -422,6 +423,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
        "_partitions": None,
        "_series": False,
        "_dataset_info_cache": None,
        "_pipeline_breaker_counter": None,
    }
    _pq_length_stats = None
    _absorb_projections = True