Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Internal error in regexp_replace() for some StringView input #12203

Merged
merged 8 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions datafusion/functions/benches/regx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
extern crate criterion;

use arrow::array::builder::StringBuilder;
use arrow::array::{ArrayRef, StringArray};
use arrow::array::{ArrayRef, AsArray, StringArray};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_functions::regex::regexplike::regexp_like;
use datafusion_functions::regex::regexpmatch::regexp_match;
Expand Down Expand Up @@ -122,12 +122,12 @@ fn criterion_benchmark(c: &mut Criterion) {

b.iter(|| {
black_box(
regexp_replace::<i32>(&[
Arc::clone(&data),
Arc::clone(&regex),
Arc::clone(&replacement),
Arc::clone(&flags),
])
regexp_replace::<i32, _, _>(
data.as_string::<i32>(),
regex.as_string::<i32>(),
replacement.as_string::<i32>(),
Some(&flags),
)
.expect("regexp_replace should work on valid values"),
)
})
Expand Down
246 changes: 161 additions & 85 deletions datafusion/functions/src/regex/regexpreplace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
// under the License.

//! Regex expressions
use arrow::array::new_null_array;
use arrow::array::ArrayAccessor;
use arrow::array::ArrayDataBuilder;
use arrow::array::BufferBuilder;
use arrow::array::GenericStringArray;
use arrow::array::StringViewBuilder;
use arrow::array::{new_null_array, ArrayIter, AsArray};
use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion_common::cast::as_string_view_array;
use datafusion_common::cast::{as_string_array, as_string_view_array};
use datafusion_common::exec_err;
use datafusion_common::plan_err;
use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -59,6 +59,7 @@ impl RegexpReplaceFunc {
Exact(vec![Utf8, Utf8, Utf8]),
Exact(vec![Utf8View, Utf8, Utf8]),
Exact(vec![Utf8, Utf8, Utf8, Utf8]),
Exact(vec![Utf8View, Utf8, Utf8, Utf8]),
],
Volatility::Immutable,
),
Expand Down Expand Up @@ -187,104 +188,117 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
/// # Ok(())
/// # }
/// ```
pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn regexp_replace<'a, T: OffsetSizeTrait, V, B>(
string_array: V,
pattern_array: B,
replacement_array: B,
flags: Option<&ArrayRef>,
) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
B: ArrayAccessor<Item = &'a str>,
{
// Default implementation for regexp_replace, assumes all args are arrays
// and args is a sequence of 3 or 4 elements.

// creating Regex is expensive so create hashmap for memoization
let mut patterns: HashMap<String, Regex> = HashMap::new();

match args.len() {
3 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let pattern_array = as_generic_string_array::<T>(&args[1])?;
let replacement_array = as_generic_string_array::<T>(&args[2])?;

let result = string_array
.iter()
.zip(pattern_array.iter())
.zip(replacement_array.iter())
.map(|((string, pattern), replacement)| match (string, pattern, replacement) {
(Some(string), Some(pattern), Some(replacement)) => {
let replacement = regex_replace_posix_groups(replacement);

// if patterns hashmap already has regexp then use else create and return
let re = match patterns.get(pattern) {
Some(re) => Ok(re),
None => {
match Regex::new(pattern) {
Ok(re) => {
patterns.insert(pattern.to_string(), re);
Ok(patterns.get(pattern).unwrap())
let string_array_iter = ArrayIter::new(string_array);
let pattern_array_iter = ArrayIter::new(pattern_array);
let replacement_array_iter = ArrayIter::new(replacement_array);

match flags {
None => {
let result = string_array_iter
.zip(pattern_array_iter)
.zip(replacement_array_iter)
.map(|((string, pattern), replacement)| {
match (string, pattern, replacement) {
(Some(string), Some(pattern), Some(replacement)) => {
let replacement = regex_replace_posix_groups(replacement);
// if patterns hashmap already has regexp then use else create and return
let re = match patterns.get(pattern) {
Some(re) => Ok(re),
None => match Regex::new(pattern) {
Ok(re) => {
patterns.insert(pattern.to_string(), re);
Ok(patterns.get(pattern).unwrap())
}
Err(err) => {
Err(DataFusionError::External(Box::new(err)))
}
},
Err(err) => Err(DataFusionError::External(Box::new(err))),
}
}
};
};

Some(re.map(|re| re.replace(string, replacement.as_str()))).transpose()
}
_ => Ok(None)
})
.collect::<Result<GenericStringArray<T>>>()?;
Some(re.map(|re| re.replace(string, replacement.as_str())))
.transpose()
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()?;

Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let pattern_array = as_generic_string_array::<T>(&args[1])?;
let replacement_array = as_generic_string_array::<T>(&args[2])?;
let flags_array = as_generic_string_array::<T>(&args[3])?;

let result = string_array
.iter()
.zip(pattern_array.iter())
.zip(replacement_array.iter())
.zip(flags_array.iter())
.map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) {
(Some(string), Some(pattern), Some(replacement), Some(flags)) => {
let replacement = regex_replace_posix_groups(replacement);

// format flags into rust pattern
let (pattern, replace_all) = if flags == "g" {
(pattern.to_string(), true)
} else if flags.contains('g') {
(format!("(?{}){}", flags.to_string().replace('g', ""), pattern), true)
} else {
(format!("(?{flags}){pattern}"), false)
};

// if patterns hashmap already has regexp then use else create and return
let re = match patterns.get(&pattern) {
Some(re) => Ok(re),
None => {
match Regex::new(pattern.as_str()) {
Ok(re) => {
patterns.insert(pattern.clone(), re);
Ok(patterns.get(&pattern).unwrap())
Some(flags) => {
let flags_array = as_generic_string_array::<T>(flags)?;

let result = string_array_iter
.zip(pattern_array_iter)
.zip(replacement_array_iter)
.zip(flags_array.iter())
.map(|(((string, pattern), replacement), flags)| {
match (string, pattern, replacement, flags) {
(Some(string), Some(pattern), Some(replacement), Some(flags)) => {
let replacement = regex_replace_posix_groups(replacement);

// format flags into rust pattern
let (pattern, replace_all) = if flags == "g" {
(pattern.to_string(), true)
} else if flags.contains('g') {
(
format!(
"(?{}){}",
flags.to_string().replace('g', ""),
pattern
),
true,
)
} else {
(format!("(?{flags}){pattern}"), false)
};

// if patterns hashmap already has regexp then use else create and return
let re = match patterns.get(&pattern) {
Some(re) => Ok(re),
None => match Regex::new(pattern.as_str()) {
Ok(re) => {
patterns.insert(pattern.clone(), re);
Ok(patterns.get(&pattern).unwrap())
}
Err(err) => {
Err(DataFusionError::External(Box::new(err)))
}
},
Err(err) => Err(DataFusionError::External(Box::new(err))),
}
};

Some(re.map(|re| {
if replace_all {
re.replace_all(string, replacement.as_str())
} else {
re.replace(string, replacement.as_str())
}
}))
.transpose()
}
};

Some(re.map(|re| {
if replace_all {
re.replace_all(string, replacement.as_str())
} else {
re.replace(string, replacement.as_str())
}
})).transpose()
}
_ => Ok(None)
})
.collect::<Result<GenericStringArray<T>>>()?;
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()?;
devanbenz marked this conversation as resolved.
Show resolved Hide resolved

Ok(Arc::new(result) as ArrayRef)
}
other => exec_err!(
"regexp_replace was called with {other} arguments. It requires at least 3 and at most 4."
),
}
}

Expand Down Expand Up @@ -496,7 +510,69 @@ pub fn specialize_regexp_replace<T: OffsetSizeTrait>(
.iter()
.map(|arg| arg.clone().into_array(inferred_length))
.collect::<Result<Vec<_>>>()?;
regexp_replace::<T>(&args)

match args[0].data_type() {
DataType::Utf8View => {
let string_array = args[0].as_string_view();
let pattern_array = args[1].as_string::<i32>();
let replacement_array = args[2].as_string::<i32>();
let regexp_replace_result = regexp_replace::<i32, _, _>(
string_array,
pattern_array,
replacement_array,
args.get(3),
)?;

if regexp_replace_result.data_type() == &DataType::Utf8 {
devanbenz marked this conversation as resolved.
Show resolved Hide resolved
let string_view_array =
as_string_array(&regexp_replace_result)?.to_owned();

let mut builder =
StringViewBuilder::with_capacity(string_view_array.len())
.with_block_size(1024 * 1024 * 2);

for val in string_view_array.iter() {
if let Some(val) = val {
builder.append_value(val);
} else {
builder.append_null();
}
}

let result = builder.finish();
Ok(Arc::new(result) as ArrayRef)
} else {
Ok(regexp_replace_result)
}
}
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
let pattern_array = args[1].as_string::<i32>();
let replacement_array = args[2].as_string::<i32>();
regexp_replace::<i32, _, _>(
string_array,
pattern_array,
replacement_array,
args.get(3),
)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
let pattern_array = args[1].as_string::<i64>();
let replacement_array = args[2].as_string::<i64>();
regexp_replace::<i64, _, _>(
string_array,
pattern_array,
replacement_array,
args.get(3),
)
}
other => {
exec_err!(
"Unsupported data type {other:?} for function regex_replace"
)
}
}
}
}
}
Expand Down
Loading