From 1676e938f0fdd6aceb663e1768c09acbad47a678 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Mon, 15 Jul 2024 21:36:03 +0800 Subject: [PATCH] stage progress --- datafusion/core/Cargo.toml | 4 - datafusion/functions/Cargo.toml | 18 ++- datafusion/functions/src/lib.rs | 1 - datafusion/functions/src/udf.rs | 103 ++++++++++++++++-- datafusion/physical-expr/Cargo.toml | 6 +- .../physical-expr/src/expressions/binary.rs | 12 +- 6 files changed, 119 insertions(+), 25 deletions(-) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 10e36e4d41b5..532ca8fde9e7 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -57,7 +57,6 @@ default = [ "unicode_expressions", "compression", "parquet", - "arrow_udf", ] encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) @@ -76,9 +75,6 @@ unicode_expressions = [ "datafusion-sql/unicode_expressions", "datafusion-functions/unicode_expressions", ] -arrow_udf = [ - "datafusion-functions/arrow_udf", -] [dependencies] ahash = { workspace = true } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index f6aa3cc6e7b5..48e46dc2bdf5 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -31,6 +31,18 @@ rust-version = { workspace = true } [lints] workspace = true +[profile.dev] +codegen-units = 1 + +[profile.release] +codegen-units = 1 + +[profile.bench] +codegen-units = 1 + +[profile.test] +codegen-units = 1 + [features] # enable core functions core_expressions = [] @@ -46,7 +58,6 @@ default = [ "regex_expressions", "string_expressions", "unicode_expressions", - "arrow_udf", ] # enable encode/decode functions encoding_expressions = ["base64", "hex"] @@ -58,8 +69,6 @@ regex_expressions = ["regex"] string_expressions = ["uuid"] # enable unicode functions unicode_expressions = ["hashbrown", "unicode-segmentation"] -# enable arrow_udf -arrow_udf = ["arrow-udf",] [lib] name = "datafusion_functions" @@ -70,7 +79,8 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } arrow-buffer = { workspace = true } -arrow-udf = { workspace = true, optional = true, features = ["global_registry"] } +arrow-udf = { version="0.3.0", features = ["global_registry"] } +linkme = { version = "0.3.27"} base64 = { version = "0.22", optional = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 5e32f34106cb..bd4f6aefac47 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -133,7 +133,6 @@ make_stub_package!(unicode, "unicode_expressions"); #[cfg(any(feature = "datetime_expressions", feature = "unicode_expressions"))] pub mod planner; -#[cfg(feature = "arrow_udf")] pub mod udf; mod utils; diff --git a/datafusion/functions/src/udf.rs b/datafusion/functions/src/udf.rs index 91774cca0d19..d89c2d95c631 100644 --- a/datafusion/functions/src/udf.rs +++ b/datafusion/functions/src/udf.rs @@ -17,11 +17,100 @@ use arrow_udf::function; -#[function("eq(bool, bool) -> bool", output="eval_eq_boolean")] -#[function("eq(string, string) -> bool", output="eval_eq_string")] -#[function("eq(binary, binary) -> bool", output="eval_eq_binary")] -#[function("eq(largestring, largestring) -> bool", output="eval_eq_largestring")] -#[function("eq(largebinary, largebinary) -> bool", output="eval_eq_largebinary")] -fn eq(_lhs: T, _rhs: T) -> bool { - _lhs == _rhs +#[function("eq(boolean, boolean) -> boolean")] +fn eq(lhs: bool, rhs: bool) -> bool { + lhs == rhs +} + +#[function("gcd(int, int) -> int", output = "eval_gcd")] +fn gcd(mut a: i32, mut b: i32) -> i32 { + while b != 0 { + (a, b) = (b, a % b); + } + a +} + +#[cfg(test)] +mod tests { + use std::{sync::Arc, vec}; + + use arrow::{ + array::{BooleanArray, RecordBatch}, + datatypes::{Field, Schema}, + }; + use arrow_udf::sig::REGISTRY; + + #[test] + fn test_eq() { + let bool_field = Field::new("", arrow::datatypes::DataType::Boolean, false); + let schema = Schema::new(vec![bool_field.clone()]); + let record_batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(BooleanArray::from(vec![true, false, true]))], + ) + .unwrap(); + + println!("Function signatures:"); + REGISTRY.iter().for_each(|sig| { + println!("{:?}", sig.name); + println!("{:?}", sig.arg_types); + println!("{:?}", sig.return_type); + }); + + let eval_eq_boolean = REGISTRY + .get("eq", &[bool_field.clone(), bool_field.clone()], &bool_field) + .unwrap() + .function + .as_scalar() + .unwrap(); + + let result = eval_eq_boolean(&record_batch).unwrap(); + + assert!(result + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0)); + } + + #[test] + fn test_gcd() { + let int_field = Field::new("", arrow::datatypes::DataType::Int32, false); + let schema = Schema::new(vec![int_field.clone(), int_field.clone()]); + let record_batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::Int32Array::from(vec![20, 30, 40])), + ], + ) + .unwrap(); + + println!("Function signatures:"); + REGISTRY.iter().for_each(|sig| { + println!("{:?}", sig.name); + println!("{:?}", sig.arg_types); + println!("{:?}", sig.return_type); + }); + + let eval_gcd_int = REGISTRY + .get("gcd", &[int_field.clone(), int_field.clone()], &int_field) + .unwrap() + .function + .as_scalar() + .unwrap(); + + let result = eval_gcd_int(&record_batch).unwrap(); + + assert_eq!( + result + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 10 + ); + } } diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 3b0743b7e1ed..e0800c3ff196 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -42,10 +42,6 @@ default = [ ] encoding_expressions = ["base64", "hex"] regex_expressions = ["regex"] -arrow_udf = [ - "arrow-udf", - "datafusion-functions/arrow_udf", -] [dependencies] ahash = { workspace = true } @@ -55,7 +51,7 @@ arrow-buffer = { workspace = true } arrow-ord = { workspace = true } arrow-schema = { workspace = true } arrow-string = { workspace = true } -arrow-udf = { workspace = true, optional = true, features = ["global_registry"] } +arrow-udf = { workspace = true, features = ["global_registry"] } base64 = { version = "0.22", optional = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 3732ce582558..ac0bb2c0b19b 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -303,7 +303,7 @@ impl PhysicalExpr for BinaryExpr { println!("schema: {:?}", schema); let record_batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ lhs.clone().into_array(batch.num_rows())?, rhs.clone().into_array(batch.num_rows())?, @@ -315,7 +315,7 @@ impl PhysicalExpr for BinaryExpr { let Some(eval_eq_string) = REGISTRY .get( "eq", - &schema + schema .all_fields() .into_iter() .map(|f| f.to_owned()) @@ -323,7 +323,11 @@ impl PhysicalExpr for BinaryExpr { .as_slice(), &Field::new("bool", DataType::Boolean, false), ) - .and_then(|f| f.function.as_scalar()) + .and_then(|f| { + println!("Function found"); + + return f.function.as_scalar(); + }) else { return internal_err!("Failed to get eq function"); }; @@ -336,7 +340,7 @@ impl PhysicalExpr for BinaryExpr { return internal_err!("Failed to get result array"); }; - return Ok(ColumnarValue::Array(result_array.clone())); + return Ok(ColumnarValue::Array(Arc::clone(result_array))); } Operator::NotEq => return apply_cmp(&lhs, &rhs, neq), Operator::Lt => return apply_cmp(&lhs, &rhs, lt),