Skip to content

Commit

Permalink
improve performance of regexp_count (apache#13364)
Browse files Browse the repository at this point in the history
* improve performance of regexp_count

* fix clippy

* collect with Int64Array to eliminate one temp Vec

---------

Co-authored-by: Dima <[email protected]>
  • Loading branch information
Dimchikkk and Dima authored Nov 12, 2024
1 parent 8d6899e commit 705dd0e
Showing 1 changed file with 39 additions and 43 deletions.
82 changes: 39 additions & 43 deletions datafusion/functions/src/regex/regexpcount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use datafusion_expr::{
};
use itertools::izip;
use regex::Regex;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};

Expand Down Expand Up @@ -312,12 +311,12 @@ where

let pattern = compile_regex(regex, flags_scalar)?;

Ok(Arc::new(Int64Array::from_iter_values(
Ok(Arc::new(
values
.iter()
.map(|value| count_matches(value, &pattern, start_scalar))
.collect::<Result<Vec<i64>, ArrowError>>()?,
)))
.collect::<Result<Int64Array, ArrowError>>()?,
))
}
(true, true, false) => {
let regex = match regex_scalar {
Expand All @@ -336,17 +335,17 @@ where
)));
}

Ok(Arc::new(Int64Array::from_iter_values(
Ok(Arc::new(
values
.iter()
.zip(flags_array.iter())
.map(|(value, flags)| {
let pattern =
compile_and_cache_regex(regex, flags, &mut regex_cache)?;
count_matches(value, &pattern, start_scalar)
count_matches(value, pattern, start_scalar)
})
.collect::<Result<Vec<i64>, ArrowError>>()?,
)))
.collect::<Result<Int64Array, ArrowError>>()?,
))
}
(true, false, true) => {
let regex = match regex_scalar {
Expand All @@ -360,13 +359,13 @@ where

let start_array = start_array.unwrap();

Ok(Arc::new(Int64Array::from_iter_values(
Ok(Arc::new(
values
.iter()
.zip(start_array.iter())
.map(|(value, start)| count_matches(value, &pattern, start))
.collect::<Result<Vec<i64>, ArrowError>>()?,
)))
.collect::<Result<Int64Array, ArrowError>>()?,
))
}
(true, false, false) => {
let regex = match regex_scalar {
Expand All @@ -385,7 +384,7 @@ where
)));
}

Ok(Arc::new(Int64Array::from_iter_values(
Ok(Arc::new(
izip!(
values.iter(),
start_array.unwrap().iter(),
Expand All @@ -395,10 +394,10 @@ where
let pattern =
compile_and_cache_regex(regex, flags, &mut regex_cache)?;

count_matches(value, &pattern, start)
count_matches(value, pattern, start)
})
.collect::<Result<Vec<i64>, ArrowError>>()?,
)))
.collect::<Result<Int64Array, ArrowError>>()?,
))
}
(false, true, true) => {
if values.len() != regex_array.len() {
Expand All @@ -409,7 +408,7 @@ where
)));
}

Ok(Arc::new(Int64Array::from_iter_values(
Ok(Arc::new(
values
.iter()
.zip(regex_array.iter())
Expand All @@ -424,10 +423,10 @@ where
flags_scalar,
&mut regex_cache,
)?;
count_matches(value, &pattern, start_scalar)
count_matches(value, pattern, start_scalar)
})
.collect::<Result<Vec<i64>, ArrowError>>()?,
)))
.collect::<Result<Int64Array, ArrowError>>()?,
))
}
(false, true, false) => {
if values.len() != regex_array.len() {
Expand All @@ -447,7 +446,7 @@ where
)));
}

Ok(Arc::new(Int64Array::from_iter_values(
Ok(Arc::new(
izip!(values.iter(), regex_array.iter(), flags_array.iter())
.map(|(value, regex, flags)| {
let regex = match regex {
Expand All @@ -458,10 +457,10 @@ where
let pattern =
compile_and_cache_regex(regex, flags, &mut regex_cache)?;

count_matches(value, &pattern, start_scalar)
count_matches(value, pattern, start_scalar)
})
.collect::<Result<Vec<i64>, ArrowError>>()?,
)))
.collect::<Result<Int64Array, ArrowError>>()?,
))
}
(false, false, true) => {
if values.len() != regex_array.len() {
Expand All @@ -481,7 +480,7 @@ where
)));
}

Ok(Arc::new(Int64Array::from_iter_values(
Ok(Arc::new(
izip!(values.iter(), regex_array.iter(), start_array.iter())
.map(|(value, regex, start)| {
let regex = match regex {
Expand All @@ -494,10 +493,10 @@ where
flags_scalar,
&mut regex_cache,
)?;
count_matches(value, &pattern, start)
count_matches(value, pattern, start)
})
.collect::<Result<Vec<i64>, ArrowError>>()?,
)))
.collect::<Result<Int64Array, ArrowError>>()?,
))
}
(false, false, false) => {
if values.len() != regex_array.len() {
Expand Down Expand Up @@ -526,7 +525,7 @@ where
)));
}

Ok(Arc::new(Int64Array::from_iter_values(
Ok(Arc::new(
izip!(
values.iter(),
regex_array.iter(),
Expand All @@ -541,27 +540,24 @@ where

let pattern =
compile_and_cache_regex(regex, flags, &mut regex_cache)?;
count_matches(value, &pattern, start)
count_matches(value, pattern, start)
})
.collect::<Result<Vec<i64>, ArrowError>>()?,
)))
.collect::<Result<Int64Array, ArrowError>>()?,
))
}
}
}

fn compile_and_cache_regex(
regex: &str,
flags: Option<&str>,
regex_cache: &mut HashMap<String, Regex>,
) -> Result<Regex, ArrowError> {
match regex_cache.entry(regex.to_string()) {
Entry::Vacant(entry) => {
let compiled = compile_regex(regex, flags)?;
entry.insert(compiled.clone());
Ok(compiled)
}
Entry::Occupied(entry) => Ok(entry.get().to_owned()),
fn compile_and_cache_regex<'a>(
regex: &'a str,
flags: Option<&'a str>,
regex_cache: &'a mut HashMap<String, Regex>,
) -> Result<&'a Regex, ArrowError> {
if !regex_cache.contains_key(regex) {
let compiled = compile_regex(regex, flags)?;
regex_cache.insert(regex.to_string(), compiled);
}
Ok(regex_cache.get(regex).unwrap())
}

fn compile_regex(regex: &str, flags: Option<&str>) -> Result<Regex, ArrowError> {
Expand Down

0 comments on commit 705dd0e

Please sign in to comment.