Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ScalarValue::eq_array optimized comparison function #844

Merged
merged 4 commits into from
Aug 11, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
353 changes: 318 additions & 35 deletions datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,31 @@ impl std::hash::Hash for ScalarValue {
}
}

// return the index into the dictionary values for array@index as well
// as a reference to the dictionary values array. Returns None for the
// index if the array is NULL at index
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apache/arrow-rs#672 proposes adding this upstream in arrow

I think this properly handles null values now in the DictionaryArray, whereas my initial version did not

/// Resolves the dictionary entry referenced by `array` at `index`.
///
/// Returns a reference to the dictionary's values array together with the
/// position inside that values array which the key at `index` points to.
/// The position is `None` when the key at `index` is null.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` if `array` is not a
/// `DictionaryArray<K>`, or if the key stored at `index` can not be
/// converted to `usize`.
#[inline]
fn get_dict_value<K: ArrowDictionaryKeyType>(
    array: &ArrayRef,
    index: usize,
) -> Result<(&ArrayRef, Option<usize>)> {
    // Report a type mismatch through the `Result` instead of panicking, so
    // callers that propagate with `?` receive a regular error.
    let dict_array = array
        .as_any()
        .downcast_ref::<DictionaryArray<K>>()
        .ok_or_else(|| {
            DataFusionError::Internal(format!(
                "Expected a dictionary array with matching key type but got array of type {:?}",
                array.data_type()
            ))
        })?;

    // look up the index in the values dictionary
    let keys_col = dict_array.keys();
    if !keys_col.is_valid(index) {
        // null key: there is no entry in the values array to point at
        return Ok((dict_array.values(), None));
    }
    let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
        DataFusionError::Internal(format!(
            "Can not convert index to usize in dictionary of type creating group by value {:?}",
            keys_col.data_type()
        ))
    })?;

    Ok((dict_array.values(), Some(values_index)))
}

macro_rules! typed_cast {
($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
Expand Down Expand Up @@ -399,6 +424,17 @@ macro_rules! build_array_from_option {
}};
}

/// Compares the element at `$index` of `$array` (downcast to `$ARRAYTYPE`)
/// with the optional scalar value `$VALUE`, evaluating to a `bool`.
///
/// A `Some` value is equal only to a valid (non-null) slot holding the same
/// value; a `None` value is equal only to a null slot.
///
/// Panics if `$array` can not be downcast to `$ARRAYTYPE`.
macro_rules! eq_array_primitive {
($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
let is_valid = array.is_valid($index);
match $VALUE {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we avoid the match if !is_valid? Would that make any difference to performance?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could if there was a specific eq_array implementation for non-null arrays.
On most kernels / code, this has a non-negligible impact on performance.
The code path in the hash aggregate could then check whether the array contains 0 nulls and choose a different implementation if this is the case.
I think at this moment it might not have that much of an impact, maybe for the "easier" hash-aggregates with only few groups at might have a higher relative impact.

Copy link
Contributor Author

@alamb alamb Aug 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed #850 to track this suggestion

// validity is tested first, so the value of a null slot is never read
Some(val) => is_valid && &array.value($index) == val,
// a null scalar matches only a null slot
None => !is_valid,
}
}};
}

