Skip to content

Commit

Permalink
Implement picklable BinOp type
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Nov 26, 2024
1 parent 99a5d12 commit 3e14ec9
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 104 deletions.
127 changes: 94 additions & 33 deletions python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@

from __future__ import annotations

from enum import IntEnum, auto
from typing import TYPE_CHECKING, ClassVar

from polars.polars import _expr_nodes as pl_expr

import pylibcudf as plc
from pylibcudf import expressions as plc_expr

from cudf_polars.containers import Column
from cudf_polars.dsl.expressions.base import AggInfo, ExecutionContext, Expr

if TYPE_CHECKING:
from collections.abc import Mapping

from typing_extensions import Self

from cudf_polars.containers import DataFrame

__all__ = ["BinOp"]
Expand All @@ -27,10 +31,90 @@ class BinOp(Expr):
__slots__ = ("op",)
_non_child = ("dtype", "op")

class Operator(IntEnum):
"""Internal and picklable representation of pylibcudf's `BinaryOperator`."""

ADD = auto()
ATAN2 = auto()
BITWISE_AND = auto()
BITWISE_OR = auto()
BITWISE_XOR = auto()
DIV = auto()
EQUAL = auto()
FLOOR_DIV = auto()
GENERIC_BINARY = auto()
GREATER = auto()
GREATER_EQUAL = auto()
INT_POW = auto()
INVALID_BINARY = auto()
LESS = auto()
LESS_EQUAL = auto()
LOGICAL_AND = auto()
LOGICAL_OR = auto()
LOG_BASE = auto()
MOD = auto()
MUL = auto()
NOT_EQUAL = auto()
NULL_EQUALS = auto()
NULL_LOGICAL_AND = auto()
NULL_LOGICAL_OR = auto()
NULL_MAX = auto()
NULL_MIN = auto()
NULL_NOT_EQUALS = auto()
PMOD = auto()
POW = auto()
PYMOD = auto()
SHIFT_LEFT = auto()
SHIFT_RIGHT = auto()
SHIFT_RIGHT_UNSIGNED = auto()
SUB = auto()
TRUE_DIV = auto()

@classmethod
def from_polars(cls, obj: pl_expr.Operator) -> BinOp.Operator:
"""Convert from polars' `Operator`."""
mapping: dict[pl_expr.Operator, BinOp.Operator] = {
pl_expr.Operator.Eq: BinOp.Operator.EQUAL,
pl_expr.Operator.EqValidity: BinOp.Operator.NULL_EQUALS,
pl_expr.Operator.NotEq: BinOp.Operator.NOT_EQUAL,
pl_expr.Operator.NotEqValidity: BinOp.Operator.NULL_NOT_EQUALS,
pl_expr.Operator.Lt: BinOp.Operator.LESS,
pl_expr.Operator.LtEq: BinOp.Operator.LESS_EQUAL,
pl_expr.Operator.Gt: BinOp.Operator.GREATER,
pl_expr.Operator.GtEq: BinOp.Operator.GREATER_EQUAL,
pl_expr.Operator.Plus: BinOp.Operator.ADD,
pl_expr.Operator.Minus: BinOp.Operator.SUB,
pl_expr.Operator.Multiply: BinOp.Operator.MUL,
pl_expr.Operator.Divide: BinOp.Operator.DIV,
pl_expr.Operator.TrueDivide: BinOp.Operator.TRUE_DIV,
pl_expr.Operator.FloorDivide: BinOp.Operator.FLOOR_DIV,
pl_expr.Operator.Modulus: BinOp.Operator.PYMOD,
pl_expr.Operator.And: BinOp.Operator.BITWISE_AND,
pl_expr.Operator.Or: BinOp.Operator.BITWISE_OR,
pl_expr.Operator.Xor: BinOp.Operator.BITWISE_XOR,
pl_expr.Operator.LogicalAnd: BinOp.Operator.LOGICAL_AND,
pl_expr.Operator.LogicalOr: BinOp.Operator.LOGICAL_OR,
}

return mapping[obj]

