
Add conversion from cudf-polars expressions to libcudf ast for parquet filters #17141

Merged: 13 commits, Oct 30, 2024
9 changes: 9 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/ir.py
@@ -28,6 +28,7 @@
import cudf_polars.dsl.expr as expr
from cudf_polars.containers import Column, DataFrame
from cudf_polars.dsl.nodebase import Node
from cudf_polars.dsl.to_ast import to_parquet_filter
from cudf_polars.utils import dtypes

if TYPE_CHECKING:
@@ -418,9 +419,14 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
colnames[0],
)
elif self.typ == "parquet":
filters = None
if self.predicate is not None and self.row_index is None:
# Can't apply filters during read if we have a row index.
filters = to_parquet_filter(self.predicate.value)
tbl_w_meta = plc.io.parquet.read_parquet(
plc.io.SourceInfo(self.paths),
columns=with_columns,
filters=filters,
nrows=n_rows,
skip_rows=self.skip_rows,
)
@@ -429,6 +435,9 @@
# TODO: consider nested column names?
tbl_w_meta.column_names(include_children=False),
)
if filters is not None:
# Filters were applied during the read, so no post-read masking is needed.
return df
elif self.typ == "ndjson":
json_schema: list[tuple[str, str, list]] = [
(name, typ, []) for name, typ in self.schema.items()
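The effect of this change is that a supported predicate is evaluated by libcudf during the parquet read instead of as a separate post-read step. A minimal usage sketch (hypothetical file and column names, assuming the polars GPU engine is installed):

import polars as pl

q = pl.scan_parquet("data.parquet").filter(pl.col("a") > 3)
# On the GPU engine, the Scan node converts the predicate with
# to_parquet_filter; when conversion succeeds, the filter is passed to
# read_parquet and the post-read mask is skipped.
result = q.collect(engine="gpu")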
263 changes: 263 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/to_ast.py
@@ -0,0 +1,263 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Conversion of expression nodes to libcudf AST nodes."""

from __future__ import annotations

from functools import partial, reduce, singledispatch
from typing import TYPE_CHECKING, TypeAlias

import pylibcudf as plc
from pylibcudf import expressions as pexpr

from polars.polars import _expr_nodes as pl_expr

from cudf_polars.dsl import expr
from cudf_polars.dsl.traversal import CachingVisitor
from cudf_polars.typing import GenericTransformer

if TYPE_CHECKING:
from collections.abc import Mapping

# Can't merge these op-mapping dictionaries because scoped enum values
# are exposed by Cython with equality/hash based on their underlying
# representation type, so in a dict they are just treated as integers.
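# As a hypothetical illustration: if BinaryOperator.ADD and ASTOperator.SIN
# shared the same underlying integer value, a merged dict would treat them
# as the same key:
#
#     d = {plc.binaryop.BinaryOperator.ADD: "binop"}
#     d[pexpr.ASTOperator.SIN] = "astop"  # would overwrite the entry above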
BINOP_TO_ASTOP = {
plc.binaryop.BinaryOperator.EQUAL: pexpr.ASTOperator.EQUAL,
plc.binaryop.BinaryOperator.NULL_EQUALS: pexpr.ASTOperator.NULL_EQUAL,
plc.binaryop.BinaryOperator.NOT_EQUAL: pexpr.ASTOperator.NOT_EQUAL,
plc.binaryop.BinaryOperator.LESS: pexpr.ASTOperator.LESS,
plc.binaryop.BinaryOperator.LESS_EQUAL: pexpr.ASTOperator.LESS_EQUAL,
plc.binaryop.BinaryOperator.GREATER: pexpr.ASTOperator.GREATER,
plc.binaryop.BinaryOperator.GREATER_EQUAL: pexpr.ASTOperator.GREATER_EQUAL,
plc.binaryop.BinaryOperator.ADD: pexpr.ASTOperator.ADD,
plc.binaryop.BinaryOperator.SUB: pexpr.ASTOperator.SUB,
plc.binaryop.BinaryOperator.MUL: pexpr.ASTOperator.MUL,
plc.binaryop.BinaryOperator.DIV: pexpr.ASTOperator.DIV,
plc.binaryop.BinaryOperator.TRUE_DIV: pexpr.ASTOperator.TRUE_DIV,
plc.binaryop.BinaryOperator.FLOOR_DIV: pexpr.ASTOperator.FLOOR_DIV,
plc.binaryop.BinaryOperator.PYMOD: pexpr.ASTOperator.PYMOD,
plc.binaryop.BinaryOperator.BITWISE_AND: pexpr.ASTOperator.BITWISE_AND,
plc.binaryop.BinaryOperator.BITWISE_OR: pexpr.ASTOperator.BITWISE_OR,
plc.binaryop.BinaryOperator.BITWISE_XOR: pexpr.ASTOperator.BITWISE_XOR,
plc.binaryop.BinaryOperator.LOGICAL_AND: pexpr.ASTOperator.LOGICAL_AND,
plc.binaryop.BinaryOperator.LOGICAL_OR: pexpr.ASTOperator.LOGICAL_OR,
plc.binaryop.BinaryOperator.NULL_LOGICAL_AND: pexpr.ASTOperator.NULL_LOGICAL_AND,
plc.binaryop.BinaryOperator.NULL_LOGICAL_OR: pexpr.ASTOperator.NULL_LOGICAL_OR,
}

UOP_TO_ASTOP = {
plc.unary.UnaryOperator.SIN: pexpr.ASTOperator.SIN,
plc.unary.UnaryOperator.COS: pexpr.ASTOperator.COS,
plc.unary.UnaryOperator.TAN: pexpr.ASTOperator.TAN,
plc.unary.UnaryOperator.ARCSIN: pexpr.ASTOperator.ARCSIN,
plc.unary.UnaryOperator.ARCCOS: pexpr.ASTOperator.ARCCOS,
plc.unary.UnaryOperator.ARCTAN: pexpr.ASTOperator.ARCTAN,
plc.unary.UnaryOperator.SINH: pexpr.ASTOperator.SINH,
plc.unary.UnaryOperator.COSH: pexpr.ASTOperator.COSH,
plc.unary.UnaryOperator.TANH: pexpr.ASTOperator.TANH,
plc.unary.UnaryOperator.ARCSINH: pexpr.ASTOperator.ARCSINH,
plc.unary.UnaryOperator.ARCCOSH: pexpr.ASTOperator.ARCCOSH,
plc.unary.UnaryOperator.ARCTANH: pexpr.ASTOperator.ARCTANH,
plc.unary.UnaryOperator.EXP: pexpr.ASTOperator.EXP,
plc.unary.UnaryOperator.LOG: pexpr.ASTOperator.LOG,
plc.unary.UnaryOperator.SQRT: pexpr.ASTOperator.SQRT,
plc.unary.UnaryOperator.CBRT: pexpr.ASTOperator.CBRT,
plc.unary.UnaryOperator.CEIL: pexpr.ASTOperator.CEIL,
plc.unary.UnaryOperator.FLOOR: pexpr.ASTOperator.FLOOR,
plc.unary.UnaryOperator.ABS: pexpr.ASTOperator.ABS,
plc.unary.UnaryOperator.RINT: pexpr.ASTOperator.RINT,
plc.unary.UnaryOperator.BIT_INVERT: pexpr.ASTOperator.BIT_INVERT,
plc.unary.UnaryOperator.NOT: pexpr.ASTOperator.NOT,
}

SUPPORTED_STATISTICS_BINOPS = {
plc.binaryop.BinaryOperator.EQUAL,
plc.binaryop.BinaryOperator.NOT_EQUAL,
plc.binaryop.BinaryOperator.LESS,
plc.binaryop.BinaryOperator.LESS_EQUAL,
plc.binaryop.BinaryOperator.GREATER,
plc.binaryop.BinaryOperator.GREATER_EQUAL,
}

REVERSED_COMPARISON = {
plc.binaryop.BinaryOperator.EQUAL: plc.binaryop.BinaryOperator.EQUAL,
plc.binaryop.BinaryOperator.NOT_EQUAL: plc.binaryop.BinaryOperator.NOT_EQUAL,
plc.binaryop.BinaryOperator.LESS: plc.binaryop.BinaryOperator.GREATER,
plc.binaryop.BinaryOperator.LESS_EQUAL: plc.binaryop.BinaryOperator.GREATER_EQUAL,
plc.binaryop.BinaryOperator.GREATER: plc.binaryop.BinaryOperator.LESS,
plc.binaryop.BinaryOperator.GREATER_EQUAL: plc.binaryop.BinaryOperator.LESS_EQUAL,
}
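# REVERSED_COMPARISON normalises ``literal OP col`` to ``col OP' literal``
# in the parquet BinOp handling below, e.g. ``lit(5) < col`` is rewritten
# to ``col > lit(5)``.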


Transformer: TypeAlias = GenericTransformer[expr.Expr, pexpr.Expression]


@singledispatch
def _to_ast(node: expr.Expr, self: Transformer) -> pexpr.Expression:
"""
Translate an expression to a pylibcudf Expression.

Parameters
----------
node
Expression to translate.
self
Recursive transformer. The state dictionary should contain a
`for_parquet` key indicating if this transformation should
provide an expression suitable for use in parquet filters.

If `for_parquet` is `False`, the dictionary should contain a
`name_to_index` mapping that maps column names to their
integer index in the table that will be used for evaluation of
the expression.

Returns
-------
pylibcudf Expression.

Raises
------
NotImplementedError or KeyError if the expression cannot be translated.
"""
raise NotImplementedError(f"Unhandled expression type {type(node)}")


@_to_ast.register
def _(node: expr.Col, self: Transformer) -> pexpr.Expression:
if self.state["for_parquet"]:
return pexpr.ColumnNameReference(node.name)
return pexpr.ColumnReference(self.state["name_to_index"][node.name])


@_to_ast.register
def _(node: expr.Literal, self: Transformer) -> pexpr.Expression:
return pexpr.Literal(plc.interop.from_arrow(node.value))


@_to_ast.register
def _(node: expr.BinOp, self: Transformer) -> pexpr.Expression:
if node.op == plc.binaryop.BinaryOperator.NULL_NOT_EQUALS:
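# BINOP_TO_ASTOP has no entry for NULL_NOT_EQUALS, so rewrite it as the
# equivalent ``NOT (a NULL_EQUAL b)``.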
return pexpr.Operation(
pexpr.ASTOperator.NOT,
self(
# Reconstruct and apply, rather than directly
# constructing the right expression so we get the
# handling of parquet special cases for free.
expr.BinOp(
node.dtype, plc.binaryop.BinaryOperator.NULL_EQUALS, *node.children
)
),
)
if self.state["for_parquet"]:
op1_col, op2_col = (isinstance(op, expr.Col) for op in node.children)
if op1_col ^ op2_col:
op = node.op
if op not in SUPPORTED_STATISTICS_BINOPS:
raise NotImplementedError(
f"Parquet filter binop with column doesn't support {node.op!r}"
)
op1, op2 = node.children
if op2_col:
(op1, op2) = (op2, op1)
op = REVERSED_COMPARISON[op]
if not isinstance(op2, expr.Literal):
raise NotImplementedError(
"Parquet filter binops must have form 'col binop literal'"
)
return pexpr.Operation(BINOP_TO_ASTOP[op], self(op1), self(op2))
elif op1_col and op2_col:
raise NotImplementedError(
"Parquet filter binops must have one column reference not two"
)
return pexpr.Operation(BINOP_TO_ASTOP[node.op], *map(self, node.children))


@_to_ast.register
def _(node: expr.BooleanFunction, self: Transformer) -> pexpr.Expression:
if node.name == pl_expr.BooleanFunction.IsIn:
needles, haystack = node.children
if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16:
# 16 is an arbitrary limit
Comment on lines +181 to +182:

Contributor: I'm confused, what is the purpose of this limit?

Author: I have to make one scalar for every value and upload it to the device, so I just picked a value as a cutoff.

Contributor: Is the idea that once we need to create more than a certain number of scalars, the allocation cost is high enough that we would underperform the CPU? The end result here is that we raise and fall back when there are more than 16 scalars, right?

Author: It means that (for example) we will do the parquet filter as a post-filter (still on the GPU) rather than during the read.

needle_ref = self(needles)
values = [pexpr.Literal(plc.interop.from_arrow(v)) for v in haystack.value]
return reduce(
partial(pexpr.Operation, pexpr.ASTOperator.LOGICAL_OR),
(
pexpr.Operation(pexpr.ASTOperator.EQUAL, needle_ref, value)
for value in values
),
)
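# Illustrative expansion: ``x.is_in([1, 2, 3])`` becomes the left-folded
# AST ``((x == 1) || (x == 2)) || (x == 3)``.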
if self.state["for_parquet"] and isinstance(node.children[0], expr.Col):
raise NotImplementedError(
f"Parquet filters don't support {node.name} on columns"
)
if node.name == pl_expr.BooleanFunction.IsNull:
return pexpr.Operation(pexpr.ASTOperator.IS_NULL, self(node.children[0]))
elif node.name == pl_expr.BooleanFunction.IsNotNull:
return pexpr.Operation(
pexpr.ASTOperator.NOT,
pexpr.Operation(pexpr.ASTOperator.IS_NULL, self(node.children[0])),
)
elif node.name == pl_expr.BooleanFunction.Not:
return pexpr.Operation(pexpr.ASTOperator.NOT, self(node.children[0]))
raise NotImplementedError(f"AST conversion does not support {node.name}")


@_to_ast.register
def _(node: expr.UnaryFunction, self: Transformer) -> pexpr.Expression:
if isinstance(node.children[0], expr.Col) and self.state["for_parquet"]:
raise NotImplementedError(
"Parquet filters don't support {node.name} on columns"
)
return pexpr.Operation(
UOP_TO_ASTOP[node._OP_MAPPING[node.name]], self(node.children[0])
)


def to_parquet_filter(node: expr.Expr) -> pexpr.Expression | None:
"""
Convert an expression to libcudf AST nodes suitable for parquet filtering.

Parameters
----------
node
Expression to convert.

Returns
-------
pylibcudf Expression if conversion is possible, otherwise None.
"""
mapper = CachingVisitor(_to_ast, state={"for_parquet": True})
try:
return mapper(node)
except (KeyError, NotImplementedError):
return None


def to_ast(
node: expr.Expr, *, name_to_index: Mapping[str, int]
) -> pexpr.Expression | None:
"""
Convert an expression to libcudf AST nodes suitable for compute_column.

Parameters
----------
node
Expression to convert.
name_to_index
Mapping from column names to their index in the table that
will be used for expression evaluation.

Returns
-------
pylibcudf Expression if conversion is possible, otherwise None.
"""
mapper = CachingVisitor(
_to_ast, state={"for_parquet": False, "name_to_index": name_to_index}
)
try:
return mapper(node)
except (KeyError, NotImplementedError):
return None
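For the non-parquet path, the resulting expression pairs with pylibcudf's column evaluation. A rough usage sketch (hypothetical names, assuming pylibcudf exposes compute_column in its transform module, mirroring libcudf's cudf::compute_column):

import pylibcudf as plc

from cudf_polars.dsl.to_ast import to_ast

# `predicate` is a cudf_polars expr.Expr and `table` a plc.Table whose
# column order matches name_to_index.
ast = to_ast(predicate, name_to_index={"a": 0, "b": 1})
if ast is not None:
    mask = plc.transform.compute_column(table, ast)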
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/testing/asserts.py
@@ -151,7 +151,7 @@ def assert_collect_raises(
collect_kwargs: dict[OptimizationArgs, bool] | None = None,
polars_collect_kwargs: dict[OptimizationArgs, bool] | None = None,
cudf_collect_kwargs: dict[OptimizationArgs, bool] | None = None,
):
) -> None:
"""
Assert that collecting the result of a query raises the expected exceptions.

6 changes: 3 additions & 3 deletions python/cudf_polars/cudf_polars/testing/plugin.py
@@ -16,7 +16,7 @@
from collections.abc import Mapping


def pytest_addoption(parser: pytest.Parser):
def pytest_addoption(parser: pytest.Parser) -> None:
"""Add plugin-specific options."""
group = parser.getgroup(
"cudf-polars", "Plugin to set GPU as default engine for polars tests"
@@ -28,7 +28,7 @@
)


def pytest_configure(config: pytest.Config):
def pytest_configure(config: pytest.Config) -> None:
"""Enable use of this module as a pytest plugin to enable GPU collection."""
no_fallback = config.getoption("--cudf-polars-no-fallback")
collect = polars.LazyFrame.collect
@@ -172,7 +172,7 @@ def pytest_configure(config: pytest.Config):

def pytest_collection_modifyitems(
session: pytest.Session, config: pytest.Config, items: list[pytest.Item]
):
) -> None:
"""Mark known failing tests."""
if config.getoption("--cudf-polars-no-fallback"):
# Don't xfail tests if running without fallback