impl ScalarValue {
/// Getter for the `DataType` of the value
pub fn get_datatype(&self) -> DataType {
Expand Down Expand Up @@ -942,28 +978,30 @@ impl ScalarValue {
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
typed_cast!(array, index, TimestampNanosecondArray, TimestampNanosecond)
}
DataType::Dictionary(index_type, _) => match **index_type {
DataType::Int8 => Self::try_from_dict_array::<Int8Type>(array, index)?,
DataType::Int16 => Self::try_from_dict_array::<Int16Type>(array, index)?,
DataType::Int32 => Self::try_from_dict_array::<Int32Type>(array, index)?,
DataType::Int64 => Self::try_from_dict_array::<Int64Type>(array, index)?,
DataType::UInt8 => Self::try_from_dict_array::<UInt8Type>(array, index)?,
DataType::UInt16 => {
Self::try_from_dict_array::<UInt16Type>(array, index)?
}
DataType::UInt32 => {
Self::try_from_dict_array::<UInt32Type>(array, index)?
}
DataType::UInt64 => {
Self::try_from_dict_array::<UInt64Type>(array, index)?
}
_ => {
return Err(DataFusionError::Internal(format!(
"Index type not supported while creating scalar from dictionary: {}",
array.data_type(),
)))
DataType::Dictionary(index_type, _) => {
let (values, values_index) = match **index_type {
DataType::Int8 => get_dict_value::<Int8Type>(array, index)?,
DataType::Int16 => get_dict_value::<Int16Type>(array, index)?,
DataType::Int32 => get_dict_value::<Int32Type>(array, index)?,
DataType::Int64 => get_dict_value::<Int64Type>(array, index)?,
DataType::UInt8 => get_dict_value::<UInt8Type>(array, index)?,
DataType::UInt16 => get_dict_value::<UInt16Type>(array, index)?,
DataType::UInt32 => get_dict_value::<UInt32Type>(array, index)?,
DataType::UInt64 => get_dict_value::<UInt64Type>(array, index)?,
_ => {
return Err(DataFusionError::Internal(format!(
"Index type not supported while creating scalar from dictionary: {}",
array.data_type(),
)))
}
};

match values_index {
Some(values_index) => Self::try_from_array(values, values_index)?,
// was null
None => values.data_type().try_into()?,
}
},
}
other => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar from array of type \"{:?}\"",
Expand All @@ -973,22 +1011,114 @@ impl ScalarValue {
})
}

fn try_from_dict_array<K: ArrowDictionaryKeyType>(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactored into get_dict_value

/// Compares a single row of array @ index for equality with self,
/// in an optimized fashion.
///
/// This method implements an optimized version of:
///
/// ```text
/// let arr_scalar = Self::try_from_array(array, index).unwrap();
/// arr_scalar.eq(self)
/// ```
///
/// *Performance note*: the arrow compute kernels should be
/// preferred over this function if at all possible as they can be
/// vectorized and are generally much faster.
///
/// This function has a few narrow usescases such as hash table key
/// comparisons where comparing a single row at a time is necessary.
#[inline]
pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool {
if let DataType::Dictionary(key_type, _) = array.data_type() {
return self.eq_array_dictionary(array, index, key_type);
}

match self {
ScalarValue::Boolean(val) => {
eq_array_primitive!(array, index, BooleanArray, val)
}
ScalarValue::Float32(val) => {
eq_array_primitive!(array, index, Float32Array, val)
}
ScalarValue::Float64(val) => {
eq_array_primitive!(array, index, Float64Array, val)
}
ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val),
ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val),
ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val),
ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val),
ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val),
ScalarValue::UInt16(val) => {
eq_array_primitive!(array, index, UInt16Array, val)
}
ScalarValue::UInt32(val) => {
eq_array_primitive!(array, index, UInt32Array, val)
}
ScalarValue::UInt64(val) => {
eq_array_primitive!(array, index, UInt64Array, val)
}
ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val),
ScalarValue::LargeUtf8(val) => {
eq_array_primitive!(array, index, LargeStringArray, val)
}
ScalarValue::Binary(val) => {
eq_array_primitive!(array, index, BinaryArray, val)
}
ScalarValue::LargeBinary(val) => {
eq_array_primitive!(array, index, LargeBinaryArray, val)
}
ScalarValue::List(_, _) => unimplemented!(),
ScalarValue::Date32(val) => {
eq_array_primitive!(array, index, Date32Array, val)
}
ScalarValue::Date64(val) => {
eq_array_primitive!(array, index, Date64Array, val)
}
ScalarValue::TimestampSecond(val) => {
eq_array_primitive!(array, index, TimestampSecondArray, val)
}
ScalarValue::TimestampMillisecond(val) => {
eq_array_primitive!(array, index, TimestampMillisecondArray, val)
}
ScalarValue::TimestampMicrosecond(val) => {
eq_array_primitive!(array, index, TimestampMicrosecondArray, val)
}
ScalarValue::TimestampNanosecond(val) => {
eq_array_primitive!(array, index, TimestampNanosecondArray, val)
}
ScalarValue::IntervalYearMonth(val) => {
eq_array_primitive!(array, index, IntervalYearMonthArray, val)
}
ScalarValue::IntervalDayTime(val) => {
eq_array_primitive!(array, index, IntervalDayTimeArray, val)
}
}
}

