Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add branch_id to distinguish between reusable branches and pipeline breakers #883

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 53 additions & 10 deletions dask_expr/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, NamedTuple

import dask
import pandas as pd
Expand All @@ -29,6 +29,10 @@
]


# Single-field tuple that tags an expression with its branch number.
# branch_id == 0 means "default branch"; non-zero ids mark reusable
# subplans so equal expressions on different branches hash differently.
BranchId = NamedTuple("BranchId", [("branch_id", int)])


def _unpack_collections(o):
if isinstance(o, Expr):
return o
Expand All @@ -44,8 +48,13 @@ class Expr:
_defaults = {}
_instances = weakref.WeakValueDictionary()

def __new__(cls, *args, **kwargs):
def __new__(cls, *args, _branch_id=None, **kwargs):
operands = list(args)
if _branch_id is None and len(operands) and isinstance(operands[-1], BranchId):
_branch_id = operands.pop(-1)
elif _branch_id is None:
_branch_id = BranchId(0)

for parameter in cls._parameters[len(operands) :]:
try:
operands.append(kwargs.pop(parameter))
Expand All @@ -54,6 +63,7 @@ def __new__(cls, *args, **kwargs):
assert not kwargs, kwargs
inst = object.__new__(cls)
inst.operands = [_unpack_collections(o) for o in operands]
inst._branch_id = _branch_id
_name = inst._name
if _name in Expr._instances:
return Expr._instances[_name]
Expand Down Expand Up @@ -116,7 +126,10 @@ def _tree_repr_lines(self, indent=0, recursive=True):
elif is_arraylike(op):
op = "<array>"
header = self._tree_repr_argument_construction(i, op, header)

if self._branch_id.branch_id != 0:
header = self._tree_repr_argument_construction(
i + 1, f" branch_id={self._branch_id.branch_id}", header
)
lines = [header] + lines
lines = [" " * indent + line for line in lines]

Expand Down Expand Up @@ -218,7 +231,7 @@ def _layer(self) -> dict:

return {(self._name, i): self._task(i) for i in range(self.npartitions)}

def rewrite(self, kind: str):
def rewrite(self, kind: str, cache):
"""Rewrite an expression

This leverages the ``._{kind}_down`` and ``._{kind}_up``
Expand All @@ -231,6 +244,9 @@ def rewrite(self, kind: str):
changed:
whether or not any change occured
"""
if self._name in cache:
return cache[self._name]

expr = self
down_name = f"_{kind}_down"
up_name = f"_{kind}_up"
Expand Down Expand Up @@ -267,21 +283,46 @@ def rewrite(self, kind: str):
changed = False
for operand in expr.operands:
if isinstance(operand, Expr):
new = operand.rewrite(kind=kind)
new = operand.rewrite(kind=kind, cache=cache)
cache[operand._name] = new
if new._name != operand._name:
changed = True
else:
new = operand
new_operands.append(new)

if changed:
expr = type(expr)(*new_operands)
expr = type(expr)(*new_operands, _branch_id=expr._branch_id)
continue
else:
break

return expr

def _reuse_up(self, parent):
return

def _reuse_down(self):
if not self.dependencies():
return
return self._bubble_branch_id_down()

def _bubble_branch_id_down(self):
b_id = self._branch_id
if b_id.branch_id <= 0:
return
if any(b_id.branch_id != d._branch_id.branch_id for d in self.dependencies()):
ops = [
op._substitute_branch_id(b_id) if isinstance(op, Expr) else op
for op in self.operands
]
return type(self)(*ops)

def _substitute_branch_id(self, branch_id):
if self._branch_id.branch_id != 0:
return self
return type(self)(*self.operands, branch_id)

