Skip to content

Commit

Permalink
Implement native support StringView for contains function
Browse files Browse the repository at this point in the history
Signed-off-by: Tai Le Manh <[email protected]>
  • Loading branch information
tlm365 committed Sep 6, 2024
1 parent 45dd141 commit 22f383b
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 110 deletions.
4 changes: 2 additions & 2 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ encoding_expressions = ["base64", "hex"]
# enable math functions
math_expressions = []
# enable regular expressions
regex_expressions = ["regex"]
regex_expressions = ["regex", "string_expressions"]
# enable string functions
string_expressions = ["regex", "uuid"]
string_expressions = ["regex_expressions", "uuid"]
# enable unicode functions
unicode_expressions = ["hashbrown", "unicode-segmentation"]

Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/regex/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! "regex" DataFusion functions
pub mod regexp_common;
pub mod regexplike;
pub mod regexpmatch;
pub mod regexpreplace;
Expand Down
121 changes: 121 additions & 0 deletions datafusion/functions/src/regex/regexp_common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Common utilities for implementing regex functions
use crate::string::common::StringArrayType;

use arrow::array::{Array, ArrayDataBuilder, BooleanArray};
use arrow::datatypes::DataType;
use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
use datafusion_common::DataFusionError;
use regex::Regex;

use std::collections::HashMap;

