From 300316b3647981a88a491ca139599fbf3b09ab7e Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Wed, 20 Nov 2024 01:02:47 +0000 Subject: [PATCH] Updated to properly handle returning largeutf8 + code refactor and a few more tests. --- datafusion/functions-nested/src/string.rs | 318 +++++++++---------- datafusion/sqllogictest/test_files/array.slt | 57 ++++ 2 files changed, 203 insertions(+), 172 deletions(-) diff --git a/datafusion/functions-nested/src/string.rs b/datafusion/functions-nested/src/string.rs index 851aeac7f6cf..da4ab2bed49a 100644 --- a/datafusion/functions-nested/src/string.rs +++ b/datafusion/functions-nested/src/string.rs @@ -32,13 +32,13 @@ use std::any::{type_name, Any}; use crate::utils::{downcast_arg, make_scalar_function}; use arrow::compute::cast; -use arrow_array::builder::{ArrayBuilder, StringViewBuilder}; +use arrow_array::builder::{ArrayBuilder, LargeStringBuilder, StringViewBuilder}; use arrow_array::cast::AsArray; use arrow_array::{GenericStringArray, StringViewArray}; use arrow_schema::DataType::{ Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, Utf8View, }; -use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::exec_err; use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::{ @@ -50,8 +50,8 @@ use std::sync::{Arc, OnceLock}; macro_rules! 
call_array_function { ($DATATYPE:expr, false) => { match $DATATYPE { - DataType::Utf8View => array_function!(StringViewArray), DataType::Utf8 => array_function!(StringArray), + DataType::Utf8View => array_function!(StringViewArray), DataType::LargeUtf8 => array_function!(LargeStringArray), DataType::Boolean => array_function!(BooleanArray), DataType::Float32 => array_function!(Float32Array), @@ -262,7 +262,7 @@ impl ScalarUDFImpl for StringToArray { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - Utf8View | Utf8 => make_scalar_function(string_to_array_inner::)(args), + Utf8 | Utf8View => make_scalar_function(string_to_array_inner::)(args), LargeUtf8 => make_scalar_function(string_to_array_inner::)(args), other => { exec_err!("unsupported type for string_to_array function as {other:?}") @@ -330,13 +330,22 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { let arr = &args[0]; - let delimiters = as_string_array(&args[1])?; - let delimiters: Vec> = delimiters.iter().collect(); + let delimiters: Vec> = match args[1].data_type() { + Utf8 => args[1].as_string::().iter().collect(), + Utf8View => args[1].as_string_view().iter().collect(), + LargeUtf8 => args[1].as_string::().iter().collect(), + other => return exec_err!("unsupported type for second argument to array_to_string function as {other:?}") + }; let mut null_string = String::from(""); let mut with_null_string = false; if args.len() == 3 { - null_string = as_string_array(&args[2])?.value(0).to_string(); + null_string = match args[2].data_type() { + Utf8 => args[2].as_string::().value(0).to_string(), + Utf8View => args[2].as_string_view().value(0).to_string(), + LargeUtf8 => args[2].as_string::().value(0).to_string(), + other => return exec_err!("unsupported type for second argument to array_to_string function as {other:?}") + }; with_null_string = true; } @@ -496,190 +505,145 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { /// 
String_to_array SQL function /// Splits string at occurrences of delimiter and returns an array of parts /// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' -pub fn string_to_array_inner(args: &[ArrayRef]) -> Result { +fn string_to_array_inner(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { return exec_err!("string_to_array expects two or three arguments"); } - match (args[0].data_type(), args[1].data_type()) { - (Utf8View, Utf8View) => { + + match args[0].data_type() { + Utf8 => { + let string_array = args[0].as_string::(); + let builder = StringBuilder::with_capacity(string_array.len(), string_array.get_buffer_memory_size()); + string_to_array_inner_2::<&GenericStringArray, StringBuilder>(args, string_array, builder) + } + Utf8View => { let string_array = args[0].as_string_view(); - let delimiter_array = args[1].as_string_view(); let builder = StringViewBuilder::with_capacity(string_array.len()); + string_to_array_inner_2::<&StringViewArray, StringViewBuilder>(args, string_array, builder) + } + LargeUtf8 => { + let string_array = args[0].as_string::(); + let builder = LargeStringBuilder::with_capacity(string_array.len(), string_array.get_buffer_memory_size()); + string_to_array_inner_2::<&GenericStringArray, LargeStringBuilder>(args, string_array, builder) + } + other => exec_err!("unsupported type for first argument to string_to_array function as {other:?}") + } +} - if args.len() == 3 { - match args[2].data_type() { - Utf8View => { - let null_type_array = Some(args[2].as_string_view()); - string_to_array_impl::< - &StringViewArray, - &StringViewArray, - &StringViewArray, - StringViewBuilder, - >( - string_array, delimiter_array, null_type_array, builder - ) - } - Utf8 | LargeUtf8 => { - let null_type_array = Some(args[2].as_string::()); - string_to_array_impl::< - &StringViewArray, - &StringViewArray, - &GenericStringArray, - StringViewBuilder, - >( - string_array, delimiter_array, null_type_array, builder - ) - } - 
other => { - exec_err!( - "unsupported type for string_to_array function as {other:?}" - ) - } - } - } else { +fn string_to_array_inner_2<'a, StringArrType, StringBuilderType>( + args: &'a [ArrayRef], + string_array: StringArrType, + string_builder: StringBuilderType, +) -> Result +where + StringArrType: StringArrayType<'a>, + StringBuilderType: StringArrayBuilderType, +{ + match args[1].data_type() { + Utf8 => { + let delimiter_array = args[1].as_string::(); + if args.len() == 2 { string_to_array_impl::< + StringArrType, + &GenericStringArray, &StringViewArray, - &StringViewArray, - &GenericStringArray, - StringViewBuilder, - >(string_array, delimiter_array, None, builder) + StringBuilderType, + >(string_array, delimiter_array, None, string_builder) + } else { + string_to_array_inner_3::, + StringBuilderType>(args, string_array, delimiter_array, string_builder) } } - (Utf8View, Utf8 | LargeUtf8) => { - let string_array = args[0].as_string_view(); - let delimiter_array = args[1].as_string::(); - let builder = StringViewBuilder::with_capacity(string_array.len()); - if args.len() == 3 { - match args[2].data_type() { - Utf8View => { - let null_type_array = Some(args[2].as_string_view()); - string_to_array_impl::< - &StringViewArray, - &GenericStringArray, - &StringViewArray, - StringViewBuilder, - >( - string_array, delimiter_array, null_type_array, builder - ) - } - Utf8 | LargeUtf8 => { - let null_type_array = Some(args[2].as_string::()); - string_to_array_impl::< - &StringViewArray, - &GenericStringArray, - &GenericStringArray, - StringViewBuilder, - >( - string_array, delimiter_array, null_type_array, builder - ) - } - other => { - exec_err!( - "unsupported type for string_to_array function as {other:?}" - ) - } - } - } else { + Utf8View => { + let delimiter_array = args[1].as_string_view(); + + if args.len() == 2 { string_to_array_impl::< + StringArrType, &StringViewArray, - &GenericStringArray, - &GenericStringArray, - StringViewBuilder, - >(string_array, 
delimiter_array, None, builder) - } - } - (Utf8 | LargeUtf8, Utf8 | LargeUtf8) => { - let string_array = args[0].as_string::(); - let delimiter_array = args[1].as_string::(); - let builder = StringBuilder::with_capacity( - string_array.len(), - string_array.get_buffer_memory_size(), - ); - if args.len() == 3 { - match args[2].data_type() { - Utf8View => { - let null_type_array = Some(args[2].as_string_view()); - string_to_array_impl::< - &GenericStringArray, - &GenericStringArray, - &StringViewArray, - StringBuilder, - >( - string_array, delimiter_array, null_type_array, builder - ) - } - Utf8 | LargeUtf8 => { - let null_type_array = Some(args[2].as_string::()); - string_to_array_impl::< - &GenericStringArray, - &GenericStringArray, - &GenericStringArray, - StringBuilder, - >( - string_array, delimiter_array, null_type_array, builder - ) - } - other => { - exec_err!( - "unsupported type for string_to_array function as {other:?}" - ) - } - } + &StringViewArray, + StringBuilderType, + >(string_array, delimiter_array, None, string_builder) } else { - string_to_array_impl::< - &GenericStringArray, - &GenericStringArray, - &GenericStringArray, - StringBuilder, - >(string_array, delimiter_array, None, builder) + string_to_array_inner_3::(args, string_array, delimiter_array, string_builder) } } - (Utf8 | LargeUtf8, Utf8View) => { - let string_array = args[0].as_string::(); - let delimiter_array = args[1].as_string_view(); - let builder = StringBuilder::with_capacity( - string_array.len(), - string_array.get_buffer_memory_size(), - ); - if args.len() == 3 { - match args[2].data_type() { - Utf8View => { - let null_type_array = Some(args[2].as_string_view()); - string_to_array_impl::< - &GenericStringArray, - &StringViewArray, - &StringViewArray, - StringBuilder, - >( - string_array, delimiter_array, null_type_array, builder - ) - } - Utf8 | LargeUtf8 => { - let null_type_array = Some(args[2].as_string::()); - string_to_array_impl::< - &GenericStringArray, - 
&StringViewArray, - &GenericStringArray, - StringBuilder, - >( - string_array, delimiter_array, null_type_array, builder - ) - } - other => { - exec_err!( - "unsupported type for string_to_array function as {other:?}" - ) - } - } - } else { + LargeUtf8 => { + let delimiter_array = args[1].as_string::(); + if args.len() == 2 { string_to_array_impl::< - &GenericStringArray, + StringArrType, + &GenericStringArray, &StringViewArray, - &GenericStringArray, - StringBuilder, - >(string_array, delimiter_array, None, builder) + StringBuilderType, + >(string_array, delimiter_array, None, string_builder) + } else { + string_to_array_inner_3::, + StringBuilderType>(args, string_array, delimiter_array, string_builder) } } + other => exec_err!("unsupported type for second argument to string_to_array function as {other:?}") + } +} + +fn string_to_array_inner_3<'a, StringArrType, DelimiterArrType, StringBuilderType>( + args: &'a [ArrayRef], + string_array: StringArrType, + delimiter_array: DelimiterArrType, + string_builder: StringBuilderType, +) -> Result +where + StringArrType: StringArrayType<'a>, + DelimiterArrType: StringArrayType<'a>, + StringBuilderType: StringArrayBuilderType, +{ + match args[2].data_type() { + Utf8 => { + let null_type_array = Some(args[2].as_string::()); + string_to_array_impl::< + StringArrType, + DelimiterArrType, + &GenericStringArray, + StringBuilderType, + >( + string_array, + delimiter_array, + null_type_array, + string_builder, + ) + } + Utf8View => { + let null_type_array = Some(args[2].as_string_view()); + string_to_array_impl::< + StringArrType, + DelimiterArrType, + &StringViewArray, + StringBuilderType, + >( + string_array, + delimiter_array, + null_type_array, + string_builder, + ) + } + LargeUtf8 => { + let null_type_array = Some(args[2].as_string::()); + string_to_array_impl::< + StringArrType, + DelimiterArrType, + &GenericStringArray, + StringBuilderType, + >( + string_array, + delimiter_array, + null_type_array, + string_builder, + ) + 
} other => { exec_err!("unsupported type for string_to_array function as {other:?}") } @@ -800,3 +764,13 @@ impl StringArrayBuilderType for StringViewBuilder { StringViewBuilder::append_null(self) } } + +impl StringArrayBuilderType for LargeStringBuilder { + fn append_value(&mut self, val: &str) { + LargeStringBuilder::append_value(self, val); + } + + fn append_null(&mut self) { + LargeStringBuilder::append_null(self); + } +} diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ecc5649173f6..da3a53dc07c3 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3985,6 +3985,18 @@ SELECT array_to_string(make_array(arrow_cast('a', 'Utf8View'), 'b', 'c', 'd'), ' ---- a,b,c,d +# array_to_string using largeutf8 for second arg +query TTT +select array_to_string(['h', 'e', 'l', 'l', 'o'], arrow_cast(',', 'LargeUtf8')), array_to_string([1, 2, 3, 4, 5], arrow_cast('-', 'LargeUtf8')), array_to_string([1.0, 2.0, 3.0], arrow_cast('|', 'LargeUtf8')); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# array_to_string using utf8view for second arg +query TTT +select array_to_string(['h', 'e', 'l', 'l', 'o'], arrow_cast(',', 'Utf8View')), array_to_string([1, 2, 3, 4, 5], arrow_cast('-', 'Utf8View')), array_to_string([1.0, 2.0, 3.0], arrow_cast('|', 'Utf8View')); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + statement ok drop table table1; @@ -6928,6 +6940,51 @@ select string_to_array(e, ',') from values; [adipiscing] NULL +# large string tests for string_to_array + +# string_to_array scalar function +query ? +SELECT string_to_array(arrow_cast('abcxxxdef', 'LargeUtf8'), 'xxx') +---- +[abc, def] + +# string_to_array scalar function +query ? +SELECT string_to_array(arrow_cast('abcxxxdef', 'LargeUtf8'), arrow_cast('xxx', 'LargeUtf8')) +---- +[abc, def] + +query ? +SELECT string_to_array(arrow_cast('abc', 'LargeUtf8'), NULL) +---- +[a, b, c] + +query ? 
+select string_to_array(arrow_cast(e, 'LargeUtf8'), ',') from values; +---- +[Lorem] +[ipsum] +[dolor] +[sit] +[amet] +[, ] +[consectetur] +[adipiscing] +NULL + +query ? +select string_to_array(arrow_cast(e, 'LargeUtf8'), ',', arrow_cast('Lorem', 'LargeUtf8')) from values; +---- +[] +[ipsum] +[dolor] +[sit] +[amet] +[, ] +[consectetur] +[adipiscing] +NULL + # string view tests for string_to_array # string_to_array scalar function