def simplify_once(self, dependents: defaultdict, simplified: dict):
"""Simplify an expression

Expand Down Expand Up @@ -346,7 +387,7 @@ def simplify_once(self, dependents: defaultdict, simplified: dict):
new_operands.append(new)

if changed:
expr = type(expr)(*new_operands)
expr = type(expr)(*new_operands, _branch_id=expr._branch_id)

break

Expand Down Expand Up @@ -391,7 +432,7 @@ def lower_once(self):
new_operands.append(new)

if changed:
out = type(out)(*new_operands)
out = type(out)(*new_operands, _branch_id=out._branch_id)

return out

Expand Down Expand Up @@ -427,7 +468,9 @@ def _lower(self):
@functools.cached_property
def _name(self):
    """Deterministic task name: ``<classname>-<token>``.

    The branch id is folded into the token so that otherwise-identical
    expressions on different branches get distinct names.
    """
    prefix = funcname(type(self)).lower()
    token = _tokenize_deterministic(*self.operands, self._branch_id)
    return prefix + "-" + token

@property
Expand Down Expand Up @@ -580,7 +623,7 @@ def substitute_parameters(self, substitutions: dict) -> Expr:
else:
new_operands.append(operand)
if changed:
return type(self)(*new_operands)
return type(self)(*new_operands, _branch_id=self._branch_id)
return self

def _node_label_args(self):
Expand Down
39 changes: 30 additions & 9 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from tlz import merge_sorted, partition, unique

from dask_expr import _core as core
from dask_expr._core import BranchId
from dask_expr._util import (
_calc_maybe_new_divisions,
_convert_to_list,
Expand Down Expand Up @@ -502,7 +503,7 @@ def _name(self):
head = funcname(self.operation)
else:
head = funcname(type(self)).lower()
return head + "-" + _tokenize_deterministic(*self.operands)
return head + "-" + _tokenize_deterministic(*self.operands, self._branch_id)

def _blockwise_arg(self, arg, i):
"""Return a Blockwise-task argument"""
Expand Down Expand Up @@ -2728,8 +2729,11 @@ class _DelayedExpr(Expr):
# TODO
_parameters = ["obj"]

def __init__(self, obj):
def __init__(self, obj, _branch_id=None):
self.obj = obj
if _branch_id is None:
_branch_id = BranchId(0)
self._branch_id = _branch_id
self.operands = [obj]

def __str__(self):
Expand Down Expand Up @@ -2758,18 +2762,29 @@ def normalize_expression(expr):
return expr._name


def optimize_until(expr: Expr, stage: core.OptimizerStage) -> Expr:
def optimize_until(
expr: Expr, stage: core.OptimizerStage, common_subplan_elimination: bool = False
) -> Expr:
result = expr
if stage == "logical":
return result

# Simplify
expr = result.simplify()
while True:
if not common_subplan_elimination:
out = result.rewrite("reuse", cache={})
else:
out = result
out = out.simplify()
if out._name == result._name or common_subplan_elimination:
break
result = out

expr = out
if stage == "simplified-logical":
return expr

# Manipulate Expression to make it more efficient
expr = expr.rewrite(kind="tune")
expr = expr.rewrite(kind="tune", cache={})
if stage == "tuned-logical":
return expr

Expand All @@ -2791,7 +2806,9 @@ def optimize_until(expr: Expr, stage: core.OptimizerStage) -> Expr:
raise ValueError(f"Stage {stage!r} not supported.")


def optimize(expr: Expr, fuse: bool = True) -> Expr:
def optimize(
expr: Expr, fuse: bool = True, common_subplan_elimination: bool = False
) -> Expr:
"""High level query optimization

This leverages three optimization passes:
Expand All @@ -2805,6 +2822,10 @@ def optimize(expr: Expr, fuse: bool = True) -> Expr:
Input expression to optimize
fuse:
whether or not to turn on blockwise fusion
common_subplan_elimination : bool, default False
whether we want to reuse common subplans that are found in the graph and
are used in self-joins or similar which require all data be held in memory
at some point. Only set this to true if your dataset fits into memory.

