Skip to content

Commit

Permalink
move Trunc, Cot, Round, iszero functions to datafusion-functions (#10000)
Browse files Browse the repository at this point in the history

* move Floor, Gcd, Lcm, Pi to datafusion-functions

* remove floor fn

* move Trunc, Cot, Round, iszero functions to datafusion-functions

* Make mod iszero public, minor ordering change to keep the alphabetical ordering theme.

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
Omega359 and alamb authored Apr 8, 2024
1 parent 0088c28 commit 78f8ef1
Show file tree
Hide file tree
Showing 17 changed files with 1,061 additions and 692 deletions.
61 changes: 7 additions & 54 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,8 @@ pub enum BuiltinScalarFunction {
Exp,
/// factorial
Factorial,
/// iszero
Iszero,
/// nanvl
Nanvl,
/// round
Round,
/// trunc
Trunc,
/// cot
Cot,

// string functions
/// concat
Concat,
Expand Down Expand Up @@ -123,11 +114,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Exp => Volatility::Immutable,
BuiltinScalarFunction::Factorial => Volatility::Immutable,
BuiltinScalarFunction::Iszero => Volatility::Immutable,
BuiltinScalarFunction::Nanvl => Volatility::Immutable,
BuiltinScalarFunction::Round => Volatility::Immutable,
BuiltinScalarFunction::Cot => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::Concat => Volatility::Immutable,
BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
Expand Down Expand Up @@ -175,16 +162,12 @@ impl BuiltinScalarFunction {
_ => Ok(Float64),
},

BuiltinScalarFunction::Iszero => Ok(Boolean),

BuiltinScalarFunction::Ceil
| BuiltinScalarFunction::Exp
| BuiltinScalarFunction::Round
| BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Cot => match input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},
BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => {
match input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
}
}
}
}

Expand Down Expand Up @@ -217,45 +200,21 @@ impl BuiltinScalarFunction {
self.volatility(),
),
BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()),
BuiltinScalarFunction::Round => Signature::one_of(
vec![
Exact(vec![Float64, Int64]),
Exact(vec![Float32, Int64]),
Exact(vec![Float64]),
Exact(vec![Float32]),
],
self.volatility(),
),
BuiltinScalarFunction::Trunc => Signature::one_of(
vec![
Exact(vec![Float32, Int64]),
Exact(vec![Float64, Int64]),
Exact(vec![Float64]),
Exact(vec![Float32]),
],
self.volatility(),
),
BuiltinScalarFunction::Nanvl => Signature::one_of(
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
self.volatility(),
),
BuiltinScalarFunction::Factorial => {
Signature::uniform(1, vec![Int64], self.volatility())
}
BuiltinScalarFunction::Ceil
| BuiltinScalarFunction::Exp
| BuiltinScalarFunction::Cot => {
BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => {
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
// We accept f32 because in this case it is clear that the best approximation
// will be as good as the number of digits in the number
Signature::uniform(1, vec![Float64, Float32], self.volatility())
}
BuiltinScalarFunction::Iszero => Signature::one_of(
vec![Exact(vec![Float32]), Exact(vec![Float64])],
self.volatility(),
),
}
}

