From e22146650f92299d99c66bd24f7ee6a2cf1bb815 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 11 Oct 2024 13:56:35 +0000 Subject: [PATCH 01/12] A few type annotations --- python/cudf_polars/cudf_polars/testing/asserts.py | 2 +- python/cudf_polars/cudf_polars/testing/plugin.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/cudf_polars/cudf_polars/testing/asserts.py b/python/cudf_polars/cudf_polars/testing/asserts.py index 7b6f3848fc4..7b45c1eaa06 100644 --- a/python/cudf_polars/cudf_polars/testing/asserts.py +++ b/python/cudf_polars/cudf_polars/testing/asserts.py @@ -151,7 +151,7 @@ def assert_collect_raises( collect_kwargs: dict[OptimizationArgs, bool] | None = None, polars_collect_kwargs: dict[OptimizationArgs, bool] | None = None, cudf_collect_kwargs: dict[OptimizationArgs, bool] | None = None, -): +) -> None: """ Assert that collecting the result of a query raises the expected exceptions. diff --git a/python/cudf_polars/cudf_polars/testing/plugin.py b/python/cudf_polars/cudf_polars/testing/plugin.py index a3607159e01..e01ccd05527 100644 --- a/python/cudf_polars/cudf_polars/testing/plugin.py +++ b/python/cudf_polars/cudf_polars/testing/plugin.py @@ -16,7 +16,7 @@ from collections.abc import Mapping -def pytest_addoption(parser: pytest.Parser): +def pytest_addoption(parser: pytest.Parser) -> None: """Add plugin-specific options.""" group = parser.getgroup( "cudf-polars", "Plugin to set GPU as default engine for polars tests" @@ -28,7 +28,7 @@ def pytest_addoption(parser: pytest.Parser): ) -def pytest_configure(config: pytest.Config): +def pytest_configure(config: pytest.Config) -> None: """Enable use of this module as a pytest plugin to enable GPU collection.""" no_fallback = config.getoption("--cudf-polars-no-fallback") collect = polars.LazyFrame.collect @@ -172,7 +172,7 @@ def pytest_configure(config: pytest.Config): def pytest_collection_modifyitems( session: pytest.Session, config: pytest.Config, items: list[pytest.Item] -): +) -> None: """Mark known failing tests.""" if config.getoption("--cudf-polars-no-fallback"): # Don't xfail tests if running without fallback From 59911d731ae1604cdb4aa37408966970a0b1d557 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 16 Oct 2024 10:30:24 +0000 Subject: [PATCH 02/12] Expose all type ids and match order with libcudf --- python/pylibcudf/pylibcudf/libcudf/types.pxd | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pylibcudf/pylibcudf/libcudf/types.pxd b/python/pylibcudf/pylibcudf/libcudf/types.pxd index eabae68bc90..60e293e5cdb 100644 --- a/python/pylibcudf/pylibcudf/libcudf/types.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/types.pxd @@ -70,18 +70,19 @@ cdef extern from "cudf/types.hpp" namespace "cudf" nogil: TIMESTAMP_MILLISECONDS TIMESTAMP_MICROSECONDS TIMESTAMP_NANOSECONDS - DICTIONARY32 - STRING - LIST - STRUCT - NUM_TYPE_IDS + DURATION_DAYS DURATION_SECONDS DURATION_MILLISECONDS DURATION_MICROSECONDS DURATION_NANOSECONDS + DICTIONARY32 + STRING + LIST DECIMAL32 DECIMAL64 DECIMAL128 + STRUCT + NUM_TYPE_IDS cdef cppclass data_type: data_type() except + From 06c15dacae58012aa8a40ab4e7ca729829786892 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 10 Oct 2024 15:40:29 +0000 Subject: [PATCH 03/12] Support all types for scalars in pylibcudf Expressions --- python/pylibcudf/pylibcudf/expressions.pyx | 50 ++++++++++++++++++- .../pylibcudf/libcudf/wrappers/durations.pxd | 5 +- .../pylibcudf/libcudf/wrappers/timestamps.pxd | 5 +- 3 files changed, 55 
insertions(+), 5 deletions(-) diff --git a/python/pylibcudf/pylibcudf/expressions.pyx b/python/pylibcudf/pylibcudf/expressions.pyx index a44c9e25987..1535f68366b 100644 --- a/python/pylibcudf/pylibcudf/expressions.pyx +++ b/python/pylibcudf/pylibcudf/expressions.pyx @@ -5,7 +5,17 @@ from pylibcudf.libcudf.expressions import \ table_reference as TableReference # no-cython-lint from cython.operator cimport dereference -from libc.stdint cimport int32_t, int64_t +from libc.stdint cimport ( + int8_t, + int16_t, + int32_t, + int64_t, + uint8_t, + uint16_t, + uint32_t, + uint64_t, +) +from libcpp cimport bool from libcpp.memory cimport make_unique, unique_ptr from libcpp.string cimport string from libcpp.utility cimport move @@ -18,12 +28,14 @@ from pylibcudf.libcudf.scalar.scalar cimport ( ) from pylibcudf.libcudf.types cimport size_type, type_id from pylibcudf.libcudf.wrappers.durations cimport ( + duration_D, duration_ms, duration_ns, duration_s, duration_us, ) from pylibcudf.libcudf.wrappers.timestamps cimport ( + timestamp_D, timestamp_ms, timestamp_ns, timestamp_s, @@ -78,6 +90,34 @@ cdef class Literal(Expression): self.c_obj = move(make_unique[libcudf_exp.literal]( dereference(self.scalar.c_obj) )) + elif tid == type_id.INT16: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.INT8: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.UINT64: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.UINT32: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.UINT16: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.UINT8: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.BOOL8: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) elif tid == type_id.FLOAT64: self.c_obj = move(make_unique[libcudf_exp.literal]( dereference(self.scalar.c_obj) @@ -110,6 +150,10 @@ cdef class Literal(Expression): self.c_obj = move(make_unique[libcudf_exp.literal]( dereference(self.scalar.c_obj) )) + elif tid == type_id.TIMESTAMP_DAYS: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) elif tid == type_id.DURATION_NANOSECONDS: self.c_obj = move(make_unique[libcudf_exp.literal]( dereference(self.scalar.c_obj) @@ -130,6 +174,10 @@ cdef class Literal(Expression): self.c_obj = move(make_unique[libcudf_exp.literal]( dereference(self.scalar.c_obj) )) + elif tid == type_id.DURATION_DAYS: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) else: raise NotImplementedError( f"Don't know how to make literal with type id {tid}" diff --git a/python/pylibcudf/pylibcudf/libcudf/wrappers/durations.pxd b/python/pylibcudf/pylibcudf/libcudf/wrappers/durations.pxd index 7c648425eb5..c9c960d0a79 100644 --- a/python/pylibcudf/pylibcudf/libcudf/wrappers/durations.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/wrappers/durations.pxd @@ -1,9 +1,10 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. 
-from libc.stdint cimport int64_t +from libc.stdint cimport int32_t, int64_t cdef extern from "cudf/wrappers/durations.hpp" namespace "cudf" nogil: + ctypedef int32_t duration_D ctypedef int64_t duration_s ctypedef int64_t duration_ms ctypedef int64_t duration_us diff --git a/python/pylibcudf/pylibcudf/libcudf/wrappers/timestamps.pxd b/python/pylibcudf/pylibcudf/libcudf/wrappers/timestamps.pxd index 50d37fd0a68..5dcd144529d 100644 --- a/python/pylibcudf/pylibcudf/libcudf/wrappers/timestamps.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/wrappers/timestamps.pxd @@ -1,9 +1,10 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. -from libc.stdint cimport int64_t +from libc.stdint cimport int32_t, int64_t cdef extern from "cudf/wrappers/timestamps.hpp" namespace "cudf" nogil: + ctypedef int32_t timestamp_D ctypedef int64_t timestamp_s ctypedef int64_t timestamp_ms ctypedef int64_t timestamp_us From 68f0a9b392a4066799be6a0a4ae8c83c5472e324 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 16 Oct 2024 10:34:56 +0000 Subject: [PATCH 04/12] Expose compute_column --- .../pylibcudf/pylibcudf/libcudf/transform.pxd | 5 ++++ python/pylibcudf/pylibcudf/transform.pxd | 3 ++ python/pylibcudf/pylibcudf/transform.pyx | 29 +++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/python/pylibcudf/pylibcudf/libcudf/transform.pxd b/python/pylibcudf/pylibcudf/libcudf/transform.pxd index d21510bd731..47d79083b66 100644 --- a/python/pylibcudf/pylibcudf/libcudf/transform.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/transform.pxd @@ -27,6 +27,11 @@ cdef extern from "cudf/transform.hpp" namespace "cudf" nogil: column_view input ) except + + cdef unique_ptr[column] compute_column( + table_view table, + expression expr + ) except + + cdef unique_ptr[column] transform( column_view input, string unary_udf, diff --git a/python/pylibcudf/pylibcudf/transform.pxd b/python/pylibcudf/pylibcudf/transform.pxd index b530f433c97..4fb623158f0 100644 --- a/python/pylibcudf/pylibcudf/transform.pxd +++ b/python/pylibcudf/pylibcudf/transform.pxd @@ -3,6 +3,7 @@ from libcpp cimport bool from pylibcudf.libcudf.types cimport bitmask_type, data_type from .column cimport Column +from .expressions cimport Expression from .gpumemoryview cimport gpumemoryview from .table cimport Table from .types cimport DataType @@ -10,6 +11,8 @@ from .types cimport DataType cpdef tuple[gpumemoryview, int] nans_to_nulls(Column input) +cpdef Column compute_column(Table input, Expression expr) + cpdef tuple[gpumemoryview, int] bools_to_mask(Column input) cpdef Column mask_to_bools(Py_ssize_t bitmask, int begin_bit, int end_bit) diff --git a/python/pylibcudf/pylibcudf/transform.pyx b/python/pylibcudf/pylibcudf/transform.pyx index bce9702752a..5d8dea8d5f9 100644 --- a/python/pylibcudf/pylibcudf/transform.pyx +++ b/python/pylibcudf/pylibcudf/transform.pyx @@ -1,5 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. +from cython.operator cimport dereference from libcpp.memory cimport unique_ptr from libcpp.string cimport string from libcpp.utility cimport move, pair @@ -43,6 +44,34 @@ cpdef tuple[gpumemoryview, int] nans_to_nulls(Column input): ) +cpdef Column compute_column(Table input, Expression expr): + """Create a column by evaluating an expression on a table. + + For details see :cpp:func:`compute_column`. 
+ + Parameters + ---------- + input : Table + Table used for expression evaluation + expr : Expression + Expression to evaluate + + Returns + ------- + Column of the evaluated expression + """ + cdef unique_ptr[column] c_result + + with nogil: + c_result = move( + cpp_transform.compute_column( + input.view(), dereference(expr.c_obj.get()) + ) + ) + + return Column.from_libcudf(move(c_result)) + + cpdef tuple[gpumemoryview, int] bools_to_mask(Column input): """Create a bitmask from a column of boolean elements From c3986cd56094297e5046de1dfcfa2e90728769ba Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 22 Oct 2024 15:34:24 +0000 Subject: [PATCH 05/12] Add pylibcudf test for compute_column --- .../pylibcudf/tests/test_expressions.py | 39 +++++++++++++++++-- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/python/pylibcudf/pylibcudf/tests/test_expressions.py b/python/pylibcudf/pylibcudf/tests/test_expressions.py index 5894ef4624c..3a0e139dda4 100644 --- a/python/pylibcudf/pylibcudf/tests/test_expressions.py +++ b/python/pylibcudf/pylibcudf/tests/test_expressions.py @@ -1,10 +1,9 @@ # Copyright (c) 2024, NVIDIA CORPORATION. import pyarrow as pa +import pyarrow.compute as pc import pylibcudf as plc import pytest - -# We can't really evaluate these expressions, so just make sure -# construction works properly +from utils import assert_column_eq def test_literal_construction_invalid(): @@ -22,7 +21,7 @@ def test_literal_construction_invalid(): ], ) def test_columnref_construction(tableref): - plc.expressions.ColumnReference(1.0, tableref) + plc.expressions.ColumnReference(1, tableref) def test_columnnameref_construction(): @@ -47,3 +46,35 @@ def test_columnnameref_construction(): ) def test_astoperation_construction(kwargs): plc.expressions.Operation(**kwargs) + + +def test_evaluation(): + table_h = pa.table({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + lit = pa.scalar(42, type=pa.int64()) + table = plc.interop.from_arrow(table_h) + # expr = abs(b * c - (a + 42)) + expr = plc.expressions.Operation( + plc.expressions.ASTOperator.ABS, + plc.expressions.Operation( + plc.expressions.ASTOperator.SUB, + plc.expressions.Operation( + plc.expressions.ASTOperator.MUL, + plc.expressions.ColumnReference(1), + plc.expressions.ColumnReference(2), + ), + plc.expressions.Operation( + plc.expressions.ASTOperator.ADD, + plc.expressions.ColumnReference(0), + plc.expressions.Literal(plc.interop.from_arrow(lit)), + ), + ), + ) + + expect = pc.abs( + pc.subtract( + pc.multiply(table_h["b"], table_h["c"]), pc.add(table_h["a"], lit) + ) + ) + got = plc.transform.compute_column(table, expr) + + assert_column_eq(expect, got) From c622a0119983d3ca6c54f370b12a482d5045c837 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 10 Oct 2024 11:35:35 +0000 Subject: [PATCH 06/12] Implement conversion from Expr nodes to pylibcudf Expressions We will use this for inequality joins and filter pushdown in the parquet reader. The handling is a bit complicated, since the subset of expressions that the parquet filter accepts is smaller than all possible expressions. Since much of the logic is similar, however, we just dispatch on a transformer state variable to determine which case we're handling. 
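
As a rough usage sketch of the two entry points added below (the `predicate`, `plc_table` and column names here are placeholders for illustration, not part of this patch):

    import pylibcudf as plc
    from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter

    # Parquet case: column references become ColumnNameReference nodes and
    # only 'col <op> literal' style comparisons are accepted.
    filters = to_parquet_filter(predicate.value)

    # General case: column references become positional ColumnReference
    # nodes, so we must say where each name lives in the evaluation table.
    ast = to_ast(predicate.value, name_to_index={"a": 0, "b": 1})
    if ast is not None:
        result = plc.transform.compute_column(plc_table, ast)

Both helpers return None rather than raising when translation fails, so callers can fall back to evaluating the original expression as a post-filter.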
--- python/cudf_polars/cudf_polars/dsl/to_ast.py | 263 +++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 python/cudf_polars/cudf_polars/dsl/to_ast.py diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py new file mode 100644 index 00000000000..0ffd6867fd3 --- /dev/null +++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py @@ -0,0 +1,263 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +"""Conversion of expression nodes to libcudf AST nodes.""" + +from __future__ import annotations + +from functools import partial, reduce, singledispatch +from typing import TYPE_CHECKING, TypeAlias + +import pylibcudf as plc +from pylibcudf import expressions as pexpr + +from polars.polars import _expr_nodes as pl_expr + +from cudf_polars.dsl import expr +from cudf_polars.dsl.traversal import CachingVisitor +from cudf_polars.typing import GenericTransformer + +if TYPE_CHECKING: + from collections.abc import Mapping + +# Can't merge these op-mapping dictionaries because scoped enum values +# are exposed by cython with equality/hash based one their underlying +# representation type. So in a dict they are just treated as integers. +BINOP_TO_ASTOP = { + plc.binaryop.BinaryOperator.EQUAL: pexpr.ASTOperator.EQUAL, + plc.binaryop.BinaryOperator.NULL_EQUALS: pexpr.ASTOperator.NULL_EQUAL, + plc.binaryop.BinaryOperator.NOT_EQUAL: pexpr.ASTOperator.NOT_EQUAL, + plc.binaryop.BinaryOperator.LESS: pexpr.ASTOperator.LESS, + plc.binaryop.BinaryOperator.LESS_EQUAL: pexpr.ASTOperator.LESS_EQUAL, + plc.binaryop.BinaryOperator.GREATER: pexpr.ASTOperator.GREATER, + plc.binaryop.BinaryOperator.GREATER_EQUAL: pexpr.ASTOperator.GREATER_EQUAL, + plc.binaryop.BinaryOperator.ADD: pexpr.ASTOperator.ADD, + plc.binaryop.BinaryOperator.SUB: pexpr.ASTOperator.SUB, + plc.binaryop.BinaryOperator.MUL: pexpr.ASTOperator.MUL, + plc.binaryop.BinaryOperator.DIV: pexpr.ASTOperator.DIV, + plc.binaryop.BinaryOperator.TRUE_DIV: pexpr.ASTOperator.TRUE_DIV, + plc.binaryop.BinaryOperator.FLOOR_DIV: pexpr.ASTOperator.FLOOR_DIV, + plc.binaryop.BinaryOperator.PYMOD: pexpr.ASTOperator.PYMOD, + plc.binaryop.BinaryOperator.BITWISE_AND: pexpr.ASTOperator.BITWISE_AND, + plc.binaryop.BinaryOperator.BITWISE_OR: pexpr.ASTOperator.BITWISE_OR, + plc.binaryop.BinaryOperator.BITWISE_XOR: pexpr.ASTOperator.BITWISE_XOR, + plc.binaryop.BinaryOperator.LOGICAL_AND: pexpr.ASTOperator.LOGICAL_AND, + plc.binaryop.BinaryOperator.LOGICAL_OR: pexpr.ASTOperator.LOGICAL_OR, + plc.binaryop.BinaryOperator.NULL_LOGICAL_AND: pexpr.ASTOperator.NULL_LOGICAL_AND, + plc.binaryop.BinaryOperator.NULL_LOGICAL_OR: pexpr.ASTOperator.NULL_LOGICAL_OR, +} + +UOP_TO_ASTOP = { + plc.unary.UnaryOperator.SIN: pexpr.ASTOperator.SIN, + plc.unary.UnaryOperator.COS: pexpr.ASTOperator.COS, + plc.unary.UnaryOperator.TAN: pexpr.ASTOperator.TAN, + plc.unary.UnaryOperator.ARCSIN: pexpr.ASTOperator.ARCSIN, + plc.unary.UnaryOperator.ARCCOS: pexpr.ASTOperator.ARCCOS, + plc.unary.UnaryOperator.ARCTAN: pexpr.ASTOperator.ARCTAN, + plc.unary.UnaryOperator.SINH: pexpr.ASTOperator.SINH, + plc.unary.UnaryOperator.COSH: pexpr.ASTOperator.COSH, + plc.unary.UnaryOperator.TANH: pexpr.ASTOperator.TANH, + plc.unary.UnaryOperator.ARCSINH: pexpr.ASTOperator.ARCSINH, + plc.unary.UnaryOperator.ARCCOSH: pexpr.ASTOperator.ARCCOSH, + plc.unary.UnaryOperator.ARCTANH: pexpr.ASTOperator.ARCTANH, + plc.unary.UnaryOperator.EXP: pexpr.ASTOperator.EXP, + plc.unary.UnaryOperator.LOG: 
pexpr.ASTOperator.LOG, + plc.unary.UnaryOperator.SQRT: pexpr.ASTOperator.SQRT, + plc.unary.UnaryOperator.CBRT: pexpr.ASTOperator.CBRT, + plc.unary.UnaryOperator.CEIL: pexpr.ASTOperator.CEIL, + plc.unary.UnaryOperator.FLOOR: pexpr.ASTOperator.FLOOR, + plc.unary.UnaryOperator.ABS: pexpr.ASTOperator.ABS, + plc.unary.UnaryOperator.RINT: pexpr.ASTOperator.RINT, + plc.unary.UnaryOperator.BIT_INVERT: pexpr.ASTOperator.BIT_INVERT, + plc.unary.UnaryOperator.NOT: pexpr.ASTOperator.NOT, +} + +SUPPORTED_STATISTICS_BINOPS = { + plc.binaryop.BinaryOperator.EQUAL, + plc.binaryop.BinaryOperator.NOT_EQUAL, + plc.binaryop.BinaryOperator.LESS, + plc.binaryop.BinaryOperator.LESS_EQUAL, + plc.binaryop.BinaryOperator.GREATER, + plc.binaryop.BinaryOperator.GREATER_EQUAL, +} + +REVERSED_COMPARISON = { + plc.binaryop.BinaryOperator.EQUAL: plc.binaryop.BinaryOperator.EQUAL, + plc.binaryop.BinaryOperator.NOT_EQUAL: plc.binaryop.BinaryOperator.NOT_EQUAL, + plc.binaryop.BinaryOperator.LESS: plc.binaryop.BinaryOperator.GREATER, + plc.binaryop.BinaryOperator.LESS_EQUAL: plc.binaryop.BinaryOperator.GREATER_EQUAL, + plc.binaryop.BinaryOperator.GREATER: plc.binaryop.BinaryOperator.LESS, + plc.binaryop.BinaryOperator.GREATER_EQUAL: plc.binaryop.BinaryOperator.LESS_EQUAL, +} + + +Transformer: TypeAlias = GenericTransformer[expr.Expr, pexpr.Expression] + + +@singledispatch +def _to_ast(node: expr.Expr, self: Transformer) -> pexpr.Expression: + """ + Translate an expression to a pylibcudf Expression. + + Parameters + ---------- + node + Expression to translate. + self + Recursive transformer. The state dictionary should contain a + `for_parquet` key indicating if this transformation should + provide an expression suitable for use in parquet filters. + + If `for_parquet` is `False`, the dictionary should contain a + `name_to_index` mapping that maps column names to their + integer index in the table that will be used for evaluation of + the expression. + + Returns + ------- + pylibcudf Expression. + + Raises + ------ + NotImplementedError or KeyError if the expression cannot be translated. + """ + raise NotImplementedError(f"Unhandled expression type {type(node)}") + + +@_to_ast.register +def _(node: expr.Col, self: Transformer) -> pexpr.Expression: + if self.state["for_parquet"]: + return pexpr.ColumnNameReference(node.name) + return pexpr.ColumnReference(self.state["name_to_index"][node.name]) + + +@_to_ast.register +def _(node: expr.Literal, self: Transformer) -> pexpr.Expression: + return pexpr.Literal(plc.interop.from_arrow(node.value)) + + +@_to_ast.register +def _(node: expr.BinOp, self: Transformer) -> pexpr.Expression: + if node.op == plc.binaryop.BinaryOperator.NULL_NOT_EQUALS: + return pexpr.Operation( + pexpr.ASTOperator.NOT, + self( + # Reconstruct and apply, rather than directly + # constructing the right expression so we get the + # handling of parquet special cases for free. 
+ expr.BinOp( + node.dtype, plc.binaryop.BinaryOperator.NULL_EQUALS, *node.children + ) + ), + ) + if self.state["for_parquet"]: + op1_col, op2_col = (isinstance(op, expr.Col) for op in node.children) + if op1_col ^ op2_col: + op = node.op + if op not in SUPPORTED_STATISTICS_BINOPS: + raise NotImplementedError( + f"Parquet filter binop with column doesn't support {node.op!r}" + ) + op1, op2 = node.children + if op2_col: + (op1, op2) = (op2, op1) + op = REVERSED_COMPARISON[op] + if not isinstance(op2, expr.Literal): + raise NotImplementedError( + "Parquet filter binops must have form 'col binop literal'" + ) + return pexpr.Operation(BINOP_TO_ASTOP[op], self(op1), self(op2)) + elif op1_col and op2_col: + raise NotImplementedError( + "Parquet filter binops must have one column reference not two" + ) + return pexpr.Operation(BINOP_TO_ASTOP[node.op], *map(self, node.children)) + + +@_to_ast.register +def _(node: expr.BooleanFunction, self: Transformer) -> pexpr.Expression: + if node.name == pl_expr.BooleanFunction.IsIn: + needles, haystack = node.children + if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16: + # 16 is an arbitrary limit + needle_ref = self(needles) + values = [pexpr.Literal(plc.interop.from_arrow(v)) for v in haystack.value] + return reduce( + partial(pexpr.Operation, pexpr.ASTOperator.LOGICAL_OR), + ( + pexpr.Operation(pexpr.ASTOperator.EQUAL, needle_ref, value) + for value in values + ), + ) + if self.state["for_parquet"] and isinstance(node.children[0], expr.Col): + raise NotImplementedError( + f"Parquet filters don't support {node.name} on columns" + ) + if node.name == pl_expr.BooleanFunction.IsNull: + return pexpr.Operation(pexpr.ASTOperator.IS_NULL, self(node.children[0])) + elif node.name == pl_expr.BooleanFunction.IsNotNull: + return pexpr.Operation( + pexpr.ASTOperator.NOT, + pexpr.Operation(pexpr.ASTOperator.IS_NULL, self(node.children[0])), + ) + elif node.name == pl_expr.BooleanFunction.Not: + return pexpr.Operation(pexpr.ASTOperator.NOT, self(node.children[0])) + raise NotImplementedError(f"AST conversion does not support {node.name}") + + +@_to_ast.register +def _(node: expr.UnaryFunction, self: Transformer) -> pexpr.Expression: + if isinstance(node.children[0], expr.Col) and self.state["for_parquet"]: + raise NotImplementedError( + "Parquet filters don't support {node.name} on columns" + ) + return pexpr.Operation( + UOP_TO_ASTOP[node._OP_MAPPING[node.name]], self(node.children[0]) + ) + + +def to_parquet_filter(node: expr.Expr) -> pexpr.Expression | None: + """ + Convert an expression to libcudf AST nodes suitable for parquet filtering. + + Parameters + ---------- + node + Expression to convert. + + Returns + ------- + pylibcudf Expression if conversion is possible, otherwise None. + """ + mapper: Transformer = CachingVisitor(_to_ast, state={"for_parquet": True}) + try: + return mapper(node) + except (KeyError, NotImplementedError): + return None + + +def to_ast( + node: expr.Expr, *, name_to_index: Mapping[str, int] +) -> pexpr.Expression | None: + """ + Convert an expression to libcudf AST nodes suitable for compute_column. + + Parameters + ---------- + node + Expression to convert. + name_to_index + Mapping from column names to their index in the table that + will be used for expression evaluation. + + Returns + ------- + pylibcudf Expressoin if conversion is possible, otherwise None. 
+ """ + mapper: Transformer = CachingVisitor( + _to_ast, state={"for_parquet": False, "name_to_index": name_to_index} + ) + try: + return mapper(node) + except (KeyError, NotImplementedError): + return None From 3732c760553d6f125279639f4ce48a8a30c731cd Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 10 Oct 2024 18:04:59 +0000 Subject: [PATCH 07/12] Implement predicate pushdown into parquet read We attempt to turn the predicate into a filter expression that the parquet reader understands. If successful then we don't have to apply the predicate as a post-filter. We can only do this when a row index is not requested. --- python/cudf_polars/cudf_polars/dsl/ir.py | 9 +++++++++ python/cudf_polars/cudf_polars/dsl/to_ast.py | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index f79e229d3f3..1aa6741d417 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -28,6 +28,7 @@ import cudf_polars.dsl.expr as expr from cudf_polars.containers import Column, DataFrame from cudf_polars.dsl.nodebase import Node +from cudf_polars.dsl.to_ast import to_parquet_filter from cudf_polars.utils import dtypes if TYPE_CHECKING: @@ -418,9 +419,14 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: colnames[0], ) elif self.typ == "parquet": + filters = None + if self.predicate is not None and self.row_index is None: + # Can't apply filters during read if we have a row index. + filters = to_parquet_filter(self.predicate.value) tbl_w_meta = plc.io.parquet.read_parquet( plc.io.SourceInfo(self.paths), columns=with_columns, + filters=filters, nrows=n_rows, skip_rows=self.skip_rows, ) @@ -429,6 +435,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: # TODO: consider nested column names? tbl_w_meta.column_names(include_children=False), ) + if filters is not None: + # Mask must have been applied. + return df elif self.typ == "ndjson": json_schema: list[tuple[str, str, list]] = [ (name, typ, []) for name, typ in self.schema.items() diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py index 0ffd6867fd3..a55f1930862 100644 --- a/python/cudf_polars/cudf_polars/dsl/to_ast.py +++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py @@ -229,7 +229,7 @@ def to_parquet_filter(node: expr.Expr) -> pexpr.Expression | None: ------- pylibcudf Expression if conversion is possible, otherwise None. """ - mapper: Transformer = CachingVisitor(_to_ast, state={"for_parquet": True}) + mapper = CachingVisitor(_to_ast, state={"for_parquet": True}) try: return mapper(node) except (KeyError, NotImplementedError): @@ -254,7 +254,7 @@ def to_ast( ------- pylibcudf Expressoin if conversion is possible, otherwise None. 
""" - mapper: Transformer = CachingVisitor( + mapper = CachingVisitor( _to_ast, state={"for_parquet": False, "name_to_index": name_to_index} ) try: From 9a62f532bb37c92490fbabe4349468e40b86be9f Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 16 Oct 2024 10:35:36 +0000 Subject: [PATCH 08/12] Add tests of parquet filters --- .../cudf_polars/tests/test_parquet_filters.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 python/cudf_polars/tests/test_parquet_filters.py diff --git a/python/cudf_polars/tests/test_parquet_filters.py b/python/cudf_polars/tests/test_parquet_filters.py new file mode 100644 index 00000000000..545a89250fc --- /dev/null +++ b/python/cudf_polars/tests/test_parquet_filters.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +@pytest.fixture(scope="module") +def df(): + return pl.DataFrame( + { + "c": ["a", "b", "c", "d", "e", "f"], + "a": [1, 2, 3, None, 4, 5], + "b": pl.Series([None, None, 3, float("inf"), 4, 0], dtype=pl.Float64), + "d": [-1, 2, -3, None, 4, -5], + } + ) + + +@pytest.fixture(scope="module") +def pq_file(tmp_path_factory, df): + tmp_path = tmp_path_factory.mktemp("parquet_filter") + df.write_parquet(tmp_path / "tmp.pq", row_group_size=3) + return pl.scan_parquet(tmp_path / "tmp.pq") + + +@pytest.mark.parametrize( + "expr", + [ + pl.col("a").is_in([0, 1]), + pl.col("a").is_between(0, 2), + (pl.col("a") < 2).not_(), + pl.lit(2) > pl.col("a"), + pl.lit(2) >= pl.col("a"), + pl.lit(2) < pl.col("a"), + pl.lit(2) <= pl.col("a"), + pl.lit(0) == pl.col("a"), + pl.lit(1) != pl.col("a"), + pl.col("a") == pl.col("d"), + (pl.col("b") < pl.lit(2, dtype=pl.Float64).sqrt()), + (pl.col("a") >= pl.lit(2)) & (pl.col("b") > 0), + pl.col("b").is_finite(), + pl.col("a").is_null(), + pl.col("a").is_not_null(), + pl.col("a").abs().is_between(0, 2), + pl.col("a").ne_missing(pl.lit(None, dtype=pl.Int64)), + ], +) +@pytest.mark.parametrize("selection", [["c", "b"], ["a"], ["a", "c"], ["b"], "c"]) +def test_scan_by_hand(expr, selection, pq_file): + df = pq_file.collect() + q = pq_file.filter(expr).select(*selection) + # Not using assert_gpu_result_equal because + # https://github.com/pola-rs/polars/issues/19238 + got = q.collect(engine=pl.GPUEngine(raise_on_fail=True)) + expect = df.filter(expr).select(*selection) + assert_frame_equal(got, expect) From 16efcaf0476d92ecf5c8f70a7fd53acaf8e1b149 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 16 Oct 2024 10:37:46 +0000 Subject: [PATCH 09/12] Add tests of to_ast and column compute --- python/cudf_polars/tests/dsl/test_to_ast.py | 78 +++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 python/cudf_polars/tests/dsl/test_to_ast.py diff --git a/python/cudf_polars/tests/dsl/test_to_ast.py b/python/cudf_polars/tests/dsl/test_to_ast.py new file mode 100644 index 00000000000..a7b779a6ec9 --- /dev/null +++ b/python/cudf_polars/tests/dsl/test_to_ast.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pylibcudf as plc +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +import cudf_polars.dsl.ir as ir_nodes +from cudf_polars import translate_ir +from cudf_polars.containers.dataframe import DataFrame, NamedColumn +from cudf_polars.dsl.to_ast import to_ast + + +@pytest.fixture(scope="module") +def df(): + return pl.LazyFrame( + { + "c": ["a", "b", "c", "d", "e", "f"], + "a": [1, 2, 3, None, 4, 5], + "b": pl.Series([None, None, 3, float("inf"), 4, 0], dtype=pl.Float64), + "d": [False, True, True, None, False, False], + } + ) + + +@pytest.mark.parametrize( + "expr", + [ + pl.col("a").is_in([0, 1]), + pl.col("a").is_between(0, 2), + (pl.col("a") < pl.col("b")).not_(), + pl.lit(2) > pl.col("a"), + pl.lit(2) >= pl.col("a"), + pl.lit(2) < pl.col("a"), + pl.lit(2) <= pl.col("a"), + pl.lit(0) == pl.col("a"), + pl.lit(1) != pl.col("a"), + (pl.col("b") < pl.lit(2, dtype=pl.Float64).sqrt()), + (pl.col("a") >= pl.lit(2)) & (pl.col("b") > 0), + pl.col("a").is_null(), + pl.col("a").is_not_null(), + pl.col("b").is_finite(), + pytest.param( + pl.col("a").sin(), + marks=pytest.mark.xfail(reason="Need to insert explicit casts"), + ), + pl.col("b").cos(), + pl.col("a").abs().is_between(0, 2), + pl.col("a").ne_missing(pl.lit(None, dtype=pl.Int64)), + [pl.col("a") * 2, pl.col("b") + pl.col("a")], + pl.col("d").not_(), + ], +) +def test_compute_column(expr, df): + q = df.select(expr) + ir = translate_ir(q._ldf.visit()) + + assert isinstance(ir, ir_nodes.Select) + table = ir.children[0].evaluate(cache={}) + name_to_index = {c.name: i for i, c in enumerate(table.columns)} + + def compute_column(e): + ast = to_ast(e.value, name_to_index=name_to_index) + if ast is not None: + return NamedColumn( + plc.transform.compute_column(table.table, ast), name=e.name + ) + return e.evaluate(table) + + got = DataFrame(map(compute_column, ir.exprs)).to_polars() + + expect = q.collect() + + assert_frame_equal(expect, got) From 43fbe363f86e29efdaaf3d274092a34958e3a8c2 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 25 Oct 2024 10:15:42 +0000 Subject: [PATCH 10/12] One less move --- python/pylibcudf/pylibcudf/transform.pyx | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pylibcudf/pylibcudf/transform.pyx b/python/pylibcudf/pylibcudf/transform.pyx index 5d8dea8d5f9..e8d95cadb0c 100644 --- a/python/pylibcudf/pylibcudf/transform.pyx +++ b/python/pylibcudf/pylibcudf/transform.pyx @@ -63,10 +63,8 @@ cpdef Column compute_column(Table input, Expression expr): cdef unique_ptr[column] c_result with nogil: - c_result = move( - cpp_transform.compute_column( - input.view(), dereference(expr.c_obj.get()) - ) + c_result = cpp_transform.compute_column( + input.view(), dereference(expr.c_obj.get()) ) return Column.from_libcudf(move(c_result)) From a37e4fa048f1f4c3bf0d4d1212ef74a4414c513e Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 25 Oct 2024 10:52:10 +0000 Subject: [PATCH 11/12] Use plc.compute_column from legacy cython --- python/cudf/cudf/_lib/transform.pyx | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/python/cudf/cudf/_lib/transform.pyx b/python/cudf/cudf/_lib/transform.pyx index 40d0c9eac3a..1589e23f716 100644 --- a/python/cudf/cudf/_lib/transform.pyx +++ b/python/cudf/cudf/_lib/transform.pyx @@ -7,20 +7,11 @@ from cudf.core._internals.expressions import parse_expression from cudf.core.buffer import acquire_spill_lock, 
as_buffer from cudf.utils import cudautils -from cython.operator cimport dereference -from libcpp.memory cimport unique_ptr -from libcpp.utility cimport move - -cimport pylibcudf.libcudf.transform as libcudf_transform from pylibcudf cimport transform as plc_transform from pylibcudf.expressions cimport Expression -from pylibcudf.libcudf.column.column cimport column -from pylibcudf.libcudf.expressions cimport expression -from pylibcudf.libcudf.table.table_view cimport table_view from pylibcudf.libcudf.types cimport size_type from cudf._lib.column cimport Column -from cudf._lib.utils cimport table_view_from_columns import pylibcudf as plc @@ -121,13 +112,8 @@ def compute_column(list columns, tuple column_names, expr: str): # At the end, all the stack contains is the expression to evaluate. cdef Expression cudf_expr = visitor.expression - cdef table_view tbl = table_view_from_columns(columns) - cdef unique_ptr[column] col - with nogil: - col = move( - libcudf_transform.compute_column( - tbl, - dereference(cudf_expr.c_obj.get()) - ) - ) - return Column.from_unique_ptr(move(col)) + result = plc_transform.compute_column( + plc.Table([col.to_pylibcudf(mode="read") for col in columns]), + cudf_expr, + ) + return Column.from_pylibcudf(result) From 953d184f64f83682353ade315f841e783d3b7a7f Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 25 Oct 2024 10:54:30 +0000 Subject: [PATCH 12/12] Less ambiguous import name --- python/cudf_polars/cudf_polars/dsl/to_ast.py | 140 ++++++++++--------- 1 file changed, 71 insertions(+), 69 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py index a55f1930862..ffdae81de55 100644 --- a/python/cudf_polars/cudf_polars/dsl/to_ast.py +++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, TypeAlias import pylibcudf as plc -from pylibcudf import expressions as pexpr +from pylibcudf import expressions as plc_expr from polars.polars import _expr_nodes as pl_expr @@ -24,52 +24,52 @@ # are exposed by cython with equality/hash based one their underlying # representation type. So in a dict they are just treated as integers. 
BINOP_TO_ASTOP = { - plc.binaryop.BinaryOperator.EQUAL: pexpr.ASTOperator.EQUAL, - plc.binaryop.BinaryOperator.NULL_EQUALS: pexpr.ASTOperator.NULL_EQUAL, - plc.binaryop.BinaryOperator.NOT_EQUAL: pexpr.ASTOperator.NOT_EQUAL, - plc.binaryop.BinaryOperator.LESS: pexpr.ASTOperator.LESS, - plc.binaryop.BinaryOperator.LESS_EQUAL: pexpr.ASTOperator.LESS_EQUAL, - plc.binaryop.BinaryOperator.GREATER: pexpr.ASTOperator.GREATER, - plc.binaryop.BinaryOperator.GREATER_EQUAL: pexpr.ASTOperator.GREATER_EQUAL, - plc.binaryop.BinaryOperator.ADD: pexpr.ASTOperator.ADD, - plc.binaryop.BinaryOperator.SUB: pexpr.ASTOperator.SUB, - plc.binaryop.BinaryOperator.MUL: pexpr.ASTOperator.MUL, - plc.binaryop.BinaryOperator.DIV: pexpr.ASTOperator.DIV, - plc.binaryop.BinaryOperator.TRUE_DIV: pexpr.ASTOperator.TRUE_DIV, - plc.binaryop.BinaryOperator.FLOOR_DIV: pexpr.ASTOperator.FLOOR_DIV, - plc.binaryop.BinaryOperator.PYMOD: pexpr.ASTOperator.PYMOD, - plc.binaryop.BinaryOperator.BITWISE_AND: pexpr.ASTOperator.BITWISE_AND, - plc.binaryop.BinaryOperator.BITWISE_OR: pexpr.ASTOperator.BITWISE_OR, - plc.binaryop.BinaryOperator.BITWISE_XOR: pexpr.ASTOperator.BITWISE_XOR, - plc.binaryop.BinaryOperator.LOGICAL_AND: pexpr.ASTOperator.LOGICAL_AND, - plc.binaryop.BinaryOperator.LOGICAL_OR: pexpr.ASTOperator.LOGICAL_OR, - plc.binaryop.BinaryOperator.NULL_LOGICAL_AND: pexpr.ASTOperator.NULL_LOGICAL_AND, - plc.binaryop.BinaryOperator.NULL_LOGICAL_OR: pexpr.ASTOperator.NULL_LOGICAL_OR, + plc.binaryop.BinaryOperator.EQUAL: plc_expr.ASTOperator.EQUAL, + plc.binaryop.BinaryOperator.NULL_EQUALS: plc_expr.ASTOperator.NULL_EQUAL, + plc.binaryop.BinaryOperator.NOT_EQUAL: plc_expr.ASTOperator.NOT_EQUAL, + plc.binaryop.BinaryOperator.LESS: plc_expr.ASTOperator.LESS, + plc.binaryop.BinaryOperator.LESS_EQUAL: plc_expr.ASTOperator.LESS_EQUAL, + plc.binaryop.BinaryOperator.GREATER: plc_expr.ASTOperator.GREATER, + plc.binaryop.BinaryOperator.GREATER_EQUAL: plc_expr.ASTOperator.GREATER_EQUAL, + plc.binaryop.BinaryOperator.ADD: plc_expr.ASTOperator.ADD, + plc.binaryop.BinaryOperator.SUB: plc_expr.ASTOperator.SUB, + plc.binaryop.BinaryOperator.MUL: plc_expr.ASTOperator.MUL, + plc.binaryop.BinaryOperator.DIV: plc_expr.ASTOperator.DIV, + plc.binaryop.BinaryOperator.TRUE_DIV: plc_expr.ASTOperator.TRUE_DIV, + plc.binaryop.BinaryOperator.FLOOR_DIV: plc_expr.ASTOperator.FLOOR_DIV, + plc.binaryop.BinaryOperator.PYMOD: plc_expr.ASTOperator.PYMOD, + plc.binaryop.BinaryOperator.BITWISE_AND: plc_expr.ASTOperator.BITWISE_AND, + plc.binaryop.BinaryOperator.BITWISE_OR: plc_expr.ASTOperator.BITWISE_OR, + plc.binaryop.BinaryOperator.BITWISE_XOR: plc_expr.ASTOperator.BITWISE_XOR, + plc.binaryop.BinaryOperator.LOGICAL_AND: plc_expr.ASTOperator.LOGICAL_AND, + plc.binaryop.BinaryOperator.LOGICAL_OR: plc_expr.ASTOperator.LOGICAL_OR, + plc.binaryop.BinaryOperator.NULL_LOGICAL_AND: plc_expr.ASTOperator.NULL_LOGICAL_AND, + plc.binaryop.BinaryOperator.NULL_LOGICAL_OR: plc_expr.ASTOperator.NULL_LOGICAL_OR, } UOP_TO_ASTOP = { - plc.unary.UnaryOperator.SIN: pexpr.ASTOperator.SIN, - plc.unary.UnaryOperator.COS: pexpr.ASTOperator.COS, - plc.unary.UnaryOperator.TAN: pexpr.ASTOperator.TAN, - plc.unary.UnaryOperator.ARCSIN: pexpr.ASTOperator.ARCSIN, - plc.unary.UnaryOperator.ARCCOS: pexpr.ASTOperator.ARCCOS, - plc.unary.UnaryOperator.ARCTAN: pexpr.ASTOperator.ARCTAN, - plc.unary.UnaryOperator.SINH: pexpr.ASTOperator.SINH, - plc.unary.UnaryOperator.COSH: pexpr.ASTOperator.COSH, - plc.unary.UnaryOperator.TANH: pexpr.ASTOperator.TANH, - plc.unary.UnaryOperator.ARCSINH: 
pexpr.ASTOperator.ARCSINH, - plc.unary.UnaryOperator.ARCCOSH: pexpr.ASTOperator.ARCCOSH, - plc.unary.UnaryOperator.ARCTANH: pexpr.ASTOperator.ARCTANH, - plc.unary.UnaryOperator.EXP: pexpr.ASTOperator.EXP, - plc.unary.UnaryOperator.LOG: pexpr.ASTOperator.LOG, - plc.unary.UnaryOperator.SQRT: pexpr.ASTOperator.SQRT, - plc.unary.UnaryOperator.CBRT: pexpr.ASTOperator.CBRT, - plc.unary.UnaryOperator.CEIL: pexpr.ASTOperator.CEIL, - plc.unary.UnaryOperator.FLOOR: pexpr.ASTOperator.FLOOR, - plc.unary.UnaryOperator.ABS: pexpr.ASTOperator.ABS, - plc.unary.UnaryOperator.RINT: pexpr.ASTOperator.RINT, - plc.unary.UnaryOperator.BIT_INVERT: pexpr.ASTOperator.BIT_INVERT, - plc.unary.UnaryOperator.NOT: pexpr.ASTOperator.NOT, + plc.unary.UnaryOperator.SIN: plc_expr.ASTOperator.SIN, + plc.unary.UnaryOperator.COS: plc_expr.ASTOperator.COS, + plc.unary.UnaryOperator.TAN: plc_expr.ASTOperator.TAN, + plc.unary.UnaryOperator.ARCSIN: plc_expr.ASTOperator.ARCSIN, + plc.unary.UnaryOperator.ARCCOS: plc_expr.ASTOperator.ARCCOS, + plc.unary.UnaryOperator.ARCTAN: plc_expr.ASTOperator.ARCTAN, + plc.unary.UnaryOperator.SINH: plc_expr.ASTOperator.SINH, + plc.unary.UnaryOperator.COSH: plc_expr.ASTOperator.COSH, + plc.unary.UnaryOperator.TANH: plc_expr.ASTOperator.TANH, + plc.unary.UnaryOperator.ARCSINH: plc_expr.ASTOperator.ARCSINH, + plc.unary.UnaryOperator.ARCCOSH: plc_expr.ASTOperator.ARCCOSH, + plc.unary.UnaryOperator.ARCTANH: plc_expr.ASTOperator.ARCTANH, + plc.unary.UnaryOperator.EXP: plc_expr.ASTOperator.EXP, + plc.unary.UnaryOperator.LOG: plc_expr.ASTOperator.LOG, + plc.unary.UnaryOperator.SQRT: plc_expr.ASTOperator.SQRT, + plc.unary.UnaryOperator.CBRT: plc_expr.ASTOperator.CBRT, + plc.unary.UnaryOperator.CEIL: plc_expr.ASTOperator.CEIL, + plc.unary.UnaryOperator.FLOOR: plc_expr.ASTOperator.FLOOR, + plc.unary.UnaryOperator.ABS: plc_expr.ASTOperator.ABS, + plc.unary.UnaryOperator.RINT: plc_expr.ASTOperator.RINT, + plc.unary.UnaryOperator.BIT_INVERT: plc_expr.ASTOperator.BIT_INVERT, + plc.unary.UnaryOperator.NOT: plc_expr.ASTOperator.NOT, } SUPPORTED_STATISTICS_BINOPS = { @@ -91,11 +91,11 @@ } -Transformer: TypeAlias = GenericTransformer[expr.Expr, pexpr.Expression] +Transformer: TypeAlias = GenericTransformer[expr.Expr, plc_expr.Expression] @singledispatch -def _to_ast(node: expr.Expr, self: Transformer) -> pexpr.Expression: +def _to_ast(node: expr.Expr, self: Transformer) -> plc_expr.Expression: """ Translate an expression to a pylibcudf Expression. 
@@ -125,22 +125,22 @@ def _to_ast(node: expr.Expr, self: Transformer) -> pexpr.Expression: @_to_ast.register -def _(node: expr.Col, self: Transformer) -> pexpr.Expression: +def _(node: expr.Col, self: Transformer) -> plc_expr.Expression: if self.state["for_parquet"]: - return pexpr.ColumnNameReference(node.name) - return pexpr.ColumnReference(self.state["name_to_index"][node.name]) + return plc_expr.ColumnNameReference(node.name) + return plc_expr.ColumnReference(self.state["name_to_index"][node.name]) @_to_ast.register -def _(node: expr.Literal, self: Transformer) -> pexpr.Expression: - return pexpr.Literal(plc.interop.from_arrow(node.value)) +def _(node: expr.Literal, self: Transformer) -> plc_expr.Expression: + return plc_expr.Literal(plc.interop.from_arrow(node.value)) @_to_ast.register -def _(node: expr.BinOp, self: Transformer) -> pexpr.Expression: +def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression: if node.op == plc.binaryop.BinaryOperator.NULL_NOT_EQUALS: - return pexpr.Operation( - pexpr.ASTOperator.NOT, + return plc_expr.Operation( + plc_expr.ASTOperator.NOT, self( # Reconstruct and apply, rather than directly # constructing the right expression so we get the @@ -166,26 +166,28 @@ def _(node: expr.BinOp, self: Transformer) -> pexpr.Expression: raise NotImplementedError( "Parquet filter binops must have form 'col binop literal'" ) - return pexpr.Operation(BINOP_TO_ASTOP[op], self(op1), self(op2)) + return plc_expr.Operation(BINOP_TO_ASTOP[op], self(op1), self(op2)) elif op1_col and op2_col: raise NotImplementedError( "Parquet filter binops must have one column reference not two" ) - return pexpr.Operation(BINOP_TO_ASTOP[node.op], *map(self, node.children)) + return plc_expr.Operation(BINOP_TO_ASTOP[node.op], *map(self, node.children)) @_to_ast.register -def _(node: expr.BooleanFunction, self: Transformer) -> pexpr.Expression: +def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: if node.name == pl_expr.BooleanFunction.IsIn: needles, haystack = node.children if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16: # 16 is an arbitrary limit needle_ref = self(needles) - values = [pexpr.Literal(plc.interop.from_arrow(v)) for v in haystack.value] + values = [ + plc_expr.Literal(plc.interop.from_arrow(v)) for v in haystack.value + ] return reduce( - partial(pexpr.Operation, pexpr.ASTOperator.LOGICAL_OR), + partial(plc_expr.Operation, plc_expr.ASTOperator.LOGICAL_OR), ( - pexpr.Operation(pexpr.ASTOperator.EQUAL, needle_ref, value) + plc_expr.Operation(plc_expr.ASTOperator.EQUAL, needle_ref, value) for value in values ), ) @@ -194,29 +196,29 @@ def _(node: expr.BooleanFunction, self: Transformer) -> pexpr.Expression: f"Parquet filters don't support {node.name} on columns" ) if node.name == pl_expr.BooleanFunction.IsNull: - return pexpr.Operation(pexpr.ASTOperator.IS_NULL, self(node.children[0])) + return plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])) elif node.name == pl_expr.BooleanFunction.IsNotNull: - return pexpr.Operation( - pexpr.ASTOperator.NOT, - pexpr.Operation(pexpr.ASTOperator.IS_NULL, self(node.children[0])), + return plc_expr.Operation( + plc_expr.ASTOperator.NOT, + plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])), ) elif node.name == pl_expr.BooleanFunction.Not: - return pexpr.Operation(pexpr.ASTOperator.NOT, self(node.children[0])) + return plc_expr.Operation(plc_expr.ASTOperator.NOT, self(node.children[0])) raise NotImplementedError(f"AST conversion does not support 
{node.name}") @_to_ast.register -def _(node: expr.UnaryFunction, self: Transformer) -> pexpr.Expression: +def _(node: expr.UnaryFunction, self: Transformer) -> plc_expr.Expression: if isinstance(node.children[0], expr.Col) and self.state["for_parquet"]: raise NotImplementedError( "Parquet filters don't support {node.name} on columns" ) - return pexpr.Operation( + return plc_expr.Operation( UOP_TO_ASTOP[node._OP_MAPPING[node.name]], self(node.children[0]) ) -def to_parquet_filter(node: expr.Expr) -> pexpr.Expression | None: +def to_parquet_filter(node: expr.Expr) -> plc_expr.Expression | None: """ Convert an expression to libcudf AST nodes suitable for parquet filtering. @@ -238,7 +240,7 @@ def to_parquet_filter(node: expr.Expr) -> pexpr.Expression | None: def to_ast( node: expr.Expr, *, name_to_index: Mapping[str, int] -) -> pexpr.Expression | None: +) -> plc_expr.Expression | None: """ Convert an expression to libcudf AST nodes suitable for compute_column.