See Also
--------
Expand All @@ -2813,7 +2834,7 @@ def optimize(expr: Expr, fuse: bool = True) -> Expr:
"""
stage: core.OptimizerStage = "fused" if fuse else "simplified-physical"

return optimize_until(expr, stage)
return optimize_until(expr, stage, common_subplan_elimination)


def is_broadcastable(dfs, s):
Expand Down Expand Up @@ -3462,7 +3483,7 @@ def __str__(self):

@functools.cached_property
def _name(self):
return f"{str(self)}-{_tokenize_deterministic(self.exprs)}"
return f"{str(self)}-{_tokenize_deterministic(self.exprs, self._branch_id)}"

def _divisions(self):
return self.exprs[0]._divisions()
Expand Down
45 changes: 44 additions & 1 deletion dask_expr/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from dask.utils import M, apply, funcname

from dask_expr._concat import Concat
from dask_expr._core import BranchId
from dask_expr._expr import (
Blockwise,
Expr,
Expand Down Expand Up @@ -300,7 +301,7 @@ def _name(self):
name = funcname(self.combine.__self__).lower() + "-tree"
else:
name = funcname(self.combine)
return name + "-" + _tokenize_deterministic(*self.operands)
return name + "-" + _tokenize_deterministic(*self.operands, self._branch_id)

def __dask_postcompute__(self):
return toolz.first, ()
Expand Down Expand Up @@ -507,6 +508,48 @@ def _lower(self):
ignore_index=getattr(self, "ignore_index", True),
)

def _reuse_up(self, parent):
return

def _substitute_branch_id(self, branch_id):
return self

def _reuse_down(self):
    """Assign a fresh branch id to this subtree during the ``reuse`` pass.

    Walks the dependency graph below this node, counting nested
    ApplyConcatApply nodes (without descending past them) and checking
    whether an IO node is reachable. If an IO source is found, rebuild
    this expression with a new ``BranchId`` and bubble it down to the
    children; otherwise return ``None`` (no change).
    """
    if self._branch_id.branch_id != 0:
        # Already on a non-default branch: nothing to do.
        return None

    from dask_expr.io import IO

    visited = set()
    pending = list(self.dependencies())
    aca_count = 1
    reached_io = False

    while pending:
        node = pending.pop()
        if node._name in visited:
            continue
        visited.add(node._name)
        if isinstance(node, IO):
            reached_io = True
        elif isinstance(node, ApplyConcatApply):
            # Count nested pipeline breakers but do not descend past them.
            aca_count += 1
        else:
            pending.extend(node.dependencies())

    if not reached_io:
        return None
    new_id = BranchId(aca_count)
    # The positional BranchId is picked up by ``Expr.__new__``.
    candidate = type(self)(*self.operands, new_id)
    bubbled = candidate._bubble_branch_id_down()
    if bubbled is None:
        return candidate
    return type(bubbled)(*bubbled.operands, _branch_id=new_id)


class Unique(ApplyConcatApply):
_parameters = ["frame", "split_every", "split_out", "shuffle_method"]
Expand Down
20 changes: 15 additions & 5 deletions dask_expr/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def _divisions(self):
@functools.cached_property
def _name(self):
    """Deterministic task name: ``<name_prefix>-<token>``.

    The branch id participates in the token so equal expressions on
    different branches produce distinct graph keys.
    """
    token = _tokenize_deterministic(*self.operands, self._branch_id)
    return self.operand("name_prefix") + "-" + token
def _layer(self):
Expand Down Expand Up @@ -103,7 +105,7 @@ def _name(self):
return (
funcname(type(self.operand("_expr"))).lower()
+ "-fused-"
+ _tokenize_deterministic(*self.operands)
+ _tokenize_deterministic(*self.operands, self._expr._branch_id)
)

@functools.cached_property
Expand Down Expand Up @@ -173,10 +175,14 @@ def _name(self):
return (
funcname(self.func).lower()
+ "-"
+ _tokenize_deterministic(*self.operands)
+ _tokenize_deterministic(*self.operands, self._branch_id)
)
else:
return self.label + "-" + _tokenize_deterministic(*self.operands)
return (
self.label
+ "-"
+ _tokenize_deterministic(*self.operands, self._branch_id)
)

@functools.cached_property
def _meta(self):
Expand Down Expand Up @@ -448,7 +454,11 @@ class FromPandasDivisions(FromPandas):

@functools.cached_property
def _name(self):
    """Deterministic task name with a fixed ``from_pd_divs`` prefix;
    the branch id is folded into the token.
    """
    token = _tokenize_deterministic(*self.operands, self._branch_id)
    return "from_pd_divs-" + token

@property
def _divisions_and_locations(self):
Expand Down
2 changes: 1 addition & 1 deletion dask_expr/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def _name(self):
return (
funcname(type(self)).lower()
+ "-"
+ _tokenize_deterministic(self.checksum, *self.operands)
+ _tokenize_deterministic(self.checksum, *self.operands, self._branch_id)
)

@property
Expand Down
Loading
Loading