Expand All @@ -268,8 +227,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Ceil
| BuiltinScalarFunction::Exp
| BuiltinScalarFunction::Factorial
| BuiltinScalarFunction::Round
| BuiltinScalarFunction::Trunc
) {
Some(vec![Some(true)])
} else {
Expand All @@ -281,14 +238,10 @@ impl BuiltinScalarFunction {
pub fn aliases(&self) -> &'static [&'static str] {
match self {
BuiltinScalarFunction::Ceil => &["ceil"],
BuiltinScalarFunction::Cot => &["cot"],
BuiltinScalarFunction::Exp => &["exp"],
BuiltinScalarFunction::Factorial => &["factorial"],
BuiltinScalarFunction::Iszero => &["iszero"],
BuiltinScalarFunction::Nanvl => &["nanvl"],
BuiltinScalarFunction::Random => &["random"],
BuiltinScalarFunction::Round => &["round"],
BuiltinScalarFunction::Trunc => &["trunc"],

// conditional functions
BuiltinScalarFunction::Coalesce => &["coalesce"],
Expand Down
46 changes: 1 addition & 45 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,20 +530,14 @@ macro_rules! nary_scalar_expr {
// generate methods for creating the supported unary/binary expressions

// math functions
scalar_expr!(Cot, cot, num, "cotangent of a number");
scalar_expr!(Factorial, factorial, num, "factorial");
scalar_expr!(
Ceil,
ceil,
num,
"nearest integer greater than or equal to argument"
);
nary_scalar_expr!(Round, round, "round to nearest integer");
nary_scalar_expr!(
Trunc,
trunc,
"truncate toward zero, with optional precision"
);

scalar_expr!(Exp, exp, num, "exponential");

scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
Expand All @@ -557,12 +551,6 @@ nary_scalar_expr!(
);
nary_scalar_expr!(Concat, concat_expr, "concatenates several strings");
scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y");
scalar_expr!(
Iszero,
iszero,
num,
"returns true if a given number is +0.0 or -0.0 otherwise returns false"
);

/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
pub fn case(expr: Expr) -> CaseBuilder {
Expand Down Expand Up @@ -872,12 +860,6 @@ impl WindowUDFImpl for SimpleWindowUDF {
}

/// Calls a named built in function
/// ```
/// use datafusion_expr::{col, lit, call_fn};
///
/// // create the expression trunc(x) < 0.2
/// let expr = call_fn("trunc", vec![col("x")]).unwrap().lt(lit(0.2));
/// ```
pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
match name.as_ref().parse::<BuiltinScalarFunction>() {
Ok(fun) => Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))),
Expand Down Expand Up @@ -935,38 +917,12 @@ mod test {
};
}

macro_rules! test_nary_scalar_expr {
($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {
let expected = [$(stringify!($arg)),*];
let result = $FUNC(
vec![
$(
col(stringify!($arg.to_string()))
),*
]
);
if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result {
let name = built_in_function::BuiltinScalarFunction::$ENUM;
assert_eq!(name, fun);
assert_eq!(expected.len(), args.len());
} else {
assert!(false, "unexpected: {:?}", result);
}
};
}

#[test]
fn scalar_function_definitions() {
test_unary_scalar_expr!(Cot, cot);
test_unary_scalar_expr!(Factorial, factorial);
test_unary_scalar_expr!(Ceil, ceil);
test_nary_scalar_expr!(Round, round, input);
test_nary_scalar_expr!(Round, round, input, decimal_places);
test_nary_scalar_expr!(Trunc, trunc, num);
test_nary_scalar_expr!(Trunc, trunc, num, precision);
test_unary_scalar_expr!(Exp, exp);
test_scalar_expr!(Nanvl, nanvl, x, y);
test_scalar_expr!(Iszero, iszero, input);

test_scalar_expr!(InitCap, initcap, string);
test_scalar_expr!(EndsWith, ends_with, string, characters);
Expand Down
166 changes: 166 additions & 0 deletions datafusion/functions/src/math/cot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, Float32Array, Float64Array};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{Float32, Float64};

use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};

use crate::utils::make_scalar_function;

/// Scalar UDF implementing the SQL `cot` (cotangent) function.
///
/// The element-wise work is done by the free function `cot` in this module;
/// this type only carries the signature and wires the kernel into
/// [`ScalarUDFImpl`].
#[derive(Debug)]
pub struct CotFunc {
    /// Accepted argument types and volatility, built once in `CotFunc::new`.
    signature: Signature,
}

impl Default for CotFunc {
fn default() -> Self {
CotFunc::new()
}
}

impl CotFunc {
pub fn new() -> Self {
use DataType::*;
Self {
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
// We accept f32 because in this case it is clear that the best approximation
// will be as good as the number of digits in the number
signature: Signature::uniform(
1,
vec![Float64, Float32],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for CotFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"cot"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match arg_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(cot, vec![])(args)
}
}

/// Array kernel for the SQL `cot` function.
///
/// Dispatches on the data type of the first argument and maps
/// `compute_cot32` / `compute_cot64` over it through
/// `make_function_scalar_inputs!`. Any other data type is rejected with an
/// execution error.
fn cot(args: &[ArrayRef]) -> Result<ArrayRef> {
    let output = match args[0].data_type() {
        Float32 => Arc::new(make_function_scalar_inputs!(
            &args[0],
            "x",
            Float32Array,
            { compute_cot32 }
        )) as ArrayRef,
        Float64 => Arc::new(make_function_scalar_inputs!(
            &args[0],
            "x",
            Float64Array,
            { compute_cot64 }
        )) as ArrayRef,
        other => {
            return exec_err!("Unsupported data type {other:?} for function cot");
        }
    };

    Ok(output)
}

/// Cotangent of `x` in single precision, computed as the reciprocal of `tan(x)`.
///
/// Where `tan(x)` is zero the result is infinite, following IEEE division.
fn compute_cot32(x: f32) -> f32 {
    x.tan().recip()
}

/// Cotangent of `x` in double precision, computed as the reciprocal of `tan(x)`.
///
/// Where `tan(x)` is zero the result is infinite, following IEEE division.
fn compute_cot64(x: f64) -> f64 {
    x.tan().recip()
}

#[cfg(test)]
mod test {
    use crate::math::cot::cot;
    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use std::sync::Arc;

    /// `cot` over a `Float32Array` must return a `Float32Array` whose values
    /// match the reference cotangents to within `1e-6`.
    #[test]
    fn test_cot_f32() {
        let args: Vec<ArrayRef> =
            vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
        let result = cot(&args).expect("cot failed on Float32 input");
        let floats = as_float32_array(&result)
            .expect("cot should return a Float32Array for Float32 input");

        let expected: [f32; 4] = [
            -1.986_460_4,
            -0.156_119_96,
            -0.501_202_8,
            0.156_119_96,
        ];
        let eps: f32 = 1e-6;

        assert_eq!(floats.len(), expected.len());
        // Compare element by element so a failure reports which value drifted.
        for (i, want) in expected.iter().copied().enumerate() {
            let got = floats.value(i);
            assert!(
                (got - want).abs() < eps,
                "cot(f32) mismatch at index {i}: got {got}, expected {want}"
            );
        }
    }

    /// `cot` over a `Float64Array` must return a `Float64Array` whose values
    /// match the reference cotangents to within `1e-12`.
    #[test]
    fn test_cot_f64() {
        let args: Vec<ArrayRef> =
            vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
        let result = cot(&args).expect("cot failed on Float64 input");
        let floats = as_float64_array(&result)
            .expect("cot should return a Float64Array for Float64 input");

        let expected: [f64; 4] = [
            -1.986_458_685_881_4,
            -0.156_119_952_161_6,
            -0.501_202_783_380_1,
            0.156_119_952_161_6,
        ];
        let eps: f64 = 1e-12;

        assert_eq!(floats.len(), expected.len());
        // Compare element by element so a failure reports which value drifted.
        for (i, want) in expected.iter().copied().enumerate() {
            let got = floats.value(i);
            assert!(
                (got - want).abs() < eps,
                "cot(f64) mismatch at index {i}: got {got}, expected {want}"
            );
        }
    }
}
Loading

0 comments on commit 78f8ef1

Please sign in to comment.