Skip to content

Commit

Permalink
Updated to properly handle returning largeutf8 + code refactor and a …
Browse files Browse the repository at this point in the history
…few more tests.
  • Loading branch information
Omega359 committed Nov 20, 2024
1 parent 203cd9e commit 300316b
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 172 deletions.
318 changes: 146 additions & 172 deletions datafusion/functions-nested/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ use std::any::{type_name, Any};

use crate::utils::{downcast_arg, make_scalar_function};
use arrow::compute::cast;
use arrow_array::builder::{ArrayBuilder, StringViewBuilder};
use arrow_array::builder::{ArrayBuilder, LargeStringBuilder, StringViewBuilder};
use arrow_array::cast::AsArray;
use arrow_array::{GenericStringArray, StringViewArray};
use arrow_schema::DataType::{
Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, Utf8View,
};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array};
use datafusion_common::cast::{as_large_list_array, as_list_array};
use datafusion_common::exec_err;
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::{
Expand All @@ -50,8 +50,8 @@ use std::sync::{Arc, OnceLock};
macro_rules! call_array_function {
($DATATYPE:expr, false) => {
match $DATATYPE {
DataType::Utf8View => array_function!(StringViewArray),
DataType::Utf8 => array_function!(StringArray),
DataType::Utf8View => array_function!(StringViewArray),
DataType::LargeUtf8 => array_function!(LargeStringArray),
DataType::Boolean => array_function!(BooleanArray),
DataType::Float32 => array_function!(Float32Array),
Expand Down Expand Up @@ -262,7 +262,7 @@ impl ScalarUDFImpl for StringToArray {

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
Utf8View | Utf8 => make_scalar_function(string_to_array_inner::<i32>)(args),
Utf8 | Utf8View => make_scalar_function(string_to_array_inner::<i32>)(args),
LargeUtf8 => make_scalar_function(string_to_array_inner::<i64>)(args),
other => {
exec_err!("unsupported type for string_to_array function as {other:?}")
Expand Down Expand Up @@ -330,13 +330,22 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

let arr = &args[0];

let delimiters = as_string_array(&args[1])?;
let delimiters: Vec<Option<&str>> = delimiters.iter().collect();
let delimiters: Vec<Option<&str>> = match args[1].data_type() {
Utf8 => args[1].as_string::<i32>().iter().collect(),
Utf8View => args[1].as_string_view().iter().collect(),
LargeUtf8 => args[1].as_string::<i64>().iter().collect(),
other => return exec_err!("unsupported type for second argument to array_to_string function as {other:?}")
};

let mut null_string = String::from("");
let mut with_null_string = false;
if args.len() == 3 {
null_string = as_string_array(&args[2])?.value(0).to_string();
null_string = match args[2].data_type() {
Utf8 => args[2].as_string::<i32>().value(0).to_string(),
Utf8View => args[2].as_string_view().value(0).to_string(),
LargeUtf8 => args[2].as_string::<i64>().value(0).to_string(),
other => return exec_err!("unsupported type for second argument to array_to_string function as {other:?}")
};
with_null_string = true;
}

Expand Down Expand Up @@ -496,190 +505,145 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
/// String_to_array SQL function
/// Splits string at occurrences of delimiter and returns an array of parts
/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]'
pub fn string_to_array_inner<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
fn string_to_array_inner<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() < 2 || args.len() > 3 {
return exec_err!("string_to_array expects two or three arguments");
}
match (args[0].data_type(), args[1].data_type()) {
(Utf8View, Utf8View) => {

match args[0].data_type() {
Utf8 => {
let string_array = args[0].as_string::<T>();
let builder = StringBuilder::with_capacity(string_array.len(), string_array.get_buffer_memory_size());
string_to_array_inner_2::<&GenericStringArray<T>, StringBuilder>(args, string_array, builder)
}
Utf8View => {
let string_array = args[0].as_string_view();
let delimiter_array = args[1].as_string_view();
let builder = StringViewBuilder::with_capacity(string_array.len());
string_to_array_inner_2::<&StringViewArray, StringViewBuilder>(args, string_array, builder)
}
LargeUtf8 => {
let string_array = args[0].as_string::<T>();
let builder = LargeStringBuilder::with_capacity(string_array.len(), string_array.get_buffer_memory_size());
string_to_array_inner_2::<&GenericStringArray<T>, LargeStringBuilder>(args, string_array, builder)
}
other => exec_err!("unsupported type for first argument to string_to_array function as {other:?}")
}
}

if args.len() == 3 {
match args[2].data_type() {
Utf8View => {
let null_type_array = Some(args[2].as_string_view());
string_to_array_impl::<
&StringViewArray,
&StringViewArray,
&StringViewArray,
StringViewBuilder,
>(
string_array, delimiter_array, null_type_array, builder
)
}
Utf8 | LargeUtf8 => {
let null_type_array = Some(args[2].as_string::<T>());
string_to_array_impl::<
&StringViewArray,
&StringViewArray,
&GenericStringArray<T>,
StringViewBuilder,
>(
string_array, delimiter_array, null_type_array, builder
)
}
other => {
exec_err!(
"unsupported type for string_to_array function as {other:?}"
)
}
}
} else {
fn string_to_array_inner_2<'a, StringArrType, StringBuilderType>(
args: &'a [ArrayRef],
string_array: StringArrType,
string_builder: StringBuilderType,
) -> Result<ArrayRef>
where
StringArrType: StringArrayType<'a>,
StringBuilderType: StringArrayBuilderType,
{
match args[1].data_type() {
Utf8 => {
let delimiter_array = args[1].as_string::<i32>();
if args.len() == 2 {
string_to_array_impl::<
StringArrType,
&GenericStringArray<i32>,
&StringViewArray,
&StringViewArray,
&GenericStringArray<T>,
StringViewBuilder,
>(string_array, delimiter_array, None, builder)
StringBuilderType,
>(string_array, delimiter_array, None, string_builder)
} else {
string_to_array_inner_3::<StringArrType,
&GenericStringArray<i32>,
StringBuilderType>(args, string_array, delimiter_array, string_builder)
}
}
(Utf8View, Utf8 | LargeUtf8) => {
let string_array = args[0].as_string_view();
let delimiter_array = args[1].as_string::<T>();
let builder = StringViewBuilder::with_capacity(string_array.len());
if args.len() == 3 {
match args[2].data_type() {
Utf8View => {
let null_type_array = Some(args[2].as_string_view());
string_to_array_impl::<
&StringViewArray,
&GenericStringArray<T>,
&StringViewArray,
StringViewBuilder,
>(
string_array, delimiter_array, null_type_array, builder
)
}
Utf8 | LargeUtf8 => {
let null_type_array = Some(args[2].as_string::<T>());
string_to_array_impl::<
&StringViewArray,
&GenericStringArray<T>,
&GenericStringArray<T>,
StringViewBuilder,
>(
string_array, delimiter_array, null_type_array, builder
)
}
other => {
exec_err!(
"unsupported type for string_to_array function as {other:?}"
)
}
}
} else {
Utf8View => {
let delimiter_array = args[1].as_string_view();

if args.len() == 2 {
string_to_array_impl::<
StringArrType,
&StringViewArray,
&GenericStringArray<T>,
&GenericStringArray<T>,
StringViewBuilder,
>(string_array, delimiter_array, None, builder)
}
}
(Utf8 | LargeUtf8, Utf8 | LargeUtf8) => {
let string_array = args[0].as_string::<T>();
let delimiter_array = args[1].as_string::<T>();
let builder = StringBuilder::with_capacity(
string_array.len(),
string_array.get_buffer_memory_size(),
);
if args.len() == 3 {
match args[2].data_type() {
Utf8View => {
let null_type_array = Some(args[2].as_string_view());
string_to_array_impl::<
&GenericStringArray<T>,
&GenericStringArray<T>,
&StringViewArray,
StringBuilder,
>(
string_array, delimiter_array, null_type_array, builder
)
}
Utf8 | LargeUtf8 => {
let null_type_array = Some(args[2].as_string::<T>());
string_to_array_impl::<
&GenericStringArray<T>,
&GenericStringArray<T>,
&GenericStringArray<T>,
StringBuilder,
>(
string_array, delimiter_array, null_type_array, builder
)
}
other => {
exec_err!(
"unsupported type for string_to_array function as {other:?}"
)
}
}
&StringViewArray,
StringBuilderType,
>(string_array, delimiter_array, None, string_builder)
} else {
string_to_array_impl::<
&GenericStringArray<T>,
&GenericStringArray<T>,
&GenericStringArray<T>,
StringBuilder,
>(string_array, delimiter_array, None, builder)
string_to_array_inner_3::<StringArrType,
&StringViewArray,
StringBuilderType>(args, string_array, delimiter_array, string_builder)
}
}
(Utf8 | LargeUtf8, Utf8View) => {
let string_array = args[0].as_string::<T>();
let delimiter_array = args[1].as_string_view();
let builder = StringBuilder::with_capacity(
string_array.len(),
string_array.get_buffer_memory_size(),
);
if args.len() == 3 {
match args[2].data_type() {
Utf8View => {
let null_type_array = Some(args[2].as_string_view());
string_to_array_impl::<
&GenericStringArray<T>,
&StringViewArray,
&StringViewArray,
StringBuilder,
>(
string_array, delimiter_array, null_type_array, builder
)
}
Utf8 | LargeUtf8 => {
let null_type_array = Some(args[2].as_string::<T>());
string_to_array_impl::<
&GenericStringArray<T>,
&StringViewArray,
&GenericStringArray<T>,
StringBuilder,
>(
string_array, delimiter_array, null_type_array, builder
)
}
other => {
exec_err!(
"unsupported type for string_to_array function as {other:?}"
)
}
}
} else {
LargeUtf8 => {
let delimiter_array = args[1].as_string::<i64>();
if args.len() == 2 {
string_to_array_impl::<
&GenericStringArray<T>,
StringArrType,
&GenericStringArray<i64>,
&StringViewArray,
&GenericStringArray<T>,
StringBuilder,
>(string_array, delimiter_array, None, builder)
StringBuilderType,
>(string_array, delimiter_array, None, string_builder)
} else {
string_to_array_inner_3::<StringArrType,
&GenericStringArray<i64>,
StringBuilderType>(args, string_array, delimiter_array, string_builder)
}
}
other => exec_err!("unsupported type for second argument to string_to_array function as {other:?}")
}
}

fn string_to_array_inner_3<'a, StringArrType, DelimiterArrType, StringBuilderType>(
args: &'a [ArrayRef],
string_array: StringArrType,
delimiter_array: DelimiterArrType,
string_builder: StringBuilderType,
) -> Result<ArrayRef>
where
StringArrType: StringArrayType<'a>,
DelimiterArrType: StringArrayType<'a>,
StringBuilderType: StringArrayBuilderType,
{
match args[2].data_type() {
Utf8 => {
let null_type_array = Some(args[2].as_string::<i32>());
string_to_array_impl::<
StringArrType,
DelimiterArrType,
&GenericStringArray<i32>,
StringBuilderType,
>(
string_array,
delimiter_array,
null_type_array,
string_builder,
)
}
Utf8View => {
let null_type_array = Some(args[2].as_string_view());
string_to_array_impl::<
StringArrType,
DelimiterArrType,
&StringViewArray,
StringBuilderType,
>(
string_array,
delimiter_array,
null_type_array,
string_builder,
)
}
LargeUtf8 => {
let null_type_array = Some(args[2].as_string::<i64>());
string_to_array_impl::<
StringArrType,
DelimiterArrType,
&GenericStringArray<i64>,
StringBuilderType,
>(
string_array,
delimiter_array,
null_type_array,
string_builder,
)
}
other => {
exec_err!("unsupported type for string_to_array function as {other:?}")
}
Expand Down Expand Up @@ -800,3 +764,13 @@ impl StringArrayBuilderType for StringViewBuilder {
StringViewBuilder::append_null(self)
}
}

impl StringArrayBuilderType for LargeStringBuilder {
fn append_value(&mut self, val: &str) {
LargeStringBuilder::append_value(self, val);
}

fn append_null(&mut self) {
LargeStringBuilder::append_null(self);
}
}
Loading

0 comments on commit 300316b

Please sign in to comment.