diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index d98d7d0abfe2..a6795e99d751 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -45,17 +45,8 @@ pub enum BuiltinScalarFunction { Exp, /// factorial Factorial, - /// iszero - Iszero, /// nanvl Nanvl, - /// round - Round, - /// trunc - Trunc, - /// cot - Cot, - // string functions /// concat Concat, @@ -123,11 +114,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Exp => Volatility::Immutable, BuiltinScalarFunction::Factorial => Volatility::Immutable, - BuiltinScalarFunction::Iszero => Volatility::Immutable, BuiltinScalarFunction::Nanvl => Volatility::Immutable, - BuiltinScalarFunction::Round => Volatility::Immutable, - BuiltinScalarFunction::Cot => Volatility::Immutable, - BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, @@ -175,16 +162,12 @@ impl BuiltinScalarFunction { _ => Ok(Float64), }, - BuiltinScalarFunction::Iszero => Ok(Boolean), - - BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Round - | BuiltinScalarFunction::Trunc - | BuiltinScalarFunction::Cot => match input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, + BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => { + match input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } } } @@ -217,24 +200,6 @@ impl BuiltinScalarFunction { self.volatility(), ), BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), - BuiltinScalarFunction::Round => Signature::one_of( - vec![ - Exact(vec![Float64, Int64]), - Exact(vec![Float32, Int64]), - Exact(vec![Float64]), - Exact(vec![Float32]), - ], - self.volatility(), - ), - BuiltinScalarFunction::Trunc => Signature::one_of( - vec![ - Exact(vec![Float32, Int64]), - Exact(vec![Float64, Int64]), - Exact(vec![Float64]), - Exact(vec![Float32]), - ], - self.volatility(), - ), BuiltinScalarFunction::Nanvl => Signature::one_of( vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], self.volatility(), @@ -242,9 +207,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Factorial => { Signature::uniform(1, vec![Int64], self.volatility()) } - BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Cot => { + BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => { // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). @@ -252,10 +215,6 @@ impl BuiltinScalarFunction { // will be as good as the number of digits in the number Signature::uniform(1, vec![Float64, Float32], self.volatility()) } - BuiltinScalarFunction::Iszero => Signature::one_of( - vec![Exact(vec![Float32]), Exact(vec![Float64])], - self.volatility(), - ), } } @@ -268,8 +227,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Factorial - | BuiltinScalarFunction::Round - | BuiltinScalarFunction::Trunc ) { Some(vec![Some(true)]) } else { @@ -281,14 +238,10 @@ impl BuiltinScalarFunction { pub fn aliases(&self) -> &'static [&'static str] { match self { BuiltinScalarFunction::Ceil => &["ceil"], - BuiltinScalarFunction::Cot => &["cot"], BuiltinScalarFunction::Exp => &["exp"], BuiltinScalarFunction::Factorial => &["factorial"], - BuiltinScalarFunction::Iszero => &["iszero"], BuiltinScalarFunction::Nanvl => &["nanvl"], BuiltinScalarFunction::Random => &["random"], - BuiltinScalarFunction::Round => &["round"], - BuiltinScalarFunction::Trunc => &["trunc"], // conditional functions BuiltinScalarFunction::Coalesce => &["coalesce"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index b554d87bade1..1e28e27af1e0 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -530,7 +530,6 @@ macro_rules! nary_scalar_expr { // generate methods for creating the supported unary/binary expressions // math functions -scalar_expr!(Cot, cot, num, "cotangent of a number"); scalar_expr!(Factorial, factorial, num, "factorial"); scalar_expr!( Ceil, @@ -538,12 +537,7 @@ scalar_expr!( num, "nearest integer greater than or equal to argument" ); -nary_scalar_expr!(Round, round, "round to nearest integer"); -nary_scalar_expr!( - Trunc, - trunc, - "truncate toward zero, with optional precision" -); + scalar_expr!(Exp, exp, num, "exponential"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); @@ -557,12 +551,6 @@ nary_scalar_expr!( ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y"); -scalar_expr!( - Iszero, - iszero, - num, - "returns true if a given number is +0.0 or -0.0 otherwise returns false" -); /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { @@ -872,12 +860,6 @@ impl WindowUDFImpl for SimpleWindowUDF { } /// Calls a named built in function -/// ``` -/// use datafusion_expr::{col, lit, call_fn}; -/// -/// // create the expression trunc(x) < 0.2 -/// let expr = call_fn("trunc", vec![col("x")]).unwrap().lt(lit(0.2)); -/// ``` pub fn call_fn(name: impl AsRef, args: Vec) -> Result { match name.as_ref().parse::() { Ok(fun) => Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))), @@ -935,38 +917,12 @@ mod test { }; } - macro_rules! test_nary_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( - vec![ - $( - col(stringify!($arg.to_string())) - ),* - ] - ); - if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; -} - #[test] fn scalar_function_definitions() { - test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Factorial, factorial); test_unary_scalar_expr!(Ceil, ceil); - test_nary_scalar_expr!(Round, round, input); - test_nary_scalar_expr!(Round, round, input, decimal_places); - test_nary_scalar_expr!(Trunc, trunc, num); - test_nary_scalar_expr!(Trunc, trunc, num, precision); test_unary_scalar_expr!(Exp, exp); test_scalar_expr!(Nanvl, nanvl, x, y); - test_scalar_expr!(Iszero, iszero, input); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs new file mode 100644 index 000000000000..66219960d9a2 --- /dev/null +++ b/datafusion/functions/src/math/cot.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Float32, Float64}; + +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub struct CotFunc { + signature: Signature, +} + +impl Default for CotFunc { + fn default() -> Self { + CotFunc::new() + } +} + +impl CotFunc { + pub fn new() -> Self { + use DataType::*; + Self { + // math expressions expect 1 argument of type f64 or f32 + // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we + // return the best approximation for it (in f64). + // We accept f32 because in this case it is clear that the best approximation + // will be as good as the number of digits in the number + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for CotFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "cot" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(cot, vec![])(args) + } +} + +///cot SQL function +fn cot(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + Float64 => Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "x", + Float64Array, + { compute_cot64 } + )) as ArrayRef), + Float32 => Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "x", + Float32Array, + { compute_cot32 } + )) as ArrayRef), + other => exec_err!("Unsupported data type {other:?} for function cot"), + } +} + +fn compute_cot32(x: f32) -> f32 { + let a = f32::tan(x); + 1.0 / a +} + +fn compute_cot64(x: f64) -> f64 { + let a = f64::tan(x); + 1.0 / a +} + +#[cfg(test)] +mod test { + use crate::math::cot::cot; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use std::sync::Arc; + + #[test] + fn test_cot_f32() { + let args: Vec = + vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; + let result = cot(&args).expect("failed to initialize function cot"); + let floats = + as_float32_array(&result).expect("failed to initialize function cot"); + + let expected = Float32Array::from(vec![ + -1.986_460_4, + -0.156_119_96, + -0.501_202_8, + 0.156_119_96, + ]); + + let eps = 1e-6; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } + + #[test] + fn test_cot_f64() { + let args: Vec = + vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; + let result = cot(&args).expect("failed to initialize function cot"); + let floats = + as_float64_array(&result).expect("failed to initialize function cot"); + + let expected = Float64Array::from(vec![ + -1.986_458_685_881_4, + -0.156_119_952_161_6, + -0.501_202_783_380_1, + 0.156_119_952_161_6, + ]); + + let eps = 1e-12; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } +} diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs new file mode 100644 index 000000000000..e6a728053359 --- /dev/null +++ b/datafusion/functions/src/math/iszero.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Boolean, Float32, Float64}; + +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub struct IsZeroFunc { + signature: Signature, +} + +impl Default for IsZeroFunc { + fn default() -> Self { + IsZeroFunc::new() + } +} + +impl IsZeroFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Float32]), Exact(vec![Float64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for IsZeroFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "iszero" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(iszero, vec![])(args) + } +} + +/// Iszero SQL function +pub fn iszero(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float64Array, + BooleanArray, + { |x: f64| { x == 0_f64 } } + )) as ArrayRef), + + Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float32Array, + BooleanArray, + { |x: f32| { x == 0_f32 } } + )) as ArrayRef), + + other => exec_err!("Unsupported data type {other:?} for function iszero"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + + use datafusion_common::cast::as_boolean_array; + + use crate::math::iszero::iszero; + + #[test] + fn test_iszero_f64() { + let args: Vec = + vec![Arc::new(Float64Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; + + let result = iszero(&args).expect("failed to initialize function iszero"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function iszero"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } + + #[test] + fn test_iszero_f32() { + let args: Vec = + vec![Arc::new(Float32Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; + + let result = iszero(&args).expect("failed to initialize function iszero"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function iszero"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 2655edfe76dc..544de04e4a98 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -17,92 +17,260 @@ //! "math" DataFusion functions +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + pub mod abs; +pub mod cot; pub mod gcd; +pub mod iszero; pub mod lcm; pub mod log; pub mod nans; pub mod pi; pub mod power; +pub mod round; +pub mod trunc; // Create UDFs -make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); -make_udf_function!(log::LogFunc, LOG, log); -make_udf_function!(power::PowerFunc, POWER, power); -make_udf_function!(gcd::GcdFunc, GCD, gcd); -make_udf_function!(lcm::LcmFunc, LCM, lcm); -make_udf_function!(pi::PiFunc, PI, pi); - -make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); -make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); -make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)])); - -make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, None); make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None); +make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)])); make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None); -make_math_unary_udf!(TanFunc, TAN, tan, tan, None); - -make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)])); -make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)])); make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)])); +make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, Some(vec![Some(true)])); - +make_math_unary_udf!(CbrtFunc, CBRT, cbrt, cbrt, None); +make_math_unary_udf!(CosFunc, COS, cos, cos, None); +make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, None); +make_udf_function!(cot::CotFunc, COT, cot); +make_math_unary_udf!(DegreesFunc, DEGREES, degrees, to_degrees, None); +make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, Some(vec![Some(true)])); +make_udf_function!(log::LogFunc, LOG, log); +make_udf_function!(gcd::GcdFunc, GCD, gcd); +make_udf_function!(nans::IsNanFunc, ISNAN, isnan); +make_udf_function!(iszero::IsZeroFunc, ISZERO, iszero); +make_udf_function!(lcm::LcmFunc, LCM, lcm); +make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)])); +make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); +make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); +make_udf_function!(pi::PiFunc, PI, pi); +make_udf_function!(power::PowerFunc, POWER, power); make_math_unary_udf!(RadiansFunc, RADIANS, radians, to_radians, None); +make_udf_function!(round::RoundFunc, ROUND, round); make_math_unary_udf!(SignumFunc, SIGNUM, signum, signum, None); make_math_unary_udf!(SinFunc, SIN, sin, sin, None); make_math_unary_udf!(SinhFunc, SINH, sinh, sinh, None); make_math_unary_udf!(SqrtFunc, SQRT, sqrt, sqrt, None); +make_math_unary_udf!(TanFunc, TAN, tan, tan, None); +make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, None); +make_udf_function!(trunc::TruncFunc, TRUNC, trunc); -make_math_unary_udf!(CbrtFunc, CBRT, cbrt, cbrt, None); -make_math_unary_udf!(CosFunc, COS, cos, cos, None); -make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, None); -make_math_unary_udf!(DegreesFunc, DEGREES, degrees, to_degrees, None); +pub mod expr_fn { + use datafusion_expr::Expr; -make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, Some(vec![Some(true)])); + #[doc = "returns the absolute value of a given number"] + pub fn abs(num: Expr) -> Expr { + super::abs().call(vec![num]) + } + + #[doc = "returns the arc cosine or inverse cosine of a number"] + pub fn acos(num: Expr) -> Expr { + super::acos().call(vec![num]) + } + + #[doc = "returns inverse hyperbolic cosine"] + pub fn acosh(num: Expr) -> Expr { + super::acosh().call(vec![num]) + } + + #[doc = "returns the arc sine or inverse sine of a number"] + pub fn asin(num: Expr) -> Expr { + super::asin().call(vec![num]) + } + + #[doc = "returns inverse hyperbolic sine"] + pub fn asinh(num: Expr) -> Expr { + super::asinh().call(vec![num]) + } + + #[doc = "returns inverse tangent"] + pub fn atan(num: Expr) -> Expr { + super::atan().call(vec![num]) + } + + #[doc = "returns inverse tangent of a division given in the argument"] + pub fn atan2(y: Expr, x: Expr) -> Expr { + super::atan2().call(vec![y, x]) + } + + #[doc = "returns inverse hyperbolic tangent"] + pub fn atanh(num: Expr) -> Expr { + super::atanh().call(vec![num]) + } + + #[doc = "cube root of a number"] + pub fn cbrt(num: Expr) -> Expr { + super::cbrt().call(vec![num]) + } + + #[doc = "cosine"] + pub fn cos(num: Expr) -> Expr { + super::cos().call(vec![num]) + } + + #[doc = "hyperbolic cosine"] + pub fn cosh(num: Expr) -> Expr { + super::cosh().call(vec![num]) + } + + #[doc = "cotangent of a number"] + pub fn cot(num: Expr) -> Expr { + super::cot().call(vec![num]) + } + + #[doc = "converts radians to degrees"] + pub fn degrees(num: Expr) -> Expr { + super::degrees().call(vec![num]) + } + + #[doc = "nearest integer less than or equal to argument"] + pub fn floor(num: Expr) -> Expr { + super::floor().call(vec![num]) + } + + #[doc = "greatest common divisor"] + pub fn gcd(x: Expr, y: Expr) -> Expr { + super::gcd().call(vec![x, y]) + } + + #[doc = "returns true if a given number is +NaN or -NaN otherwise returns false"] + pub fn isnan(num: Expr) -> Expr { + super::isnan().call(vec![num]) + } + + #[doc = "returns true if a given number is +0.0 or -0.0 otherwise returns false"] + pub fn iszero(num: Expr) -> Expr { + super::iszero().call(vec![num]) + } + + #[doc = "least common multiple"] + pub fn lcm(x: Expr, y: Expr) -> Expr { + super::lcm().call(vec![x, y]) + } + + #[doc = "natural logarithm (base e) of a number"] + pub fn ln(num: Expr) -> Expr { + super::ln().call(vec![num]) + } + + #[doc = "logarithm of a number for a particular `base`"] + pub fn log(base: Expr, num: Expr) -> Expr { + super::log().call(vec![base, num]) + } + + #[doc = "base 2 logarithm of a number"] + pub fn log2(num: Expr) -> Expr { + super::log2().call(vec![num]) + } + + #[doc = "base 10 logarithm of a number"] + pub fn log10(num: Expr) -> Expr { + super::log10().call(vec![num]) + } + + #[doc = "Returns an approximate value of π"] + pub fn pi() -> Expr { + super::pi().call(vec![]) + } + + #[doc = "`base` raised to the power of `exponent`"] + pub fn power(base: Expr, exponent: Expr) -> Expr { + super::power().call(vec![base, exponent]) + } + + #[doc = "converts degrees to radians"] + pub fn radians(num: Expr) -> Expr { + super::radians().call(vec![num]) + } + + #[doc = "round to nearest integer"] + pub fn round(args: Vec) -> Expr { + super::round().call(args) + } + + #[doc = "sign of the argument (-1, 0, +1)"] + pub fn signum(num: Expr) -> Expr { + super::signum().call(vec![num]) + } + + #[doc = "sine"] + pub fn sin(num: Expr) -> Expr { + super::sin().call(vec![num]) + } + + #[doc = "hyperbolic sine"] + pub fn sinh(num: Expr) -> Expr { + super::sinh().call(vec![num]) + } + + #[doc = "square root of a number"] + pub fn sqrt(num: Expr) -> Expr { + super::sqrt().call(vec![num]) + } + + #[doc = "returns the tangent of a number"] + pub fn tan(num: Expr) -> Expr { + super::tan().call(vec![num]) + } + + #[doc = "returns the hyperbolic tangent of a number"] + pub fn tanh(num: Expr) -> Expr { + super::tanh().call(vec![num]) + } + + #[doc = "truncate toward zero, with optional precision"] + pub fn trunc(args: Vec) -> Expr { + super::trunc().call(args) + } +} -// Export the functions out of this package, both as expr_fn as well as a list of functions -export_functions!( - ( - isnan, - num, - "returns true if a given number is +NaN or -NaN otherwise returns false" - ), - (abs, num, "returns the absolute value of a given number"), - (power, base exponent, "`base` raised to the power of `exponent`"), - (log, base num, "logarithm of a number for a particular `base`"), - (log2, num, "base 2 logarithm of a number"), - (log10, num, "base 10 logarithm of a number"), - (ln, num, "natural logarithm (base e) of a number"), - ( - acos, - num, - "returns the arc cosine or inverse cosine of a number" - ), - ( - asin, - num, - "returns the arc sine or inverse sine of a number" - ), - (tan, num, "returns the tangent of a number"), - (tanh, num, "returns the hyperbolic tangent of a number"), - (atanh, num, "returns inverse hyperbolic tangent"), - (asinh, num, "returns inverse hyperbolic sine"), - (acosh, num, "returns inverse hyperbolic cosine"), - (atan, num, "returns inverse tangent"), - (atan2, y x, "returns inverse tangent of a division given in the argument"), - (radians, num, "converts degrees to radians"), - (signum, num, "sign of the argument (-1, 0, +1)"), - (sin, num, "sine"), - (sinh, num, "hyperbolic sine"), - (sqrt, num, "square root of a number"), - (cbrt, num, "cube root of a number"), - (cos, num, "cosine"), - (cosh, num, "hyperbolic cosine"), - (degrees, num, "converts radians to degrees"), - (gcd, x y, "greatest common divisor"), - (lcm, x y, "least common multiple"), - (floor, num, "nearest integer less than or equal to argument"), - (pi, , "Returns an approximate value of π") -); +/// Return a list of all functions in this package +pub fn functions() -> Vec> { + vec![ + abs(), + acos(), + acosh(), + asin(), + asinh(), + atan(), + atan2(), + atanh(), + cbrt(), + cos(), + cosh(), + cot(), + degrees(), + floor(), + gcd(), + isnan(), + iszero(), + lcm(), + ln(), + log(), + log2(), + log10(), + pi(), + power(), + radians(), + round(), + signum(), + sin(), + sinh(), + sqrt(), + tan(), + tanh(), + trunc(), + ] +} diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs new file mode 100644 index 000000000000..f4a163137a35 --- /dev/null +++ b/datafusion/functions/src/math/round.rs @@ -0,0 +1,252 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Float32, Float64}; + +use crate::utils::make_scalar_function; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, FuncMonotonicity}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub struct RoundFunc { + signature: Signature, +} + +impl Default for RoundFunc { + fn default() -> Self { + RoundFunc::new() + } +} + +impl RoundFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Float64, Int64]), + Exact(vec![Float32, Int64]), + Exact(vec![Float64]), + Exact(vec![Float32]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RoundFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "round" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(round, vec![])(args) + } + + fn monotonicity(&self) -> Result> { + Ok(Some(vec![Some(true)])) + } +} + +/// Round SQL function +pub fn round(args: &[ArrayRef]) -> Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!( + "round function requires one or two arguments, got {}", + args.len() + ); + } + + let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); + + if args.len() == 2 { + decimal_places = ColumnarValue::Array(args[1].clone()); + } + + match args[0].data_type() { + DataType::Float64 => match decimal_places { + ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { + let decimal_places = decimal_places.try_into().unwrap(); + + Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Float64Array, + { + |value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + } + } + )) as ArrayRef) + } + ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float64Array, + Int64Array, + { + |value: f64, decimal_places: i64| { + (value * 10.0_f64.powi(decimal_places.try_into().unwrap())) + .round() + / 10.0_f64.powi(decimal_places.try_into().unwrap()) + } + } + )) as ArrayRef), + _ => { + exec_err!("round function requires a scalar or array for decimal_places") + } + }, + + DataType::Float32 => match decimal_places { + ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { + let decimal_places = decimal_places.try_into().unwrap(); + + Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Float32Array, + { + |value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + } + } + )) as ArrayRef) + } + ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float32Array, + Int64Array, + { + |value: f32, decimal_places: i64| { + (value * 10.0_f32.powi(decimal_places.try_into().unwrap())) + .round() + / 10.0_f32.powi(decimal_places.try_into().unwrap()) + } + } + )) as ArrayRef), + _ => { + exec_err!("round function requires a scalar or array for decimal_places") + } + }, + + other => exec_err!("Unsupported data type {other:?} for function round"), + } +} + +#[cfg(test)] +mod test { + use crate::math::round::round; + use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use std::sync::Arc; + + #[test] + fn test_round_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![125.2345; 10])), // input + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = + as_float32_array(&result).expect("failed to initialize function round"); + + let expected = Float32Array::from(vec![ + 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, + ]); + + assert_eq!(floats, &expected); + } + + #[test] + fn test_round_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![125.2345; 10])), // input + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = + as_float64_array(&result).expect("failed to initialize function round"); + + let expected = Float64Array::from(vec![ + 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, + ]); + + assert_eq!(floats, &expected); + } + + #[test] + fn test_round_f32_one_input() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = + as_float32_array(&result).expect("failed to initialize function round"); + + let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]); + + assert_eq!(floats, &expected); + } + + #[test] + fn test_round_f64_one_input() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = + as_float64_array(&result).expect("failed to initialize function round"); + + let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]); + + assert_eq!(floats, &expected); + } +} diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs new file mode 100644 index 000000000000..6f88099889cc --- /dev/null +++ b/datafusion/functions/src/math/trunc.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Float32, Float64}; + +use crate::utils::make_scalar_function; +use datafusion_common::ScalarValue::Int64; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, FuncMonotonicity}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub struct TruncFunc { + signature: Signature, +} + +impl Default for TruncFunc { + fn default() -> Self { + TruncFunc::new() + } +} + +impl TruncFunc { + pub fn new() -> Self { + use DataType::*; + Self { + // math expressions expect 1 argument of type f64 or f32 + // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we + // return the best approximation for it (in f64). + // We accept f32 because in this case it is clear that the best approximation + // will be as good as the number of digits in the number + signature: Signature::one_of( + vec![ + Exact(vec![Float32, Int64]), + Exact(vec![Float64, Int64]), + Exact(vec![Float64]), + Exact(vec![Float32]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TruncFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(trunc, vec![])(args) + } + + fn monotonicity(&self) -> Result> { + Ok(Some(vec![Some(true)])) + } +} + +/// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function +fn trunc(args: &[ArrayRef]) -> Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!( + "truncate function requires one or two arguments, got {}", + args.len() + ); + } + + //if only one arg then invoke toolchain trunc(num) and precision = 0 by default + //or then invoke the compute_truncate method to process precision + let num = &args[0]; + let precision = if args.len() == 1 { + ColumnarValue::Scalar(Int64(Some(0))) + } else { + ColumnarValue::Array(args[1].clone()) + }; + + match args[0].data_type() { + Float64 => match precision { + ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( + make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }), + ) as ArrayRef), + ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( + num, + precision, + "x", + "y", + Float64Array, + Int64Array, + { compute_truncate64 } + )) as ArrayRef), + _ => exec_err!("trunc function requires a scalar or array for precision"), + }, + Float32 => match precision { + ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( + make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }), + ) as ArrayRef), + ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( + num, + precision, + "x", + "y", + Float32Array, + Int64Array, + { compute_truncate32 } + )) as ArrayRef), + _ => exec_err!("trunc function requires a scalar or array for precision"), + }, + other => exec_err!("Unsupported data type {other:?} for function trunc"), + } +} + +fn compute_truncate32(x: f32, y: i64) -> f32 { + let factor = 10.0_f32.powi(y as i32); + (x * factor).round() / factor +} + +fn compute_truncate64(x: f64, y: i64) -> f64 { + let factor = 10.0_f64.powi(y as i32); + (x * factor).round() / factor +} + +#[cfg(test)] +mod test { + use crate::math::trunc::trunc; + use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use std::sync::Arc; + + #[test] + fn test_truncate_32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![ + 15.0, + 1_234.267_8, + 1_233.123_4, + 3.312_979_2, + -21.123_4, + ])), + Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])), + ]; + + let result = trunc(&args).expect("failed to initialize function truncate"); + let floats = + as_float32_array(&result).expect("failed to initialize function truncate"); + + assert_eq!(floats.len(), 5); + assert_eq!(floats.value(0), 15.0); + assert_eq!(floats.value(1), 1_234.268); + assert_eq!(floats.value(2), 1_233.12); + assert_eq!(floats.value(3), 3.312_98); + assert_eq!(floats.value(4), -21.123_4); + } + + #[test] + fn test_truncate_64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![ + 5.0, + 234.267_812_176, + 123.123_456_789, + 123.312_979_313_2, + -321.123_1, + ])), + Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])), + ]; + + let result = trunc(&args).expect("failed to initialize function truncate"); + let floats = + as_float64_array(&result).expect("failed to initialize function truncate"); + + assert_eq!(floats.len(), 5); + assert_eq!(floats.value(0), 5.0); + assert_eq!(floats.value(1), 234.268); + assert_eq!(floats.value(2), 123.12); + assert_eq!(floats.value(3), 123.312_98); + assert_eq!(floats.value(4), -321.123_1); + } + + #[test] + fn test_truncate_64_one_arg() { + let args: Vec = vec![Arc::new(Float64Array::from(vec![ + 5.0, + 234.267_812, + 123.123_45, + 123.312_979_313_2, + -321.123, + ]))]; + + let result = trunc(&args).expect("failed to initialize function truncate"); + let floats = + as_float64_array(&result).expect("failed to initialize function truncate"); + + assert_eq!(floats.len(), 5); + assert_eq!(floats.value(0), 5.0); + assert_eq!(floats.value(1), 234.0); + assert_eq!(floats.value(2), 123.0); + assert_eq!(floats.value(3), 123.0); + assert_eq!(floats.value(4), -321.0); + } +} diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 5efcf5942c39..92772e4623be 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -117,8 +117,7 @@ mod tests { use itertools::Itertools; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{BuiltinScalarFunction, Operator, ScalarUDF}; + use datafusion_expr::{Operator, ScalarUDF}; use crate::equivalence::tests::{ apply_projection, convert_to_orderings, convert_to_orderings_owned, @@ -649,11 +648,13 @@ mod tests { col_b.clone(), )) as Arc; - let round_c = &crate::functions::create_physical_expr( - &BuiltinScalarFunction::Round, + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let round_c = &create_physical_expr( + &test_fun, &[col_c.clone()], &schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; let option_asc = SortOptions { diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 124acdc7ac78..2be85a69d7da 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -184,22 +184,10 @@ pub fn create_physical_fun( BuiltinScalarFunction::Factorial => { Arc::new(|args| make_scalar_function_inner(math_expressions::factorial)(args)) } - BuiltinScalarFunction::Iszero => { - Arc::new(|args| make_scalar_function_inner(math_expressions::iszero)(args)) - } BuiltinScalarFunction::Nanvl => { Arc::new(|args| make_scalar_function_inner(math_expressions::nanvl)(args)) } BuiltinScalarFunction::Random => Arc::new(math_expressions::random), - BuiltinScalarFunction::Round => { - Arc::new(|args| make_scalar_function_inner(math_expressions::round)(args)) - } - BuiltinScalarFunction::Trunc => { - Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args)) - } - BuiltinScalarFunction::Cot => { - Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) - } // string functions BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index b29230de1f76..55fb54563787 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -27,7 +27,6 @@ use arrow::datatypes::DataType; use arrow_array::Array; use rand::{thread_rng, Rng}; -use datafusion_common::ScalarValue::Int64; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; @@ -154,17 +153,8 @@ macro_rules! make_function_scalar_inputs_return_type { }}; } -math_unary_function!("asin", asin); -math_unary_function!("acos", acos); -math_unary_function!("atan", atan); -math_unary_function!("asinh", asinh); -math_unary_function!("acosh", acosh); -math_unary_function!("atanh", atanh); math_unary_function!("ceil", ceil); math_unary_function!("exp", exp); -math_unary_function!("ln", ln); -math_unary_function!("log2", log2); -math_unary_function!("log10", log10); /// Factorial SQL function pub fn factorial(args: &[ArrayRef]) -> Result { @@ -247,29 +237,6 @@ pub fn isnan(args: &[ArrayRef]) -> Result { } } -/// Iszero SQL function -pub fn iszero(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float64Array, - BooleanArray, - { |x: f64| { x == 0_f64 } } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float32Array, - BooleanArray, - { |x: f32| { x == 0_f32 } } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function iszero"), - } -} - /// Random SQL function pub fn random(args: &[ColumnarValue]) -> Result { let len: usize = match &args[0] { @@ -282,192 +249,6 @@ pub fn random(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } -/// Round SQL function -pub fn round(args: &[ArrayRef]) -> Result { - if args.len() != 1 && args.len() != 2 { - return exec_err!( - "round function requires one or two arguments, got {}", - args.len() - ); - } - - let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); - - if args.len() == 2 { - decimal_places = ColumnarValue::Array(args[1].clone()); - } - - match args[0].data_type() { - DataType::Float64 => match decimal_places { - ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places = decimal_places.try_into().unwrap(); - - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float64Array, - { - |value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) - } - ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float64Array, - Int64Array, - { - |value: f64, decimal_places: i64| { - (value * 10.0_f64.powi(decimal_places.try_into().unwrap())) - .round() - / 10.0_f64.powi(decimal_places.try_into().unwrap()) - } - } - )) as ArrayRef), - _ => { - exec_err!("round function requires a scalar or array for decimal_places") - } - }, - - DataType::Float32 => match decimal_places { - ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places = decimal_places.try_into().unwrap(); - - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float32Array, - { - |value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) - } - ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float32Array, - Int64Array, - { - |value: f32, decimal_places: i64| { - (value * 10.0_f32.powi(decimal_places.try_into().unwrap())) - .round() - / 10.0_f32.powi(decimal_places.try_into().unwrap()) - } - } - )) as ArrayRef), - _ => { - exec_err!("round function requires a scalar or array for decimal_places") - } - }, - - other => exec_err!("Unsupported data type {other:?} for function round"), - } -} - -///cot SQL function -pub fn cot(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "x", - Float64Array, - { compute_cot64 } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "x", - Float32Array, - { compute_cot32 } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function cot"), - } -} - -fn compute_cot32(x: f32) -> f32 { - let a = f32::tan(x); - 1.0 / a -} - -fn compute_cot64(x: f64) -> f64 { - let a = f64::tan(x); - 1.0 / a -} - -/// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function -pub fn trunc(args: &[ArrayRef]) -> Result { - if args.len() != 1 && args.len() != 2 { - return exec_err!( - "truncate function requires one or two arguments, got {}", - args.len() - ); - } - - //if only one arg then invoke toolchain trunc(num) and precision = 0 by default - //or then invoke the compute_truncate method to process precision - let num = &args[0]; - let precision = if args.len() == 1 { - ColumnarValue::Scalar(Int64(Some(0))) - } else { - ColumnarValue::Array(args[1].clone()) - }; - - match args[0].data_type() { - DataType::Float64 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }), - ) as ArrayRef), - ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( - num, - precision, - "x", - "y", - Float64Array, - Int64Array, - { compute_truncate64 } - )) as ArrayRef), - _ => exec_err!("trunc function requires a scalar or array for precision"), - }, - DataType::Float32 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }), - ) as ArrayRef), - ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( - num, - precision, - "x", - "y", - Float32Array, - Int64Array, - { compute_truncate32 } - )) as ArrayRef), - _ => exec_err!("trunc function requires a scalar or array for precision"), - }, - other => exec_err!("Unsupported data type {other:?} for function trunc"), - } -} - -fn compute_truncate32(x: f32, y: i64) -> f32 { - let factor = 10.0_f32.powi(y as i32); - (x * factor).round() / factor -} - -fn compute_truncate64(x: f64, y: i64) -> f64 { - let factor = 10.0_f64.powi(y as i32); - (x * factor).round() / factor -} - #[cfg(test)] mod tests { use arrow::array::{Float64Array, NullArray}; @@ -492,72 +273,6 @@ mod tests { assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); } - #[test] - fn test_round_f32() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![125.2345; 10])), // input - Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places - ]; - - let result = round(&args).expect("failed to initialize function round"); - let floats = - as_float32_array(&result).expect("failed to initialize function round"); - - let expected = Float32Array::from(vec![ - 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, - ]); - - assert_eq!(floats, &expected); - } - - #[test] - fn test_round_f64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![125.2345; 10])), // input - Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places - ]; - - let result = round(&args).expect("failed to initialize function round"); - let floats = - as_float64_array(&result).expect("failed to initialize function round"); - - let expected = Float64Array::from(vec![ - 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, - ]); - - assert_eq!(floats, &expected); - } - - #[test] - fn test_round_f32_one_input() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input - ]; - - let result = round(&args).expect("failed to initialize function round"); - let floats = - as_float32_array(&result).expect("failed to initialize function round"); - - let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]); - - assert_eq!(floats, &expected); - } - - #[test] - fn test_round_f64_one_input() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input - ]; - - let result = round(&args).expect("failed to initialize function round"); - let floats = - as_float64_array(&result).expect("failed to initialize function round"); - - let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]); - - assert_eq!(floats, &expected); - } - #[test] fn test_factorial_i64() { let args: Vec = vec![ @@ -573,124 +288,6 @@ mod tests { assert_eq!(ints, &expected); } - #[test] - fn test_cot_f32() { - let args: Vec = - vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; - let result = cot(&args).expect("failed to initialize function cot"); - let floats = - as_float32_array(&result).expect("failed to initialize function cot"); - - let expected = Float32Array::from(vec![ - -1.986_460_4, - -0.156_119_96, - -0.501_202_8, - 0.156_119_96, - ]); - - let eps = 1e-6; - assert_eq!(floats.len(), 4); - assert!((floats.value(0) - expected.value(0)).abs() < eps); - assert!((floats.value(1) - expected.value(1)).abs() < eps); - assert!((floats.value(2) - expected.value(2)).abs() < eps); - assert!((floats.value(3) - expected.value(3)).abs() < eps); - } - - #[test] - fn test_cot_f64() { - let args: Vec = - vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; - let result = cot(&args).expect("failed to initialize function cot"); - let floats = - as_float64_array(&result).expect("failed to initialize function cot"); - - let expected = Float64Array::from(vec![ - -1.986_458_685_881_4, - -0.156_119_952_161_6, - -0.501_202_783_380_1, - 0.156_119_952_161_6, - ]); - - let eps = 1e-12; - assert_eq!(floats.len(), 4); - assert!((floats.value(0) - expected.value(0)).abs() < eps); - assert!((floats.value(1) - expected.value(1)).abs() < eps); - assert!((floats.value(2) - expected.value(2)).abs() < eps); - assert!((floats.value(3) - expected.value(3)).abs() < eps); - } - - #[test] - fn test_truncate_32() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![ - 15.0, - 1_234.267_8, - 1_233.123_4, - 3.312_979_2, - -21.123_4, - ])), - Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])), - ]; - - let result = trunc(&args).expect("failed to initialize function truncate"); - let floats = - as_float32_array(&result).expect("failed to initialize function truncate"); - - assert_eq!(floats.len(), 5); - assert_eq!(floats.value(0), 15.0); - assert_eq!(floats.value(1), 1_234.268); - assert_eq!(floats.value(2), 1_233.12); - assert_eq!(floats.value(3), 3.312_98); - assert_eq!(floats.value(4), -21.123_4); - } - - #[test] - fn test_truncate_64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![ - 5.0, - 234.267_812_176, - 123.123_456_789, - 123.312_979_313_2, - -321.123_1, - ])), - Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])), - ]; - - let result = trunc(&args).expect("failed to initialize function truncate"); - let floats = - as_float64_array(&result).expect("failed to initialize function truncate"); - - assert_eq!(floats.len(), 5); - assert_eq!(floats.value(0), 5.0); - assert_eq!(floats.value(1), 234.268); - assert_eq!(floats.value(2), 123.12); - assert_eq!(floats.value(3), 123.312_98); - assert_eq!(floats.value(4), -321.123_1); - } - - #[test] - fn test_truncate_64_one_arg() { - let args: Vec = vec![Arc::new(Float64Array::from(vec![ - 5.0, - 234.267_812, - 123.123_45, - 123.312_979_313_2, - -321.123, - ]))]; - - let result = trunc(&args).expect("failed to initialize function truncate"); - let floats = - as_float64_array(&result).expect("failed to initialize function truncate"); - - assert_eq!(floats.len(), 5); - assert_eq!(floats.value(0), 5.0); - assert_eq!(floats.value(1), 234.0); - assert_eq!(floats.value(2), 123.0); - assert_eq!(floats.value(3), 123.0); - assert_eq!(floats.value(4), -321.0); - } - #[test] fn test_nanvl_f64() { let args: Vec = vec![ @@ -766,36 +363,4 @@ mod tests { assert!(!booleans.value(2)); assert!(booleans.value(3)); } - - #[test] - fn test_iszero_f64() { - let args: Vec = - vec![Arc::new(Float64Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; - - let result = iszero(&args).expect("failed to initialize function iszero"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function iszero"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } - - #[test] - fn test_iszero_f32() { - let args: Vec = - vec![Arc::new(Float32Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; - - let result = iszero(&args).expect("failed to initialize function iszero"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function iszero"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 0f245673f6cd..c7c0d9b5a656 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -555,12 +555,12 @@ enum ScalarFunction { // 11 was Log // 12 was Log10 // 13 was Log2 - Round = 14; + // 14 was Round // 15 was Signum // 16 was Sin // 17 was Sqrt // Tan = 18; - Trunc = 19; + // 19 was Trunc // 20 was Array // RegexpMatch = 21; // 22 was BitLength @@ -642,7 +642,7 @@ enum ScalarFunction { // 98 was Cardinality // 99 was ArrayElement // 100 was ArraySlice - Cot = 103; + // 103 was Cot // 104 was ArrayHas // 105 was ArrayHasAny // 106 was ArrayHasAll @@ -653,7 +653,7 @@ enum ScalarFunction { Nanvl = 111; // 112 was Flatten // 113 was IsNan - Iszero = 114; + // 114 was Iszero // 115 was ArrayEmpty // 116 was ArrayPopBack // 117 was StringToArray diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0922fccc7917..c8a1fba40765 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22794,17 +22794,13 @@ impl serde::Serialize for ScalarFunction { Self::Unknown => "unknown", Self::Ceil => "Ceil", Self::Exp => "Exp", - Self::Round => "Round", - Self::Trunc => "Trunc", Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", Self::Random => "Random", Self::Coalesce => "Coalesce", Self::Factorial => "Factorial", - Self::Cot => "Cot", Self::Nanvl => "Nanvl", - Self::Iszero => "Iszero", Self::EndsWith => "EndsWith", }; serializer.serialize_str(variant) @@ -22820,17 +22816,13 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "unknown", "Ceil", "Exp", - "Round", - "Trunc", "Concat", "ConcatWithSeparator", "InitCap", "Random", "Coalesce", "Factorial", - "Cot", "Nanvl", - "Iszero", "EndsWith", ]; @@ -22875,17 +22867,13 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "unknown" => Ok(ScalarFunction::Unknown), "Ceil" => Ok(ScalarFunction::Ceil), "Exp" => Ok(ScalarFunction::Exp), - "Round" => Ok(ScalarFunction::Round), - "Trunc" => Ok(ScalarFunction::Trunc), "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), "Random" => Ok(ScalarFunction::Random), "Coalesce" => Ok(ScalarFunction::Coalesce), "Factorial" => Ok(ScalarFunction::Factorial), - "Cot" => Ok(ScalarFunction::Cot), "Nanvl" => Ok(ScalarFunction::Nanvl), - "Iszero" => Ok(ScalarFunction::Iszero), "EndsWith" => Ok(ScalarFunction::EndsWith), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index db7614144983..facf24219810 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2854,12 +2854,12 @@ pub enum ScalarFunction { /// 11 was Log /// 12 was Log10 /// 13 was Log2 - Round = 14, + /// 14 was Round /// 15 was Signum /// 16 was Sin /// 17 was Sqrt /// Tan = 18; - Trunc = 19, + /// 19 was Trunc /// 20 was Array /// RegexpMatch = 21; /// 22 was BitLength @@ -2941,7 +2941,7 @@ pub enum ScalarFunction { /// 98 was Cardinality /// 99 was ArrayElement /// 100 was ArraySlice - Cot = 103, + /// 103 was Cot /// 104 was ArrayHas /// 105 was ArrayHasAny /// 106 was ArrayHasAll @@ -2952,7 +2952,7 @@ pub enum ScalarFunction { Nanvl = 111, /// 112 was Flatten /// 113 was IsNan - Iszero = 114, + /// 114 was Iszero /// 115 was ArrayEmpty /// 116 was ArrayPopBack /// 117 was StringToArray @@ -2989,17 +2989,13 @@ impl ScalarFunction { ScalarFunction::Unknown => "unknown", ScalarFunction::Ceil => "Ceil", ScalarFunction::Exp => "Exp", - ScalarFunction::Round => "Round", - ScalarFunction::Trunc => "Trunc", ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", ScalarFunction::Random => "Random", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Factorial => "Factorial", - ScalarFunction::Cot => "Cot", ScalarFunction::Nanvl => "Nanvl", - ScalarFunction::Iszero => "Iszero", ScalarFunction::EndsWith => "EndsWith", } } @@ -3009,17 +3005,13 @@ impl ScalarFunction { "unknown" => Some(Self::Unknown), "Ceil" => Some(Self::Ceil), "Exp" => Some(Self::Exp), - "Round" => Some(Self::Round), - "Trunc" => Some(Self::Trunc), "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), "Random" => Some(Self::Random), "Coalesce" => Some(Self::Coalesce), "Factorial" => Some(Self::Factorial), - "Cot" => Some(Self::Cot), "Nanvl" => Some(Self::Nanvl), - "Iszero" => Some(Self::Iszero), "EndsWith" => Some(Self::EndsWith), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 6a2e89fe00a3..e9eb53e45199 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -37,13 +37,13 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - ceil, coalesce, concat_expr, concat_ws_expr, cot, ends_with, exp, + ceil, coalesce, concat_expr, concat_ws_expr, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, initcap, iszero, + factorial, initcap, logical_plan::{PlanType, StringifiedPlan}, - nanvl, random, round, trunc, AggregateFunction, Between, BinaryExpr, - BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, - GetIndexedField, GroupingSet, + nanvl, random, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, + BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, + GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -419,12 +419,9 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { use protobuf::ScalarFunction; match f { ScalarFunction::Unknown => todo!(), - ScalarFunction::Cot => Self::Cot, ScalarFunction::Exp => Self::Exp, ScalarFunction::Factorial => Self::Factorial, ScalarFunction::Ceil => Self::Ceil, - ScalarFunction::Round => Self::Round, - ScalarFunction::Trunc => Self::Trunc, ScalarFunction::Concat => Self::Concat, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, @@ -432,7 +429,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Random => Self::Random, ScalarFunction::Coalesce => Self::Coalesce, ScalarFunction::Nanvl => Self::Nanvl, - ScalarFunction::Iszero => Self::Iszero, } } } @@ -1299,8 +1295,6 @@ pub fn parse_expr( Ok(factorial(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Round => Ok(round(parse_exprs(args, registry, codec)?)), - ScalarFunction::Trunc => Ok(trunc(parse_exprs(args, registry, codec)?)), ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } @@ -1318,14 +1312,10 @@ pub fn parse_expr( ScalarFunction::Coalesce => { Ok(coalesce(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Nanvl => Ok(nanvl( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Iszero => { - Ok(iszero(parse_expr(&args[0], registry, codec)?)) - } } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index db9653e32346..ed5e7a302b20 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1407,12 +1407,9 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { fn try_from(scalar: &BuiltinScalarFunction) -> Result { let scalar_function = match scalar { - BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Exp => Self::Exp, BuiltinScalarFunction::Factorial => Self::Factorial, BuiltinScalarFunction::Ceil => Self::Ceil, - BuiltinScalarFunction::Round => Self::Round, - BuiltinScalarFunction::Trunc => Self::Trunc, BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, @@ -1420,7 +1417,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Random => Self::Random, BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Nanvl => Self::Nanvl, - BuiltinScalarFunction::Iszero => Self::Iszero, }; Ok(scalar_function) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 5dacf692e904..a74b1a38935b 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -34,9 +34,7 @@ use datafusion::datasource::physical_plan::{ FileSinkConfig, ParquetExec, }; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::{ - create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, -}; +use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::expressions::NthValueAgg; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; @@ -603,31 +601,6 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { ))) } -#[test] -fn roundtrip_builtin_scalar_function() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - - let input = Arc::new(EmptyExec::new(schema.clone())); - - let fun_def = ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Trunc); - - let expr = ScalarFunctionExpr::new( - "trunc", - fun_def, - vec![col("a", &schema)?], - DataType::Float64, - Some(vec![Some(true)]), - false, - ); - - let project = - ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?; - - roundtrip_test(Arc::new(project)) -} - #[test] fn roundtrip_scalar_udf() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index e923a15372d0..19288123558a 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2683,6 +2683,11 @@ fn logical_plan_with_dialect_and_options( vec![DataType::Int32, DataType::Int32], DataType::Int32, )) + .with_udf(make_udf( + "round", + vec![DataType::Float64, DataType::Int64], + DataType::Float32, + )) .with_udf(make_udf( "arrow_cast", vec![DataType::Int64, DataType::Utf8],