From 341939ab0d26e003324e6355c0c4c19f43980567 Mon Sep 17 00:00:00 2001 From: Max Justus Spransy Date: Fri, 29 Nov 2024 15:59:18 -0800 Subject: [PATCH] feat(functions): add jaro_winkler string similarity function --- src/query/functions/src/scalars/other.rs | 196 ++++++++++++++++++ src/query/functions/src/scalars/string.rs | 2 +- .../it/scalars/testdata/function_list.txt | 2 + ...02_0079_function_strings_jaro_winkler.test | 64 ++++++ 4 files changed, 263 insertions(+), 1 deletion(-) create mode 100644 tests/sqllogictests/suites/query/functions/02_0079_function_strings_jaro_winkler.test diff --git a/src/query/functions/src/scalars/other.rs b/src/query/functions/src/scalars/other.rs index 0c4a44c38d55..aad129db4b25 100644 --- a/src/query/functions/src/scalars/other.rs +++ b/src/query/functions/src/scalars/other.rs @@ -48,6 +48,7 @@ use databend_common_expression::types::TimestampType; use databend_common_expression::types::ValueType; use databend_common_expression::vectorize_with_builder_1_arg; use databend_common_expression::vectorize_with_builder_2_arg; +use databend_common_expression::vectorize_2_arg; use databend_common_expression::Column; use databend_common_expression::Domain; use databend_common_expression::EvalContext; @@ -241,6 +242,17 @@ pub fn register(registry: &mut FunctionRegistry) { Value::Column(col) }, ); + + registry + .register_passthrough_nullable_2_arg::( + "jaro_winkler", + |_, _, _| FunctionDomain::Full, + vectorize_2_arg::( + |s1, s2, _ctx| { + jaro_winkler::jaro_winkler(s1, s2).into() + }, + ), + ); } fn register_inet_aton(registry: &mut FunctionRegistry) { @@ -486,3 +498,187 @@ pub fn compute_grouping(cols: &[usize], grouping_id: u32) -> u32 { } grouping } +// +// this implementation comes from https://github.com/joshuaclayton/jaro_winkler +pub(crate) mod jaro_winkler { + #![deny(missing_docs)] + + //! `jaro_winkler` is a crate for calculating Jaro-Winkler distance of two strings. + //! + //! # Examples + //! + //! ``` + //! use jaro_winkler::jaro_winkler; + //! + //! assert_eq!(jaro_winkler("martha", "marhta"), 0.9611111111111111); + //! assert_eq!(jaro_winkler("", "words"), 0.0); + //! assert_eq!(jaro_winkler("same", "same"), 1.0); + //! ``` + + enum DataWrapper { + Vec(Vec), + Bitwise(u128), + } + + impl DataWrapper { + fn build(len: usize) -> Self { + if len <= 128 { + DataWrapper::Bitwise(0) + } else { + let mut internal = Vec::with_capacity(len); + internal.extend(std::iter::repeat(false).take(len)); + DataWrapper::Vec(internal) + } + } + + fn get(&self, idx: usize) -> bool { + match self { + DataWrapper::Vec(v) => v[idx], + DataWrapper::Bitwise(v1) => (v1 >> idx) & 1 == 1, + } + } + + fn set_true(&mut self, idx: usize) { + match self { + DataWrapper::Vec(v) => v[idx] = true, + DataWrapper::Bitwise(v1) => *v1 |= 1 << idx, + } + } + } + + /// Calculates the Jaro-Winkler distance of two strings. + /// + /// The return value is between 0.0 and 1.0, where 1.0 means the strings are equal. + pub fn jaro_winkler(left_: &str, right_: &str) -> f64 { + let llen = left_.len(); + let rlen = right_.len(); + + let (left, right, s1_len, s2_len) = if llen < rlen { + (right_, left_, rlen, llen) + } else { + (left_, right_, llen, rlen) + }; + + match (s1_len, s2_len) { + (0, 0) => return 1.0, + (0, _) | (_, 0) => return 0.0, + (_, _) => (), + } + + if left == right { + return 1.0; + } + + let range = matching_distance(s1_len, s2_len); + let mut s1m = DataWrapper::build(s1_len); + let mut s2m = DataWrapper::build(s2_len); + let mut matching: f64 = 0.0; + let mut transpositions: f64 = 0.0; + let left_as_bytes = left.as_bytes(); + let right_as_bytes = right.as_bytes(); + + for i in 0..s2_len { + let mut j = (i as isize - range as isize).max(0) as usize; + let l = (i + range + 1).min(s1_len); + while j < l { + if right_as_bytes[i] == left_as_bytes[j] && !s1m.get(j) { + s1m.set_true(j); + s2m.set_true(i); + matching += 1.0; + break; + } + + j += 1; + } + } + + if matching == 0.0 { + return 0.0; + } + + let mut l = 0; + + for i in 0..s2_len - 1 { + if s2m.get(i) { + let mut j = l; + + while j < s1_len { + if s1m.get(j) { + l = j + 1; + break; + } + + j += 1; + } + + if right_as_bytes[i] != left_as_bytes[j] { + transpositions += 1.0; + } + } + } + transpositions = (transpositions / 2.0).ceil(); + + let jaro = (matching / (s1_len as f64) + + matching / (s2_len as f64) + + (matching - transpositions) / matching) + / 3.0; + + let prefix_length = left_as_bytes + .iter() + .zip(right_as_bytes) + .take(4) + .take_while(|(l, r)| l == r) + .count() as f64; + + jaro + prefix_length * 0.1 * (1.0 - jaro) + } + + fn matching_distance(s1_len: usize, s2_len: usize) -> usize { + let max = s1_len.max(s2_len) as f32; + ((max / 2.0).floor() - 1.0) as usize + } + +#[cfg(test)] + mod tests { + use super::*; + + #[test] + fn different_is_zero() { + assert_eq!(jaro_winkler("foo", "bar"), 0.0); + } + + #[test] + fn same_is_one() { + assert_eq!(jaro_winkler("foo", "foo"), 1.0); + assert_eq!(jaro_winkler("", ""), 1.0); + } + + #[test] + fn test_hello() { + assert_eq!(jaro_winkler("hell", "hello"), 0.96); + } + + macro_rules! assert_within { + ($x:expr, $y:expr, delta=$d:expr) => { + assert!(($x - $y).abs() <= $d) + }; + } + + #[test] + fn test_boundary() { + let long_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s"; + let longer_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s"; + let result = jaro_winkler(long_value, longer_value); + assert_within!(result, 0.82, delta = 0.01); + } + + #[test] + fn test_close_to_boundary() { + let long_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test"; + assert_eq!(long_value.len(), 129); + let longer_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s"; + let result = jaro_winkler(long_value, longer_value); + assert_within!(result, 0.8, delta = 0.001); + } + } +} diff --git a/src/query/functions/src/scalars/string.rs b/src/query/functions/src/scalars/string.rs index ecf7cb11ac6d..72d00c8aad00 100644 --- a/src/query/functions/src/scalars/string.rs +++ b/src/query/functions/src/scalars/string.rs @@ -776,7 +776,7 @@ pub fn register(registry: &mut FunctionRegistry) { output.commit_row(); }, ), - ) + ); } pub(crate) mod soundex { diff --git a/src/query/functions/tests/it/scalars/testdata/function_list.txt b/src/query/functions/tests/it/scalars/testdata/function_list.txt index afc725e13450..82deddcce9d0 100644 --- a/src/query/functions/tests/it/scalars/testdata/function_list.txt +++ b/src/query/functions/tests/it/scalars/testdata/function_list.txt @@ -2288,6 +2288,8 @@ Functions overloads: 1 is_string(Variant NULL) :: Boolean NULL 0 is_true(Boolean) :: Boolean 1 is_true(Boolean NULL) :: Boolean +0 jaro_winkler(String, String) :: Float64 +1 jaro_winkler(String NULL, String NULL) :: Float64 NULL 0 jq FACTORY 0 json_array FACTORY 0 json_array_distinct(Variant) :: Variant diff --git a/tests/sqllogictests/suites/query/functions/02_0079_function_strings_jaro_winkler.test b/tests/sqllogictests/suites/query/functions/02_0079_function_strings_jaro_winkler.test new file mode 100644 index 000000000000..7a589a121a09 --- /dev/null +++ b/tests/sqllogictests/suites/query/functions/02_0079_function_strings_jaro_winkler.test @@ -0,0 +1,64 @@ +query T +SELECT jaro_winkler(NULL, 'hello') +---- +NULL + +query T +SELECT jaro_winkler('hello', NULL) +---- +NULL + +query T +SELECT jaro_winkler(NULL, NULL) +---- +NULL + +query T +SELECT jaro_winkler('', '') +---- +1.0 + +query T +SELECT jaro_winkler('hello', 'hello') +---- +1.0 + +query T +SELECT jaro_winkler('hello', 'helo') +---- +0.9533333333333333 + +query T +SELECT jaro_winkler('martha', 'marhta') +---- +0.9611111111111111 + +query T +SELECT jaro_winkler('你好', '你好啊') +---- +0.9333333333333333 + +query T +SELECT jaro_winkler('🦀hello', '🦀helo') +---- +0.9777777777777777 + +query T +SELECT jaro_winkler('dixon', 'dicksonx') +---- +0.8133333333333332 + +query T +SELECT jaro_winkler('duane', 'dwayne') +---- +0.8400000000000001 + +query T +select jaro_winkler('asdf', 'as x c f'); +---- +0.6592592592592592 + +query T +SELECT jaro_winkler('', 'hello') +---- +0.0