From 34ec9d41faa8e73f37cd68f971985e851987bb3d Mon Sep 17 00:00:00 2001 From: kf zheng <100595273+Kev1n8@users.noreply.github.com> Date: Mon, 12 Aug 2024 20:01:40 +0800 Subject: [PATCH] Implement native stringview support for BTRIM (#11920) * add utf8view support for generic_trim * add utf8view support for BTRIM * stop LTRIM and RTRIM from complaining generic_trim missing args * add tests to cover utf8view support of BTRIM * fix typo and tiny err * remove useless imports --- datafusion/functions/src/string/btrim.rs | 24 ++++-- datafusion/functions/src/string/common.rs | 78 ++++++++++++++++++- datafusion/functions/src/string/ltrim.rs | 2 +- datafusion/functions/src/string/rtrim.rs | 2 +- .../sqllogictest/test_files/string_view.slt | 37 ++++++++- 5 files changed, 131 insertions(+), 12 deletions(-) diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 349928d09664..86470dd7a646 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -16,9 +16,8 @@ // under the License. use arrow::array::{ArrayRef, OffsetSizeTrait}; -use std::any::Any; - use arrow::datatypes::DataType; +use std::any::Any; use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; @@ -32,7 +31,8 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' fn btrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Both) + let use_string_view = args[0].data_type() == &DataType::Utf8View; + general_trim::(args, TrimType::Both, use_string_view) } #[derive(Debug)] @@ -52,7 +52,16 @@ impl BTrimFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + vec![ + // Planner attempts coercion to the target type starting with the most preferred candidate. 
+ // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`. + // If that fails, it proceeds to `(Utf8, Utf8)`. + Exact(vec![Utf8View, Utf8View]), + // Exact(vec![Utf8, Utf8View]), + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8View]), + Exact(vec![Utf8]), + ], Volatility::Immutable, ), aliases: vec![String::from("trim")], @@ -79,7 +88,7 @@ impl ScalarUDFImpl for BTrimFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function( + DataType::Utf8 | DataType::Utf8View => make_scalar_function( btrim::, vec![Hint::Pad, Hint::AcceptsSingular], )(args), @@ -87,7 +96,10 @@ impl ScalarUDFImpl for BTrimFunc { btrim::, vec![Hint::Pad, Hint::AcceptsSingular], )(args), - other => exec_err!("Unsupported data type {other:?} for function btrim"), + other => exec_err!( + "Unsupported data type {other:?} for function btrim,\ + expected for Utf8, LargeUtf8 or Utf8View." + ), } } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index d36bd5cecc47..7037c1d1c3c3 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -25,7 +25,7 @@ use arrow::array::{ use arrow::buffer::{Buffer, MutableBuffer, NullBuffer}; use arrow::datatypes::DataType; -use datafusion_common::cast::as_generic_string_array; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -49,6 +49,7 @@ impl Display for TrimType { pub(crate) fn general_trim( args: &[ArrayRef], trim_type: TrimType, + use_string_view: bool, ) -> Result { let func = match trim_type { TrimType::Left => |input, pattern: &str| { @@ -68,6 +69,74 @@ pub(crate) fn general_trim( }, }; + if use_string_view { + string_view_trim::(trim_type, func, args) + } else { + string_trim::(trim_type, func, args) + } +} + +// 
Removing 'a makes the compiler complain about the lifetime of `func`. +fn string_view_trim<'a, T: OffsetSizeTrait>( + trim_type: TrimType, + func: fn(&'a str, &'a str) -> &'a str, + args: &'a [ArrayRef], +) -> Result { + let string_array = as_string_view_array(&args[0])?; + + match args.len() { + 1 => { + let result = string_array + .iter() + .map(|string| string.map(|string: &str| func(string, " "))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let characters_array = as_string_view_array(&args[1])?; + + if characters_array.len() == 1 { + if characters_array.is_null(0) { + return Ok(new_null_array( + // The schema is expecting utf8 as null + &DataType::Utf8, + string_array.len(), + )); + } + + let characters = characters_array.value(0); + let result = string_array + .iter() + .map(|item| item.map(|string| func(string, characters))) + .collect::>(); + return Ok(Arc::new(result) as ArrayRef); + } + + let result = string_array + .iter() + .zip(characters_array.iter()) + .map(|(string, characters)| match (string, characters) { + (Some(string), Some(characters)) => Some(func(string, characters)), + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." + ) + } + } +} + +fn string_trim<'a, T: OffsetSizeTrait>( + trim_type: TrimType, + func: fn(&'a str, &'a str) -> &'a str, + args: &'a [ArrayRef], +) -> Result { let string_array = as_generic_string_array::(&args[0])?; match args.len() { @@ -84,7 +153,10 @@ pub(crate) fn general_trim( if characters_array.len() == 1 { if characters_array.is_null(0) { - return Ok(new_null_array(args[0].data_type(), args[0].len())); + return Ok(new_null_array( + string_array.data_type(), + string_array.len(), + )); } let characters = characters_array.value(0); @@ -109,7 +181,7 @@ pub(crate) fn general_trim( other => { exec_err!( "{trim_type} was called with {other} arguments. 
It requires at least 1 and at most 2." - ) + ) } } } diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index de14bbaa2bcf..6a9fafdd9299 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -32,7 +32,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. /// ltrim('zzzytest', 'xyz') = 'test' fn ltrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Left) + general_trim::(args, TrimType::Left, false) } #[derive(Debug)] diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 2d29b50cb173..50b626e3df0e 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -32,7 +32,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. 
/// rtrim('testxxzx', 'xyz') = 'test' fn rtrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Right) + general_trim::(args, TrimType::Right, false) } #[derive(Debug)] diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 5edda9b80431..fcd71b7f7e94 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -563,15 +563,50 @@ SELECT 228 0 NULL ## Ensure no casts for BTRIM +# Test BTRIM with Utf8View input +query TT +EXPLAIN SELECT + BTRIM(column1_utf8view) AS l +FROM test; +---- +logical_plan +01)Projection: btrim(test.column1_utf8view) AS l +02)--TableScan: test projection=[column1_utf8view] + +# Test BTRIM with Utf8View input and Utf8View pattern query TT EXPLAIN SELECT BTRIM(column1_utf8view, 'foo') AS l FROM test; ---- logical_plan -01)Projection: btrim(CAST(test.column1_utf8view AS Utf8), Utf8("foo")) AS l +01)Projection: btrim(test.column1_utf8view, Utf8View("foo")) AS l +02)--TableScan: test projection=[column1_utf8view] + +# Test BTRIM with Utf8View bytes longer than 12 +query TT +EXPLAIN SELECT + BTRIM(column1_utf8view, 'this is longer than 12') AS l +FROM test; +---- +logical_plan +01)Projection: btrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l 02)--TableScan: test projection=[column1_utf8view] +# Test BTRIM outputs +query TTTT +SELECT + BTRIM(column1_utf8view, 'foo') AS l1, + BTRIM(column1_utf8view, 'A') AS l2, + BTRIM(column1_utf8view) AS l3, + BTRIM(column1_utf8view, NULL) AS l4 +FROM test; +---- +Andrew ndrew Andrew NULL +Xiangpeng Xiangpeng Xiangpeng NULL +Raphael Raphael Raphael NULL +NULL NULL NULL NULL + ## Ensure no casts for CHARACTER_LENGTH query TT EXPLAIN SELECT