diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
index 8db8172ebd1..1682e7a8a9c 100644
--- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
+++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
@@ -6,13 +6,12 @@
 
 from __future__ import annotations
 
+from enum import IntEnum, auto
 from functools import partial, reduce
 from typing import TYPE_CHECKING, Any, ClassVar
 
 import pyarrow as pa
 
-from polars.polars import _expr_nodes as pl_expr
-
 import pylibcudf as plc
 
 from cudf_polars.containers import Column
@@ -24,7 +23,10 @@
 if TYPE_CHECKING:
     from collections.abc import Mapping
 
+    from typing_extensions import Self
+
     import polars.type_aliases as pl_types
+    from polars.polars import _expr_nodes as pl_expr
 
     from cudf_polars.containers import DataFrame
 
@@ -32,13 +34,46 @@
 
 
 class BooleanFunction(Expr):
+    class Name(IntEnum):
+        """Internal and picklable representation of polars' `BooleanFunction`."""
+
+        All = auto()
+        AllHorizontal = auto()
+        Any = auto()
+        AnyHorizontal = auto()
+        IsBetween = auto()
+        IsDuplicated = auto()
+        IsFinite = auto()
+        IsFirstDistinct = auto()
+        IsIn = auto()
+        IsInfinite = auto()
+        IsLastDistinct = auto()
+        IsNan = auto()
+        IsNotNan = auto()
+        IsNotNull = auto()
+        IsNull = auto()
+        IsUnique = auto()
+        Not = auto()
+
+        @classmethod
+        def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self:
+            """Convert from polars' `BooleanFunction`."""
+            try:
+                function, name = str(obj).split(".", maxsplit=1)
+            except ValueError:
+                # Failed to unpack string
+                function = None
+            if function != "BooleanFunction":
+                raise ValueError("BooleanFunction required")
+            return getattr(cls, name)
+
     __slots__ = ("name", "options")
     _non_child = ("dtype", "name", "options")
 
     def __init__(
         self,
         dtype: plc.DataType,
-        name: pl_expr.BooleanFunction,
+        name: BooleanFunction.Name,
         options: tuple[Any, ...],
         *children: Expr,
     ) -> None:
@@ -46,7 +81,7 @@ def __init__(
         self.options = options
         self.name = name
         self.children = children
-        if self.name == pl_expr.BooleanFunction.IsIn and not all(
+        if self.name is BooleanFunction.Name.IsIn and not all(
             c.dtype == self.children[0].dtype for c in self.children
         ):
             # TODO: If polars IR doesn't put the casts in, we need to
@@ -110,12 +145,12 @@ def do_evaluate(
     ) -> Column:
         """Evaluate this expression given a dataframe for context."""
         if self.name in (
-            pl_expr.BooleanFunction.IsFinite,
-            pl_expr.BooleanFunction.IsInfinite,
+            BooleanFunction.Name.IsFinite,
+            BooleanFunction.Name.IsInfinite,
         ):
             # Avoid evaluating the child if the dtype tells us it's unnecessary.
             (child,) = self.children
-            is_finite = self.name == pl_expr.BooleanFunction.IsFinite
+            is_finite = self.name is BooleanFunction.Name.IsFinite
             if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64):
                 value = plc.interop.from_arrow(
                     pa.scalar(value=is_finite, type=plc.interop.to_arrow(self.dtype))
@@ -142,10 +177,10 @@ def do_evaluate(
         ]
         # Kleene logic for Any (OR) and All (AND) if ignore_nulls is
        # False
-        if self.name in (pl_expr.BooleanFunction.Any, pl_expr.BooleanFunction.All):
+        if self.name in (BooleanFunction.Name.Any, BooleanFunction.Name.All):
             (ignore_nulls,) = self.options
             (column,) = columns
-            is_any = self.name == pl_expr.BooleanFunction.Any
+            is_any = self.name is BooleanFunction.Name.Any
             agg = plc.aggregation.any() if is_any else plc.aggregation.all()
             result = plc.reduce.reduce(column.obj, agg, self.dtype)
             if not ignore_nulls and column.obj.null_count() > 0:
@@ -165,27 +200,27 @@ def do_evaluate(
                     #   False || Null => Null    True && Null => Null
                     return Column(plc.Column.all_null_like(column.obj, 1))
             return Column(plc.Column.from_scalar(result, 1))
-        if self.name == pl_expr.BooleanFunction.IsNull:
+        if self.name is BooleanFunction.Name.IsNull:
             (column,) = columns
             return Column(plc.unary.is_null(column.obj))
-        elif self.name == pl_expr.BooleanFunction.IsNotNull:
+        elif self.name is BooleanFunction.Name.IsNotNull:
             (column,) = columns
             return Column(plc.unary.is_valid(column.obj))
-        elif self.name == pl_expr.BooleanFunction.IsNan:
+        elif self.name is BooleanFunction.Name.IsNan:
             (column,) = columns
             return Column(
                 plc.unary.is_nan(column.obj).with_mask(
                     column.obj.null_mask(), column.obj.null_count()
                 )
             )
-        elif self.name == pl_expr.BooleanFunction.IsNotNan:
+        elif self.name is BooleanFunction.Name.IsNotNan:
             (column,) = columns
             return Column(
                 plc.unary.is_not_nan(column.obj).with_mask(
                     column.obj.null_mask(), column.obj.null_count()
                 )
             )
-        elif self.name == pl_expr.BooleanFunction.IsFirstDistinct:
+        elif self.name is BooleanFunction.Name.IsFirstDistinct:
             (column,) = columns
             return self._distinct(
                 column,
@@ -197,7 +232,7 @@ def do_evaluate(
                     pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
                 ),
             )
-        elif self.name == pl_expr.BooleanFunction.IsLastDistinct:
+        elif self.name is BooleanFunction.Name.IsLastDistinct:
             (column,) = columns
             return self._distinct(
                 column,
@@ -209,7 +244,7 @@ def do_evaluate(
                     pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
                 ),
             )
-        elif self.name == pl_expr.BooleanFunction.IsUnique:
+        elif self.name is BooleanFunction.Name.IsUnique:
             (column,) = columns
             return self._distinct(
                 column,
@@ -221,7 +256,7 @@ def do_evaluate(
                     pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
                 ),
             )
-        elif self.name == pl_expr.BooleanFunction.IsDuplicated:
+        elif self.name is BooleanFunction.Name.IsDuplicated:
             (column,) = columns
             return self._distinct(
                 column,
@@ -233,7 +268,7 @@ def do_evaluate(
                     pa.scalar(value=True, type=plc.interop.to_arrow(self.dtype))
                 ),
             )
-        elif self.name == pl_expr.BooleanFunction.AllHorizontal:
+        elif self.name is BooleanFunction.Name.AllHorizontal:
             return Column(
                 reduce(
                     partial(
@@ -244,7 +279,7 @@ def do_evaluate(
                     (c.obj for c in columns),
                 )
             )
-        elif self.name == pl_expr.BooleanFunction.AnyHorizontal:
+        elif self.name is BooleanFunction.Name.AnyHorizontal:
             return Column(
                 reduce(
                     partial(
@@ -255,10 +290,10 @@ def do_evaluate(
                     (c.obj for c in columns),
                 )
             )
-        elif self.name == pl_expr.BooleanFunction.IsIn:
+        elif self.name is BooleanFunction.Name.IsIn:
             needles, haystack = columns
             return Column(plc.search.contains(haystack.obj, needles.obj))
-        elif self.name == pl_expr.BooleanFunction.Not:
+        elif self.name is BooleanFunction.Name.Not:
             (column,) = columns
             return Column(
                 plc.unary.unary_operation(column.obj, plc.unary.UnaryOperator.NOT)
diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py
index cd8e5c6a4eb..c2dddfd9940 100644
--- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py
+++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py
@@ -6,12 +6,11 @@
 
 from __future__ import annotations
 
+from enum import IntEnum, auto
 from typing import TYPE_CHECKING, Any, ClassVar
 
 import pyarrow as pa
 
-from polars.polars import _expr_nodes as pl_expr
-
 import pylibcudf as plc
 
 from cudf_polars.containers import Column
@@ -20,33 +19,94 @@
 if TYPE_CHECKING:
     from collections.abc import Mapping
 
+    from typing_extensions import Self
+
+    from polars.polars import _expr_nodes as pl_expr
+
     from cudf_polars.containers import DataFrame
 
 __all__ = ["TemporalFunction"]
 
 
 class TemporalFunction(Expr):
+    class Name(IntEnum):
+        """Internal and picklable representation of polars' `TemporalFunction`."""
+
+        BaseUtcOffset = auto()
+        CastTimeUnit = auto()
+        Century = auto()
+        Combine = auto()
+        ConvertTimeZone = auto()
+        DSTOffset = auto()
+        Date = auto()
+        Datetime = auto()
+        DatetimeFunction = auto()
+        Day = auto()
+        Duration = auto()
+        Hour = auto()
+        IsLeapYear = auto()
+        IsoYear = auto()
+        Microsecond = auto()
+        Millennium = auto()
+        Millisecond = auto()
+        Minute = auto()
+        Month = auto()
+        MonthEnd = auto()
+        MonthStart = auto()
+        Nanosecond = auto()
+        OffsetBy = auto()
+        OrdinalDay = auto()
+        Quarter = auto()
+        ReplaceTimeZone = auto()
+        Round = auto()
+        Second = auto()
+        Time = auto()
+        TimeStamp = auto()
+        ToString = auto()
+        TotalDays = auto()
+        TotalHours = auto()
+        TotalMicroseconds = auto()
+        TotalMilliseconds = auto()
+        TotalMinutes = auto()
+        TotalNanoseconds = auto()
+        TotalSeconds = auto()
+        Truncate = auto()
+        Week = auto()
+        WeekDay = auto()
+        WithTimeUnit = auto()
+        Year = auto()
+
+        @classmethod
+        def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self:
+            """Convert from polars' `TemporalFunction`."""
+            try:
+                function, name = str(obj).split(".", maxsplit=1)
+            except ValueError:
+                # Failed to unpack string
+                function = None
+            if function != "TemporalFunction":
+                raise ValueError("TemporalFunction required")
+            return getattr(cls, name)
+
     __slots__ = ("name", "options")
-    _COMPONENT_MAP: ClassVar[
-        dict[pl_expr.TemporalFunction, plc.datetime.DatetimeComponent]
-    ] = {
-        pl_expr.TemporalFunction.Year: plc.datetime.DatetimeComponent.YEAR,
-        pl_expr.TemporalFunction.Month: plc.datetime.DatetimeComponent.MONTH,
-        pl_expr.TemporalFunction.Day: plc.datetime.DatetimeComponent.DAY,
-        pl_expr.TemporalFunction.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY,
-        pl_expr.TemporalFunction.Hour: plc.datetime.DatetimeComponent.HOUR,
-        pl_expr.TemporalFunction.Minute: plc.datetime.DatetimeComponent.MINUTE,
-        pl_expr.TemporalFunction.Second: plc.datetime.DatetimeComponent.SECOND,
-        pl_expr.TemporalFunction.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND,
-        pl_expr.TemporalFunction.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND,
-        pl_expr.TemporalFunction.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
-    }
     _non_child = ("dtype", "name", "options")
+    _COMPONENT_MAP: ClassVar[dict[Name, plc.datetime.DatetimeComponent]] = {
+        Name.Year: plc.datetime.DatetimeComponent.YEAR,
+        Name.Month: plc.datetime.DatetimeComponent.MONTH,
+        Name.Day: plc.datetime.DatetimeComponent.DAY,
+        Name.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY,
+        Name.Hour: plc.datetime.DatetimeComponent.HOUR,
+        Name.Minute: plc.datetime.DatetimeComponent.MINUTE,
+        Name.Second: plc.datetime.DatetimeComponent.SECOND,
+        Name.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND,
+        Name.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND,
+        Name.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
+    }
 
     def __init__(
         self,
         dtype: plc.DataType,
-        name: pl_expr.TemporalFunction,
+        name: TemporalFunction.Name,
         options: tuple[Any, ...],
         *children: Expr,
     ) -> None:
@@ -70,7 +130,7 @@ def do_evaluate(
             for child in self.children
         ]
         (column,) = columns
-        if self.name == pl_expr.TemporalFunction.Microsecond:
+        if self.name is TemporalFunction.Name.Microsecond:
             millis = plc.datetime.extract_datetime_component(
                 column.obj, plc.datetime.DatetimeComponent.MILLISECOND
             )
@@ -90,7 +150,7 @@ def do_evaluate(
                 plc.types.DataType(plc.types.TypeId.INT32),
             )
             return Column(total_micros)
-        elif self.name == pl_expr.TemporalFunction.Nanosecond:
+        elif self.name is TemporalFunction.Name.Nanosecond:
             millis = plc.datetime.extract_datetime_component(
                 column.obj, plc.datetime.DatetimeComponent.MILLISECOND
             )
diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py
index 8b66c9d4676..92c3c658c21 100644
--- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py
+++ b/python/cudf_polars/cudf_polars/dsl/expressions/string.py
@@ -6,13 +6,13 @@
 
 from __future__ import annotations
 
+from enum import IntEnum, auto
 from typing import TYPE_CHECKING, Any
 
 import pyarrow as pa
 import pyarrow.compute as pc
 
 from polars.exceptions import InvalidOperationError
-from polars.polars import _expr_nodes as pl_expr
 
 import pylibcudf as plc
 
@@ -23,19 +23,82 @@
 if TYPE_CHECKING:
     from collections.abc import Mapping
 
+    from typing_extensions import Self
+
+    from polars.polars import _expr_nodes as pl_expr
+
     from cudf_polars.containers import DataFrame
 
 __all__ = ["StringFunction"]
 
 
 class StringFunction(Expr):
+    class Name(IntEnum):
+        """Internal and picklable representation of polars' `StringFunction`."""
+
+        Base64Decode = auto()
+        Base64Encode = auto()
+        ConcatHorizontal = auto()
+        ConcatVertical = auto()
+        Contains = auto()
+        ContainsMany = auto()
+        CountMatches = auto()
+        EndsWith = auto()
+        EscapeRegex = auto()
+        Extract = auto()
+        ExtractAll = auto()
+        ExtractGroups = auto()
+        Find = auto()
+        Head = auto()
+        HexDecode = auto()
+        HexEncode = auto()
+        JsonDecode = auto()
+        JsonPathMatch = auto()
+        LenBytes = auto()
+        LenChars = auto()
+        Lowercase = auto()
+        PadEnd = auto()
+        PadStart = auto()
+        Replace = auto()
+        ReplaceMany = auto()
+        Reverse = auto()
+        Slice = auto()
+        Split = auto()
+        SplitExact = auto()
+        SplitN = auto()
+        StartsWith = auto()
+        StripChars = auto()
+        StripCharsEnd = auto()
+        StripCharsStart = auto()
+        StripPrefix = auto()
+        StripSuffix = auto()
+        Strptime = auto()
+        Tail = auto()
+        Titlecase = auto()
+        ToDecimal = auto()
+        ToInteger = auto()
+        Uppercase = auto()
+        ZFill = auto()
+
+        @classmethod
+        def from_polars(cls, obj: pl_expr.StringFunction) -> Self:
+            """Convert from polars' `StringFunction`."""
+            try:
+                function, name = str(obj).split(".", maxsplit=1)
+            except ValueError:
+                # Failed to unpack string
+                function = None
+            if function != "StringFunction":
+                raise ValueError("StringFunction required")
+            return getattr(cls, name)
+
     __slots__ = ("name", "options", "_regex_program")
     _non_child = ("dtype", "name", "options")
 
     def __init__(
         self,
         dtype: plc.DataType,
-        name: pl_expr.StringFunction,
+        name: StringFunction.Name,
         options: tuple[Any, ...],
         *children: Expr,
     ) -> None:
@@ -47,21 +110,21 @@ def __init__(
 
     def _validate_input(self):
         if self.name not in (
-            pl_expr.StringFunction.Contains,
-            pl_expr.StringFunction.EndsWith,
-            pl_expr.StringFunction.Lowercase,
-            pl_expr.StringFunction.Replace,
-            pl_expr.StringFunction.ReplaceMany,
-            pl_expr.StringFunction.Slice,
-            pl_expr.StringFunction.Strptime,
-            pl_expr.StringFunction.StartsWith,
-            pl_expr.StringFunction.StripChars,
-            pl_expr.StringFunction.StripCharsStart,
-            pl_expr.StringFunction.StripCharsEnd,
-            pl_expr.StringFunction.Uppercase,
+            StringFunction.Name.Contains,
+            StringFunction.Name.EndsWith,
+            StringFunction.Name.Lowercase,
+            StringFunction.Name.Replace,
+            StringFunction.Name.ReplaceMany,
+            StringFunction.Name.Slice,
+            StringFunction.Name.Strptime,
+            StringFunction.Name.StartsWith,
+            StringFunction.Name.StripChars,
+            StringFunction.Name.StripCharsStart,
+            StringFunction.Name.StripCharsEnd,
+            StringFunction.Name.Uppercase,
         ):
             raise NotImplementedError(f"String function {self.name}")
-        if self.name == pl_expr.StringFunction.Contains:
+        if self.name is StringFunction.Name.Contains:
             literal, strict = self.options
             if not literal:
                 if not strict:
@@ -82,7 +145,7 @@ def _validate_input(self):
                     raise NotImplementedError(
                         f"Unsupported regex {pattern} for GPU engine."
                     ) from e
-        elif self.name == pl_expr.StringFunction.Replace:
+        elif self.name is StringFunction.Name.Replace:
             _, literal = self.options
             if not literal:
                 raise NotImplementedError("literal=False is not supported for replace")
@@ -93,7 +156,7 @@ def _validate_input(self):
                 raise NotImplementedError(
                     "libcudf replace does not support empty strings"
                 )
-        elif self.name == pl_expr.StringFunction.ReplaceMany:
+        elif self.name is StringFunction.Name.ReplaceMany:
             (ascii_case_insensitive,) = self.options
             if ascii_case_insensitive:
                 raise NotImplementedError(
@@ -109,12 +172,12 @@ def _validate_input(self):
                     "libcudf replace_many is implemented differently from polars "
                     "for empty strings"
                 )
-        elif self.name == pl_expr.StringFunction.Slice:
+        elif self.name is StringFunction.Name.Slice:
             if not all(isinstance(child, Literal) for child in self.children[1:]):
                 raise NotImplementedError(
                     "Slice only supports literal start and stop values"
                 )
-        elif self.name == pl_expr.StringFunction.Strptime:
+        elif self.name is StringFunction.Name.Strptime:
             format, _, exact, cache = self.options
             if cache:
                 raise NotImplementedError("Strptime cache is a CPU feature")
@@ -123,9 +186,9 @@ def _validate_input(self):
             if not exact:
                 raise NotImplementedError("Strptime does not support exact=False")
         elif self.name in {
-            pl_expr.StringFunction.StripChars,
-            pl_expr.StringFunction.StripCharsStart,
-            pl_expr.StringFunction.StripCharsEnd,
+            StringFunction.Name.StripChars,
+            StringFunction.Name.StripCharsStart,
+            StringFunction.Name.StripCharsEnd,
         }:
             if not isinstance(self.children[1], Literal):
                 raise NotImplementedError(
@@ -140,7 +203,7 @@ def do_evaluate(
         mapping: Mapping[Expr, Column] | None = None,
     ) -> Column:
         """Evaluate this expression given a dataframe for context."""
-        if self.name == pl_expr.StringFunction.Contains:
+        if self.name is StringFunction.Name.Contains:
             child, arg = self.children
             column = child.evaluate(df, context=context, mapping=mapping)
 
@@ -157,7 +220,7 @@ def do_evaluate(
                 return Column(
                     plc.strings.contains.contains_re(column.obj, self._regex_program)
                 )
-        elif self.name == pl_expr.StringFunction.Slice:
+        elif self.name is StringFunction.Name.Slice:
             child, expr_offset, expr_length = self.children
             assert isinstance(expr_offset, Literal)
             assert isinstance(expr_length, Literal)
@@ -188,16 +251,16 @@ def do_evaluate(
                 )
             )
         elif self.name in {
-            pl_expr.StringFunction.StripChars,
-            pl_expr.StringFunction.StripCharsStart,
-            pl_expr.StringFunction.StripCharsEnd,
+            StringFunction.Name.StripChars,
+            StringFunction.Name.StripCharsStart,
+            StringFunction.Name.StripCharsEnd,
         }:
             column, chars = (
                 c.evaluate(df, context=context, mapping=mapping) for c in self.children
             )
-            if self.name == pl_expr.StringFunction.StripCharsStart:
+            if self.name is StringFunction.Name.StripCharsStart:
                 side = plc.strings.SideType.LEFT
-            elif self.name == pl_expr.StringFunction.StripCharsEnd:
+            elif self.name is StringFunction.Name.StripCharsEnd:
                 side = plc.strings.SideType.RIGHT
             else:
                 side = plc.strings.SideType.BOTH
@@ -207,13 +270,13 @@ def do_evaluate(
             child.evaluate(df, context=context, mapping=mapping)
             for child in self.children
         ]
-        if self.name == pl_expr.StringFunction.Lowercase:
+        if self.name is StringFunction.Name.Lowercase:
             (column,) = columns
             return Column(plc.strings.case.to_lower(column.obj))
-        elif self.name == pl_expr.StringFunction.Uppercase:
+        elif self.name is StringFunction.Name.Uppercase:
             (column,) = columns
             return Column(plc.strings.case.to_upper(column.obj))
-        elif self.name == pl_expr.StringFunction.EndsWith:
+        elif self.name is StringFunction.Name.EndsWith:
             column, suffix = columns
             return Column(
                 plc.strings.find.ends_with(
@@ -223,7 +286,7 @@ def do_evaluate(
                     else suffix.obj,
                 )
             )
-        elif self.name == pl_expr.StringFunction.StartsWith:
+        elif self.name is StringFunction.Name.StartsWith:
             column, prefix = columns
             return Column(
                 plc.strings.find.starts_with(
@@ -233,7 +296,7 @@ def do_evaluate(
                     else prefix.obj,
                 )
             )
-        elif self.name == pl_expr.StringFunction.Strptime:
+        elif self.name is StringFunction.Name.Strptime:
             # TODO: ignores ambiguous
             format, strict, exact, cache = self.options
             col = self.children[0].evaluate(df, context=context, mapping=mapping)
@@ -265,7 +328,7 @@ def do_evaluate(
                     res.columns()[0], self.dtype, format
                 )
             )
-        elif self.name == pl_expr.StringFunction.Replace:
+        elif self.name is StringFunction.Name.Replace:
             column, target, repl = columns
             n, _ = self.options
             return Column(
@@ -273,7 +336,7 @@ def do_evaluate(
                     column.obj, target.obj_scalar, repl.obj_scalar, maxrepl=n
                 )
             )
-        elif self.name == pl_expr.StringFunction.ReplaceMany:
+        elif self.name is StringFunction.Name.ReplaceMany:
             column, target, repl = columns
             return Column(
                 plc.strings.replace.replace_multiple(column.obj, target.obj, repl.obj)
             )
diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py
index acc4b3669af..c3febc833e2 100644
--- a/python/cudf_polars/cudf_polars/dsl/to_ast.py
+++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py
@@ -8,8 +8,6 @@
 from functools import partial, reduce, singledispatch
 from typing import TYPE_CHECKING, TypeAlias
 
-from polars.polars import _expr_nodes as pl_expr
-
 import pylibcudf as plc
 from pylibcudf import expressions as plc_expr
 
@@ -185,7 +183,7 @@ def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression:
 
 @_to_ast.register
 def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression:
-    if node.name == pl_expr.BooleanFunction.IsIn:
+    if node.name is expr.BooleanFunction.Name.IsIn:
         needles, haystack = node.children
         if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16:
             # 16 is an arbitrary limit
@@ -204,14 +202,14 @@ def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression:
             raise NotImplementedError(
                 f"Parquet filters don't support {node.name} on columns"
             )
-    if node.name == pl_expr.BooleanFunction.IsNull:
+    if node.name is expr.BooleanFunction.Name.IsNull:
         return plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0]))
-    elif node.name == pl_expr.BooleanFunction.IsNotNull:
+    elif node.name is expr.BooleanFunction.Name.IsNotNull:
         return plc_expr.Operation(
             plc_expr.ASTOperator.NOT,
             plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])),
         )
-    elif node.name == pl_expr.BooleanFunction.Not:
+    elif node.name is expr.BooleanFunction.Name.Not:
         return plc_expr.Operation(plc_expr.ASTOperator.NOT, self(node.children[0]))
     raise NotImplementedError(f"AST conversion does not support {node.name}")
 
diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py
index 9480ce6e535..b1e2de63ba6 100644
--- a/python/cudf_polars/cudf_polars/dsl/translate.py
+++ b/python/cudf_polars/cudf_polars/dsl/translate.py
@@ -531,10 +531,16 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
                     column.dtype,
                     pa.scalar("", type=plc.interop.to_arrow(column.dtype)),
                 )
-            return expr.StringFunction(dtype, name, options, column, chars)
+            return expr.StringFunction(
+                dtype,
+                expr.StringFunction.Name.from_polars(name),
+                options,
+                column,
+                chars,
+            )
         return expr.StringFunction(
             dtype,
-            name,
+            expr.StringFunction.Name.from_polars(name),
             options,
             *(translator.translate_expr(n=n) for n in node.input),
         )
@@ -551,7 +557,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
             )
         return expr.BooleanFunction(
             dtype,
-            name,
+            expr.BooleanFunction.Name.from_polars(name),
             options,
             *(translator.translate_expr(n=n) for n in node.input),
         )
@@ -571,7 +577,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
         }
         result_expr = expr.TemporalFunction(
             dtype,
-            name,
+            expr.TemporalFunction.Name.from_polars(name),
             options,
             *(translator.translate_expr(n=n) for n in node.input),
         )
diff --git a/python/cudf_polars/tests/dsl/test_serialization.py b/python/cudf_polars/tests/dsl/test_serialization.py
new file mode 100644
index 00000000000..7de8f959843
--- /dev/null
+++ b/python/cudf_polars/tests/dsl/test_serialization.py
@@ -0,0 +1,56 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import pickle
+
+import pytest
+
+from polars.polars import _expr_nodes as pl_expr
+
+from cudf_polars.dsl.expressions.boolean import BooleanFunction
+from cudf_polars.dsl.expressions.datetime import TemporalFunction
+from cudf_polars.dsl.expressions.string import StringFunction
+
+
+@pytest.fixture(params=[BooleanFunction, StringFunction, TemporalFunction])
+def function(request):
+    return request.param
+
+
+def test_function_name_serialization_all_values(function):
+    # Test serialization and deserialization for all values of function.Name
+    for name in function.Name:
+        serialized_name = pickle.dumps(name)
+        deserialized_name = pickle.loads(serialized_name)
+        assert deserialized_name is name
+
+
+def test_function_name_invalid(function):
+    # Test invalid attribute name
+    with pytest.raises(AttributeError, match="InvalidAttribute"):
+        assert function.Name.InvalidAttribute is function.Name.InvalidAttribute
+
+
+def test_from_polars_all_names(function):
+    # Test that all valid names of polars expressions are correctly converted
+    polars_function = getattr(pl_expr, function.__name__)
+    polars_names = [name for name in dir(polars_function) if not name.startswith("_")]
+    # Check names advertised by polars are the same as we advertise
+    assert set(polars_names) == set(function.Name.__members__)
+    for name in function.Name:
+        attr = getattr(polars_function, name.name)
+        assert function.Name.from_polars(attr) == name
+
+
+def test_from_polars_invalid_attribute(function):
+    # Test converting from invalid attribute name
+    with pytest.raises(ValueError, match=f"{function.__name__} required"):
+        function.Name.from_polars("InvalidAttribute")
+
+
+def test_from_polars_invalid_polars_attribute(function):
+    # Test converting from polars function with invalid attribute name
+    with pytest.raises(AttributeError, match="InvalidAttribute"):
+        function.Name.from_polars(f"{function.__name__}.InvalidAttribute")
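For illustration only, outside the patch itself: a minimal sketch of how the new picklable `Name` enums are intended to be used, mirroring the round-trip exercised by the test file above. It assumes `cudf_polars` and `polars` are importable in the current environment.

```python
import pickle

from polars.polars import _expr_nodes as pl_expr

from cudf_polars.dsl.expressions.boolean import BooleanFunction

# IntEnum members reconstruct to the same member object on unpickling,
# so expression nodes holding a Name can be serialized safely.
name = BooleanFunction.Name.IsNull
assert pickle.loads(pickle.dumps(name)) is name

# from_polars maps the polars-side enum value onto the internal representation.
assert BooleanFunction.Name.from_polars(pl_expr.BooleanFunction.IsNull) is name
```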