From 33e130b0e714c11da185c2505f0dcc6d3ba50098 Mon Sep 17 00:00:00 2001 From: Lordworms Date: Tue, 20 Aug 2024 17:32:51 -0700 Subject: [PATCH] refine code --- datafusion/functions/src/unicode/rpad.rs | 329 ++++++++++++----------- 1 file changed, 167 insertions(+), 162 deletions(-) diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 0b71b8ba5652..c1d6f327928f 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; -use arrow::array::{ArrayRef, GenericStringBuilder, OffsetSizeTrait}; -use arrow::datatypes::DataType; -use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_string_view_array, +use arrow::array::{ + ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringViewArray, }; +use arrow::datatypes::DataType; +use datafusion_common::cast::as_int64_array; use datafusion_common::DataFusionError; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; @@ -29,6 +31,7 @@ use std::any::Any; use std::fmt::Write; use std::sync::Arc; use unicode_segmentation::UnicodeSegmentation; +use DataType::{LargeUtf8, Utf8, Utf8View}; #[derive(Debug)] pub struct RPadFunc { @@ -84,180 +87,182 @@ impl ScalarUDFImpl for RPadFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args.len() { - 2 => match args[0].data_type() { - DataType::Utf8 | DataType::Utf8View => { - make_scalar_function(rpad::, vec![])(args) - } - DataType::LargeUtf8 => { - make_scalar_function(rpad::, vec![])(args) - } - other => exec_err!("Unsupported data type {other:?} for function rpad"), - }, - 3 => match (args[0].data_type(), args[2].data_type()) { - ( - DataType::Utf8 | DataType::Utf8View, - DataType::Utf8 | DataType::Utf8View, - ) => make_scalar_function(rpad::, vec![])(args), - (DataType::LargeUtf8, DataType::LargeUtf8) => { - make_scalar_function(rpad::, vec![])(args) - } - (DataType::LargeUtf8, DataType::Utf8View | DataType::Utf8) => { - make_scalar_function(rpad::, vec![])(args) - } - (DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8) => { - make_scalar_function(rpad::, vec![])(args) - } - (first_type, last_type) => { - exec_err!("unsupported arguments type for rpad, first argument type is {}, last argument type is {}", first_type, last_type) - } - }, - number => { - exec_err!("unsupported arguments number {} for rpad", number) + match ( + args.len(), + args[0].data_type(), + args.get(2).map(|arg| arg.data_type()), + ) { + (2, Utf8 | Utf8View, _) => { + make_scalar_function(rpad::, vec![])(args) + } + (2, LargeUtf8, _) => make_scalar_function(rpad::, vec![])(args), + (3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => { + make_scalar_function(rpad::, vec![])(args) + } + (3, LargeUtf8, Some(LargeUtf8)) => { + make_scalar_function(rpad::, vec![])(args) + } + (3, Utf8 | Utf8View, Some(LargeUtf8)) => { + make_scalar_function(rpad::, vec![])(args) + } + (3, LargeUtf8, Some(Utf8 | Utf8View)) => { + make_scalar_function(rpad::, vec![])(args) + } + (_, _, _) => { + exec_err!("Unsupported combination of data types for function rpad") } } } } -macro_rules! process_rpad { - // For the two-argument case - ($string_array:expr, $length_array:expr, $builder:expr) => {{ - $string_array - .iter() - .zip($length_array.iter()) - .try_for_each(|(string, length)| -> Result<(), DataFusionError> { - match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {} too large", - length - ); - } - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - $builder.append_value(""); - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - $builder.append_value(graphemes[..length].concat()); - } else { - $builder.write_str(string)?; - $builder - .write_str(&" ".repeat(length - graphemes.len()))?; - $builder.append_value(""); - } - } - } - _ => $builder.append_null(), - } - Ok(()) - })?; - Ok(Arc::new($builder.finish()) as ArrayRef) - }}; - - // For the three-argument case - ($string_array:expr, $length_array:expr, $fill_array:expr, $builder:expr) => {{ - $string_array - .iter() - .zip($length_array.iter()) - .zip($fill_array.iter()) - .try_for_each(|((string, length), fill)| -> Result<(), DataFusionError> { - match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {} too large", - length - ); - } - let length = if length < 0 { 0 } else { length as usize }; - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); +pub fn rpad( + args: &[ArrayRef], +) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!( + "rpad was called with {} arguments. It requires 2 or 3 arguments.", + args.len() + ); + } - if length < graphemes.len() { - $builder.append_value(graphemes[..length].concat()); - } else if fill_chars.is_empty() { - $builder.append_value(string); - } else { - $builder.write_str(string)?; - let fill_str = fill_chars - .iter() - .cycle() - .take(length - graphemes.len()) - .collect::(); - $builder.write_str(&fill_str)?; - $builder.append_value(""); - } - } - _ => $builder.append_null(), - } - Ok(()) - })?; - Ok(Arc::new($builder.finish()) as ArrayRef) - }}; + let length_array = as_int64_array(&args[1])?; + match ( + args.len(), + args[0].data_type(), + args.get(2).map(|arg| arg.data_type()), + ) { + (2, Utf8View, _) => { + rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>( + args[0].as_string_view(), + length_array, + None, + ) + } + (3, Utf8View, Some(Utf8View)) => { + rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>( + args[0].as_string_view(), + length_array, + Some(args[2].as_string_view()), + ) + } + (3, Utf8View, Some(Utf8 | LargeUtf8)) => { + rpad_impl::<&StringViewArray, &GenericStringArray, StringArrayLen>( + args[0].as_string_view(), + length_array, + Some(args[2].as_string::()), + ) + } + (3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::< + &GenericStringArray, + &StringViewArray, + StringArrayLen, + >( + args[0].as_string::(), + length_array, + Some(args[2].as_string_view()), + ), + (_, _, _) => rpad_impl::< + &GenericStringArray, + &GenericStringArray, + StringArrayLen, + >( + args[0].as_string::(), + length_array, + args.get(2).map(|arg| arg.as_string::()), + ), + } } /// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. /// rpad('hi', 5, 'xy') = 'hixyx' -pub fn rpad( - args: &[ArrayRef], -) -> Result { +pub fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>( + string_array: StringArrType, + length_array: &Int64Array, + fill_array: Option, +) -> Result +where + StringArrType: StringArrayType<'a>, + FillArrType: StringArrayType<'a>, + StringArrayLen: OffsetSizeTrait, +{ let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - match (args.len(), args[0].data_type()) { - (2, DataType::Utf8View) => { - let string_array = as_string_view_array(&args[0])?; - let length_array = as_int64_array(&args[1])?; - process_rpad!(string_array, length_array, builder) - } - (2, _) => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - process_rpad!(string_array, length_array, builder) - } - (3, DataType::Utf8View) => { - let string_array = as_string_view_array(&args[0])?; - let length_array = as_int64_array(&args[1])?; - match args[2].data_type() { - DataType::Utf8View => { - let fill_array = as_string_view_array(&args[2])?; - process_rpad!(string_array, length_array, fill_array, builder) - } - DataType::Utf8 | DataType::LargeUtf8 => { - let fill_array = as_generic_string_array::(&args[2])?; - process_rpad!(string_array, length_array, fill_array, builder) - } - other_type => { - exec_err!("unsupported type for rpad's third operator: {}", other_type) - } - } + match fill_array { + None => { + string_array.iter().zip(length_array.iter()).try_for_each( + |(string, length)| -> Result<(), DataFusionError> { + match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {} too large", + length + ); + } + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + builder.append_value(""); + } else { + let graphemes = + string.graphemes(true).collect::>(); + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else { + builder.write_str(string)?; + builder.write_str( + &" ".repeat(length - graphemes.len()), + )?; + builder.append_value(""); + } + } + } + _ => builder.append_null(), + } + Ok(()) + }, + )?; } - (3, _) => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - match args[2].data_type() { - DataType::Utf8View => { - let fill_array = as_string_view_array(&args[2])?; - process_rpad!(string_array, length_array, fill_array, builder) - } - DataType::Utf8 | DataType::LargeUtf8 => { - let fill_array = as_generic_string_array::(&args[2])?; - process_rpad!(string_array, length_array, fill_array, builder) - } - other_type => { - exec_err!("unsupported type for rpad's third operator: {}", other_type) - } - } + Some(fill_array) => { + string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .try_for_each( + |((string, length), fill)| -> Result<(), DataFusionError> { + match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {} too large", + length + ); + } + let length = if length < 0 { 0 } else { length as usize }; + let graphemes = + string.graphemes(true).collect::>(); - + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else if fill.is_empty() { + builder.append_value(string); + } else { + builder.write_str(string)?; + fill.chars() + .cycle() + .take(length - graphemes.len()) + .for_each(|ch| builder.write_char(ch).unwrap()); + builder.append_value(""); + } + } + _ => builder.append_null(), + } + Ok(()) + }, + )?; } - (other, other_type) => exec_err!( - "rpad requires 2 or 3 arguments with corresponding types, but got {} arguments with type {}", - other, other_type - ), } + + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)]