Skip to content

Commit

Permalink
Move repeat, replace, split_part to datafusion_functions (apache#9784)
Browse files Browse the repository at this point in the history
* Fix to_timestamp benchmark

* Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started.

* Fixed missing trim() function.

* Move repeat, replace, split_part to datafusion_functions
  • Loading branch information
Omega359 authored Mar 24, 2024
1 parent cb9da2b commit 1e4ddb6
Show file tree
Hide file tree
Showing 13 changed files with 469 additions and 261 deletions.
44 changes: 7 additions & 37 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,12 @@ pub enum BuiltinScalarFunction {
Lpad,
/// random
Random,
/// repeat
Repeat,
/// replace
Replace,
/// reverse
Reverse,
/// right
Right,
/// rpad
Rpad,
/// split_part
SplitPart,
/// strpos
Strpos,
/// substr
Expand Down Expand Up @@ -238,12 +232,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Left => Volatility::Immutable,
BuiltinScalarFunction::Lpad => Volatility::Immutable,
BuiltinScalarFunction::Radians => Volatility::Immutable,
BuiltinScalarFunction::Repeat => Volatility::Immutable,
BuiltinScalarFunction::Replace => Volatility::Immutable,
BuiltinScalarFunction::Reverse => Volatility::Immutable,
BuiltinScalarFunction::Right => Volatility::Immutable,
BuiltinScalarFunction::Rpad => Volatility::Immutable,
BuiltinScalarFunction::SplitPart => Volatility::Immutable,
BuiltinScalarFunction::Strpos => Volatility::Immutable,
BuiltinScalarFunction::Substr => Volatility::Immutable,
BuiltinScalarFunction::Translate => Volatility::Immutable,
Expand Down Expand Up @@ -293,22 +284,13 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"),
BuiltinScalarFunction::Pi => Ok(Float64),
BuiltinScalarFunction::Random => Ok(Float64),
BuiltinScalarFunction::Repeat => {
utf8_to_str_type(&input_expr_types[0], "repeat")
}
BuiltinScalarFunction::Replace => {
utf8_to_str_type(&input_expr_types[0], "replace")
}
BuiltinScalarFunction::Reverse => {
utf8_to_str_type(&input_expr_types[0], "reverse")
}
BuiltinScalarFunction::Right => {
utf8_to_str_type(&input_expr_types[0], "right")
}
BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"),
BuiltinScalarFunction::SplitPart => {
utf8_to_str_type(&input_expr_types[0], "split_part")
}
BuiltinScalarFunction::EndsWith => Ok(Boolean),
BuiltinScalarFunction::Strpos => {
utf8_to_int_type(&input_expr_types[0], "strpos/instr/position")
Expand Down Expand Up @@ -417,21 +399,12 @@ impl BuiltinScalarFunction {
self.volatility(),
)
}
BuiltinScalarFunction::Left
| BuiltinScalarFunction::Repeat
| BuiltinScalarFunction::Right => Signature::one_of(
vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])],
self.volatility(),
),
BuiltinScalarFunction::SplitPart => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, Utf8, Int64]),
Exact(vec![Utf8, LargeUtf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
],
self.volatility(),
),
BuiltinScalarFunction::Left | BuiltinScalarFunction::Right => {
Signature::one_of(
vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])],
self.volatility(),
)
}

BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => {
Signature::one_of(
Expand Down Expand Up @@ -467,7 +440,7 @@ impl BuiltinScalarFunction {
self.volatility(),
),

BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => {
BuiltinScalarFunction::Translate => {
Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility())
}
BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()),
Expand Down Expand Up @@ -637,12 +610,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::InitCap => &["initcap"],
BuiltinScalarFunction::Left => &["left"],
BuiltinScalarFunction::Lpad => &["lpad"],
BuiltinScalarFunction::Repeat => &["repeat"],
BuiltinScalarFunction::Replace => &["replace"],
BuiltinScalarFunction::Reverse => &["reverse"],
BuiltinScalarFunction::Right => &["right"],
BuiltinScalarFunction::Rpad => &["rpad"],
BuiltinScalarFunction::SplitPart => &["split_part"],
BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"],
BuiltinScalarFunction::Substr => &["substr"],
BuiltinScalarFunction::Translate => &["translate"],
Expand Down
6 changes: 0 additions & 6 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,11 +598,8 @@ scalar_expr!(
);
scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`");
scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `from` with `to` in the `string`");
scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times");
scalar_expr!(Reverse, reverse, string, "reverses the `string`");
scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`");
scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index.");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`");
scalar_expr!(Substr, substr, string position, "substring from the `position` to the end");
Expand Down Expand Up @@ -1056,13 +1053,10 @@ mod test {
test_scalar_expr!(Left, left, string, count);
test_nary_scalar_expr!(Lpad, lpad, string, count);
test_nary_scalar_expr!(Lpad, lpad, string, count, characters);
test_scalar_expr!(Replace, replace, string, from, to);
test_scalar_expr!(Repeat, repeat, string, count);
test_scalar_expr!(Reverse, reverse, string);
test_scalar_expr!(Right, right, string, count);
test_nary_scalar_expr!(Rpad, rpad, string, count);
test_nary_scalar_expr!(Rpad, rpad, string, count, characters);
test_scalar_expr!(SplitPart, split_part, expr, delimiter, index);
test_scalar_expr!(EndsWith, ends_with, string, characters);
test_scalar_expr!(Strpos, strpos, string, substring);
test_scalar_expr!(Substr, substr, string, position);
Expand Down
24 changes: 24 additions & 0 deletions datafusion/functions/src/string/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ mod lower;
mod ltrim;
mod octet_length;
mod overlay;
mod repeat;
mod replace;
mod rtrim;
mod split_part;
mod starts_with;
mod to_hex;
mod upper;
Expand All @@ -43,8 +46,11 @@ make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim);
make_udf_function!(lower::LowerFunc, LOWER, lower);
make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length);
make_udf_function!(overlay::OverlayFunc, OVERLAY, overlay);
make_udf_function!(repeat::RepeatFunc, REPEAT, repeat);
make_udf_function!(replace::ReplaceFunc, REPLACE, replace);
make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim);
make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with);
make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part);
make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex);
make_udf_function!(upper::UpperFunc, UPPER, upper);
make_udf_function!(uuid::UuidFunc, UUID, uuid);
Expand Down Expand Up @@ -87,11 +93,26 @@ pub mod expr_fn {
super::overlay().call(args)
}

#[doc = "Repeats the `string` to `n` times"]
pub fn repeat(string: Expr, n: Expr) -> Expr {
    // Arguments are passed positionally: the string first, then the count.
    let arguments = vec![string, n];
    super::repeat().call(arguments)
}

#[doc = "Replaces all occurrences of `from` with `to` in the `string`"]
pub fn replace(string: Expr, from: Expr, to: Expr) -> Expr {
    // Arguments are passed positionally: input, search pattern, replacement.
    let arguments = vec![string, from, to];
    super::replace().call(arguments)
}

#[doc = "Removes all characters, spaces by default, from the end of a string"]
pub fn rtrim(args: Vec<Expr>) -> Expr {
    // The caller-supplied argument list is forwarded unchanged to the UDF.
    let udf = super::rtrim();
    udf.call(args)
}

#[doc = "Splits a string based on a delimiter and picks out the desired field based on the index."]
pub fn split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr {
    // Arguments are passed positionally: input, delimiter, field index.
    let arguments = vec![string, delimiter, index];
    super::split_part().call(arguments)
}

#[doc = "Returns true if string starts with prefix."]
pub fn starts_with(arg1: Expr, arg2: Expr) -> Expr {
super::starts_with().call(vec![arg1, arg2])
Expand Down Expand Up @@ -128,7 +149,10 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
ltrim(),
octet_length(),
overlay(),
repeat(),
replace(),
rtrim(),
split_part(),
starts_with(),
to_hex(),
upper(),
Expand Down
144 changes: 144 additions & 0 deletions datafusion/functions/src/string/repeat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::datatypes::DataType;

use datafusion_common::cast::{as_generic_string_array, as_int64_array};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};

