Skip to content

Commit

Permalink
feat:implement calcite style 'levenshtein' string function (#8168)
Browse files Browse the repository at this point in the history
* feat:implement calcite style 'levenshtein' string function

* format doc style

* cargo lock
  • Loading branch information
Syleechan authored Nov 17, 2023
1 parent 1e6ff64 commit 7618e4d
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 4 deletions.
12 changes: 12 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ pub enum BuiltinScalarFunction {
ArrowTypeof,
/// overlay
OverLay,
/// levenshtein
Levenshtein,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -464,6 +466,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::FromUnixtime => Volatility::Immutable,
BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
BuiltinScalarFunction::OverLay => Volatility::Immutable,
BuiltinScalarFunction::Levenshtein => Volatility::Immutable,

// Stable builtin functions
BuiltinScalarFunction::Now => Volatility::Stable,
Expand Down Expand Up @@ -829,6 +832,10 @@ impl BuiltinScalarFunction {
utf8_to_str_type(&input_expr_types[0], "overlay")
}

BuiltinScalarFunction::Levenshtein => {
utf8_to_int_type(&input_expr_types[0], "levenshtein")
}

BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
Expand Down Expand Up @@ -1293,6 +1300,10 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::Levenshtein => Signature::one_of(
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
self.volatility(),
),
BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
Expand Down Expand Up @@ -1457,6 +1468,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Trim => &["trim"],
BuiltinScalarFunction::Upper => &["upper"],
BuiltinScalarFunction::Uuid => &["uuid"],
BuiltinScalarFunction::Levenshtein => &["levenshtein"],

// regex functions
BuiltinScalarFunction::RegexpMatch => &["regexp_match"],
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,7 @@ scalar_expr!(
);

scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type");
scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings");

scalar_expr!(
Struct,
Expand Down Expand Up @@ -1195,6 +1196,7 @@ mod test {
test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
}

#[test]
Expand Down
13 changes: 13 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,19 @@ pub fn create_physical_fun(
"Unsupported data type {other:?} for function overlay",
))),
}),
BuiltinScalarFunction::Levenshtein => {
Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::levenshtein::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::levenshtein::<i64>)(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function levenshtein",
))),
})
}
})
}

Expand Down
67 changes: 65 additions & 2 deletions datafusion/physical-expr/src/string_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
use arrow::{
array::{
Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, OffsetSizeTrait,
StringArray,
Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array,
OffsetSizeTrait, StringArray,
},
datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType},
};
use datafusion_common::utils::datafusion_strsim;
use datafusion_common::{
cast::{
as_generic_string_array, as_int64_array, as_primitive_array, as_string_array,
Expand Down Expand Up @@ -643,12 +644,59 @@ pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

///Returns the Levenshtein distance between the two given strings.
/// LEVENSHTEIN('kitten', 'sitting') = 3
pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return Err(DataFusionError::Internal(format!(
"levenshtein function requires two arguments, got {}",
args.len()
)));
}
let str1_array = as_generic_string_array::<T>(&args[0])?;
let str2_array = as_generic_string_array::<T>(&args[1])?;
match args[0].data_type() {
DataType::Utf8 => {
let result = str1_array
.iter()
.zip(str2_array.iter())
.map(|(string1, string2)| match (string1, string2) {
(Some(string1), Some(string2)) => {
Some(datafusion_strsim::levenshtein(string1, string2) as i32)
}
_ => None,
})
.collect::<Int32Array>();
Ok(Arc::new(result) as ArrayRef)
}
DataType::LargeUtf8 => {
let result = str1_array
.iter()
.zip(str2_array.iter())
.map(|(string1, string2)| match (string1, string2) {
(Some(string1), Some(string2)) => {
Some(datafusion_strsim::levenshtein(string1, string2) as i64)
}
_ => None,
})
.collect::<Int64Array>();
Ok(Arc::new(result) as ArrayRef)
}
other => {
internal_err!(
"levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8."
)
}
}
}

#[cfg(test)]
mod tests {

use crate::string_expressions;
use arrow::{array::Int32Array, datatypes::Int32Type};
use arrow_array::Int64Array;
use datafusion_common::cast::as_int32_array;

use super::*;

Expand Down Expand Up @@ -707,4 +755,19 @@ mod tests {

Ok(())
}

#[test]
fn to_levenshtein() -> Result<()> {
let string1_array =
Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"]));
let string2_array =
Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"]));
let res = levenshtein::<i32>(&[string1_array, string2_array]).unwrap();
let result =
as_int32_array(&res).expect("failed to initialized function levenshtein");
let expected = Int32Array::from(vec![2, 3, 2, 3]);
assert_eq!(&expected, result);

Ok(())
}
}
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ enum ScalarFunction {
OverLay = 121;
Range = 122;
ArrayPopFront = 123;
Levenshtein = 124;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use datafusion_expr::{
date_part, date_trunc, decode, degrees, digest, encode, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left,
ln, log, log10, log2,
levenshtein, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power,
radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right,
Expand Down Expand Up @@ -549,6 +549,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Iszero => Self::Iszero,
ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
ScalarFunction::OverLay => Self::OverLay,
ScalarFunction::Levenshtein => Self::Levenshtein,
}
}
}
Expand Down Expand Up @@ -1630,6 +1631,10 @@ pub fn parse_expr(
))
}
}
ScalarFunction::Levenshtein => Ok(levenshtein(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)),
ScalarFunction::ToTimestampMillis => {
Ok(to_timestamp_millis(parse_expr(&args[0], registry)?))
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Iszero => Self::Iszero,
BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
BuiltinScalarFunction::OverLay => Self::OverLay,
BuiltinScalarFunction::Levenshtein => Self::Levenshtein,
};

Ok(scalar_function)
Expand Down
22 changes: 21 additions & 1 deletion datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ INSERT INTO products (product_id, product_name, price) VALUES
(1, 'OldBrand Product 1', 19.99),
(2, 'OldBrand Product 2', 29.99),
(3, 'OldBrand Product 3', 39.99),
(4, 'OldBrand Product 4', 49.99)
(4, 'OldBrand Product 4', 49.99)

query ITR
SELECT * REPLACE (price*2 AS price) FROM products
Expand Down Expand Up @@ -857,3 +857,23 @@ NULL
NULL
Thomxas
NULL

query I
SELECT levenshtein('kitten', 'sitting')
----
3

query I
SELECT levenshtein('kitten', NULL)
----
NULL

query ?
SELECT levenshtein(NULL, 'sitting')
----
NULL

query ?
SELECT levenshtein(NULL, NULL)
----
NULL
15 changes: 15 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@ nullif(expression1, expression2)
- [upper](#upper)
- [uuid](#uuid)
- [overlay](#overlay)
- [levenshtein](#levenshtein)

### `ascii`

Expand Down Expand Up @@ -1137,6 +1138,20 @@ overlay(str PLACING substr FROM pos [FOR count])
- **pos**: the start position to replace of str.
- **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead.

### `levenshtein`

Returns the Levenshtein distance between the two given strings.
For example, `levenshtein('kitten', 'sitting') = 3`

```
levenshtein(str1, str2)
```

#### Arguments

- **str1**: String expression to compute Levenshtein distance with str2.
- **str2**: String expression to compute Levenshtein distance with str1.

## Binary String Functions

- [decode](#decode)
Expand Down

0 comments on commit 7618e4d

Please sign in to comment.