Skip to content

Commit

Permalink
Implement native support StringView for substr_index
Browse files Browse the repository at this point in the history
Signed-off-by: Chojan Shang <[email protected]>
  • Loading branch information
PsiACE committed Aug 14, 2024
1 parent 69c99a7 commit 488dc58
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 20 deletions.
83 changes: 63 additions & 20 deletions datafusion/functions/src/unicode/substrindex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder};
use arrow::datatypes::DataType;
use arrow::array::{
ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
PrimitiveArray, StringBuilder,
};
use arrow::datatypes::{DataType, Int32Type, Int64Type};

use datafusion_common::cast::{as_generic_string_array, as_int64_array};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
Expand All @@ -46,6 +48,7 @@ impl SubstrIndexFunc {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8View, Utf8View, Int64]),
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
],
Expand Down Expand Up @@ -74,15 +77,7 @@ impl ScalarUDFImpl for SubstrIndexFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(substr_index::<i32>, vec![])(args),
DataType::LargeUtf8 => {
make_scalar_function(substr_index::<i64>, vec![])(args)
}
other => {
exec_err!("Unsupported data type {other:?} for function substr_index")
}
}
make_scalar_function(substr_index, vec![])(args)
}

fn aliases(&self) -> &[String] {
Expand All @@ -95,23 +90,71 @@ impl ScalarUDFImpl for SubstrIndexFunc {
/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache
/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org
/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org
pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
fn substr_index(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 3 {
return exec_err!(
"substr_index was called with {} arguments. It requires 3.",
args.len()
);
}

let string_array = as_generic_string_array::<T>(&args[0])?;
let delimiter_array = as_generic_string_array::<T>(&args[1])?;
let count_array = as_int64_array(&args[2])?;
match args[0].data_type() {
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
let delimiter_array = args[1].as_string::<i32>();
let count_array: &PrimitiveArray<Int64Type> = args[2].as_primitive();
substr_index_general::<Int32Type, _, _>(
string_array,
delimiter_array,
count_array,
)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
let delimiter_array = args[1].as_string::<i64>();
let count_array: &PrimitiveArray<Int64Type> = args[2].as_primitive();
substr_index_general::<Int64Type, _, _>(
string_array,
delimiter_array,
count_array,
)
}
DataType::Utf8View => {
let string_array = args[0].as_string_view();
let delimiter_array = args[1].as_string_view();
let count_array: &PrimitiveArray<Int64Type> = args[2].as_primitive();
substr_index_general::<Int32Type, _, _>(
string_array,
delimiter_array,
count_array,
)
}
other => {
exec_err!("Unsupported data type {other:?} for function substr_index")
}
}
}

pub fn substr_index_general<
'a,
T: ArrowPrimitiveType,
V: ArrayAccessor<Item = &'a str>,
P: ArrayAccessor<Item = i64>,
>(
string_array: V,
delimiter_array: V,
count_array: P,
) -> Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
{
let mut builder = StringBuilder::new();
string_array
.iter()
.zip(delimiter_array.iter())
.zip(count_array.iter())
let string_iter = ArrayIter::new(string_array);
let delimiter_array_iter = ArrayIter::new(delimiter_array);
let count_array_iter = ArrayIter::new(count_array);
string_iter
.zip(delimiter_array_iter)
.zip(count_array_iter)
.for_each(|((string, delimiter), n)| match (string, delimiter, n) {
(Some(string), Some(delimiter), Some(n)) => {
// In MySQL, these cases will return an empty string.
Expand Down
59 changes: 59 additions & 0 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,65 @@ arrow.apache.org 100 arrow.apache.org
. 3 .
. 100 .

query I
SELECT levenshtein(NULL, NULL)
----
NULL

# Test substring_index using '.' as delimiter with utf8view
query TIT
SELECT str, n, substring_index(arrow_cast(str, 'Utf8View'), '.', n) AS c FROM
(VALUES
ROW('arrow.apache.org'),
ROW('.'),
ROW('...'),
ROW(NULL)
) AS strings(str),
(VALUES
ROW(1),
ROW(2),
ROW(3),
ROW(100),
ROW(-1),
ROW(-2),
ROW(-3),
ROW(-100)
) AS occurrences(n)
ORDER BY str DESC, n;
----
NULL -100 NULL
NULL -3 NULL
NULL -2 NULL
NULL -1 NULL
NULL 1 NULL
NULL 2 NULL
NULL 3 NULL
NULL 100 NULL
arrow.apache.org -100 arrow.apache.org
arrow.apache.org -3 arrow.apache.org
arrow.apache.org -2 apache.org
arrow.apache.org -1 org
arrow.apache.org 1 arrow
arrow.apache.org 2 arrow.apache
arrow.apache.org 3 arrow.apache.org
arrow.apache.org 100 arrow.apache.org
... -100 ...
... -3 ..
... -2 .
... -1 (empty)
... 1 (empty)
... 2 .
... 3 ..
... 100 ...
. -100 .
. -3 .
. -2 .
. -1 (empty)
. 1 (empty)
. 2 .
. 3 .
. 100 .

# Test substring_index using 'ac' as delimiter
query TIT
SELECT str, n, substring_index(str, 'ac', n) AS c FROM
Expand Down
22 changes: 22 additions & 0 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,28 @@ logical_plan
02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1
03)----TableScan: test projection=[column1_utf8view]

## Ensure no casts for SUBSTRINDEX
query TT
EXPLAIN SELECT
SUBSTR_INDEX(column1_utf8view, 'a', 1) as c,
SUBSTR_INDEX(column1_utf8view, 'a', 2) as c2
FROM test;
----
logical_plan
01)Projection: substr_index(test.column1_utf8view, Utf8View("a"), Int64(1)) AS c, substr_index(test.column1_utf8view, Utf8View("a"), Int64(2)) AS c2
02)--TableScan: test projection=[column1_utf8view]

query TT
SELECT
SUBSTR_INDEX(column1_utf8view, 'a', 1) as c,
SUBSTR_INDEX(column1_utf8view, 'a', 2) as c2
FROM test;
----
Andrew Andrew
Xi Xiangpeng
R Raph
NULL NULL

## Ensure no casts on columns for STARTS_WITH
query TT
EXPLAIN SELECT
Expand Down

0 comments on commit 488dc58

Please sign in to comment.