-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ScalarValue::eq_array optimized comparison function #844
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -277,6 +277,31 @@ impl std::hash::Hash for ScalarValue { | |
} | ||
} | ||
|
||
// return the index into the dictionary values for array@index as well | ||
// as a reference to the dictionary values array. Returns None for the | ||
// index if the array is NULL at index | ||
#[inline] | ||
fn get_dict_value<K: ArrowDictionaryKeyType>( | ||
array: &ArrayRef, | ||
index: usize, | ||
) -> Result<(&ArrayRef, Option<usize>)> { | ||
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap(); | ||
|
||
// look up the index in the values dictionary | ||
let keys_col = dict_array.keys(); | ||
if !keys_col.is_valid(index) { | ||
return Ok((dict_array.values(), None)); | ||
} | ||
let values_index = keys_col.value(index).to_usize().ok_or_else(|| { | ||
DataFusionError::Internal(format!( | ||
"Can not convert index to usize in dictionary of type creating group by value {:?}", | ||
keys_col.data_type() | ||
)) | ||
})?; | ||
|
||
Ok((dict_array.values(), Some(values_index))) | ||
} | ||
|
||
macro_rules! typed_cast { | ||
($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ | ||
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); | ||
|
@@ -399,6 +424,17 @@ macro_rules! build_array_from_option { | |
}}; | ||
} | ||
|
||
macro_rules! eq_array_primitive { | ||
($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ | ||
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); | ||
let is_valid = array.is_valid($index); | ||
match $VALUE { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we avoid the match if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it could if there was a specific There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Filed #850 to track this suggestion |
||
Some(val) => is_valid && &array.value($index) == val, | ||
None => !is_valid, | ||
} | ||
}}; | ||
} | ||
|
||
impl ScalarValue { | ||
/// Getter for the `DataType` of the value | ||
pub fn get_datatype(&self) -> DataType { | ||
|
@@ -942,28 +978,30 @@ impl ScalarValue { | |
DataType::Timestamp(TimeUnit::Nanosecond, _) => { | ||
typed_cast!(array, index, TimestampNanosecondArray, TimestampNanosecond) | ||
} | ||
DataType::Dictionary(index_type, _) => match **index_type { | ||
DataType::Int8 => Self::try_from_dict_array::<Int8Type>(array, index)?, | ||
DataType::Int16 => Self::try_from_dict_array::<Int16Type>(array, index)?, | ||
DataType::Int32 => Self::try_from_dict_array::<Int32Type>(array, index)?, | ||
DataType::Int64 => Self::try_from_dict_array::<Int64Type>(array, index)?, | ||
DataType::UInt8 => Self::try_from_dict_array::<UInt8Type>(array, index)?, | ||
DataType::UInt16 => { | ||
Self::try_from_dict_array::<UInt16Type>(array, index)? | ||
} | ||
DataType::UInt32 => { | ||
Self::try_from_dict_array::<UInt32Type>(array, index)? | ||
} | ||
DataType::UInt64 => { | ||
Self::try_from_dict_array::<UInt64Type>(array, index)? | ||
} | ||
_ => { | ||
return Err(DataFusionError::Internal(format!( | ||
"Index type not supported while creating scalar from dictionary: {}", | ||
array.data_type(), | ||
))) | ||
DataType::Dictionary(index_type, _) => { | ||
let (values, values_index) = match **index_type { | ||
DataType::Int8 => get_dict_value::<Int8Type>(array, index)?, | ||
DataType::Int16 => get_dict_value::<Int16Type>(array, index)?, | ||
DataType::Int32 => get_dict_value::<Int32Type>(array, index)?, | ||
DataType::Int64 => get_dict_value::<Int64Type>(array, index)?, | ||
DataType::UInt8 => get_dict_value::<UInt8Type>(array, index)?, | ||
DataType::UInt16 => get_dict_value::<UInt16Type>(array, index)?, | ||
DataType::UInt32 => get_dict_value::<UInt32Type>(array, index)?, | ||
DataType::UInt64 => get_dict_value::<UInt64Type>(array, index)?, | ||
_ => { | ||
return Err(DataFusionError::Internal(format!( | ||
"Index type not supported while creating scalar from dictionary: {}", | ||
array.data_type(), | ||
))) | ||
} | ||
}; | ||
|
||
match values_index { | ||
Some(values_index) => Self::try_from_array(values, values_index)?, | ||
// was null | ||
None => values.data_type().try_into()?, | ||
} | ||
}, | ||
} | ||
other => { | ||
return Err(DataFusionError::NotImplemented(format!( | ||
"Can't create a scalar from array of type \"{:?}\"", | ||
|
@@ -973,22 +1011,114 @@ impl ScalarValue { | |
}) | ||
} | ||
|
||
fn try_from_dict_array<K: ArrowDictionaryKeyType>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. refactored into |
||
/// Compares a single row of array @ index for equality with self, | ||
/// in an optimized fashion. | ||
/// | ||
/// This method implements an optimized version of: | ||
/// | ||
/// ```text | ||
/// let arr_scalar = Self::try_from_array(array, index).unwrap(); | ||
/// arr_scalar.eq(self) | ||
/// ``` | ||
/// | ||
/// *Performance note*: the arrow compute kernels should be | ||
/// preferred over this function if at all possible as they can be | ||
/// vectorized and are generally much faster. | ||
/// | ||
/// This function has a few narrow usescases such as hash table key | ||
/// comparisons where comparing a single row at a time is necessary. | ||
#[inline] | ||
pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { | ||
if let DataType::Dictionary(key_type, _) = array.data_type() { | ||
return self.eq_array_dictionary(array, index, key_type); | ||
} | ||
|
||
match self { | ||
ScalarValue::Boolean(val) => { | ||
eq_array_primitive!(array, index, BooleanArray, val) | ||
} | ||
ScalarValue::Float32(val) => { | ||
eq_array_primitive!(array, index, Float32Array, val) | ||
} | ||
ScalarValue::Float64(val) => { | ||
eq_array_primitive!(array, index, Float64Array, val) | ||
} | ||
ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), | ||
ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), | ||
ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), | ||
ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), | ||
ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), | ||
ScalarValue::UInt16(val) => { | ||
eq_array_primitive!(array, index, UInt16Array, val) | ||
} | ||
ScalarValue::UInt32(val) => { | ||
eq_array_primitive!(array, index, UInt32Array, val) | ||
} | ||
ScalarValue::UInt64(val) => { | ||
eq_array_primitive!(array, index, UInt64Array, val) | ||
} | ||
ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), | ||
ScalarValue::LargeUtf8(val) => { | ||
eq_array_primitive!(array, index, LargeStringArray, val) | ||
} | ||
ScalarValue::Binary(val) => { | ||
eq_array_primitive!(array, index, BinaryArray, val) | ||
} | ||
ScalarValue::LargeBinary(val) => { | ||
eq_array_primitive!(array, index, LargeBinaryArray, val) | ||
} | ||
ScalarValue::List(_, _) => unimplemented!(), | ||
ScalarValue::Date32(val) => { | ||
eq_array_primitive!(array, index, Date32Array, val) | ||
} | ||
ScalarValue::Date64(val) => { | ||
eq_array_primitive!(array, index, Date64Array, val) | ||
} | ||
ScalarValue::TimestampSecond(val) => { | ||
eq_array_primitive!(array, index, TimestampSecondArray, val) | ||
} | ||
ScalarValue::TimestampMillisecond(val) => { | ||
eq_array_primitive!(array, index, TimestampMillisecondArray, val) | ||
} | ||
ScalarValue::TimestampMicrosecond(val) => { | ||
eq_array_primitive!(array, index, TimestampMicrosecondArray, val) | ||
} | ||
ScalarValue::TimestampNanosecond(val) => { | ||
eq_array_primitive!(array, index, TimestampNanosecondArray, val) | ||
} | ||
ScalarValue::IntervalYearMonth(val) => { | ||
eq_array_primitive!(array, index, IntervalYearMonthArray, val) | ||
} | ||
ScalarValue::IntervalDayTime(val) => { | ||
eq_array_primitive!(array, index, IntervalDayTimeArray, val) | ||
} | ||
} | ||
} | ||
|
||
/// Compares a dictionary array with indexes of type `key_type` | ||
/// with the array @ index for equality with self | ||
fn eq_array_dictionary( | ||
&self, | ||
array: &ArrayRef, | ||
index: usize, | ||
) -> Result<Self> { | ||
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap(); | ||
|
||
// look up the index in the values dictionary | ||
// (note validity was previously checked in `try_from_array`) | ||
let keys_col = dict_array.keys(); | ||
let values_index = keys_col.value(index).to_usize().ok_or_else(|| { | ||
DataFusionError::Internal(format!( | ||
"Can not convert index to usize in dictionary of type creating group by value {:?}", | ||
keys_col.data_type() | ||
)) | ||
})?; | ||
Self::try_from_array(dict_array.values(), values_index) | ||
key_type: &DataType, | ||
) -> bool { | ||
let (values, values_index) = match key_type { | ||
DataType::Int8 => get_dict_value::<Int8Type>(array, index).unwrap(), | ||
DataType::Int16 => get_dict_value::<Int16Type>(array, index).unwrap(), | ||
DataType::Int32 => get_dict_value::<Int32Type>(array, index).unwrap(), | ||
DataType::Int64 => get_dict_value::<Int64Type>(array, index).unwrap(), | ||
DataType::UInt8 => get_dict_value::<UInt8Type>(array, index).unwrap(), | ||
DataType::UInt16 => get_dict_value::<UInt16Type>(array, index).unwrap(), | ||
DataType::UInt32 => get_dict_value::<UInt32Type>(array, index).unwrap(), | ||
DataType::UInt64 => get_dict_value::<UInt64Type>(array, index).unwrap(), | ||
_ => unreachable!("Invalid dictionary keys type: {:?}", key_type), | ||
}; | ||
|
||
match values_index { | ||
Some(values_index) => self.eq_array(values, values_index), | ||
None => self.is_null(), | ||
} | ||
} | ||
} | ||
|
||
|
@@ -1654,6 +1784,159 @@ mod tests { | |
assert_eq!(std::mem::size_of::<ScalarValue>(), 32); | ||
} | ||
|
||
#[test]
fn scalar_eq_array() {
    // Validate that eq_array has the same semantics as ScalarValue::eq

    // Re-types every `Some` element of an `Option` vector via an `as` cast.
    macro_rules! make_typed_vec {
        ($VALUES:expr, $TARGET:ident) => {{
            $VALUES
                .iter()
                .map(|item| item.map(|inner| inner as $TARGET))
                .collect::<Vec<_>>()
        }};
    }

    let bool_vals = vec![Some(true), None, Some(false)];
    let f32_vals = vec![Some(-1.0), None, Some(1.0)];
    let f64_vals = make_typed_vec!(f32_vals, f64);

    let i8_vals = vec![Some(-1), None, Some(1)];
    let i16_vals = make_typed_vec!(i8_vals, i16);
    let i32_vals = make_typed_vec!(i8_vals, i32);
    let i64_vals = make_typed_vec!(i8_vals, i64);

    let u8_vals = vec![Some(0), None, Some(1)];
    let u16_vals = make_typed_vec!(u8_vals, u16);
    let u32_vals = make_typed_vec!(u8_vals, u32);
    let u64_vals = make_typed_vec!(u8_vals, u64);

    let str_vals = vec![Some("foo"), None, Some("bar")];

    /// Pairs an array with one scalar per row; `scalars[i]` is expected to
    /// equal `array[i]`. Assumes each element is unique (i.e. not equal to
    /// the element at any other index).
    struct TestCase {
        array: ArrayRef,
        scalars: Vec<ScalarValue>,
    }

    /// Create a test case by casting the input to the specified array type
    macro_rules! make_test_case {
        ($VALUES:expr, $ARRAY:ident, $VARIANT:ident) => {{
            TestCase {
                array: Arc::new($VALUES.iter().collect::<$ARRAY>()),
                scalars: $VALUES
                    .iter()
                    .map(|item| ScalarValue::$VARIANT(*item))
                    .collect(),
            }
        }};
    }

    /// Create a test case for string arrays and owned-string scalars
    macro_rules! make_str_test_case {
        ($VALUES:expr, $ARRAY:ident, $VARIANT:ident) => {{
            TestCase {
                array: Arc::new($VALUES.iter().cloned().collect::<$ARRAY>()),
                scalars: $VALUES
                    .iter()
                    .map(|item| ScalarValue::$VARIANT(item.map(|s| s.to_string())))
                    .collect(),
            }
        }};
    }

    /// Create a test case for binary arrays and byte-vector scalars
    macro_rules! make_binary_test_case {
        ($VALUES:expr, $ARRAY:ident, $VARIANT:ident) => {{
            TestCase {
                array: Arc::new($VALUES.iter().cloned().collect::<$ARRAY>()),
                scalars: $VALUES
                    .iter()
                    .map(|item| {
                        ScalarValue::$VARIANT(item.map(|s| s.as_bytes().to_vec()))
                    })
                    .collect(),
            }
        }};
    }

    /// Create a test case for DictionaryArray<$KEY>
    macro_rules! make_str_dict_test_case {
        ($VALUES:expr, $KEY:ident, $VARIANT:ident) => {{
            TestCase {
                array: Arc::new(
                    $VALUES
                        .iter()
                        .cloned()
                        .collect::<DictionaryArray<$KEY>>(),
                ),
                scalars: $VALUES
                    .iter()
                    .map(|item| ScalarValue::$VARIANT(item.map(|s| s.to_string())))
                    .collect(),
            }
        }};
    }

    let cases = vec![
        make_test_case!(bool_vals, BooleanArray, Boolean),
        make_test_case!(f32_vals, Float32Array, Float32),
        make_test_case!(f64_vals, Float64Array, Float64),
        make_test_case!(i8_vals, Int8Array, Int8),
        make_test_case!(i16_vals, Int16Array, Int16),
        make_test_case!(i32_vals, Int32Array, Int32),
        make_test_case!(i64_vals, Int64Array, Int64),
        make_test_case!(u8_vals, UInt8Array, UInt8),
        make_test_case!(u16_vals, UInt16Array, UInt16),
        make_test_case!(u32_vals, UInt32Array, UInt32),
        make_test_case!(u64_vals, UInt64Array, UInt64),
        make_str_test_case!(str_vals, StringArray, Utf8),
        make_str_test_case!(str_vals, LargeStringArray, LargeUtf8),
        make_binary_test_case!(str_vals, BinaryArray, Binary),
        make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary),
        make_test_case!(i32_vals, Date32Array, Date32),
        make_test_case!(i64_vals, Date64Array, Date64),
        make_test_case!(i64_vals, TimestampSecondArray, TimestampSecond),
        make_test_case!(i64_vals, TimestampMillisecondArray, TimestampMillisecond),
        make_test_case!(i64_vals, TimestampMicrosecondArray, TimestampMicrosecond),
        make_test_case!(i64_vals, TimestampNanosecondArray, TimestampNanosecond),
        make_test_case!(i32_vals, IntervalYearMonthArray, IntervalYearMonth),
        make_test_case!(i64_vals, IntervalDayTimeArray, IntervalDayTime),
        make_str_dict_test_case!(str_vals, Int8Type, Utf8),
        make_str_dict_test_case!(str_vals, Int16Type, Utf8),
        make_str_dict_test_case!(str_vals, Int32Type, Utf8),
        make_str_dict_test_case!(str_vals, Int64Type, Utf8),
        make_str_dict_test_case!(str_vals, UInt8Type, Utf8),
        make_str_dict_test_case!(str_vals, UInt16Type, Utf8),
        make_str_dict_test_case!(str_vals, UInt32Type, Utf8),
        make_str_dict_test_case!(str_vals, UInt64Type, Utf8),
    ];

    for TestCase { array, scalars } in cases {
        assert_eq!(array.len(), scalars.len());

        for (index, scalar) in scalars.into_iter().enumerate() {
            // the scalar must match the row it was built from ...
            assert!(
                scalar.eq_array(&array, index),
                "Expected {:?} to be equal to {:?} at index {}",
                scalar,
                array,
                index
            );

            // ... and must *not* match any other row
            for other_index in (0..array.len()).filter(|&other| other != index) {
                assert!(
                    !scalar.eq_array(&array, other_index),
                    "Expected {:?} to be NOT equal to {:?} at index {}",
                    scalar,
                    array,
                    other_index
                );
            }
        }
    }
}
|
||
#[test] | ||
fn scalar_partial_ordering() { | ||
use ScalarValue::*; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
apache/arrow-rs#672 proposes adding this upstream in arrow
I think this properly handles
null
values now in the DictionaryArray, whereas my initial version did not