#[cfg(doc)]
use arrow::array::{LargeStringArray, StringArray, StringViewArray};
/// Perform SQL `array ~ regex_array` operation on
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`].
///
/// Row `i` of the result tells whether `array[i]` matches the regular
/// expression in `regex_array[i]`.
///
/// * If the complete pattern for a row is the empty string, the result value
///   for that row is always `true` and no regex is compiled.
/// * If either `array[i]` or `regex_array[i]` is null, the result row is null
///   (the validity bitmap is the union of both inputs' null buffers).
///
/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flags,
/// which allow special search modes, such as case-insensitive and multi-line mode.
/// A non-null flag is prepended to the pattern as `(?flag)`; a null flag leaves
/// the pattern unchanged.
/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
/// for more information.
///
/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs].
///
/// # Errors
///
/// Returns [`DataFusionError::Execution`] if `array`, `regex_array` and (when
/// provided) `flags_array` do not all have the same length, or if a pattern
/// fails to compile.
///
/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37
pub fn regexp_is_match_utf8<'a, S1, S2, S3>(
    array: &'a S1,
    regex_array: &'a S2,
    flags_array: Option<&'a S3>,
) -> datafusion_common::Result<BooleanArray, DataFusionError>
where
    &'a S1: StringArrayType<'a>,
    &'a S2: StringArrayType<'a>,
    &'a S3: StringArrayType<'a>,
{
    if array.len() != regex_array.len() {
        return Err(DataFusionError::Execution(
            "Cannot perform comparison operation on arrays of different length"
                .to_string(),
        ));
    }

    // The flags are zipped with the patterns below. A shorter flags array would
    // silently truncate the iteration and leave the result buffer with fewer
    // bits than the length handed to `build_unchecked`, so reject it up front.
    if let Some(flags) = flags_array {
        if flags.len() != array.len() {
            return Err(DataFusionError::Execution(
                "Cannot perform comparison operation on arrays of different length"
                    .to_string(),
            ));
        }
    }

    let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());

    // Cache of compiled regexes keyed by the complete (flag-prefixed) pattern,
    // so a pattern repeated across rows is compiled only once.
    let mut patterns: HashMap<String, Regex> = HashMap::new();
    let mut result = BooleanBufferBuilder::new(array.len());

    let complete_pattern = match flags_array {
        Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map(
            |(pattern, flags)| {
                pattern.map(|pattern| match flags {
                    Some(flag) => format!("(?{flag}){pattern}"),
                    None => pattern.to_string(),
                })
            },
        )) as Box<dyn Iterator<Item = Option<String>>>,
        None => Box::new(
            regex_array
                .iter()
                .map(|pattern| pattern.map(|pattern| pattern.to_string())),
        ),
    };

    for (value, pattern) in array.iter().zip(complete_pattern) {
        match (value, pattern) {
            // An empty pattern matches everything; skip regex compilation.
            (Some(_), Some(pattern)) if pattern.is_empty() => {
                result.append(true);
            }
            (Some(value), Some(pattern)) => {
                // Single hash lookup: reuse the cached regex or compile and store it.
                let re = match patterns.entry(pattern) {
                    std::collections::hash_map::Entry::Occupied(entry) => {
                        entry.into_mut()
                    }
                    std::collections::hash_map::Entry::Vacant(entry) => {
                        let re = Regex::new(entry.key().as_str()).map_err(|e| {
                            DataFusionError::Execution(format!(
                                "Regular expression did not compile: {e:?}"
                            ))
                        })?;
                        entry.insert(re)
                    }
                };
                result.append(re.is_match(value));
            }
            // A null value or null pattern: the slot is masked by `nulls`,
            // the stored bit is just a placeholder.
            _ => result.append(false),
        }
    }

    // SAFETY: exactly one bit was appended per row and all inputs were checked
    // to have `array.len()` rows, so the single values buffer holds
    // `array.len()` bits. `nulls` is the union of two null buffers of that
    // same length.
    let data = unsafe {
        ArrayDataBuilder::new(DataType::Boolean)
            .len(array.len())
            .buffers(vec![result.into()])
            .nulls(nulls)
            .build_unchecked()
    };

    Ok(BooleanArray::from(data))
}
100 changes: 2 additions & 98 deletions datafusion/functions/src/string/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,20 @@

//! Common utilities for implementing string functions
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::sync::Arc;

use arrow::array::{
new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ArrayRef,
BooleanArray, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray,
GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray,
StringBuilder, StringViewArray,
};
use arrow::buffer::{Buffer, MutableBuffer, NullBuffer};
use arrow::datatypes::DataType;
use arrow_buffer::BooleanBufferBuilder;
use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
use datafusion_common::Result;
use datafusion_common::{exec_err, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use regex::Regex;

pub(crate) enum TrimType {
Left,
Expand Down Expand Up @@ -481,96 +478,3 @@ where
GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
}))
}

#[cfg(doc)]
use arrow::array::LargeStringArray;
/// Evaluate the SQL `array ~ regex_array` predicate row by row on
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] inputs.
///
/// A row whose complete pattern is the empty string always yields `true`.
/// A row where the value or the pattern is null is null in the output.
///
/// `flags_array` optionally supplies per-row regex flags (also a
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]), enabling
/// search modes such as case-insensitive or multi-line matching. A non-null
/// flag is prepended to the pattern as `(?flag)`.
/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
/// for more information.
///
/// Returns an execution error when `array` and `regex_array` differ in length
/// or when a pattern cannot be compiled.
///
/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs].
///
/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37
pub fn regexp_is_match<'a, S1, S2, S3>(
    array: &'a S1,
    regex_array: &'a S2,
    flags_array: Option<&'a S3>,
) -> Result<BooleanArray, DataFusionError>
where
    &'a S1: StringArrayType<'a>,
    &'a S2: StringArrayType<'a>,
    &'a S3: StringArrayType<'a>,
{
    let row_count = array.len();
    if row_count != regex_array.len() {
        return Err(DataFusionError::Execution(
            "Cannot perform comparison operation on arrays of different length"
                .to_string(),
        ));
    }

    // Output validity: a row is valid only if both the value and pattern are.
    let validity = NullBuffer::union(array.nulls(), regex_array.nulls());

    // Compiled regexes, keyed by the full pattern text (flags included).
    let mut compiled: HashMap<String, Regex> = HashMap::new();
    let mut bits = BooleanBufferBuilder::new(row_count);

    // Per-row pattern text with any flags folded in; `None` for null patterns.
    let full_patterns = match flags_array {
        Some(flags) => {
            Box::new(regex_array.iter().zip(flags.iter()).map(|(pat, flag)| {
                let pat = pat?;
                Some(match flag {
                    Some(flag) => format!("(?{flag}){pat}"),
                    None => pat.to_string(),
                })
            })) as Box<dyn Iterator<Item = Option<String>>>
        }
        None => Box::new(regex_array.iter().map(|pat| pat.map(String::from))),
    };

    for (text, pat) in array.iter().zip(full_patterns) {
        let matched = match (text, pat) {
            // Empty pattern: trivially a match, nothing to compile.
            (Some(_), Some(pat)) if pat.is_empty() => true,
            (Some(text), Some(pat)) => {
                let re = match compiled.entry(pat) {
                    std::collections::hash_map::Entry::Occupied(slot) => {
                        slot.into_mut()
                    }
                    std::collections::hash_map::Entry::Vacant(slot) => {
                        let fresh = Regex::new(slot.key().as_str()).map_err(|e| {
                            DataFusionError::Execution(format!(
                                "Regular expression did not compile: {e:?}"
                            ))
                        })?;
                        slot.insert(fresh)
                    }
                };
                re.is_match(text)
            }
            // Null value or null pattern: placeholder bit, masked by `validity`.
            _ => false,
        };
        bits.append(matched);
    }

    let data = unsafe {
        ArrayDataBuilder::new(DataType::Boolean)
            .len(row_count)
            .buffers(vec![bits.into()])
            .nulls(validity)
            .build_unchecked()
    };

    Ok(BooleanArray::from(data))
}
20 changes: 10 additions & 10 deletions datafusion/functions/src/string/contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::string::common::regexp_is_match;
use crate::regex::regexp_common::regexp_is_match_utf8;
use crate::utils::make_scalar_function;

use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray};
Expand Down Expand Up @@ -92,7 +92,7 @@ pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
(Utf8View, Utf8View) => {
let mod_str = args[0].as_string_view();
let match_str = args[1].as_string_view();
let res = regexp_is_match::<
let res = regexp_is_match_utf8::<
StringViewArray,
StringViewArray,
GenericStringArray<i32>,
Expand All @@ -103,7 +103,7 @@ pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
(Utf8View, Utf8) => {
let mod_str = args[0].as_string_view();
let match_str = args[1].as_string::<i32>();
let res = regexp_is_match::<
let res = regexp_is_match_utf8::<
StringViewArray,
GenericStringArray<i32>,
GenericStringArray<i32>,
Expand All @@ -114,7 +114,7 @@ pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
(Utf8View, LargeUtf8) => {
let mod_str = args[0].as_string_view();
let match_str = args[1].as_string::<i64>();
let res = regexp_is_match::<
let res = regexp_is_match_utf8::<
StringViewArray,
GenericStringArray<i64>,
GenericStringArray<i32>,
Expand All @@ -125,7 +125,7 @@ pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
(Utf8, Utf8View) => {
let mod_str = args[0].as_string::<i32>();
let match_str = args[1].as_string_view();
let res = regexp_is_match::<
let res = regexp_is_match_utf8::<
GenericStringArray<i32>,
StringViewArray,
GenericStringArray<i32>,
Expand All @@ -136,7 +136,7 @@ pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
(Utf8, Utf8) => {
let mod_str = args[0].as_string::<i32>();
let match_str = args[1].as_string::<i32>();
let res = regexp_is_match::<
let res = regexp_is_match_utf8::<
GenericStringArray<i32>,
GenericStringArray<i32>,
GenericStringArray<i32>,
Expand All @@ -147,7 +147,7 @@ pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
(Utf8, LargeUtf8) => {
let mod_str = args[0].as_string::<i32>();
let match_str = args[1].as_string::<i64>();
let res = regexp_is_match::<
let res = regexp_is_match_utf8::<
GenericStringArray<i32>,
GenericStringArray<i64>,
GenericStringArray<i32>,
Expand All @@ -158,7 +158,7 @@ pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
(LargeUtf8, Utf8View) => {
let mod_str = args[0].as_string::<i64>();
let match_str = args[1].as_string_view();
let res = regexp_is_match::<
let res = regexp_is_match_utf8::<
GenericStringArray<i64>,
StringViewArray,
GenericStringArray<i32>,
Expand All @@ -169,7 +169,7 @@ pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
(LargeUtf8, Utf8) => {
let mod_str = args[0].as_string::<i64>();
let match_str = args[1].as_string::<i32>();
let res = regexp_is_match::<
let res = regexp_is_match_utf8::<
GenericStringArray<i64>,
GenericStringArray<i32>,
GenericStringArray<i32>,
Expand All @@ -180,7 +180,7 @@ pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
(LargeUtf8, LargeUtf8) => {
let mod_str = args[0].as_string::<i64>();
let match_str = args[1].as_string::<i64>();
let res = regexp_is_match::<
let res = regexp_is_match_utf8::<
GenericStringArray<i64>,
GenericStringArray<i64>,
GenericStringArray<i32>,
Expand Down

0 comments on commit 22f383b

Please sign in to comment.