use crate::string::common::*;

/// Scalar UDF implementation registered under the name `repeat`.
///
/// The only state is the function's [`Signature`], which is built once in
/// `RepeatFunc::new` and handed out by reference from `signature()`.
#[derive(Debug)]
pub(super) struct RepeatFunc {
// Accepted argument type combinations and volatility for `repeat`.
signature: Signature,
}

impl RepeatFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for RepeatFunc {
    /// Exposes `self` as `Any` so callers can downcast to `RepeatFunc`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The name this function is registered under.
    fn name(&self) -> &str {
        "repeat"
    }

    /// The signature built in `RepeatFunc::new`.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Derives the return type from the type of the first (string) argument.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        let string_type = &arg_types[0];
        utf8_to_str_type(string_type, "repeat")
    }

    /// Dispatches to the kernel whose offset width matches the first
    /// argument: `i32` offsets for `Utf8`, `i64` offsets for `LargeUtf8`.
    /// Any other type is reported as an execution error.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let first = &args[0];
        match first.data_type() {
            DataType::Utf8 => {
                let kernel = make_scalar_function(repeat::<i32>, vec![]);
                kernel(args)
            }
            DataType::LargeUtf8 => {
                let kernel = make_scalar_function(repeat::<i64>, vec![]);
                kernel(args)
            }
            other => exec_err!("Unsupported data type {other:?} for function repeat"),
        }
    }
}

/// Repeats each string the specified number of times.
/// repeat('Pg', 4) = 'PgPgPgPg'
///
/// `args[0]` is the string array (offset width `T`) and `args[1]` is the
/// `Int64` count array; rows are paired positionally.
///
/// * A NULL string or a NULL count yields a NULL output row.
/// * A zero or negative count yields an empty string, matching PostgreSQL's
///   `repeat`.
///
/// # Errors
///
/// Returns an error if `args[0]` is not a string array with offset width `T`
/// or `args[1]` is not an `Int64` array.
fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    let string_array = as_generic_string_array::<T>(&args[0])?;
    let number_array = as_int64_array(&args[1])?;

    let result = string_array
        .iter()
        .zip(number_array.iter())
        .map(|(string, number)| match (string, number) {
            // Only strictly positive counts are cast. A negative `i64` cast
            // with `as usize` wraps to an enormous value, which would make
            // `str::repeat` panic on capacity overflow.
            (Some(string), Some(number)) if number > 0 => {
                Some(string.repeat(number as usize))
            }
            // Zero or negative count: nothing to repeat.
            (Some(_), Some(_)) => Some(String::new()),
            _ => None,
        })
        .collect::<GenericStringArray<T>>();

    Ok(Arc::new(result) as ArrayRef)
}

// Unit tests for `RepeatFunc`, driven through the shared `test_function!`
// macro with scalar arguments.
#[cfg(test)]
mod tests {
use arrow::array::{Array, StringArray};
use arrow::datatypes::DataType::Utf8;

use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

use crate::string::common::test::test_function;
use crate::string::repeat::RepeatFunc;

#[test]
fn test_functions() -> Result<()> {
// repeat('Pg', 4) is expected to produce 'PgPgPgPg'.
test_function!(
RepeatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
],
Ok(Some("PgPgPgPg")),
&str,
Utf8,
StringArray
);

// A NULL string with a valid count is expected to produce NULL.
test_function!(
RepeatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8(None)),
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
],
Ok(None),
&str,
Utf8,
StringArray
);
// A valid string with a NULL count is expected to produce NULL.
test_function!(
RepeatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))),
ColumnarValue::Scalar(ScalarValue::Int64(None)),
],
Ok(None),
&str,
Utf8,
StringArray
);

Ok(())
}
}
Loading

0 comments on commit 1e4ddb6

Please sign in to comment.