diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 3fbcadd3de5a..86d17654c060 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -277,6 +277,31 @@ impl std::hash::Hash for ScalarValue { } } +// return the index into the dictionary values for array@index as well +// as a reference to the dictionary values array. Returns None for the +// index if the array is NULL at index +#[inline] +fn get_dict_value( + array: &ArrayRef, + index: usize, +) -> Result<(&ArrayRef, Option)> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + + // look up the index in the values dictionary + let keys_col = dict_array.keys(); + if !keys_col.is_valid(index) { + return Ok((dict_array.values(), None)); + } + let values_index = keys_col.value(index).to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert index to usize in dictionary of type creating group by value {:?}", + keys_col.data_type() + )) + })?; + + Ok((dict_array.values(), Some(values_index))) +} + macro_rules! typed_cast { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); @@ -399,6 +424,17 @@ macro_rules! build_array_from_option { }}; } +macro_rules! eq_array_primitive { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + let is_valid = array.is_valid($index); + match $VALUE { + Some(val) => is_valid && &array.value($index) == val, + None => !is_valid, + } + }}; +} + impl ScalarValue { /// Getter for the `DataType` of the value pub fn get_datatype(&self) -> DataType { @@ -942,28 +978,30 @@ impl ScalarValue { DataType::Timestamp(TimeUnit::Nanosecond, _) => { typed_cast!(array, index, TimestampNanosecondArray, TimestampNanosecond) } - DataType::Dictionary(index_type, _) => match **index_type { - DataType::Int8 => Self::try_from_dict_array::(array, index)?, - DataType::Int16 => Self::try_from_dict_array::(array, index)?, - DataType::Int32 => Self::try_from_dict_array::(array, index)?, - DataType::Int64 => Self::try_from_dict_array::(array, index)?, - DataType::UInt8 => Self::try_from_dict_array::(array, index)?, - DataType::UInt16 => { - Self::try_from_dict_array::(array, index)? - } - DataType::UInt32 => { - Self::try_from_dict_array::(array, index)? - } - DataType::UInt64 => { - Self::try_from_dict_array::(array, index)? - } - _ => { - return Err(DataFusionError::Internal(format!( - "Index type not supported while creating scalar from dictionary: {}", - array.data_type(), - ))) + DataType::Dictionary(index_type, _) => { + let (values, values_index) = match **index_type { + DataType::Int8 => get_dict_value::(array, index)?, + DataType::Int16 => get_dict_value::(array, index)?, + DataType::Int32 => get_dict_value::(array, index)?, + DataType::Int64 => get_dict_value::(array, index)?, + DataType::UInt8 => get_dict_value::(array, index)?, + DataType::UInt16 => get_dict_value::(array, index)?, + DataType::UInt32 => get_dict_value::(array, index)?, + DataType::UInt64 => get_dict_value::(array, index)?, + _ => { + return Err(DataFusionError::Internal(format!( + "Index type not supported while creating scalar from dictionary: {}", + array.data_type(), + ))) + } + }; + + match values_index { + Some(values_index) => Self::try_from_array(values, values_index)?, + // was null + None => values.data_type().try_into()?, } - }, + } other => { return Err(DataFusionError::NotImplemented(format!( "Can't create a scalar from array of type \"{:?}\"", @@ -973,22 +1011,114 @@ impl ScalarValue { }) } - fn try_from_dict_array( + /// Compares a single row of array @ index for equality with self, + /// in an optimized fashion. + /// + /// This method implements an optimized version of: + /// + /// ```text + /// let arr_scalar = Self::try_from_array(array, index).unwrap(); + /// arr_scalar.eq(self) + /// ``` + /// + /// *Performance note*: the arrow compute kernels should be + /// preferred over this function if at all possible as they can be + /// vectorized and are generally much faster. + /// + /// This function has a few narrow usescases such as hash table key + /// comparisons where comparing a single row at a time is necessary. + #[inline] + pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { + if let DataType::Dictionary(key_type, _) = array.data_type() { + return self.eq_array_dictionary(array, index, key_type); + } + + match self { + ScalarValue::Boolean(val) => { + eq_array_primitive!(array, index, BooleanArray, val) + } + ScalarValue::Float32(val) => { + eq_array_primitive!(array, index, Float32Array, val) + } + ScalarValue::Float64(val) => { + eq_array_primitive!(array, index, Float64Array, val) + } + ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), + ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), + ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), + ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), + ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), + ScalarValue::UInt16(val) => { + eq_array_primitive!(array, index, UInt16Array, val) + } + ScalarValue::UInt32(val) => { + eq_array_primitive!(array, index, UInt32Array, val) + } + ScalarValue::UInt64(val) => { + eq_array_primitive!(array, index, UInt64Array, val) + } + ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), + ScalarValue::LargeUtf8(val) => { + eq_array_primitive!(array, index, LargeStringArray, val) + } + ScalarValue::Binary(val) => { + eq_array_primitive!(array, index, BinaryArray, val) + } + ScalarValue::LargeBinary(val) => { + eq_array_primitive!(array, index, LargeBinaryArray, val) + } + ScalarValue::List(_, _) => unimplemented!(), + ScalarValue::Date32(val) => { + eq_array_primitive!(array, index, Date32Array, val) + } + ScalarValue::Date64(val) => { + eq_array_primitive!(array, index, Date64Array, val) + } + ScalarValue::TimestampSecond(val) => { + eq_array_primitive!(array, index, TimestampSecondArray, val) + } + ScalarValue::TimestampMillisecond(val) => { + eq_array_primitive!(array, index, TimestampMillisecondArray, val) + } + ScalarValue::TimestampMicrosecond(val) => { + eq_array_primitive!(array, index, TimestampMicrosecondArray, val) + } + ScalarValue::TimestampNanosecond(val) => { + eq_array_primitive!(array, index, TimestampNanosecondArray, val) + } + ScalarValue::IntervalYearMonth(val) => { + eq_array_primitive!(array, index, IntervalYearMonthArray, val) + } + ScalarValue::IntervalDayTime(val) => { + eq_array_primitive!(array, index, IntervalDayTimeArray, val) + } + } + } + + /// Compares a dictionary array with indexes of type `key_type` + /// with the array @ index for equality with self + fn eq_array_dictionary( + &self, array: &ArrayRef, index: usize, - ) -> Result { - let dict_array = array.as_any().downcast_ref::>().unwrap(); - - // look up the index in the values dictionary - // (note validity was previously checked in `try_from_array`) - let keys_col = dict_array.keys(); - let values_index = keys_col.value(index).to_usize().ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert index to usize in dictionary of type creating group by value {:?}", - keys_col.data_type() - )) - })?; - Self::try_from_array(dict_array.values(), values_index) + key_type: &DataType, + ) -> bool { + let (values, values_index) = match key_type { + DataType::Int8 => get_dict_value::(array, index).unwrap(), + DataType::Int16 => get_dict_value::(array, index).unwrap(), + DataType::Int32 => get_dict_value::(array, index).unwrap(), + DataType::Int64 => get_dict_value::(array, index).unwrap(), + DataType::UInt8 => get_dict_value::(array, index).unwrap(), + DataType::UInt16 => get_dict_value::(array, index).unwrap(), + DataType::UInt32 => get_dict_value::(array, index).unwrap(), + DataType::UInt64 => get_dict_value::(array, index).unwrap(), + _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + }; + + match values_index { + Some(values_index) => self.eq_array(values, values_index), + None => self.is_null(), + } } } @@ -1654,6 +1784,159 @@ mod tests { assert_eq!(std::mem::size_of::(), 32); } + #[test] + fn scalar_eq_array() { + // Validate that eq_array has the same semantics as ScalarValue::eq + macro_rules! make_typed_vec { + ($INPUT:expr, $TYPE:ident) => {{ + $INPUT + .iter() + .map(|v| v.map(|v| v as $TYPE)) + .collect::>() + }}; + } + + let bool_vals = vec![Some(true), None, Some(false)]; + let f32_vals = vec![Some(-1.0), None, Some(1.0)]; + let f64_vals = make_typed_vec!(f32_vals, f64); + + let i8_vals = vec![Some(-1), None, Some(1)]; + let i16_vals = make_typed_vec!(i8_vals, i16); + let i32_vals = make_typed_vec!(i8_vals, i32); + let i64_vals = make_typed_vec!(i8_vals, i64); + + let u8_vals = vec![Some(0), None, Some(1)]; + let u16_vals = make_typed_vec!(u8_vals, u16); + let u32_vals = make_typed_vec!(u8_vals, u32); + let u64_vals = make_typed_vec!(u8_vals, u64); + + let str_vals = vec![Some("foo"), None, Some("bar")]; + + /// Test each value in `scalar` with the corresponding element + /// at `array`. Assumes each element is unique (aka not equal + /// with all other indexes) + struct TestCase { + array: ArrayRef, + scalars: Vec, + } + + /// Create a test case for casing the input to the specified array type + macro_rules! make_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()), + scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(), + } + }}; + } + + macro_rules! make_str_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string()))) + .collect(), + } + }}; + } + + macro_rules! make_binary_test_case { + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| { + ScalarValue::$SCALAR_TY(v.map(|v| v.as_bytes().to_vec())) + }) + .collect(), + } + }}; + } + + /// create a test case for DictionaryArray<$INDEX_TY> + macro_rules! make_str_dict_test_case { + ($INPUT:expr, $INDEX_TY:ident, $SCALAR_TY:ident) => {{ + TestCase { + array: Arc::new( + $INPUT + .iter() + .cloned() + .collect::>(), + ), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string()))) + .collect(), + } + }}; + } + + let cases = vec![ + make_test_case!(bool_vals, BooleanArray, Boolean), + make_test_case!(f32_vals, Float32Array, Float32), + make_test_case!(f64_vals, Float64Array, Float64), + make_test_case!(i8_vals, Int8Array, Int8), + make_test_case!(i16_vals, Int16Array, Int16), + make_test_case!(i32_vals, Int32Array, Int32), + make_test_case!(i64_vals, Int64Array, Int64), + make_test_case!(u8_vals, UInt8Array, UInt8), + make_test_case!(u16_vals, UInt16Array, UInt16), + make_test_case!(u32_vals, UInt32Array, UInt32), + make_test_case!(u64_vals, UInt64Array, UInt64), + make_str_test_case!(str_vals, StringArray, Utf8), + make_str_test_case!(str_vals, LargeStringArray, LargeUtf8), + make_binary_test_case!(str_vals, BinaryArray, Binary), + make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary), + make_test_case!(i32_vals, Date32Array, Date32), + make_test_case!(i64_vals, Date64Array, Date64), + make_test_case!(i64_vals, TimestampSecondArray, TimestampSecond), + make_test_case!(i64_vals, TimestampMillisecondArray, TimestampMillisecond), + make_test_case!(i64_vals, TimestampMicrosecondArray, TimestampMicrosecond), + make_test_case!(i64_vals, TimestampNanosecondArray, TimestampNanosecond), + make_test_case!(i32_vals, IntervalYearMonthArray, IntervalYearMonth), + make_test_case!(i64_vals, IntervalDayTimeArray, IntervalDayTime), + make_str_dict_test_case!(str_vals, Int8Type, Utf8), + make_str_dict_test_case!(str_vals, Int16Type, Utf8), + make_str_dict_test_case!(str_vals, Int32Type, Utf8), + make_str_dict_test_case!(str_vals, Int64Type, Utf8), + make_str_dict_test_case!(str_vals, UInt8Type, Utf8), + make_str_dict_test_case!(str_vals, UInt16Type, Utf8), + make_str_dict_test_case!(str_vals, UInt32Type, Utf8), + make_str_dict_test_case!(str_vals, UInt64Type, Utf8), + ]; + + for case in cases { + let TestCase { array, scalars } = case; + assert_eq!(array.len(), scalars.len()); + + for (index, scalar) in scalars.into_iter().enumerate() { + assert!( + scalar.eq_array(&array, index), + "Expected {:?} to be equal to {:?} at index {}", + scalar, + array, + index + ); + + // test that all other elements are *not* equal + for other_index in 0..array.len() { + if index != other_index { + assert!( + !scalar.eq_array(&array, other_index), + "Expected {:?} to be NOT equal to {:?} at index {}", + scalar, + array, + other_index + ); + } + } + } + } + } + #[test] fn scalar_partial_ordering() { use ScalarValue::*;