From 488dc588df08ff72d2510f6db0749df8c7d713e2 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Wed, 14 Aug 2024 11:18:08 +0800 Subject: [PATCH] Implement native support StringView for substr_index Signed-off-by: Chojan Shang --- .../functions/src/unicode/substrindex.rs | 83 ++++++++++++++----- .../sqllogictest/test_files/functions.slt | 59 +++++++++++++ .../sqllogictest/test_files/string_view.slt | 22 +++++ 3 files changed, 144 insertions(+), 20 deletions(-) diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index f8ecab9073c4..6591ee26403a 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -18,10 +18,12 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder}; -use arrow::datatypes::DataType; +use arrow::array::{ + ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, + PrimitiveArray, StringBuilder, +}; +use arrow::datatypes::{DataType, Int32Type, Int64Type}; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -46,6 +48,7 @@ impl SubstrIndexFunc { Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Utf8View, Int64]), Exact(vec![Utf8, Utf8, Int64]), Exact(vec![LargeUtf8, LargeUtf8, Int64]), ], @@ -74,15 +77,7 @@ impl ScalarUDFImpl for SubstrIndexFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(substr_index::, vec![])(args), - DataType::LargeUtf8 => { - make_scalar_function(substr_index::, vec![])(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function substr_index") - } - } + make_scalar_function(substr_index, vec![])(args) } fn aliases(&self) -> &[String] { @@ -95,7 +90,7 @@ impl ScalarUDFImpl for SubstrIndexFunc { /// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache /// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org /// SUBSTRING_INDEX('www.apache.org', '.', -1) = org -pub fn substr_index(args: &[ArrayRef]) -> Result { +fn substr_index(args: &[ArrayRef]) -> Result { if args.len() != 3 { return exec_err!( "substr_index was called with {} arguments. It requires 3.", @@ -103,15 +98,63 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { ); } - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - let count_array = as_int64_array(&args[2])?; + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + let delimiter_array = args[1].as_string::(); + let count_array: &PrimitiveArray = args[2].as_primitive(); + substr_index_general::( + string_array, + delimiter_array, + count_array, + ) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + let delimiter_array = args[1].as_string::(); + let count_array: &PrimitiveArray = args[2].as_primitive(); + substr_index_general::( + string_array, + delimiter_array, + count_array, + ) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + let delimiter_array = args[1].as_string_view(); + let count_array: &PrimitiveArray = args[2].as_primitive(); + substr_index_general::( + string_array, + delimiter_array, + count_array, + ) + } + other => { + exec_err!("Unsupported data type {other:?} for function substr_index") + } + } +} +pub fn substr_index_general< + 'a, + T: ArrowPrimitiveType, + V: ArrayAccessor, + P: ArrayAccessor, +>( + string_array: V, + delimiter_array: V, + count_array: P, +) -> Result +where + T::Native: OffsetSizeTrait, +{ let mut builder = StringBuilder::new(); - string_array - .iter() - .zip(delimiter_array.iter()) - .zip(count_array.iter()) + let string_iter = ArrayIter::new(string_array); + let delimiter_array_iter = ArrayIter::new(delimiter_array); + let count_array_iter = ArrayIter::new(count_array); + string_iter + .zip(delimiter_array_iter) + .zip(count_array_iter) .for_each(|((string, delimiter), n)| match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { // In MySQL, these cases will return an empty string. diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index bea3016a21d3..21dc7949d7af 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -1014,6 +1014,65 @@ arrow.apache.org 100 arrow.apache.org . 3 . . 100 . +query I +SELECT levenshtein(NULL, NULL) +---- +NULL + +# Test substring_index using '.' as delimiter with utf8view +query TIT +SELECT str, n, substring_index(arrow_cast(str, 'Utf8View'), '.', n) AS c FROM + (VALUES + ROW('arrow.apache.org'), + ROW('.'), + ROW('...'), + ROW(NULL) + ) AS strings(str), + (VALUES + ROW(1), + ROW(2), + ROW(3), + ROW(100), + ROW(-1), + ROW(-2), + ROW(-3), + ROW(-100) + ) AS occurrences(n) +ORDER BY str DESC, n; +---- +NULL -100 NULL +NULL -3 NULL +NULL -2 NULL +NULL -1 NULL +NULL 1 NULL +NULL 2 NULL +NULL 3 NULL +NULL 100 NULL +arrow.apache.org -100 arrow.apache.org +arrow.apache.org -3 arrow.apache.org +arrow.apache.org -2 apache.org +arrow.apache.org -1 org +arrow.apache.org 1 arrow +arrow.apache.org 2 arrow.apache +arrow.apache.org 3 arrow.apache.org +arrow.apache.org 100 arrow.apache.org +... -100 ... +... -3 .. +... -2 . +... -1 (empty) +... 1 (empty) +... 2 . +... 3 .. +... 100 ... +. -100 . +. -3 . +. -2 . +. -1 (empty) +. 1 (empty) +. 2 . +. 3 . +. 100 . + # Test substring_index using 'ac' as delimiter query TIT SELECT str, n, substring_index(str, 'ac', n) AS c FROM diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 0a9b73babb96..651b28ccaa1d 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -969,6 +969,28 @@ logical_plan 02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1 03)----TableScan: test projection=[column1_utf8view] +## Ensure no casts for SUBSTRINDEX +query TT +EXPLAIN SELECT + SUBSTR_INDEX(column1_utf8view, 'a', 1) as c, + SUBSTR_INDEX(column1_utf8view, 'a', 2) as c2 +FROM test; +---- +logical_plan +01)Projection: substr_index(test.column1_utf8view, Utf8View("a"), Int64(1)) AS c, substr_index(test.column1_utf8view, Utf8View("a"), Int64(2)) AS c2 +02)--TableScan: test projection=[column1_utf8view] + +query TT +SELECT + SUBSTR_INDEX(column1_utf8view, 'a', 1) as c, + SUBSTR_INDEX(column1_utf8view, 'a', 2) as c2 +FROM test; +---- +Andrew Andrew +Xi Xiangpeng +R Raph +NULL NULL + ## Ensure no casts on columns for STARTS_WITH query TT EXPLAIN SELECT