/// Compares the element of the dictionary-encoded `array` at `index`
/// for equality with self. `key_type` is the data type of the
/// dictionary's keys (indexes) and selects the concrete key type used
/// to read `array`.
fn eq_array_dictionary(
&self,
array: &ArrayRef,
index: usize,
) -> Result<Self> {
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
// (note validity was previously checked in `try_from_array`)
let keys_col = dict_array.keys();
let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
keys_col.data_type()
))
})?;
Self::try_from_array(dict_array.values(), values_index)
key_type: &DataType,
) -> bool {
// Resolve the key at `index` to the values array and the position
// within it; `unwrap` panics if `get_dict_value` reports an error.
let (values, values_index) = match key_type {
DataType::Int8 => get_dict_value::<Int8Type>(array, index).unwrap(),
DataType::Int16 => get_dict_value::<Int16Type>(array, index).unwrap(),
DataType::Int32 => get_dict_value::<Int32Type>(array, index).unwrap(),
DataType::Int64 => get_dict_value::<Int64Type>(array, index).unwrap(),
DataType::UInt8 => get_dict_value::<UInt8Type>(array, index).unwrap(),
DataType::UInt16 => get_dict_value::<UInt16Type>(array, index).unwrap(),
DataType::UInt32 => get_dict_value::<UInt32Type>(array, index).unwrap(),
DataType::UInt64 => get_dict_value::<UInt64Type>(array, index).unwrap(),
_ => unreachable!("Invalid dictionary keys type: {:?}", key_type),
};

match values_index {
// non-null key: compare self against the referenced entry of the values array
Some(values_index) => self.eq_array(values, values_index),
// null key: equal only if `self.is_null()` holds
None => self.is_null(),
}
}
}

Expand Down Expand Up @@ -1654,6 +1784,159 @@ mod tests {
assert_eq!(std::mem::size_of::<ScalarValue>(), 32);
}

#[test]
fn scalar_eq_array() {
// Validate that eq_array has the same semantics as ScalarValue::eq:
// every scalar must be equal to the array element it was built from
// and not equal to the element at any other index.

// Converts a vector of optional numbers into a vector of optional
// `$TYPE` values using an `as` cast on each present value.
macro_rules! make_typed_vec {
($INPUT:expr, $TYPE:ident) => {{
$INPUT
.iter()
.map(|v| v.map(|v| v as $TYPE))
.collect::<Vec<_>>()
}};
}

// Input values: each vector holds three distinct entries, with a
// null in the middle position.
let bool_vals = vec![Some(true), None, Some(false)];
let f32_vals = vec![Some(-1.0), None, Some(1.0)];
let f64_vals = make_typed_vec!(f32_vals, f64);

let i8_vals = vec![Some(-1), None, Some(1)];
let i16_vals = make_typed_vec!(i8_vals, i16);
let i32_vals = make_typed_vec!(i8_vals, i32);
let i64_vals = make_typed_vec!(i8_vals, i64);

let u8_vals = vec![Some(0), None, Some(1)];
let u16_vals = make_typed_vec!(u8_vals, u16);
let u32_vals = make_typed_vec!(u8_vals, u32);
let u64_vals = make_typed_vec!(u8_vals, u64);

let str_vals = vec![Some("foo"), None, Some("bar")];

/// Test each value in `scalars` with the corresponding element
/// at `array`. Assumes each element is unique (aka not equal
/// with all other indexes)
struct TestCase {
array: ArrayRef,
scalars: Vec<ScalarValue>,
}

/// Create a test case for casting the input to the specified array type
macro_rules! make_test_case {
($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()),
scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(),
}
}};
}

