diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 4cebbba839fa..13fa3d55672d 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -18,10 +18,10 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray}; use arrow::datatypes::DataType; -use datafusion_common::cast::as_generic_string_array; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; @@ -45,7 +45,11 @@ impl ReplaceFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Utf8, Utf8])], + vec![ + Exact(vec![Utf8View, Utf8View, Utf8View]), + Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]), + ], Volatility::Immutable, ), } @@ -73,6 +77,7 @@ impl ScalarUDFImpl for ReplaceFunc { match args[0].data_type() { DataType::Utf8 => make_scalar_function(replace::, vec![])(args), DataType::LargeUtf8 => make_scalar_function(replace::, vec![])(args), + DataType::Utf8View => make_scalar_function(replace_view, vec![])(args), other => { exec_err!("Unsupported data type {other:?} for function replace") } @@ -80,6 +85,23 @@ impl ScalarUDFImpl for ReplaceFunc { } } +fn replace_view(args: &[ArrayRef]) -> Result { + let string_array = as_string_view_array(&args[0])?; + let from_array = as_string_view_array(&args[1])?; + let to_array = as_string_view_array(&args[2])?; + + let result = string_array + .iter() + .zip(from_array.iter()) + .zip(to_array.iter()) + .map(|((string, from), to)| match (string, from, to) { + (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)), + _ => None, + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} /// Replaces all occurrences in string of substring from with substring to. /// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef' fn replace(args: &[ArrayRef]) -> Result { @@ -100,4 +122,60 @@ fn replace(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -mod test {} +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::test::test_function; + use arrow::array::Array; + use arrow::array::LargeStringArray; + use arrow::array::StringArray; + use arrow::datatypes::DataType::{LargeUtf8, Utf8}; + use datafusion_common::ScalarValue; + #[test] + fn test_functions() -> Result<()> { + test_function!( + ReplaceFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))), + ], + Ok(Some("aacccdqcccc")), + &str, + Utf8, + StringArray + ); + + test_function!( + ReplaceFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( + "aabbb" + )))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("bbb")))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))), + ], + Ok(Some("aacc")), + &str, + LargeUtf8, + LargeStringArray + ); + + test_function!( + ReplaceFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "aabbbcw" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("cc")))), + ], + Ok(Some("aaccbcw")), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index cb592fdda0c8..426123dd9be1 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -826,6 +826,16 @@ SELECT replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'hello') ---- foohello +query T +SELECT replace(arrow_cast('foobar', 'Utf8View'), arrow_cast('bar', 'Utf8View'), arrow_cast('hello', 'Utf8View')) +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'LargeUtf8'), arrow_cast('bar', 'LargeUtf8'), arrow_cast('hello', 'LargeUtf8')) +---- +foohello + query T SELECT rtrim(' foo ') ---- diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index e094bcaf1b5d..15e8fce20296 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -902,7 +902,6 @@ logical_plan 01)Projection: regexp_replace(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$"), Utf8("\1")) AS k 02)--TableScan: test projection=[column1_utf8view] - ## Ensure no casts for REPEAT query TT EXPLAIN SELECT @@ -914,7 +913,6 @@ logical_plan 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for REPLACE -## TODO file ticket query TT EXPLAIN SELECT REPLACE(column1_utf8view, 'foo', 'bar') as c1, @@ -922,9 +920,20 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: replace(__common_expr_1, Utf8("foo"), Utf8("bar")) AS c1, replace(__common_expr_1, CAST(test.column2_utf8view AS Utf8), Utf8("bar")) AS c2 -02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view -03)----TableScan: test projection=[column1_utf8view, column2_utf8view] +01)Projection: replace(test.column1_utf8view, Utf8View("foo"), Utf8View("bar")) AS c1, replace(test.column1_utf8view, test.column2_utf8view, Utf8View("bar")) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +query TT +SELECT + REPLACE(column1_utf8view, 'foo', 'bar') as c1, + REPLACE(column1_utf8view, column2_utf8view, 'bar') as c2 +FROM test; +---- +Andrew Andrew +Xiangpeng bar +Raphael baraphael +NULL NULL + ## Ensure no casts for REVERSE query TT