diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 23d57f38efae..45bfa2351128 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::builder::StringBuilder; -use arrow::array::{ArrayRef, StringArray}; +use arrow::array::{ArrayRef, AsArray, StringArray}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; @@ -122,12 +122,12 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - regexp_replace::(&[ - Arc::clone(&data), - Arc::clone(®ex), - Arc::clone(&replacement), - Arc::clone(&flags), - ]) + regexp_replace::( + data.as_string::(), + regex.as_string::(), + replacement.as_string::(), + Some(&flags), + ) .expect("regexp_replace should work on valid values"), ) }) diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index d28c6cd36d65..9693bb6e8b5e 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -16,13 +16,13 @@ // under the License. //! Regx expressions -use arrow::array::new_null_array; -use arrow::array::ArrayAccessor; use arrow::array::ArrayDataBuilder; use arrow::array::BufferBuilder; use arrow::array::GenericStringArray; use arrow::array::StringViewBuilder; +use arrow::array::{new_null_array, ArrayIter, AsArray}; use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; +use arrow::array::{ArrayAccessor, StringViewArray}; use arrow::datatypes::DataType; use datafusion_common::cast::as_string_view_array; use datafusion_common::exec_err; @@ -59,6 +59,7 @@ impl RegexpReplaceFunc { Exact(vec![Utf8, Utf8, Utf8]), Exact(vec![Utf8View, Utf8, Utf8]), Exact(vec![Utf8, Utf8, Utf8, Utf8]), + Exact(vec![Utf8View, Utf8, Utf8, Utf8]), ], Volatility::Immutable, ), @@ -187,104 +188,147 @@ fn regex_replace_posix_groups(replacement: &str) -> String { /// # Ok(()) /// # } /// ``` -pub fn regexp_replace(args: &[ArrayRef]) -> Result { +pub fn regexp_replace<'a, T: OffsetSizeTrait, V, B>( + string_array: V, + pattern_array: B, + replacement_array: B, + flags: Option<&ArrayRef>, +) -> Result +where + V: ArrayAccessor, + B: ArrayAccessor, +{ // Default implementation for regexp_replace, assumes all args are arrays // and args is a sequence of 3 or 4 elements. // creating Regex is expensive so create hashmap for memoization let mut patterns: HashMap = HashMap::new(); - match args.len() { - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let pattern_array = as_generic_string_array::(&args[1])?; - let replacement_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(pattern_array.iter()) - .zip(replacement_array.iter()) - .map(|((string, pattern), replacement)| match (string, pattern, replacement) { - (Some(string), Some(pattern), Some(replacement)) => { - let replacement = regex_replace_posix_groups(replacement); - - // if patterns hashmap already has regexp then use else create and return - let re = match patterns.get(pattern) { - Some(re) => Ok(re), - None => { - match Regex::new(pattern) { - Ok(re) => { - patterns.insert(pattern.to_string(), re); - Ok(patterns.get(pattern).unwrap()) + let datatype = string_array.data_type().to_owned(); + + let string_array_iter = ArrayIter::new(string_array); + let pattern_array_iter = ArrayIter::new(pattern_array); + let replacement_array_iter = ArrayIter::new(replacement_array); + + match flags { + None => { + let result_iter = string_array_iter + .zip(pattern_array_iter) + .zip(replacement_array_iter) + .map(|((string, pattern), replacement)| { + match (string, pattern, replacement) { + (Some(string), Some(pattern), Some(replacement)) => { + let replacement = regex_replace_posix_groups(replacement); + // if patterns hashmap already has regexp then use else create and return + let re = match patterns.get(pattern) { + Some(re) => Ok(re), + None => match Regex::new(pattern) { + Ok(re) => { + patterns.insert(pattern.to_string(), re); + Ok(patterns.get(pattern).unwrap()) + } + Err(err) => { + Err(DataFusionError::External(Box::new(err))) + } }, - Err(err) => Err(DataFusionError::External(Box::new(err))), - } - } - }; + }; - Some(re.map(|re| re.replace(string, replacement.as_str()))).transpose() + Some(re.map(|re| re.replace(string, replacement.as_str()))) + .transpose() + } + _ => Ok(None), + } + }); + + match datatype { + DataType::Utf8 | DataType::LargeUtf8 => { + let result = + result_iter.collect::>>()?; + Ok(Arc::new(result) as ArrayRef) } - _ => Ok(None) - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) + DataType::Utf8View => { + let result = result_iter.collect::>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "Unsupported data type {other:?} for function regex_replace" + ) + } + } } - 4 => { - let string_array = as_generic_string_array::(&args[0])?; - let pattern_array = as_generic_string_array::(&args[1])?; - let replacement_array = as_generic_string_array::(&args[2])?; - let flags_array = as_generic_string_array::(&args[3])?; - - let result = string_array - .iter() - .zip(pattern_array.iter()) - .zip(replacement_array.iter()) - .zip(flags_array.iter()) - .map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) { - (Some(string), Some(pattern), Some(replacement), Some(flags)) => { - let replacement = regex_replace_posix_groups(replacement); - - // format flags into rust pattern - let (pattern, replace_all) = if flags == "g" { - (pattern.to_string(), true) - } else if flags.contains('g') { - (format!("(?{}){}", flags.to_string().replace('g', ""), pattern), true) - } else { - (format!("(?{flags}){pattern}"), false) - }; - - // if patterns hashmap already has regexp then use else create and return - let re = match patterns.get(&pattern) { - Some(re) => Ok(re), - None => { - match Regex::new(pattern.as_str()) { - Ok(re) => { - patterns.insert(pattern.clone(), re); - Ok(patterns.get(&pattern).unwrap()) + Some(flags) => { + let flags_array = as_generic_string_array::(flags)?; + + let result_iter = string_array_iter + .zip(pattern_array_iter) + .zip(replacement_array_iter) + .zip(flags_array.iter()) + .map(|(((string, pattern), replacement), flags)| { + match (string, pattern, replacement, flags) { + (Some(string), Some(pattern), Some(replacement), Some(flags)) => { + let replacement = regex_replace_posix_groups(replacement); + + // format flags into rust pattern + let (pattern, replace_all) = if flags == "g" { + (pattern.to_string(), true) + } else if flags.contains('g') { + ( + format!( + "(?{}){}", + flags.to_string().replace('g', ""), + pattern + ), + true, + ) + } else { + (format!("(?{flags}){pattern}"), false) + }; + + // if patterns hashmap already has regexp then use else create and return + let re = match patterns.get(&pattern) { + Some(re) => Ok(re), + None => match Regex::new(pattern.as_str()) { + Ok(re) => { + patterns.insert(pattern.clone(), re); + Ok(patterns.get(&pattern).unwrap()) + } + Err(err) => { + Err(DataFusionError::External(Box::new(err))) + } }, - Err(err) => Err(DataFusionError::External(Box::new(err))), - } + }; + + Some(re.map(|re| { + if replace_all { + re.replace_all(string, replacement.as_str()) + } else { + re.replace(string, replacement.as_str()) + } + })) + .transpose() } - }; - - Some(re.map(|re| { - if replace_all { - re.replace_all(string, replacement.as_str()) - } else { - re.replace(string, replacement.as_str()) - } - })).transpose() + _ => Ok(None), + } + }); + + match datatype { + DataType::Utf8 | DataType::LargeUtf8 => { + let result = + result_iter.collect::>>()?; + Ok(Arc::new(result) as ArrayRef) } - _ => Ok(None) - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) + DataType::Utf8View => { + let result = result_iter.collect::>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "Unsupported data type {other:?} for function regex_replace" + ) + } + } } - other => exec_err!( - "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." - ), } } @@ -496,7 +540,47 @@ pub fn specialize_regexp_replace( .iter() .map(|arg| arg.clone().into_array(inferred_length)) .collect::>>()?; - regexp_replace::(&args) + + match args[0].data_type() { + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + let pattern_array = args[1].as_string::(); + let replacement_array = args[2].as_string::(); + regexp_replace::( + string_array, + pattern_array, + replacement_array, + args.get(3), + ) + } + DataType::Utf8 => { + let string_array = args[0].as_string::(); + let pattern_array = args[1].as_string::(); + let replacement_array = args[2].as_string::(); + regexp_replace::( + string_array, + pattern_array, + replacement_array, + args.get(3), + ) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + let pattern_array = args[1].as_string::(); + let replacement_array = args[2].as_string::(); + regexp_replace::( + string_array, + pattern_array, + replacement_array, + args.get(3), + ) + } + other => { + exec_err!( + "Unsupported data type {other:?} for function regex_replace" + ) + } + } } } } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 83c75b8df38c..2a3159bdbff2 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -460,6 +460,96 @@ Xiangpeng Raphael NULL +### Test REGEXP_REPLACE + +# Should run REGEXP_REPLACE with Scalar value for utf8view +query T +SELECT + REGEXP_REPLACE(column1_utf8view, 'e', 'f') AS k +FROM test; +---- +Andrfw +Xiangpfng +Raphafl +NULL + +# Should run REGEXP_REPLACE with Scalar value for utf8view with flag +query T +SELECT + REGEXP_REPLACE(column1_utf8view, 'e', 'f', 'i') AS k +FROM test; +---- +Andrfw +Xiangpfng +Raphafl +NULL + +# Should run REGEXP_REPLACE with Scalar value for utf8 +query T +SELECT + REGEXP_REPLACE(column1_utf8, 'e', 'f') AS k +FROM test; +---- +Andrfw +Xiangpfng +Raphafl +NULL + +# Should run REGEXP_REPLACE with Scalar value for utf8 with flag +query T +SELECT + REGEXP_REPLACE(column1_utf8, 'e', 'f', 'i') AS k +FROM test; +---- +Andrfw +Xiangpfng +Raphafl +NULL + +# Should run REGEXP_REPLACE with ScalarArray value for utf8view +query T +SELECT + REGEXP_REPLACE(column1_utf8view, lower(column1_utf8view), 'bar') AS k +FROM test; +---- +Andrew +Xiangpeng +Raphael +NULL + +# Should run REGEXP_REPLACE with ScalarArray value for utf8view with flag +query T +SELECT + REGEXP_REPLACE(column1_utf8view, lower(column1_utf8view), 'bar', 'g') AS k +FROM test; +---- +Andrew +Xiangpeng +Raphael +NULL + +# Should run REGEXP_REPLACE with ScalarArray value for utf8 +query T +SELECT + REGEXP_REPLACE(column1_utf8, lower(column1_utf8), 'bar') AS k +FROM test; +---- +Andrew +Xiangpeng +Raphael +NULL + +# Should run REGEXP_REPLACE with ScalarArray value for utf8 with flag +query T +SELECT + REGEXP_REPLACE(column1_utf8, lower(column1_utf8), 'bar', 'g') AS k +FROM test; +---- +Andrew +Xiangpeng +Raphael +NULL + ### Initcap query TT