forked from apache/datafusion
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
array_distance
function (apache#12211)
* Add `distance` aggregation function Signed-off-by: Austin Liu <[email protected]> Add `distance` aggregation function Signed-off-by: Austin Liu <[email protected]> * Add sql logic test for `distance` Signed-off-by: Austin Liu <[email protected]> * Simplify diff calculation Signed-off-by: Austin Liu <[email protected]> * Add `array_distance`/`list_distance` as list function in functions-nested Signed-off-by: Austin Liu <[email protected]> * Remove aggregate function `distance` Signed-off-by: Austin Liu <[email protected]> * format Signed-off-by: Austin Liu <[email protected]> * clean up error handling Signed-off-by: Austin Liu <[email protected]> * Add `array_distance` in scalar array functions docs Signed-off-by: Austin Liu <[email protected]> * Update bulletin Signed-off-by: Austin Liu <[email protected]> * Prettify example Signed-off-by: Austin Liu <[email protected]> --------- Signed-off-by: Austin Liu <[email protected]>
- Loading branch information
1 parent
1fce2a9
commit bd50698
Showing
4 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
//! [ScalarUDFImpl] definitions for array_distance function. | ||
use crate::utils::{downcast_arg, make_scalar_function}; | ||
use arrow_array::{ | ||
Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, | ||
}; | ||
use arrow_schema::DataType; | ||
use arrow_schema::DataType::{FixedSizeList, Float64, LargeList, List}; | ||
use core::any::type_name; | ||
use datafusion_common::cast::{ | ||
as_float32_array, as_float64_array, as_generic_list_array, as_int32_array, | ||
as_int64_array, | ||
}; | ||
use datafusion_common::DataFusionError; | ||
use datafusion_common::{exec_err, Result}; | ||
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; | ||
use std::any::Any; | ||
use std::sync::Arc; | ||
|
||
make_udf_expr_and_func!( | ||
ArrayDistance, | ||
array_distance, | ||
array, | ||
"returns the Euclidean distance between two numeric arrays.", | ||
array_distance_udf | ||
); | ||
|
||
#[derive(Debug)] | ||
pub(super) struct ArrayDistance { | ||
signature: Signature, | ||
aliases: Vec<String>, | ||
} | ||
|
||
impl ArrayDistance { | ||
pub fn new() -> Self { | ||
Self { | ||
signature: Signature::variadic_any(Volatility::Immutable), | ||
aliases: vec!["list_distance".to_string()], | ||
} | ||
} | ||
} | ||
|
||
impl ScalarUDFImpl for ArrayDistance { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
|
||
fn name(&self) -> &str { | ||
"array_distance" | ||
} | ||
|
||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
|
||
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
match arg_types[0] { | ||
List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64), | ||
_ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."), | ||
} | ||
} | ||
|
||
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
make_scalar_function(array_distance_inner)(args) | ||
} | ||
|
||
fn aliases(&self) -> &[String] { | ||
&self.aliases | ||
} | ||
} | ||
|
||
pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
if args.len() != 2 { | ||
return exec_err!("array_distance expects exactly two arguments"); | ||
} | ||
|
||
match (&args[0].data_type(), &args[1].data_type()) { | ||
(List(_), List(_)) => general_array_distance::<i32>(args), | ||
(LargeList(_), LargeList(_)) => general_array_distance::<i64>(args), | ||
(array_type1, array_type2) => { | ||
exec_err!("array_distance does not support types '{array_type1:?}' and '{array_type2:?}'") | ||
} | ||
} | ||
} | ||
|
||
fn general_array_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> { | ||
let list_array1 = as_generic_list_array::<O>(&arrays[0])?; | ||
let list_array2 = as_generic_list_array::<O>(&arrays[1])?; | ||
|
||
let result = list_array1 | ||
.iter() | ||
.zip(list_array2.iter()) | ||
.map(|(arr1, arr2)| compute_array_distance(arr1, arr2)) | ||
.collect::<Result<Float64Array>>()?; | ||
|
||
Ok(Arc::new(result) as ArrayRef) | ||
} | ||
|
||
/// Computes the Euclidean distance between two arrays | ||
fn compute_array_distance( | ||
arr1: Option<ArrayRef>, | ||
arr2: Option<ArrayRef>, | ||
) -> Result<Option<f64>> { | ||
let value1 = match arr1 { | ||
Some(arr) => arr, | ||
None => return Ok(None), | ||
}; | ||
let value2 = match arr2 { | ||
Some(arr) => arr, | ||
None => return Ok(None), | ||
}; | ||
|
||
let mut value1 = value1; | ||
let mut value2 = value2; | ||
|
||
loop { | ||
match value1.data_type() { | ||
List(_) => { | ||
if downcast_arg!(value1, ListArray).null_count() > 0 { | ||
return Ok(None); | ||
} | ||
value1 = downcast_arg!(value1, ListArray).value(0); | ||
} | ||
LargeList(_) => { | ||
if downcast_arg!(value1, LargeListArray).null_count() > 0 { | ||
return Ok(None); | ||
} | ||
value1 = downcast_arg!(value1, LargeListArray).value(0); | ||
} | ||
_ => break, | ||
} | ||
|
||
match value2.data_type() { | ||
List(_) => { | ||
if downcast_arg!(value2, ListArray).null_count() > 0 { | ||
return Ok(None); | ||
} | ||
value2 = downcast_arg!(value2, ListArray).value(0); | ||
} | ||
LargeList(_) => { | ||
if downcast_arg!(value2, LargeListArray).null_count() > 0 { | ||
return Ok(None); | ||
} | ||
value2 = downcast_arg!(value2, LargeListArray).value(0); | ||
} | ||
_ => break, | ||
} | ||
} | ||
|
||
// Check for NULL values inside the arrays | ||
if value1.null_count() != 0 || value2.null_count() != 0 { | ||
return Ok(None); | ||
} | ||
|
||
let values1 = convert_to_f64_array(&value1)?; | ||
let values2 = convert_to_f64_array(&value2)?; | ||
|
||
if values1.len() != values2.len() { | ||
return exec_err!("Both arrays must have the same length"); | ||
} | ||
|
||
let sum_squares: f64 = values1 | ||
.iter() | ||
.zip(values2.iter()) | ||
.map(|(v1, v2)| { | ||
let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0); | ||
diff * diff | ||
}) | ||
.sum(); | ||
|
||
Ok(Some(sum_squares.sqrt())) | ||
} | ||
|
||
/// Converts an array of any numeric type to a Float64Array. | ||
fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> { | ||
match array.data_type() { | ||
DataType::Float64 => Ok(as_float64_array(array)?.clone()), | ||
DataType::Float32 => { | ||
let array = as_float32_array(array)?; | ||
let converted: Float64Array = | ||
array.iter().map(|v| v.map(|v| v as f64)).collect(); | ||
Ok(converted) | ||
} | ||
DataType::Int64 => { | ||
let array = as_int64_array(array)?; | ||
let converted: Float64Array = | ||
array.iter().map(|v| v.map(|v| v as f64)).collect(); | ||
Ok(converted) | ||
} | ||
DataType::Int32 => { | ||
let array = as_int32_array(array)?; | ||
let converted: Float64Array = | ||
array.iter().map(|v| v.map(|v| v as f64)).collect(); | ||
Ok(converted) | ||
} | ||
_ => exec_err!("Unsupported array type for conversion to Float64Array"), | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters