Skip to content

Commit

Permalink
Implement native support StringView for REGEXP_LIKE
Browse files Browse the repository at this point in the history
Signed-off-by: Tai Le Manh <[email protected]>
  • Loading branch information
tlm365 committed Oct 12, 2024
1 parent 8cf030a commit 899fb7f
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 51 deletions.
8 changes: 2 additions & 6 deletions datafusion/functions/benches/regx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,8 @@ fn criterion_benchmark(c: &mut Criterion) {

b.iter(|| {
black_box(
regexp_like::<i32>(&[
Arc::clone(&data),
Arc::clone(&regex),
Arc::clone(&flags),
])
.expect("regexp_like should work on valid values"),
regexp_like(&[Arc::clone(&data), Arc::clone(&regex), Arc::clone(&flags)])
.expect("regexp_like should work on valid values"),
)
})
});
Expand Down
202 changes: 158 additions & 44 deletions datafusion/functions/src/regex/regexplike.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@
// specific language governing permissions and limitations
// under the License.

//! Regx expressions
use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait};
//! Regex expressions
use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray};
use arrow::compute::kernels::regexp;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
use datafusion_common::exec_err;
use datafusion_common::ScalarValue;
use datafusion_common::{arrow_datafusion_err, plan_err};
use datafusion_common::{
cast::as_generic_string_array, internal_err, DataFusionError, Result,
};
use datafusion_common::{internal_err, DataFusionError, Result};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX;
use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};

use std::any::Any;
use std::sync::{Arc, OnceLock};

Expand Down Expand Up @@ -82,14 +83,27 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo

impl RegexpLikeFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
TypeSignature::Exact(vec![Utf8View, Utf8]),
TypeSignature::Exact(vec![Utf8View, Utf8View]),
TypeSignature::Exact(vec![Utf8View, LargeUtf8]),
TypeSignature::Exact(vec![Utf8, Utf8]),
TypeSignature::Exact(vec![Utf8, Utf8View]),
TypeSignature::Exact(vec![Utf8, LargeUtf8]),
TypeSignature::Exact(vec![LargeUtf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, Utf8View]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]),
TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8]),
TypeSignature::Exact(vec![Utf8View, LargeUtf8, Utf8]),
TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
TypeSignature::Exact(vec![Utf8, Utf8View, Utf8]),
TypeSignature::Exact(vec![Utf8, LargeUtf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, Utf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, Utf8View, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Utf8]),
],
Volatility::Immutable,
),
Expand Down Expand Up @@ -120,6 +134,7 @@ impl ScalarUDFImpl for RegexpLikeFunc {
_ => Boolean,
})
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let len = args
.iter()
Expand All @@ -135,7 +150,7 @@ impl ScalarUDFImpl for RegexpLikeFunc {
.map(|arg| arg.clone().into_array(inferred_length))
.collect::<Result<Vec<_>>>()?;

let result = regexp_like_func(&args);
let result = regexp_like(&args);
if is_scalar {
// If all inputs are scalar, keeps output as scalar
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
Expand All @@ -149,15 +164,7 @@ impl ScalarUDFImpl for RegexpLikeFunc {
Some(get_regexp_like_doc())
}
}
fn regexp_like_func(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Utf8 => regexp_like::<i32>(args),
DataType::LargeUtf8 => regexp_like::<i64>(args),
other => {
internal_err!("Unsupported data type {other:?} for function regexp_like")
}
}
}

/// Tests a string using a regular expression returning true if at
/// least one match, false otherwise.
///
Expand Down Expand Up @@ -200,47 +207,114 @@ fn regexp_like_func(args: &[ArrayRef]) -> Result<ArrayRef> {
/// # Ok(())
/// # }
/// ```
pub fn regexp_like<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn regexp_like(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
let values = as_generic_string_array::<T>(&args[0])?;
let regex = as_generic_string_array::<T>(&args[1])?;
let flags: Option<&GenericStringArray<T>> = None;
let array = regexp::regexp_is_match(values, regex, flags)
.map_err(|e| arrow_datafusion_err!(e))?;

Ok(Arc::new(array) as ArrayRef)
}
2 => handle_regexp_like(&args[0], &args[1], None),
3 => {
let values = as_generic_string_array::<T>(&args[0])?;
let regex = as_generic_string_array::<T>(&args[1])?;
let flags = as_generic_string_array::<T>(&args[2])?;
let flags = args[2].as_string::<i32>();

if flags.iter().any(|s| s == Some("g")) {
return plan_err!("regexp_like() does not support the \"global\" option");
}

let array = regexp::regexp_is_match(values, regex, Some(flags))
.map_err(|e| arrow_datafusion_err!(e))?;

Ok(Arc::new(array) as ArrayRef)
}
handle_regexp_like(&args[0], &args[1], Some(flags))
},
other => exec_err!(
"regexp_like was called with {other} arguments. It requires at least 2 and at most 3."
"`regexp_like` was called with {other} arguments. It requires at least 2 and at most 3."
),
}
}

/// Runs the arrow `regexp_is_match` kernel for one (values, patterns) type
/// combination and returns the kernel's result as an [`ArrayRef`].
///
/// The string columns may independently be `Utf8`, `LargeUtf8` or `Utf8View`;
/// each is downcast to its concrete array type before the kernel is called, so
/// no cast to a common string type is required.
///
/// # Arguments
/// * `values` - strings to test.
/// * `patterns` - regular expressions, one per value.
/// * `flags` - optional regex flags, always a `Utf8` array.
///
/// # Errors
/// * An arrow error (wrapped via `arrow_datafusion_err!`) when the kernel
///   itself fails.
/// * An internal error when the (values, patterns) data types are not one of
///   the nine supported string combinations.
fn handle_regexp_like(
    values: &ArrayRef,
    patterns: &ArrayRef,
    flags: Option<&GenericStringArray<i32>>,
) -> Result<ArrayRef> {
    // NOTE: the `AsArray` downcasts below panic when the array is not of the
    // requested type, so every arm must downcast to exactly the type it
    // matched on.
    let array = match (values.data_type(), patterns.data_type()) {
        (Utf8View, Utf8) => {
            let value = values.as_string_view();
            let pattern = patterns.as_string::<i32>();

            regexp::regexp_is_match(value, pattern, flags)
                .map_err(|e| arrow_datafusion_err!(e))?
        }
        (Utf8View, Utf8View) => {
            let value = values.as_string_view();
            let pattern = patterns.as_string_view();

            regexp::regexp_is_match(value, pattern, flags)
                .map_err(|e| arrow_datafusion_err!(e))?
        }
        (Utf8View, LargeUtf8) => {
            let value = values.as_string_view();
            let pattern = patterns.as_string::<i64>();

            regexp::regexp_is_match(value, pattern, flags)
                .map_err(|e| arrow_datafusion_err!(e))?
        }
        (Utf8, Utf8) => {
            let value = values.as_string::<i32>();
            let pattern = patterns.as_string::<i32>();

            regexp::regexp_is_match(value, pattern, flags)
                .map_err(|e| arrow_datafusion_err!(e))?
        }
        (Utf8, Utf8View) => {
            let value = values.as_string::<i32>();
            let pattern = patterns.as_string_view();

            regexp::regexp_is_match(value, pattern, flags)
                .map_err(|e| arrow_datafusion_err!(e))?
        }
        (Utf8, LargeUtf8) => {
            // `values` is a `Utf8` array here: downcast it with
            // `as_string::<i32>()`, not `as_string_view()`.
            let value = values.as_string::<i32>();
            let pattern = patterns.as_string::<i64>();

            regexp::regexp_is_match(value, pattern, flags)
                .map_err(|e| arrow_datafusion_err!(e))?
        }
        (LargeUtf8, Utf8) => {
            let value = values.as_string::<i64>();
            let pattern = patterns.as_string::<i32>();

            regexp::regexp_is_match(value, pattern, flags)
                .map_err(|e| arrow_datafusion_err!(e))?
        }
        (LargeUtf8, Utf8View) => {
            let value = values.as_string::<i64>();
            let pattern = patterns.as_string_view();

            regexp::regexp_is_match(value, pattern, flags)
                .map_err(|e| arrow_datafusion_err!(e))?
        }
        (LargeUtf8, LargeUtf8) => {
            let value = values.as_string::<i64>();
            let pattern = patterns.as_string::<i64>();

            regexp::regexp_is_match(value, pattern, flags)
                .map_err(|e| arrow_datafusion_err!(e))?
        }
        other => {
            return internal_err!(
                "Unsupported data type {other:?} for function `regexp_like`"
            )
        }
    };

    Ok(Arc::new(array) as ArrayRef)
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow::array::BooleanBuilder;
use arrow::array::StringArray;
use arrow::array::{BooleanBuilder, StringViewArray};

use crate::regex::regexplike::regexp_like;

#[test]
fn test_case_sensitive_regexp_like() {
fn test_case_sensitive_regexp_like_utf8() {
let values = StringArray::from(vec!["abc"; 5]);

let patterns =
Expand All @@ -254,13 +328,33 @@ mod tests {
expected_builder.append_value(false);
let expected = expected_builder.finish();

let re = regexp_like::<i32>(&[Arc::new(values), Arc::new(patterns)]).unwrap();
let re = regexp_like(&[Arc::new(values), Arc::new(patterns)]).unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_case_sensitive_regexp_like_utf8view() {
let values = StringViewArray::from(vec!["abc"; 5]);

let patterns =
StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);

let mut expected_builder: BooleanBuilder = BooleanBuilder::new();
expected_builder.append_value(true);
expected_builder.append_value(false);
expected_builder.append_value(true);
expected_builder.append_value(false);
expected_builder.append_value(false);
let expected = expected_builder.finish();

let re = regexp_like(&[Arc::new(values), Arc::new(patterns)]).unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_case_insensitive_regexp_like() {
fn test_case_insensitive_regexp_like_utf8() {
let values = StringArray::from(vec!["abc"; 5]);
let patterns =
StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
Expand All @@ -274,9 +368,29 @@ mod tests {
expected_builder.append_value(false);
let expected = expected_builder.finish();

let re =
regexp_like::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
.unwrap();
let re = regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_case_insensitive_regexp_like_utf8view() {
let values = StringViewArray::from(vec!["abc"; 5]);
let patterns =
StringViewArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
let flags = StringArray::from(vec!["i"; 5]);

let mut expected_builder: BooleanBuilder = BooleanBuilder::new();
expected_builder.append_value(true);
expected_builder.append_value(true);
expected_builder.append_value(true);
expected_builder.append_value(true);
expected_builder.append_value(false);
let expected = expected_builder.finish();

let re = regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}
Expand All @@ -288,7 +402,7 @@ mod tests {
let flags = StringArray::from(vec!["g"]);

let re_err =
regexp_like::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
.expect_err("unsupported flag should have failed");

assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/string/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: regexp_like(CAST(test.column1_utf8view AS Utf8), Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k
01)Projection: regexp_like(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for REGEXP_MATCH
Expand Down

0 comments on commit 899fb7f

Please sign in to comment.