@classmethod
def to_pylibcudf(cls, obj: Self) -> plc.binaryop.BinaryOperator:
"""Convert to pylibcudf's `BinaryOperator`."""
return getattr(plc.binaryop.BinaryOperator, obj.name)

@classmethod
def to_pylibcudf_expr(cls, obj: Self) -> plc.binaryop.BinaryOperator:
"""Convert to pylibcudf's `ASTOperator`."""
if obj is BinOp.Operator.NULL_EQUALS:
# Name mismatch in pylibcudf's `BinaryOperator` and `ASTOperator`.
return plc_expr.ASTOperator.NULL_EQUAL
return getattr(plc_expr.ASTOperator, obj.name)

def __init__(
self,
dtype: plc.DataType,
op: plc.binaryop.BinaryOperator,
op: BinOp.Operator,
left: Expr,
right: Expr,
) -> None:
Expand All @@ -43,44 +127,19 @@ def __init__(
self.op = op
self.children = (left, right)
if not plc.binaryop.is_supported_operation(
self.dtype, left.dtype, right.dtype, op
self.dtype, left.dtype, right.dtype, BinOp.Operator.to_pylibcudf(op)
):
raise NotImplementedError(
f"Operation {op.name} not supported "
f"for types {left.dtype.id().name} and {right.dtype.id().name} "
f"with output type {self.dtype.id().name}"
)

_BOOL_KLEENE_MAPPING: ClassVar[
dict[plc.binaryop.BinaryOperator, plc.binaryop.BinaryOperator]
] = {
plc.binaryop.BinaryOperator.BITWISE_AND: plc.binaryop.BinaryOperator.NULL_LOGICAL_AND,
plc.binaryop.BinaryOperator.BITWISE_OR: plc.binaryop.BinaryOperator.NULL_LOGICAL_OR,
plc.binaryop.BinaryOperator.LOGICAL_AND: plc.binaryop.BinaryOperator.NULL_LOGICAL_AND,
plc.binaryop.BinaryOperator.LOGICAL_OR: plc.binaryop.BinaryOperator.NULL_LOGICAL_OR,
}

_MAPPING: ClassVar[dict[pl_expr.Operator, plc.binaryop.BinaryOperator]] = {
pl_expr.Operator.Eq: plc.binaryop.BinaryOperator.EQUAL,
pl_expr.Operator.EqValidity: plc.binaryop.BinaryOperator.NULL_EQUALS,
pl_expr.Operator.NotEq: plc.binaryop.BinaryOperator.NOT_EQUAL,
pl_expr.Operator.NotEqValidity: plc.binaryop.BinaryOperator.NULL_NOT_EQUALS,
pl_expr.Operator.Lt: plc.binaryop.BinaryOperator.LESS,
pl_expr.Operator.LtEq: plc.binaryop.BinaryOperator.LESS_EQUAL,
pl_expr.Operator.Gt: plc.binaryop.BinaryOperator.GREATER,
pl_expr.Operator.GtEq: plc.binaryop.BinaryOperator.GREATER_EQUAL,
pl_expr.Operator.Plus: plc.binaryop.BinaryOperator.ADD,
pl_expr.Operator.Minus: plc.binaryop.BinaryOperator.SUB,
pl_expr.Operator.Multiply: plc.binaryop.BinaryOperator.MUL,
pl_expr.Operator.Divide: plc.binaryop.BinaryOperator.DIV,
pl_expr.Operator.TrueDivide: plc.binaryop.BinaryOperator.TRUE_DIV,
pl_expr.Operator.FloorDivide: plc.binaryop.BinaryOperator.FLOOR_DIV,
pl_expr.Operator.Modulus: plc.binaryop.BinaryOperator.PYMOD,
pl_expr.Operator.And: plc.binaryop.BinaryOperator.BITWISE_AND,
pl_expr.Operator.Or: plc.binaryop.BinaryOperator.BITWISE_OR,
pl_expr.Operator.Xor: plc.binaryop.BinaryOperator.BITWISE_XOR,
pl_expr.Operator.LogicalAnd: plc.binaryop.BinaryOperator.LOGICAL_AND,
pl_expr.Operator.LogicalOr: plc.binaryop.BinaryOperator.LOGICAL_OR,
_BOOL_KLEENE_MAPPING: ClassVar[dict[Operator, Operator]] = {
Operator.BITWISE_AND: Operator.NULL_LOGICAL_AND,
Operator.BITWISE_OR: Operator.NULL_LOGICAL_OR,
Operator.LOGICAL_AND: Operator.NULL_LOGICAL_AND,
Operator.LOGICAL_OR: Operator.NULL_LOGICAL_OR,
}

def do_evaluate(
Expand All @@ -103,7 +162,9 @@ def do_evaluate(
elif right.is_scalar:
rop = right.obj_scalar
return Column(
plc.binaryop.binary_operation(lop, rop, self.op, self.dtype),
plc.binaryop.binary_operation(
lop, rop, BinOp.Operator.to_pylibcudf(self.op), self.dtype
),
)

def collect_agg(self, *, depth: int) -> AggInfo:
Expand Down
19 changes: 10 additions & 9 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ExecutionContext,
Expr,
)
from cudf_polars.dsl.expressions.binaryop import BinOp

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -80,24 +81,24 @@ def _distinct(
_BETWEEN_OPS: ClassVar[
dict[
pl_types.ClosedInterval,
tuple[plc.binaryop.BinaryOperator, plc.binaryop.BinaryOperator],
tuple[BinOp.Operator, BinOp.Operator],
]
] = {
"none": (
plc.binaryop.BinaryOperator.GREATER,
plc.binaryop.BinaryOperator.LESS,
BinOp.Operator.GREATER,
BinOp.Operator.LESS,
),
"left": (
plc.binaryop.BinaryOperator.GREATER_EQUAL,
plc.binaryop.BinaryOperator.LESS,
BinOp.Operator.GREATER_EQUAL,
BinOp.Operator.LESS,
),
"right": (
plc.binaryop.BinaryOperator.GREATER,
plc.binaryop.BinaryOperator.LESS_EQUAL,
BinOp.Operator.GREATER,
BinOp.Operator.LESS_EQUAL,
),
"both": (
plc.binaryop.BinaryOperator.GREATER_EQUAL,
plc.binaryop.BinaryOperator.LESS_EQUAL,
BinOp.Operator.GREATER_EQUAL,
BinOp.Operator.LESS_EQUAL,
),
}

Expand Down
67 changes: 21 additions & 46 deletions python/cudf_polars/cudf_polars/dsl/to_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,6 @@

from cudf_polars.typing import ExprTransformer

# Can't merge these op-mapping dictionaries because scoped enum values
# are exposed by cython with equality/hash based one their underlying
# representation type. So in a dict they are just treated as integers.
BINOP_TO_ASTOP = {
plc.binaryop.BinaryOperator.EQUAL: plc_expr.ASTOperator.EQUAL,
plc.binaryop.BinaryOperator.NULL_EQUALS: plc_expr.ASTOperator.NULL_EQUAL,
plc.binaryop.BinaryOperator.NOT_EQUAL: plc_expr.ASTOperator.NOT_EQUAL,
plc.binaryop.BinaryOperator.LESS: plc_expr.ASTOperator.LESS,
plc.binaryop.BinaryOperator.LESS_EQUAL: plc_expr.ASTOperator.LESS_EQUAL,
plc.binaryop.BinaryOperator.GREATER: plc_expr.ASTOperator.GREATER,
plc.binaryop.BinaryOperator.GREATER_EQUAL: plc_expr.ASTOperator.GREATER_EQUAL,
plc.binaryop.BinaryOperator.ADD: plc_expr.ASTOperator.ADD,
plc.binaryop.BinaryOperator.SUB: plc_expr.ASTOperator.SUB,
plc.binaryop.BinaryOperator.MUL: plc_expr.ASTOperator.MUL,
plc.binaryop.BinaryOperator.DIV: plc_expr.ASTOperator.DIV,
plc.binaryop.BinaryOperator.TRUE_DIV: plc_expr.ASTOperator.TRUE_DIV,
plc.binaryop.BinaryOperator.FLOOR_DIV: plc_expr.ASTOperator.FLOOR_DIV,
plc.binaryop.BinaryOperator.PYMOD: plc_expr.ASTOperator.PYMOD,
plc.binaryop.BinaryOperator.BITWISE_AND: plc_expr.ASTOperator.BITWISE_AND,
plc.binaryop.BinaryOperator.BITWISE_OR: plc_expr.ASTOperator.BITWISE_OR,
plc.binaryop.BinaryOperator.BITWISE_XOR: plc_expr.ASTOperator.BITWISE_XOR,
plc.binaryop.BinaryOperator.LOGICAL_AND: plc_expr.ASTOperator.LOGICAL_AND,
plc.binaryop.BinaryOperator.LOGICAL_OR: plc_expr.ASTOperator.LOGICAL_OR,
plc.binaryop.BinaryOperator.NULL_LOGICAL_AND: plc_expr.ASTOperator.NULL_LOGICAL_AND,
plc.binaryop.BinaryOperator.NULL_LOGICAL_OR: plc_expr.ASTOperator.NULL_LOGICAL_OR,
}

UOP_TO_ASTOP = {
plc.unary.UnaryOperator.SIN: plc_expr.ASTOperator.SIN,
plc.unary.UnaryOperator.COS: plc_expr.ASTOperator.COS,
Expand All @@ -75,21 +48,21 @@
}

SUPPORTED_STATISTICS_BINOPS = {
plc.binaryop.BinaryOperator.EQUAL,
plc.binaryop.BinaryOperator.NOT_EQUAL,
plc.binaryop.BinaryOperator.LESS,
plc.binaryop.BinaryOperator.LESS_EQUAL,
plc.binaryop.BinaryOperator.GREATER,
plc.binaryop.BinaryOperator.GREATER_EQUAL,
expr.BinOp.Operator.EQUAL,
expr.BinOp.Operator.NOT_EQUAL,
expr.BinOp.Operator.LESS,
expr.BinOp.Operator.LESS_EQUAL,
expr.BinOp.Operator.GREATER,
expr.BinOp.Operator.GREATER_EQUAL,
}

REVERSED_COMPARISON = {
plc.binaryop.BinaryOperator.EQUAL: plc.binaryop.BinaryOperator.EQUAL,
plc.binaryop.BinaryOperator.NOT_EQUAL: plc.binaryop.BinaryOperator.NOT_EQUAL,
plc.binaryop.BinaryOperator.LESS: plc.binaryop.BinaryOperator.GREATER,
plc.binaryop.BinaryOperator.LESS_EQUAL: plc.binaryop.BinaryOperator.GREATER_EQUAL,
plc.binaryop.BinaryOperator.GREATER: plc.binaryop.BinaryOperator.LESS,
plc.binaryop.BinaryOperator.GREATER_EQUAL: plc.binaryop.BinaryOperator.LESS_EQUAL,
expr.BinOp.Operator.EQUAL: expr.BinOp.Operator.EQUAL,
expr.BinOp.Operator.NOT_EQUAL: expr.BinOp.Operator.NOT_EQUAL,
expr.BinOp.Operator.LESS: expr.BinOp.Operator.GREATER,
expr.BinOp.Operator.LESS_EQUAL: expr.BinOp.Operator.GREATER_EQUAL,
expr.BinOp.Operator.GREATER: expr.BinOp.Operator.LESS,
expr.BinOp.Operator.GREATER_EQUAL: expr.BinOp.Operator.LESS_EQUAL,
}


Expand Down Expand Up @@ -147,22 +120,20 @@ def _(node: expr.Literal, self: Transformer) -> plc_expr.Expression:

@_to_ast.register
def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression:
if node.op == plc.binaryop.BinaryOperator.NULL_NOT_EQUALS:
if node.op == expr.BinOp.Operator.NULL_NOT_EQUALS:
return plc_expr.Operation(
plc_expr.ASTOperator.NOT,
self(
# Reconstruct and apply, rather than directly
# constructing the right expression so we get the
# handling of parquet special cases for free.
expr.BinOp(
node.dtype, plc.binaryop.BinaryOperator.NULL_EQUALS, *node.children
)
expr.BinOp(node.dtype, expr.BinOp.Operator.NULL_EQUALS, *node.children)
),
)
if self.state["for_parquet"]:
op1_col, op2_col = (isinstance(op, expr.Col) for op in node.children)
if op1_col ^ op2_col:
op = node.op
op: expr.BinOp.Operator = node.op
if op not in SUPPORTED_STATISTICS_BINOPS:
raise NotImplementedError(
f"Parquet filter binop with column doesn't support {node.op!r}"
Expand All @@ -175,12 +146,16 @@ def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression:
raise NotImplementedError(
"Parquet filter binops must have form 'col binop literal'"
)
return plc_expr.Operation(BINOP_TO_ASTOP[op], self(op1), self(op2))
return plc_expr.Operation(
expr.BinOp.Operator.to_pylibcudf_expr(op), self(op1), self(op2)
)
elif op1_col and op2_col:
raise NotImplementedError(
"Parquet filter binops must have one column reference not two"
)
return plc_expr.Operation(BINOP_TO_ASTOP[node.op], *map(self, node.children))
return plc_expr.Operation(
expr.BinOp.Operator.to_pylibcudf_expr(node.op), *map(self, node.children)
)


@_to_ast.register
Expand Down
14 changes: 6 additions & 8 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,13 +330,11 @@ def _(

dtype = plc.DataType(plc.TypeId.BOOL8)
predicate = functools.reduce(
functools.partial(
expr.BinOp, dtype, plc.binaryop.BinaryOperator.LOGICAL_AND
),
functools.partial(expr.BinOp, dtype, expr.BinOp.Operator.LOGICAL_AND),
(
expr.BinOp(
dtype,
expr.BinOp._MAPPING[op],
expr.BinOp.Operator.from_polars(op),
insert_colrefs(
left.value,
table_ref=plc.expressions.TableReference.LEFT,
Expand Down Expand Up @@ -545,7 +543,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
lop, rop = expr.BooleanFunction._BETWEEN_OPS[closed]
return expr.BinOp(
dtype,
plc.binaryop.BinaryOperator.LOGICAL_AND,
expr.BinOp.Operator.LOGICAL_AND,
expr.BinOp(dtype, lop, column, lo),
expr.BinOp(dtype, rop, column, hi),
)
Expand Down Expand Up @@ -586,12 +584,12 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
(child,) = children
return expr.BinOp(
dtype,
plc.binaryop.BinaryOperator.LOG_BASE,
expr.BinOp.Operator.LOG_BASE,
child,
expr.Literal(dtype, pa.scalar(base, type=plc.interop.to_arrow(dtype))),
)
elif name == "pow":
return expr.BinOp(dtype, plc.binaryop.BinaryOperator.POW, *children)
return expr.BinOp(dtype, expr.BinOp.Operator.POW, *children)
return expr.UnaryFunction(dtype, name, options, *children)
raise NotImplementedError(
f"No handler for Expr function node with {name=}"
Expand Down Expand Up @@ -707,7 +705,7 @@ def _(
) -> expr.Expr:
return expr.BinOp(
dtype,
expr.BinOp._MAPPING[node.op],
expr.BinOp.Operator.from_polars(node.op),
translator.translate_expr(n=node.left),
translator.translate_expr(n=node.right),
)
Expand Down
2 changes: 1 addition & 1 deletion python/cudf_polars/tests/dsl/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def make_expr(n1, n2):
a = expr.Col(plc.DataType(plc.TypeId.INT8), n1)
b = expr.Col(plc.DataType(plc.TypeId.INT8), n2)

return expr.BinOp(dt, plc.binaryop.BinaryOperator.ADD, a, b)
return expr.BinOp(dt, expr.BinOp.Operator.ADD, a, b)

e1 = make_expr("a", "b")
e2 = make_expr("a", "b")
Expand Down
Loading

0 comments on commit 3e14ec9

Please sign in to comment.