diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index 772b04136129..e285bd85b197 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -21,7 +21,9 @@ use std::sync::Arc; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::cast::{ + as_generic_string_array, as_int64_array, as_string_view_array, +}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; @@ -46,8 +48,10 @@ impl OverlayFunc { Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Utf8View, Int64, Int64]), Exact(vec![Utf8, Utf8, Int64, Int64]), Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8View, Utf8View, Int64]), Exact(vec![Utf8, Utf8, Int64]), Exact(vec![LargeUtf8, LargeUtf8, Int64]), ], @@ -76,54 +80,107 @@ impl ScalarUDFImpl for OverlayFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(overlay::, vec![])(args), + DataType::Utf8View | DataType::Utf8 => { + make_scalar_function(overlay::, vec![])(args) + } DataType::LargeUtf8 => make_scalar_function(overlay::, vec![])(args), other => exec_err!("Unsupported data type {other:?} for function overlay"), } } } +macro_rules! process_overlay { + // For the three-argument case + ($string_array:expr, $characters_array:expr, $pos_num:expr) => {{ + $string_array + .iter() + .zip($characters_array.iter()) + .zip($pos_num.iter()) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>() + }}; + + // For the four-argument case + ($string_array:expr, $characters_array:expr, $pos_num:expr, $len_num:expr) => {{ + $string_array + .iter() + .zip($characters_array.iter()) + .zip($pos_num.iter()) + .zip($len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>() + }}; +} + /// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) /// Replaces a substring of string1 with string2 starting at the integer bit /// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas /// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead -pub fn overlay(args: &[ArrayRef]) -> Result { +fn overlay(args: &[ArrayRef]) -> Result { + let use_string_view = args[0].data_type() == &DataType::Utf8View; + if use_string_view { + string_view_overlay::(args) + } else { + string_overlay::(args) + } +} + +pub fn string_overlay(args: &[ArrayRef]) -> Result { match args.len() { 3 => { let string_array = as_generic_string_array::(&args[0])?; let characters_array = as_generic_string_array::(&args[1])?; let pos_num = as_int64_array(&args[2])?; - let result = string_array - .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) - .map(|((string, characters), start_pos)| { - match (string, characters, start_pos) { - (Some(string), Some(characters), Some(start_pos)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = characters_len as i64; - let mut res = - String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; + let result = process_overlay!(string_array, characters_array, pos_num)?; Ok(Arc::new(result) as ArrayRef) } 4 => { @@ -132,37 +189,34 @@ pub fn overlay(args: &[ArrayRef]) -> Result { let pos_num = as_int64_array(&args[2])?; let len_num = as_int64_array(&args[3])?; - let result = string_array - .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) - .zip(len_num.iter()) - .map(|(((string, characters), start_pos), len)| { - match (string, characters, start_pos, len) { - (Some(string), Some(characters), Some(start_pos), Some(len)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = len.min(string_len as i64); - let mut res = - String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; + let result = + process_overlay!(string_array, characters_array, pos_num, len_num)?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!("overlay was called with {other} arguments. It requires 3 or 4.") + } + } +} + +pub fn string_view_overlay(args: &[ArrayRef]) -> Result { + match args.len() { + 3 => { + let string_array = as_string_view_array(&args[0])?; + let characters_array = as_string_view_array(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + + let result = process_overlay!(string_array, characters_array, pos_num)?; + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = as_string_view_array(&args[0])?; + let characters_array = as_string_view_array(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = + process_overlay!(string_array, characters_array, pos_num, len_num)?; Ok(Arc::new(result) as ArrayRef) } other => { diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 3255ddccdb81..2d535b5caad3 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -871,7 +871,7 @@ SELECT products.* REPLACE (price*2 AS price, product_id+1000 AS product_id) FROM 1003 OldBrand Product 3 79.98 1004 OldBrand Product 4 99.98 -#overlay tests +# overlay tests statement ok CREATE TABLE over_test( str TEXT, @@ -913,6 +913,31 @@ NULL Thomxas NULL +# overlay tests with utf8view +query T +SELECT overlay(arrow_cast(str, 'Utf8View') placing arrow_cast(characters, 'Utf8View') from pos for len) from over_test +---- +abc +qwertyasdfg +ijkz +Thomas +NULL +NULL +NULL +NULL + +query T +SELECT overlay(arrow_cast(str, 'Utf8View') placing arrow_cast(characters, 'Utf8View') from pos) from over_test +---- +abc +qwertyasdfg +ijk +Thomxas +NULL +NULL +Thomxas +NULL + query I SELECT levenshtein('kitten', 'sitting') ---- diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index fcd71b7f7e94..670856a3ec5f 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -719,16 +719,23 @@ logical_plan 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for OVERLAY -## TODO file ticket query TT EXPLAIN SELECT OVERLAY(column1_utf8view PLACING 'foo' FROM 2 ) as c1 FROM test; ---- logical_plan -01)Projection: overlay(CAST(test.column1_utf8view AS Utf8), Utf8("foo"), Int64(2)) AS c1 +01)Projection: overlay(test.column1_utf8view, Utf8View("foo"), Int64(2)) AS c1 02)--TableScan: test projection=[column1_utf8view] +query T +SELECT OVERLAY(column1_utf8view PLACING 'foo' FROM 2 ) as c1 FROM test; +---- +Afooew +Xfoogpeng +Rfooael +NULL + ## Ensure no casts for REGEXP_LIKE query TT EXPLAIN SELECT