Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement native support StringView for find in set #11970

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 39 additions & 26 deletions datafusion/functions/src/unicode/find_in_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ use std::any::Any;
use std::sync::Arc;

use arrow::array::{
ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
PrimitiveArray,
};
use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};

use datafusion_common::cast::as_generic_string_array;
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
Expand All @@ -46,7 +46,11 @@ impl FindInSetFunc {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
vec![
Exact(vec![Utf8View, Utf8View]),
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
],
Volatility::Immutable,
),
}
Expand All @@ -71,41 +75,50 @@ impl ScalarUDFImpl for FindInSetFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(find_in_set::<Int32Type>, vec![])(args)
}
DataType::LargeUtf8 => {
make_scalar_function(find_in_set::<Int64Type>, vec![])(args)
}
other => {
exec_err!("Unsupported data type {other:?} for function find_in_set")
}
}
make_scalar_function(find_in_set, vec![])(args)
}
}

///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings
///A string list is a string composed of substrings separated by , characters.
pub fn find_in_set<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
{
fn find_in_set(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!(
"find_in_set was called with {} arguments. It requires 2.",
args.len()
);
}
match args[0].data_type() {
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
let str_list_array = args[1].as_string::<i32>();
find_in_set_general::<Int32Type, _>(string_array, str_list_array)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
let str_list_array = args[1].as_string::<i64>();
find_in_set_general::<Int64Type, _>(string_array, str_list_array)
}
DataType::Utf8View => {
let string_array = args[0].as_string_view();
let str_list_array = args[1].as_string_view();
find_in_set_general::<Int32Type, _>(string_array, str_list_array)
}
_ => unreachable!(),
}
}

let str_array: &GenericStringArray<T::Native> =
as_generic_string_array::<T::Native>(&args[0])?;
let str_list_array: &GenericStringArray<T::Native> =
as_generic_string_array::<T::Native>(&args[1])?;

let result = str_array
.iter()
.zip(str_list_array.iter())
pub fn find_in_set_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor<Item = &'a str>>(
string_array: V,
str_list_array: V,
) -> Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
{
let string_iter = ArrayIter::new(string_array);
let str_list_iter = ArrayIter::new(str_list_array);
let result = string_iter
.zip(str_list_iter)
.map(|(string, str_list)| match (string, str_list) {
(Some(string), Some(str_list)) => {
let mut res = 0;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: find_in_set(CAST(test.column1_utf8view AS Utf8), Utf8("a,b,c,d")) AS c
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should probably add a test that actually invokes find_in_set with a StringVieAwArray argument.

However, I think @2010YOUY01 has some ideas on a more general testing framework, so maybe this is good enough for now. Let me know what you think

#11790 (comment)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add some sqllogictests to make sure it is covered.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! 5d5a4e6

01)Projection: find_in_set(test.column1_utf8view, Utf8View("a,b,c,d")) AS c
02)--TableScan: test projection=[column1_utf8view]


Expand Down