From 01a370e69931311db6ce08337aaefb309a668c3c Mon Sep 17 00:00:00 2001
From: Michael J Ward
Date: Tue, 14 May 2024 08:40:58 -0500
Subject: [PATCH] Upgrade to datafusion 38 (#691)

* chore: upgrade datafusion deps

Ref #690

* update concat and concat_ws to use datafusion_functions

Moved in https://github.com/apache/datafusion/pull/10089

* feat: upgrade functions.rs

Upstream is continuing its migration to UDFs.

Ref https://github.com/apache/datafusion/pull/10098
Ref https://github.com/apache/datafusion/pull/10372

* fix ScalarUDF import

* feat: remove deprecated supports_filter_pushdown and impl supports_filters_pushdown

Deprecated function removed in https://github.com/apache/datafusion/pull/9923

* use `unnest_columns_with_options` instead of deprecated `unnest_column_with_options`

* remove ScalarFunction wrappers

These relied on the upstream BuiltinScalarFunction, which has now been removed.

Ref https://github.com/apache/datafusion/pull/10098

* update dataframe `test_describe`

`null_count` was fixed upstream.

Ref https://github.com/apache/datafusion/pull/10260

* remove PyDFField and related methods

DFField was removed upstream.

Ref: https://github.com/apache/datafusion/pull/9595

* bump `datafusion-python` package version to 38.0.0

* re-implement `PyExpr::column_name`

The previous implementation relied on `DFField`, which was removed upstream.

Ref: https://github.com/apache/datafusion/pull/9595
---
 Cargo.lock                         | 145 +++++++++++++--------------
 Cargo.toml                         |  16 +--
 datafusion/__init__.py             |   3 -
 datafusion/tests/test_dataframe.py |   6 +-
 datafusion/tests/test_imports.py   |   7 +-
 src/common.rs                      |   2 -
 src/common/df_field.rs             | 111 ---------------------
 src/dataframe.rs                   |   2 +-
 src/dataset.rs                     |  16 ++-
 src/expr.rs                        |  26 +++--
 src/expr/scalar_function.rs        |  65 ------------
 src/functions.rs                   | 152 +++++++++++++--------------
 src/udf.rs                         |   2 +-
 13 files changed, 181 insertions(+), 372 deletions(-)
 delete mode 100644 src/common/df_field.rs
 delete mode 100644 src/expr/scalar_function.rs

diff --git a/Cargo.lock b/Cargo.lock
index 5eb791b46..6b4568b96 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -733,9 +733,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion"
-version = "37.1.0"
+version = "38.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "85069782056753459dc47e386219aa1fdac5b731f26c28abb8c0ffd4b7c5ab11"
+checksum = "05fb4eeeb7109393a0739ac5b8fd892f95ccef691421491c85544f7997366f68"
 dependencies = [
  "ahash",
  "apache-avro",
@@ -754,6 +754,7 @@ dependencies = [
  "datafusion-execution",
  "datafusion-expr",
  "datafusion-functions",
+ "datafusion-functions-aggregate",
  "datafusion-functions-array",
  "datafusion-optimizer",
  "datafusion-physical-expr",
@@ -786,9 +787,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion-common"
-version = "37.1.0"
+version = "38.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "309d9040751f6dc9e33c85dce6abb55a46ef7ea3644577dd014611c379447ef3"
+checksum = "741aeac15c82f239f2fc17deccaab19873abbd62987be20023689b15fa72fa09"
 dependencies = [
  "ahash",
  "apache-avro",
@@ -809,18 +810,18 @@ dependencies = [
 
 [[package]]
 name = "datafusion-common-runtime"
-version = "37.1.0"
+version = "38.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a3e4a44d8ef1b1e85d32234e6012364c411c3787859bb3bba893b0332cb03dfd"
+checksum = "6e8ddfb8d8cb51646a30da0122ecfffb81ca16919ae9a3495a9e7468bdcd52b8"
 dependencies = [
  "tokio",
 ]
 
 [[package]]
 name = "datafusion-execution"
-version = "37.1.0"
+version = "38.0.0"
 source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "06a3a29ae36bcde07d179cc33b45656a8e7e4d023623e320e48dcf1200eeee95" +checksum = "282122f90b20e8f98ebfa101e4bf20e718fd2684cf81bef4e8c6366571c64404" dependencies = [ "arrow", "chrono", @@ -839,9 +840,9 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "37.1.0" +version = "38.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a3542aa322029c2121a671ce08000d4b274171070df13f697b14169ccf4f628" +checksum = "5478588f733df0dfd87a62671c7478f590952c95fa2fa5c137e3ff2929491e22" dependencies = [ "ahash", "arrow", @@ -849,6 +850,7 @@ dependencies = [ "chrono", "datafusion-common", "paste", + "serde_json", "sqlparser", "strum 0.26.1", "strum_macros 0.26.1", @@ -856,9 +858,9 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "37.1.0" +version = "38.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd221792c666eac174ecc09e606312844772acc12cbec61a420c2fca1ee70959" +checksum = "f4afd261cea6ac9c3ca1192fd5e9f940596d8e9208c5b1333f4961405db53185" dependencies = [ "arrow", "base64 0.22.1", @@ -869,21 +871,39 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", + "hashbrown 0.14.3", "hex", "itertools 0.12.0", "log", "md-5", + "rand", "regex", "sha2", "unicode-segmentation", "uuid", ] +[[package]] +name = "datafusion-functions-aggregate" +version = "38.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b36a6c4838ab94b5bf8f7a96ce6ce059d805c5d1dcaa6ace49e034eb65cd999" +dependencies = [ + "arrow", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr-common", + "log", + "paste", + "sqlparser", +] + [[package]] name = "datafusion-functions-array" -version = "37.1.0" +version = "38.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e501801e84d9c6ef54caaebcda1b18a6196a24176c12fb70e969bc0572e03c55" +checksum = "d5fdd200a6233f48d3362e7ccb784f926f759100e44ae2137a5e2dcb986a59c4" dependencies = [ "arrow", "arrow-array", @@ -901,9 +921,9 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "37.1.0" +version = "38.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bd7f5087817deb961764e8c973d243b54f8572db414a8f0a8f33a48f991e0a" +checksum = "54f2820938810e8a2d71228fd6f59f33396aebc5f5f687fcbf14de5aab6a7e1a" dependencies = [ "arrow", "async-trait", @@ -912,6 +932,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "hashbrown 0.14.3", + "indexmap", "itertools 0.12.0", "log", "regex-syntax", @@ -919,9 +940,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "37.1.0" +version = "38.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cabc0d9aaa0f5eb1b472112f16223c9ffd2fb04e58cbf65c0a331ee6e993f96" +checksum = "9adf8eb12716f52ddf01e09eb6c94d3c9b291e062c05c91b839a448bddba2ff8" dependencies = [ "ahash", "arrow", @@ -931,37 +952,45 @@ dependencies = [ "arrow-schema", "arrow-string", "base64 0.22.1", - "blake2", - "blake3", "chrono", "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-functions-aggregate", + "datafusion-physical-expr-common", "half", "hashbrown 0.14.3", "hex", "indexmap", "itertools 0.12.0", "log", - "md-5", "paste", "petgraph", - "rand", "regex", - "sha2", - "unicode-segmentation", +] + +[[package]] +name = "datafusion-physical-expr-common" +version = 
"38.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d5472c3230584c150197b3f2c23f2392b9dc54dbfb62ad41e7e36447cfce4be" +dependencies = [ + "arrow", + "datafusion-common", + "datafusion-expr", ] [[package]] name = "datafusion-physical-plan" -version = "37.1.0" +version = "38.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17c0523e9c8880f2492a88bbd857dde02bed1ed23f3e9211a89d3d7ec3b44af9" +checksum = "18ae750c38389685a8b62e5b899bbbec488950755ad6d218f3662d35b800c4fe" dependencies = [ "ahash", "arrow", "arrow-array", "arrow-buffer", + "arrow-ord", "arrow-schema", "async-trait", "chrono", @@ -969,7 +998,9 @@ dependencies = [ "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", + "datafusion-functions-aggregate", "datafusion-physical-expr", + "datafusion-physical-expr-common", "futures", "half", "hashbrown 0.14.3", @@ -985,7 +1016,7 @@ dependencies = [ [[package]] name = "datafusion-python" -version = "37.1.0" +version = "38.0.0" dependencies = [ "async-trait", "datafusion", @@ -1013,9 +1044,9 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "37.1.0" +version = "38.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49eb54b42227136f6287573f2434b1de249fe1b8e6cd6cc73a634e4a3ec29356" +checksum = "befc67a3cdfbfa76853f43b10ac27337821bb98e519ab6baf431fcc0bcfcafdb" dependencies = [ "arrow", "arrow-array", @@ -1029,9 +1060,9 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "37.1.0" +version = "38.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd3b496697ac22a857c7d497b9d6b40edec19ed2e3e86e2b77051541fefb4c6d" +checksum = "1f62542caa77df003e23a8bc2f1b8a1ffc682fe447c7fcb4905d109e3d7a5b9d" dependencies = [ "async-recursion", "chrono", @@ -1260,19 +1291,6 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" -[[package]] -name = "git2" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b3ba52851e73b46a4c3df1d89343741112003f0f6f13beb0dfac9e457c3fdcd" -dependencies = [ - "bitflags 2.4.2", - "libc", - "libgit2-sys", - "log", - "url", -] - [[package]] name = "glob" version = "0.3.1" @@ -1654,18 +1672,6 @@ dependencies = [ "rle-decode-fast", ] -[[package]] -name = "libgit2-sys" -version = "0.16.2+1.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee4126d8b4ee5c9d9ea891dd875cfdc1e9d0950437179104b183d7d8a74d24e8" -dependencies = [ - "cc", - "libc", - "libz-sys", - "pkg-config", -] - [[package]] name = "libm" version = "0.2.8" @@ -1682,18 +1688,6 @@ dependencies = [ "libc", ] -[[package]] -name = "libz-sys" -version = "1.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "295c17e837573c8c821dbaeb3cceb3d745ad082f7572191409e69cbc1b3fd050" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "linux-raw-sys" version = "0.4.13" @@ -2762,9 +2756,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.44.0" +version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aaf9c7ff146298ffda83a200f8d5084f08dcee1edfc135fcc1d646a45d50ffd6" +checksum = "f7bbffee862a796d67959a89859d6b1046bb5016d63e23835ad0da182777bbe0" dependencies = [ "log", "sqlparser_derive", @@ -2830,11 
+2824,10 @@ dependencies = [ [[package]] name = "substrait" -version = "0.28.1" +version = "0.32.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df9531ae6784dee4c018ebdb0226872b63cc28765bfa65c1e53b6c58584232af" +checksum = "f01344023c2614171a9ffd6e387eea14e12f7387c5b6adb33f1563187d65e376" dependencies = [ - "git2", "heck 0.5.0", "prettyplease", "prost", @@ -3221,12 +3214,6 @@ dependencies = [ "serde", ] -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - [[package]] name = "version_check" version = "0.9.4" diff --git a/Cargo.toml b/Cargo.toml index 9da36d710..cde3be222 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "datafusion-python" -version = "37.1.0" +version = "38.0.0" homepage = "https://datafusion.apache.org/python" repository = "https://github.com/apache/datafusion-python" authors = ["Apache DataFusion "] @@ -37,13 +37,13 @@ substrait = ["dep:datafusion-substrait"] tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.8" pyo3 = { version = "0.20", features = ["extension-module", "abi3", "abi3-py38"] } -datafusion = { version = "37.1.0", features = ["pyarrow", "avro", "unicode_expressions"] } -datafusion-common = { version = "37.1.0", features = ["pyarrow"] } -datafusion-expr = "37.1.0" -datafusion-functions-array = "37.1.0" -datafusion-optimizer = "37.1.0" -datafusion-sql = "37.1.0" -datafusion-substrait = { version = "37.1.0", optional = true } +datafusion = { version = "38.0.0", features = ["pyarrow", "avro", "unicode_expressions"] } +datafusion-common = { version = "38.0.0", features = ["pyarrow"] } +datafusion-expr = "38.0.0" +datafusion-functions-array = "38.0.0" +datafusion-optimizer = "38.0.0" +datafusion-sql = "38.0.0" +datafusion-substrait = { version = "38.0.0", optional = true } prost = "0.12" prost-types = "0.12" uuid = { version = "1.8", features = ["v4"] } diff --git a/datafusion/__init__.py b/datafusion/__init__.py index c50bf649d..d0b823bbd 100644 --- a/datafusion/__init__.py +++ b/datafusion/__init__.py @@ -37,7 +37,6 @@ ) from .common import ( - DFField, DFSchema, ) @@ -64,8 +63,6 @@ IsNotFalse, IsNotUnknown, Negative, - ScalarFunction, - BuiltinScalarFunction, InList, Exists, Subquery, diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index efb1679b9..2f6a818ea 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -730,9 +730,9 @@ def test_describe(df): "max", "median", ], - "a": [3.0, 3.0, 2.0, 1.0, 1.0, 3.0, 2.0], - "b": [3.0, 3.0, 5.0, 1.0, 4.0, 6.0, 5.0], - "c": [3.0, 3.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0], + "a": [3.0, 0.0, 2.0, 1.0, 1.0, 3.0, 2.0], + "b": [3.0, 0.0, 5.0, 1.0, 4.0, 6.0, 5.0], + "c": [3.0, 0.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0], } diff --git a/datafusion/tests/test_imports.py b/datafusion/tests/test_imports.py index 766ddce89..2a8a3de83 100644 --- a/datafusion/tests/test_imports.py +++ b/datafusion/tests/test_imports.py @@ -27,7 +27,6 @@ ) from datafusion.common import ( - DFField, DFSchema, ) @@ -64,8 +63,6 @@ IsNotFalse, IsNotUnknown, Negative, - ScalarFunction, - BuiltinScalarFunction, InList, Exists, Subquery, @@ -139,8 +136,6 @@ def test_class_module_is_datafusion(): IsNotFalse, IsNotUnknown, Negative, - ScalarFunction, - BuiltinScalarFunction, InList, Exists, Subquery, @@ -165,7 +160,7 @@ def 
test_class_module_is_datafusion(): assert klass.__module__ == "datafusion.expr" # schema - for klass in [DFField, DFSchema]: + for klass in [DFSchema]: assert klass.__module__ == "datafusion.common" diff --git a/src/common.rs b/src/common.rs index 45523173c..682639aca 100644 --- a/src/common.rs +++ b/src/common.rs @@ -18,7 +18,6 @@ use pyo3::prelude::*; pub mod data_type; -pub mod df_field; pub mod df_schema; pub mod function; pub mod schema; @@ -26,7 +25,6 @@ pub mod schema; /// Initializes the `common` module to match the pattern of `datafusion-common` https://docs.rs/datafusion-common/18.0.0/datafusion_common/index.html pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/common/df_field.rs b/src/common/df_field.rs deleted file mode 100644 index 68c05361f..000000000 --- a/src/common/df_field.rs +++ /dev/null @@ -1,111 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::arrow::datatypes::DataType; -use datafusion_common::{DFField, OwnedTableReference}; -use pyo3::prelude::*; - -use super::data_type::PyDataType; - -/// PyDFField wraps an arrow-datafusion `DFField` struct type -/// and also supplies convenience methods for interacting -/// with the `DFField` instance in the context of Python -#[pyclass(name = "DFField", module = "datafusion.common", subclass)] -#[derive(Debug, Clone)] -pub struct PyDFField { - pub field: DFField, -} - -impl From for DFField { - fn from(py_field: PyDFField) -> DFField { - py_field.field - } -} - -impl From for PyDFField { - fn from(field: DFField) -> PyDFField { - PyDFField { field } - } -} - -#[pymethods] -impl PyDFField { - #[new] - #[pyo3(signature = (qualifier=None, name="", data_type=DataType::Int64.into(), nullable=false))] - fn new(qualifier: Option, name: &str, data_type: PyDataType, nullable: bool) -> Self { - PyDFField { - field: DFField::new( - qualifier.map(OwnedTableReference::from), - name, - data_type.into(), - nullable, - ), - } - } - - // TODO: Need bindings for Array `Field` first - // #[staticmethod] - // #[pyo3(name = "from")] - // fn py_from(field: Field) -> Self {} - - // TODO: Need bindings for Array `Field` first - // #[staticmethod] - // #[pyo3(name = "from_qualified")] - // fn py_from_qualified(field: Field) -> Self {} - - #[pyo3(name = "name")] - fn py_name(&self) -> PyResult { - Ok(self.field.name().clone()) - } - - #[pyo3(name = "data_type")] - fn py_data_type(&self) -> PyResult { - Ok(self.field.data_type().clone().into()) - } - - #[pyo3(name = "is_nullable")] - fn py_is_nullable(&self) -> PyResult { - Ok(self.field.is_nullable()) - } - - #[pyo3(name = "qualified_name")] - fn py_qualified_name(&self) -> PyResult { - 
Ok(self.field.qualified_name()) - } - - // TODO: Need bindings for `Column` first - // #[pyo3(name = "qualified_column")] - // fn py_qualified_column(&self) -> PyResult {} - - // TODO: Need bindings for `Column` first - // #[pyo3(name = "unqualified_column")] - // fn py_unqualified_column(&self) -> PyResult {} - - #[pyo3(name = "qualifier")] - fn py_qualifier(&self) -> PyResult> { - Ok(self.field.qualifier().map(|q| format!("{}", q))) - } - - // TODO: Need bindings for Arrow `Field` first - // #[pyo3(name = "field")] - // fn py_field(&self) -> PyResult {} - - #[pyo3(name = "strip_qualifier")] - fn py_strip_qualifier(&self) -> PyResult { - Ok(self.field.clone().strip_qualifier().into()) - } -} diff --git a/src/dataframe.rs b/src/dataframe.rs index f1efc0c7a..8f4514398 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -301,7 +301,7 @@ impl PyDataFrame { .df .as_ref() .clone() - .unnest_column_with_options(column, unnest_options)?; + .unnest_columns_with_options(&[column], unnest_options)?; Ok(Self::new(df)) } diff --git a/src/dataset.rs b/src/dataset.rs index 713610c51..fcbb503c0 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -117,10 +117,16 @@ impl TableProvider for Dataset { /// Tests whether the table provider can make use of a filter expression /// to optimise data retrieval. - fn supports_filter_pushdown(&self, filter: &Expr) -> DFResult { - match PyArrowFilterExpression::try_from(filter) { - Ok(_) => Ok(TableProviderFilterPushDown::Exact), - _ => Ok(TableProviderFilterPushDown::Unsupported), - } + fn supports_filters_pushdown( + &self, + filter: &[&Expr], + ) -> DFResult> { + filter + .iter() + .map(|&f| match PyArrowFilterExpression::try_from(f) { + Ok(_) => Ok(TableProviderFilterPushDown::Exact), + _ => Ok(TableProviderFilterPushDown::Unsupported), + }) + .collect() } } diff --git a/src/expr.rs b/src/expr.rs index 3be0d025c..2f1477457 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -15,20 +15,20 @@ // specific language governing permissions and limitations // under the License. 
+use datafusion_expr::utils::exprlist_to_fields; +use datafusion_expr::LogicalPlan; use pyo3::{basic::CompareOp, prelude::*}; use std::convert::{From, Into}; +use std::sync::Arc; -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::scalar::ScalarValue; -use datafusion_common::DFField; use datafusion_expr::{ col, expr::{AggregateFunction, InList, InSubquery, ScalarFunction, Sort, WindowFunction}, - lit, - utils::exprlist_to_fields, - Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, Like, LogicalPlan, - Operator, TryCast, + lit, Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, Like, Operator, + TryCast, }; use crate::common::data_type::{DataTypeMap, RexType}; @@ -80,7 +80,6 @@ pub mod logical_node; pub mod placeholder; pub mod projection; pub mod repartition; -pub mod scalar_function; pub mod scalar_subquery; pub mod scalar_variable; pub mod signature; @@ -567,14 +566,14 @@ impl PyExpr { impl PyExpr { pub fn _column_name(&self, plan: &LogicalPlan) -> Result { let field = Self::expr_to_field(&self.expr, plan)?; - Ok(field.qualified_column().flat_name()) + Ok(field.name().to_owned()) } - /// Create a [DFField] representing an [Expr], given an input [LogicalPlan] to resolve against + /// Create a [Field] representing an [Expr], given an input [LogicalPlan] to resolve against pub fn expr_to_field( expr: &Expr, input_plan: &LogicalPlan, - ) -> Result { + ) -> Result, DataFusionError> { match expr { Expr::Sort(Sort { expr, .. }) => { // DataFusion does not support create_name for sort expressions (since they never @@ -583,16 +582,15 @@ impl PyExpr { } Expr::Wildcard { .. } => { // Since * could be any of the valid column names just return the first one - Ok(input_plan.schema().field(0).clone()) + Ok(Arc::new(input_plan.schema().field(0).clone())) } _ => { let fields = exprlist_to_fields(&[expr.clone()], input_plan).map_err(PyErr::from)?; - Ok(fields[0].clone()) + Ok(fields[0].1.clone()) } } } - fn _types(expr: &Expr) -> PyResult { match expr { Expr::BinaryExpr(BinaryExpr { @@ -665,8 +663,6 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/expr/scalar_function.rs b/src/expr/scalar_function.rs deleted file mode 100644 index 776ca3297..000000000 --- a/src/expr/scalar_function.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::expr::PyExpr; -use datafusion_expr::{BuiltinScalarFunction, Expr}; -use pyo3::prelude::*; - -#[pyclass(name = "ScalarFunction", module = "datafusion.expr", subclass)] -#[derive(Clone)] -pub struct PyScalarFunction { - scalar_function: BuiltinScalarFunction, - args: Vec, -} - -impl PyScalarFunction { - pub fn new(scalar_function: BuiltinScalarFunction, args: Vec) -> Self { - Self { - scalar_function, - args, - } - } -} - -#[pyclass(name = "BuiltinScalarFunction", module = "datafusion.expr", subclass)] -#[derive(Clone)] -pub struct PyBuiltinScalarFunction { - scalar_function: BuiltinScalarFunction, -} - -impl From for PyBuiltinScalarFunction { - fn from(scalar_function: BuiltinScalarFunction) -> PyBuiltinScalarFunction { - PyBuiltinScalarFunction { scalar_function } - } -} - -impl From for BuiltinScalarFunction { - fn from(scalar_function: PyBuiltinScalarFunction) -> Self { - scalar_function.scalar_function - } -} - -#[pymethods] -impl PyScalarFunction { - fn fun(&self) -> PyResult { - Ok(self.scalar_function.into()) - } - - fn args(&self) -> PyResult> { - Ok(self.args.iter().map(|e| e.clone().into()).collect()) - } -} diff --git a/src/functions.rs b/src/functions.rs index 7f6b1a877..4b137d90d 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -24,17 +24,46 @@ use crate::expr::window::PyWindowFrame; use crate::expr::PyExpr; use datafusion::execution::FunctionRegistry; use datafusion::functions; +use datafusion::functions_aggregate; use datafusion_common::{Column, ScalarValue, TableReference}; use datafusion_expr::expr::Alias; use datafusion_expr::{ aggregate_function, expr::{ - find_df_window_func, AggregateFunction, AggregateFunctionDefinition, ScalarFunction, Sort, - WindowFunction, + find_df_window_func, AggregateFunction, AggregateFunctionDefinition, Sort, WindowFunction, }, - lit, BuiltinScalarFunction, Expr, WindowFunctionDefinition, + lit, Expr, WindowFunctionDefinition, }; +#[pyfunction] +#[pyo3(signature = (y, x, distinct = false, filter = None, order_by = None))] +pub fn covar_samp( + y: PyExpr, + x: PyExpr, + distinct: bool, + filter: Option, + order_by: Option>, + // null_treatment: Option, +) -> PyExpr { + let filter = filter.map(|x| Box::new(x.expr)); + let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::>()); + functions_aggregate::expr_fn::covar_samp(y.expr, x.expr, distinct, filter, order_by, None) + .into() +} + +#[pyfunction] +#[pyo3(signature = (y, x, distinct = false, filter = None, order_by = None))] +pub fn covar( + y: PyExpr, + x: PyExpr, + distinct: bool, + filter: Option, + order_by: Option>, +) -> PyExpr { + // alias for covar_samp + covar_samp(y, x, distinct, filter, order_by) +} + #[pyfunction] fn in_list(expr: PyExpr, value: Vec, negated: bool) -> PyExpr { datafusion_expr::in_list( @@ -134,7 +163,7 @@ fn digest(value: PyExpr, method: PyExpr) -> PyExpr { #[pyo3(signature = (*args))] fn concat(args: Vec) -> PyResult { let args = args.into_iter().map(|e| e.expr).collect::>(); - Ok(datafusion_expr::concat(&args).into()) + Ok(functions::string::expr_fn::concat(args).into()) } /// Concatenates all but the first argument, with separators. 
@@ -144,7 +173,7 @@ fn concat(args: Vec) -> PyResult { #[pyo3(signature = (sep, *args))] fn concat_ws(sep: String, args: Vec) -> PyResult { let args = args.into_iter().map(|e| e.expr).collect::>(); - Ok(datafusion_expr::concat_ws(lit(sep), args).into()) + Ok(functions::string::expr_fn::concat_ws(lit(sep), args).into()) } /// Creates a new Sort Expr @@ -249,27 +278,6 @@ fn window( }) } -macro_rules! scalar_function { - ($NAME: ident, $FUNC: ident) => { - scalar_function!($NAME, $FUNC, stringify!($NAME)); - }; - - ($NAME: ident, $FUNC: ident, $DOC: expr) => { - #[doc = $DOC] - #[pyfunction] - #[pyo3(signature = (*args))] - fn $NAME(args: Vec) -> PyExpr { - let expr = datafusion_expr::Expr::ScalarFunction(ScalarFunction { - func_def: datafusion_expr::ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::$FUNC, - ), - args: args.into_iter().map(|e| e.into()).collect(), - }); - expr.into() - } - }; -} - macro_rules! aggregate_function { ($NAME: ident, $FUNC: ident) => { aggregate_function!($NAME, $FUNC, stringify!($NAME)); @@ -370,21 +378,21 @@ macro_rules! array_fn { expr_fn!(abs, num); expr_fn!(acos, num); -scalar_function!(acosh, Acosh); +expr_fn!(acosh, num); expr_fn!(ascii, arg1, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character."); expr_fn!(asin, num); -scalar_function!(asinh, Asinh); -scalar_function!(atan, Atan); -scalar_function!(atanh, Atanh); -scalar_function!(atan2, Atan2); +expr_fn!(asinh, num); +expr_fn!(atan, num); +expr_fn!(atanh, num); +expr_fn!(atan2, y x); expr_fn!( bit_length, arg, "Returns number of bits in the string (8 times the octet_length)." ); expr_fn_vec!(btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string."); -scalar_function!(cbrt, Cbrt); -scalar_function!(ceil, Ceil); +expr_fn!(cbrt, num); +expr_fn!(ceil, num); expr_fn!( character_length, string, @@ -393,44 +401,44 @@ expr_fn!( expr_fn!(length, string); expr_fn!(char_length, string); expr_fn!(chr, arg, "Returns the character with the given code."); -scalar_function!(coalesce, Coalesce); -scalar_function!(cos, Cos); -scalar_function!(cosh, Cosh); -scalar_function!(degrees, Degrees); +expr_fn_vec!(coalesce); +expr_fn!(cos, num); +expr_fn!(cosh, num); +expr_fn!(degrees, num); expr_fn!(decode, input encoding); expr_fn!(encode, input encoding); -scalar_function!(exp, Exp); -scalar_function!(factorial, Factorial); -scalar_function!(floor, Floor); -scalar_function!(gcd, Gcd); -scalar_function!(initcap, InitCap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters."); +expr_fn!(exp, num); +expr_fn!(factorial, num); +expr_fn!(floor, num); +expr_fn!(gcd, x y); +expr_fn!(initcap, string, "Converts the first letter of each word to upper case and the rest to lower case. 
Words are sequences of alphanumeric characters separated by non-alphanumeric characters."); expr_fn!(isnan, num); -scalar_function!(iszero, Iszero); -scalar_function!(lcm, Lcm); -scalar_function!(left, Left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters."); -scalar_function!(ln, Ln); -scalar_function!(log, Log); -scalar_function!(log10, Log10); -scalar_function!(log2, Log2); +expr_fn!(iszero, num); +expr_fn!(lcm, x y); +expr_fn!(left, string n, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters."); +expr_fn!(ln, num); +expr_fn!(log, base num); +expr_fn!(log10, num); +expr_fn!(log2, num); expr_fn!(lower, arg1, "Converts the string to all lower case"); -scalar_function!(lpad, Lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right)."); +expr_fn_vec!(lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right)."); expr_fn_vec!(ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string."); expr_fn!( md5, input_arg, "Computes the MD5 hash of the argument, with the result written in hexadecimal." ); -scalar_function!( +expr_fn!( nanvl, - Nanvl, + x y, "Returns x if x is not NaN otherwise returns y." ); expr_fn!(nullif, arg_1 arg_2); expr_fn_vec!(octet_length, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces."); -scalar_function!(pi, Pi); -scalar_function!(power, Power); -scalar_function!(pow, Power); -scalar_function!(radians, Radians); +expr_fn!(pi); +expr_fn!(power, base exponent); +expr_fn!(pow, power, base exponent); +expr_fn!(radians, num); expr_fn!(regexp_match, input_arg1 input_arg2); expr_fn!( regexp_replace, @@ -443,31 +451,31 @@ expr_fn!( string from to, "Replaces all occurrences in string of substring from with substring to." ); -scalar_function!( +expr_fn!( reverse, - Reverse, + string, "Reverses the order of the characters in the string." ); -scalar_function!(right, Right, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters."); -scalar_function!(round, Round); -scalar_function!(rpad, Rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated."); +expr_fn!(right, string n, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters."); +expr_fn_vec!(round); +expr_fn_vec!(rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated."); expr_fn_vec!(rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string."); expr_fn!(sha224, input_arg1); expr_fn!(sha256, input_arg1); expr_fn!(sha384, input_arg1); expr_fn!(sha512, input_arg1); -scalar_function!(signum, Signum); -scalar_function!(sin, Sin); -scalar_function!(sinh, Sinh); +expr_fn!(signum, num); +expr_fn!(sin, num); +expr_fn!(sinh, num); expr_fn!( split_part, string delimiter index, "Splits string at occurrences of delimiter and returns the n'th field (counting from one)." 
); -scalar_function!(sqrt, Sqrt); +expr_fn!(sqrt, num); expr_fn!(starts_with, arg1 arg2, "Returns true if string starts with prefix."); -scalar_function!(strpos, Strpos, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)"); -scalar_function!(substr, Substr); +expr_fn!(strpos, string substring, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)"); +expr_fn!(substr, string position); expr_fn!(tan, num); expr_fn!(tanh, num); expr_fn!( @@ -488,15 +496,15 @@ expr_fn!(date_trunc, part date); expr_fn!(datetrunc, date_trunc, part date); expr_fn!(date_bin, stride source origin); -scalar_function!(translate, Translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted."); +expr_fn!(translate, string from to, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted."); expr_fn_vec!(trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string."); -scalar_function!(trunc, Trunc); +expr_fn_vec!(trunc); expr_fn!(upper, arg1, "Converts the string to all upper case."); expr_fn!(uuid); -expr_fn!(r#struct, args); // Use raw identifier since struct is a keyword +expr_fn_vec!(r#struct); // Use raw identifier since struct is a keyword expr_fn!(from_unixtime, unixtime); expr_fn!(arrow_typeof, arg_1); -scalar_function!(random, Random); +expr_fn!(random); // Array Functions array_fn!(array_append, array element); @@ -565,9 +573,7 @@ aggregate_function!(array_agg, ArrayAgg); aggregate_function!(avg, Avg); aggregate_function!(corr, Correlation); aggregate_function!(count, Count); -aggregate_function!(covar, Covariance); aggregate_function!(covar_pop, CovariancePop); -aggregate_function!(covar_samp, Covariance); aggregate_function!(grouping, Grouping); aggregate_function!(max, Max); aggregate_function!(mean, Avg); diff --git a/src/udf.rs b/src/udf.rs index 69519f499..8f5ca30b1 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -23,9 +23,9 @@ use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef}; use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}; use datafusion::error::DataFusionError; -use datafusion::physical_plan::udf::ScalarUDF; use datafusion_expr::create_udf; use datafusion_expr::function::ScalarFunctionImplementation; +use datafusion_expr::ScalarUDF; use crate::expr::PyExpr; use crate::utils::parse_volatility;
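
A minimal Python-side sketch (not part of the patch itself) of how the functions migrated above are exercised after the upgrade: `concat` and `concat_ws` remain exposed through `datafusion.functions`, so existing call sites keep working even though their Rust backing moved from `BuiltinScalarFunction` to the UDF-based `expr_fn` helpers. It assumes the usual top-level exports (`SessionContext`, `column`, `literal`); the column names and literals are illustrative only.

    from datafusion import SessionContext, column, literal
    from datafusion import functions as f

    ctx = SessionContext()
    # Illustrative two-column input, registered via SQL for brevity.
    df = ctx.sql("SELECT 'hello' AS a, 'world' AS b")
    # concat is now backed by datafusion_functions upstream; the Python call is unchanged.
    df = df.select(f.concat(column("a"), literal(" "), column("b")))
    df.show()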