Skip to content

Commit

Permalink
feat(functions): add jaro_winkler string similarity function
Browse files Browse the repository at this point in the history
  • Loading branch information
maxjustus committed Dec 12, 2024
1 parent ab08029 commit a97f916
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 1 deletion.
192 changes: 192 additions & 0 deletions src/query/functions/src/scalars/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use databend_common_expression::types::SimpleDomain;
use databend_common_expression::types::StringType;
use databend_common_expression::types::TimestampType;
use databend_common_expression::types::ValueType;
use databend_common_expression::vectorize_2_arg;
use databend_common_expression::vectorize_with_builder_1_arg;
use databend_common_expression::vectorize_with_builder_2_arg;
use databend_common_expression::Column;
Expand Down Expand Up @@ -241,6 +242,14 @@ pub fn register(registry: &mut FunctionRegistry) {
Value::Column(col)
},
);

registry.register_passthrough_nullable_2_arg::<StringType, StringType, Float64Type, _, _>(
"jaro_winkler",
|_, _, _| FunctionDomain::Full,
vectorize_2_arg::<StringType, StringType, Float64Type>(|s1, s2, _ctx| {
jaro_winkler::jaro_winkler(s1, s2).into()
}),
);
}

fn register_inet_aton(registry: &mut FunctionRegistry) {
Expand Down Expand Up @@ -486,3 +495,186 @@ pub fn compute_grouping(cols: &[usize], grouping_id: u32) -> u32 {
}
grouping
}
// this implementation comes from https://github.com/joshuaclayton/jaro_winkler
pub(crate) mod jaro_winkler {
#![deny(missing_docs)]

//! `jaro_winkler` is a crate for calculating Jaro-Winkler distance of two strings.
//!
//! # Examples
//!
//! ```
//! use jaro_winkler::jaro_winkler;
//!
//! assert_eq!(jaro_winkler("martha", "marhta"), 0.9611111111111111);
//! assert_eq!(jaro_winkler("", "words"), 0.0);
//! assert_eq!(jaro_winkler("same", "same"), 1.0);
//! ```
enum DataWrapper {
Vec(Vec<bool>),
Bitwise(u128),
}

impl DataWrapper {
fn build(len: usize) -> Self {
if len <= 128 {
DataWrapper::Bitwise(0)
} else {
let mut internal = Vec::with_capacity(len);
internal.extend(std::iter::repeat(false).take(len));
DataWrapper::Vec(internal)
}
}

fn get(&self, idx: usize) -> bool {
match self {
DataWrapper::Vec(v) => v[idx],
DataWrapper::Bitwise(v1) => (v1 >> idx) & 1 == 1,
}
}

fn set_true(&mut self, idx: usize) {
match self {
DataWrapper::Vec(v) => v[idx] = true,
DataWrapper::Bitwise(v1) => *v1 |= 1 << idx,
}
}
}

/// Calculates the Jaro-Winkler distance of two strings.
///
/// The return value is between 0.0 and 1.0, where 1.0 means the strings are equal.
pub fn jaro_winkler(left_: &str, right_: &str) -> f64 {
let llen = left_.len();
let rlen = right_.len();

let (left, right, s1_len, s2_len) = if llen < rlen {
(right_, left_, rlen, llen)
} else {
(left_, right_, llen, rlen)
};

match (s1_len, s2_len) {
(0, 0) => return 1.0,
(0, _) | (_, 0) => return 0.0,
(_, _) => (),
}

if left == right {
return 1.0;
}

let range = matching_distance(s1_len, s2_len);
let mut s1m = DataWrapper::build(s1_len);
let mut s2m = DataWrapper::build(s2_len);
let mut matching: f64 = 0.0;
let mut transpositions: f64 = 0.0;
let left_as_bytes = left.as_bytes();
let right_as_bytes = right.as_bytes();

for i in 0..s2_len {
let mut j = (i as isize - range as isize).max(0) as usize;
let l = (i + range + 1).min(s1_len);
while j < l {
if right_as_bytes[i] == left_as_bytes[j] && !s1m.get(j) {
s1m.set_true(j);
s2m.set_true(i);
matching += 1.0;
break;
}

j += 1;
}
}

if matching == 0.0 {
return 0.0;
}

let mut l = 0;

for i in 0..s2_len - 1 {
if s2m.get(i) {
let mut j = l;

while j < s1_len {
if s1m.get(j) {
l = j + 1;
break;
}

j += 1;
}

if right_as_bytes[i] != left_as_bytes[j] {
transpositions += 1.0;
}
}
}
transpositions = (transpositions / 2.0).ceil();

let jaro = (matching / (s1_len as f64)
+ matching / (s2_len as f64)
+ (matching - transpositions) / matching)
/ 3.0;

let prefix_length = left_as_bytes
.iter()
.zip(right_as_bytes)
.take(4)
.take_while(|(l, r)| l == r)
.count() as f64;

jaro + prefix_length * 0.1 * (1.0 - jaro)
}

fn matching_distance(s1_len: usize, s2_len: usize) -> usize {
let max = s1_len.max(s2_len) as f32;
((max / 2.0).floor() - 1.0) as usize
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn different_is_zero() {
assert_eq!(jaro_winkler("foo", "bar"), 0.0);
}

#[test]
fn same_is_one() {
assert_eq!(jaro_winkler("foo", "foo"), 1.0);
assert_eq!(jaro_winkler("", ""), 1.0);
}

#[test]
fn test_hello() {
assert_eq!(jaro_winkler("hell", "hello"), 0.96);
}

macro_rules! assert_within {
($x:expr, $y:expr, delta=$d:expr) => {
assert!(($x - $y).abs() <= $d)
};
}

#[test]
fn test_boundary() {
let long_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s";
let longer_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s";
let result = jaro_winkler(long_value, longer_value);
assert_within!(result, 0.82, delta = 0.01);
}

#[test]
fn test_close_to_boundary() {
let long_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test";
assert_eq!(long_value.len(), 129);
let longer_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s";
let result = jaro_winkler(long_value, longer_value);
assert_within!(result, 0.8, delta = 0.001);
}
}
}
2 changes: 1 addition & 1 deletion src/query/functions/src/scalars/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ pub fn register(registry: &mut FunctionRegistry) {
output.commit_row();
},
),
)
);
}

