Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/11953: Support StringView for TRANSLATE() fn #11967

Merged
merged 9 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 45 additions & 21 deletions datafusion/functions/src/unicode/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, StringArray};
use arrow::datatypes::DataType;
use hashbrown::HashMap;
use unicode_segmentation::UnicodeSegmentation;

use datafusion_common::cast::as_generic_string_array;
use crate::utils::{make_scalar_function, utf8_to_str_type};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

use crate::utils::{make_scalar_function, utf8_to_str_type};

#[derive(Debug)]
pub struct TranslateFunc {
signature: Signature,
Expand All @@ -46,7 +44,10 @@ impl TranslateFunc {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Utf8, Utf8, Utf8])],
vec![
Exact(vec![Utf8View, Utf8, Utf8]),
Exact(vec![Utf8, Utf8, Utf8]),
],
Volatility::Immutable,
),
}
Expand All @@ -71,27 +72,50 @@ impl ScalarUDFImpl for TranslateFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(translate::<i32>, vec![])(args),
DataType::LargeUtf8 => make_scalar_function(translate::<i64>, vec![])(args),
other => {
exec_err!("Unsupported data type {other:?} for function translate")
}
make_scalar_function(invoke_translate, vec![])(args)
}
}

fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Utf8View => {
let string_array = args[0].as_string_view();
let from_array = args[1].as_string::<i32>();
let to_array = args[1].as_string::<i32>();
translate::<_, _>(string_array, from_array, to_array)
}
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
let from_array = args[1].as_string::<i32>();
let to_array = args[1].as_string::<i32>();
translate::<_, _>(string_array, from_array, to_array)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
let from_array = args[1].as_string::<i64>();
let to_array = args[1].as_string::<i64>();
translate::<_, _>(string_array, from_array, to_array)
}
other => {
exec_err!("Unsupported data type {other:?} for function translate")
}
}
}

/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.
/// translate('12345', '143', 'ax') = 'a2x5'
fn translate<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = as_generic_string_array::<T>(&args[0])?;
let from_array = as_generic_string_array::<T>(&args[1])?;
let to_array = as_generic_string_array::<T>(&args[2])?;

let result = string_array
.iter()
.zip(from_array.iter())
.zip(to_array.iter())
fn translate<'a, V, B>(string_array: V, from_array: B, to_array: B) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
B: ArrayAccessor<Item = &'a str>,
{
let string_array_iter = ArrayIter::new(string_array);
let from_array_iter = ArrayIter::new(from_array);
let to_array_iter = ArrayIter::new(to_array);

let result = string_array_iter
.zip(from_array_iter)
.zip(to_array_iter)
.map(|((string, from), to)| match (string, from, to) {
(Some(string), Some(from), Some(to)) => {
// create a hashmap of [char, index] to change from O(n) to O(1) for from list
Expand Down Expand Up @@ -120,7 +144,7 @@ fn translate<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
}
_ => None,
})
.collect::<GenericStringArray<T>>();
.collect::<StringArray>();
devanbenz marked this conversation as resolved.
Show resolved Hide resolved

Ok(Arc::new(result) as ArrayRef)
}
Expand Down
16 changes: 14 additions & 2 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,19 @@ logical_plan
01)Projection: starts_with(test.column1_utf8view, Utf8View("äöüß")) AS c1, starts_with(test.column1_utf8view, Utf8View("")) AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4
02)--TableScan: test projection=[column1_utf8view]

### Test TRANSLATE

# Should run TRANSLATE using utf8view column successfully
devanbenz marked this conversation as resolved.
Show resolved Hide resolved
query T
SELECT
TRANSLATE(column1_utf8view, 'foo', 'bar') as c
FROM test;
----
Andrew
Xiangpeng
Raphael
NULL

### Initcap

query TT
Expand Down Expand Up @@ -895,14 +908,13 @@ logical_plan
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]

## Ensure no casts for TRANSLATE
## TODO file ticket
query TT
EXPLAIN SELECT
TRANSLATE(column1_utf8view, 'foo', 'bar') as c
FROM test;
----
logical_plan
01)Projection: translate(CAST(test.column1_utf8view AS Utf8), Utf8("foo"), Utf8("bar")) AS c
01)Projection: translate(test.column1_utf8view, Utf8("foo"), Utf8("bar")) AS c
devanbenz marked this conversation as resolved.
Show resolved Hide resolved
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for FIND_IN_SET
Expand Down
Loading