/// Create a test case for string input: the array is collected from
/// the `&str` values and the scalars hold owned `String` copies
macro_rules! make_str_test_case {
($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()),
scalars: $INPUT
.iter()
.map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string())))
.collect(),
}
}};
}

/// Create a test case for binary input: the scalars hold the bytes
/// of each input string as a `Vec<u8>`
macro_rules! make_binary_test_case {
($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()),
scalars: $INPUT
.iter()
.map(|v| {
ScalarValue::$SCALAR_TY(v.map(|v| v.as_bytes().to_vec()))
})
.collect(),
}
}};
}

/// create a test case for DictionaryArray<$INDEX_TY>
macro_rules! make_str_dict_test_case {
($INPUT:expr, $INDEX_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new(
$INPUT
.iter()
.cloned()
.collect::<DictionaryArray<$INDEX_TY>>(),
),
scalars: $INPUT
.iter()
.map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string())))
.collect(),
}
}};
}

// One case per array type exercised, including string dictionaries
// for each of the eight integer key types.
let cases = vec![
make_test_case!(bool_vals, BooleanArray, Boolean),
make_test_case!(f32_vals, Float32Array, Float32),
make_test_case!(f64_vals, Float64Array, Float64),
make_test_case!(i8_vals, Int8Array, Int8),
make_test_case!(i16_vals, Int16Array, Int16),
make_test_case!(i32_vals, Int32Array, Int32),
make_test_case!(i64_vals, Int64Array, Int64),
make_test_case!(u8_vals, UInt8Array, UInt8),
make_test_case!(u16_vals, UInt16Array, UInt16),
make_test_case!(u32_vals, UInt32Array, UInt32),
make_test_case!(u64_vals, UInt64Array, UInt64),
make_str_test_case!(str_vals, StringArray, Utf8),
make_str_test_case!(str_vals, LargeStringArray, LargeUtf8),
make_binary_test_case!(str_vals, BinaryArray, Binary),
make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary),
make_test_case!(i32_vals, Date32Array, Date32),
make_test_case!(i64_vals, Date64Array, Date64),
make_test_case!(i64_vals, TimestampSecondArray, TimestampSecond),
make_test_case!(i64_vals, TimestampMillisecondArray, TimestampMillisecond),
make_test_case!(i64_vals, TimestampMicrosecondArray, TimestampMicrosecond),
make_test_case!(i64_vals, TimestampNanosecondArray, TimestampNanosecond),
make_test_case!(i32_vals, IntervalYearMonthArray, IntervalYearMonth),
make_test_case!(i64_vals, IntervalDayTimeArray, IntervalDayTime),
make_str_dict_test_case!(str_vals, Int8Type, Utf8),
make_str_dict_test_case!(str_vals, Int16Type, Utf8),
make_str_dict_test_case!(str_vals, Int32Type, Utf8),
make_str_dict_test_case!(str_vals, Int64Type, Utf8),
make_str_dict_test_case!(str_vals, UInt8Type, Utf8),
make_str_dict_test_case!(str_vals, UInt16Type, Utf8),
make_str_dict_test_case!(str_vals, UInt32Type, Utf8),
make_str_dict_test_case!(str_vals, UInt64Type, Utf8),
];

for case in cases {
let TestCase { array, scalars } = case;
// scalars and array elements are paired up by position
assert_eq!(array.len(), scalars.len());

for (index, scalar) in scalars.into_iter().enumerate() {
// the scalar must match the element it was created from
assert!(
scalar.eq_array(&array, index),
"Expected {:?} to be equal to {:?} at index {}",
scalar,
array,
index
);

// test that all other elements are *not* equal
for other_index in 0..array.len() {
if index != other_index {
assert!(
!scalar.eq_array(&array, other_index),
"Expected {:?} to be NOT equal to {:?} at index {}",
scalar,
array,
other_index
);
}
}
}
}
}

#[test]
fn scalar_partial_ordering() {
use ScalarValue::*;
Expand Down