pub(crate) mod soundex {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2288,6 +2288,8 @@ Functions overloads:
1 is_string(Variant NULL) :: Boolean NULL
0 is_true(Boolean) :: Boolean
1 is_true(Boolean NULL) :: Boolean
0 jaro_winkler(String, String) :: Float64
1 jaro_winkler(String NULL, String NULL) :: Float64 NULL
0 jq FACTORY
0 json_array FACTORY
0 json_array_distinct(Variant) :: Variant
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
query T
SELECT jaro_winkler(NULL, 'hello')
----
NULL

query T
SELECT jaro_winkler('hello', NULL)
----
NULL

query T
SELECT jaro_winkler(NULL, NULL)
----
NULL

query T
SELECT jaro_winkler('', '')
----
1.0

query T
SELECT jaro_winkler('hello', 'hello')
----
1.0

query T
SELECT jaro_winkler('hello', 'helo')
----
0.9533333333333333

query T
SELECT jaro_winkler('martha', 'marhta')
----
0.9611111111111111

query T
SELECT jaro_winkler('你好', '你好啊')
----
0.9333333333333333

query T
SELECT jaro_winkler('🦀hello', '🦀helo')
----
0.9777777777777777

query T
SELECT jaro_winkler('dixon', 'dicksonx')
----
0.8133333333333332

query T
SELECT jaro_winkler('duane', 'dwayne')
----
0.8400000000000001

query T
select jaro_winkler('asdf', 'as x c f');
----
0.6592592592592592

query T
SELECT jaro_winkler('', 'hello')
----
0.0

0 comments on commit a97f916

Please sign in to comment.