Skip to content

Commit

Permalink
Add ScalarValue::eq_array
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Aug 9, 2021
1 parent 4ddd2f5 commit 832803c
Showing 1 changed file with 310 additions and 35 deletions.
345 changes: 310 additions & 35 deletions datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,31 @@ impl std::hash::Hash for ScalarValue {
}
}

// return the index into the dictionary values for array@index as well
// as a reference to the dictionary values array. Returns None for the
// index if the array is NULL at index
#[inline]
fn get_dict_value<K: ArrowDictionaryKeyType>(
array: &ArrayRef,
index: usize,
) -> Result<(&ArrayRef, Option<usize>)> {
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
let keys_col = dict_array.keys();
if !keys_col.is_valid(index) {
return Ok((dict_array.values(), None));
}
let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
keys_col.data_type()
))
})?;

Ok((dict_array.values(), Some(values_index)))
}

macro_rules! typed_cast {
($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
Expand Down Expand Up @@ -399,6 +424,17 @@ macro_rules! build_array_from_option {
}};
}

macro_rules! eq_array_primitive {
($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{
let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
let is_valid = array.is_valid($index);
match $VALUE {
Some(val) => is_valid && &array.value($index) == val,
None => !is_valid,
}
}};
}

impl ScalarValue {
/// Getter for the `DataType` of the value
pub fn get_datatype(&self) -> DataType {
Expand Down Expand Up @@ -942,28 +978,30 @@ impl ScalarValue {
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
typed_cast!(array, index, TimestampNanosecondArray, TimestampNanosecond)
}
DataType::Dictionary(index_type, _) => match **index_type {
DataType::Int8 => Self::try_from_dict_array::<Int8Type>(array, index)?,
DataType::Int16 => Self::try_from_dict_array::<Int16Type>(array, index)?,
DataType::Int32 => Self::try_from_dict_array::<Int32Type>(array, index)?,
DataType::Int64 => Self::try_from_dict_array::<Int64Type>(array, index)?,
DataType::UInt8 => Self::try_from_dict_array::<UInt8Type>(array, index)?,
DataType::UInt16 => {
Self::try_from_dict_array::<UInt16Type>(array, index)?
}
DataType::UInt32 => {
Self::try_from_dict_array::<UInt32Type>(array, index)?
}
DataType::UInt64 => {
Self::try_from_dict_array::<UInt64Type>(array, index)?
}
_ => {
return Err(DataFusionError::Internal(format!(
"Index type not supported while creating scalar from dictionary: {}",
array.data_type(),
)))
DataType::Dictionary(index_type, _) => {
let (values, values_index) = match **index_type {
DataType::Int8 => get_dict_value::<Int8Type>(array, index)?,
DataType::Int16 => get_dict_value::<Int16Type>(array, index)?,
DataType::Int32 => get_dict_value::<Int32Type>(array, index)?,
DataType::Int64 => get_dict_value::<Int64Type>(array, index)?,
DataType::UInt8 => get_dict_value::<UInt8Type>(array, index)?,
DataType::UInt16 => get_dict_value::<UInt16Type>(array, index)?,
DataType::UInt32 => get_dict_value::<UInt32Type>(array, index)?,
DataType::UInt64 => get_dict_value::<UInt64Type>(array, index)?,
_ => {
return Err(DataFusionError::Internal(format!(
"Index type not supported while creating scalar from dictionary: {}",
array.data_type(),
)))
}
};

match values_index {
Some(values_index) => Self::try_from_array(values, values_index)?,
// was null
None => values.data_type().try_into()?,
}
},
}
other => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar from array of type \"{:?}\"",
Expand All @@ -973,22 +1011,106 @@ impl ScalarValue {
})
}

fn try_from_dict_array<K: ArrowDictionaryKeyType>(
/// Compares array @ index for equality with self, in an optimized fashion
///
/// This method implements an optimized version of:
///
/// ``text
/// let arr_scalar = Self::try_from_array(array, index).unwrap();
/// arr_scalar.eq(self)
/// ```
#[inline]
pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool {
if let DataType::Dictionary(key_type, _) = array.data_type() {
return self.eq_array_dictionary(array, index, key_type);
}

match self {
ScalarValue::Boolean(val) => {
eq_array_primitive!(array, index, BooleanArray, val)
}
ScalarValue::Float32(val) => {
eq_array_primitive!(array, index, Float32Array, val)
}
ScalarValue::Float64(val) => {
eq_array_primitive!(array, index, Float64Array, val)
}
ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val),
ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val),
ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val),
ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val),
ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val),
ScalarValue::UInt16(val) => {
eq_array_primitive!(array, index, UInt16Array, val)
}
ScalarValue::UInt32(val) => {
eq_array_primitive!(array, index, UInt32Array, val)
}
ScalarValue::UInt64(val) => {
eq_array_primitive!(array, index, UInt64Array, val)
}
ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val),
ScalarValue::LargeUtf8(val) => {
eq_array_primitive!(array, index, LargeStringArray, val)
}
ScalarValue::Binary(val) => {
eq_array_primitive!(array, index, BinaryArray, val)
}
ScalarValue::LargeBinary(val) => {
eq_array_primitive!(array, index, LargeBinaryArray, val)
}
ScalarValue::List(_, _) => unimplemented!(),
ScalarValue::Date32(val) => {
eq_array_primitive!(array, index, Date32Array, val)
}
ScalarValue::Date64(val) => {
eq_array_primitive!(array, index, Date64Array, val)
}
ScalarValue::TimestampSecond(val) => {
eq_array_primitive!(array, index, TimestampSecondArray, val)
}
ScalarValue::TimestampMillisecond(val) => {
eq_array_primitive!(array, index, TimestampMillisecondArray, val)
}
ScalarValue::TimestampMicrosecond(val) => {
eq_array_primitive!(array, index, TimestampMicrosecondArray, val)
}
ScalarValue::TimestampNanosecond(val) => {
eq_array_primitive!(array, index, TimestampNanosecondArray, val)
}
ScalarValue::IntervalYearMonth(val) => {
eq_array_primitive!(array, index, IntervalYearMonthArray, val)
}
ScalarValue::IntervalDayTime(val) => {
eq_array_primitive!(array, index, IntervalDayTimeArray, val)
}
}
}

/// Compares a dictionary array with indexes of type `key_type`
/// with the array @ index for equality with self
fn eq_array_dictionary(
&self,
array: &ArrayRef,
index: usize,
) -> Result<Self> {
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
// (note validity was previously checked in `try_from_array`)
let keys_col = dict_array.keys();
let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
keys_col.data_type()
))
})?;
Self::try_from_array(dict_array.values(), values_index)
key_type: &DataType,
) -> bool {
let (values, values_index) = match key_type {
DataType::Int8 => get_dict_value::<Int8Type>(array, index).unwrap(),
DataType::Int16 => get_dict_value::<Int16Type>(array, index).unwrap(),
DataType::Int32 => get_dict_value::<Int32Type>(array, index).unwrap(),
DataType::Int64 => get_dict_value::<Int64Type>(array, index).unwrap(),
DataType::UInt8 => get_dict_value::<UInt8Type>(array, index).unwrap(),
DataType::UInt16 => get_dict_value::<UInt16Type>(array, index).unwrap(),
DataType::UInt32 => get_dict_value::<UInt32Type>(array, index).unwrap(),
DataType::UInt64 => get_dict_value::<UInt64Type>(array, index).unwrap(),
_ => unreachable!("Invalid dictionary keys type: {:?}", key_type),
};

match values_index {
Some(values_index) => self.eq_array(values, values_index),
None => self.is_null(),
}
}
}

Expand Down Expand Up @@ -1654,6 +1776,159 @@ mod tests {
assert_eq!(std::mem::size_of::<ScalarValue>(), 32);
}

#[test]
fn scalar_eq_array() {
// Validate that eq_array has the same semantics as ScalarValue::eq
macro_rules! make_typed_vec {
($INPUT:expr, $TYPE:ident) => {{
$INPUT
.iter()
.map(|v| v.map(|v| v as $TYPE))
.collect::<Vec<_>>()
}};
}

let bool_vals = vec![Some(true), None, Some(false)];
let f32_vals = vec![Some(-1.0), None, Some(1.0)];
let f64_vals = make_typed_vec!(f32_vals, f64);

let i8_vals = vec![Some(-1), None, Some(1)];
let i16_vals = make_typed_vec!(i8_vals, i16);
let i32_vals = make_typed_vec!(i8_vals, i32);
let i64_vals = make_typed_vec!(i8_vals, i64);

let u8_vals = vec![Some(0), None, Some(1)];
let u16_vals = make_typed_vec!(u8_vals, u16);
let u32_vals = make_typed_vec!(u8_vals, u32);
let u64_vals = make_typed_vec!(u8_vals, u64);

let str_vals = vec![Some("foo"), None, Some("bar")];

/// Test each value in `scalar` with the corresponding element
/// at `array`. Assumes each element is unique (aka not equal
/// with all other indexes)
struct TestCase {
array: ArrayRef,
scalars: Vec<ScalarValue>,
}

/// Create a test case for casing the input to the specified array type
macro_rules! make_test_case {
($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()),
scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(),
}
}};
}

macro_rules! make_str_test_case {
($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()),
scalars: $INPUT
.iter()
.map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string())))
.collect(),
}
}};
}

macro_rules! make_binary_test_case {
($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new($INPUT.iter().cloned().collect::<$ARRAY_TY>()),
scalars: $INPUT
.iter()
.map(|v| {
ScalarValue::$SCALAR_TY(v.map(|v| v.as_bytes().to_vec()))
})
.collect(),
}
}};
}

/// create a test case for DictionaryArray<$INDEX_TY>
macro_rules! make_str_dict_test_case {
($INPUT:expr, $INDEX_TY:ident, $SCALAR_TY:ident) => {{
TestCase {
array: Arc::new(
$INPUT
.iter()
.cloned()
.collect::<DictionaryArray<$INDEX_TY>>(),
),
scalars: $INPUT
.iter()
.map(|v| ScalarValue::$SCALAR_TY(v.map(|v| v.to_string())))
.collect(),
}
}};
}

let cases = vec![
make_test_case!(bool_vals, BooleanArray, Boolean),
make_test_case!(f32_vals, Float32Array, Float32),
make_test_case!(f64_vals, Float64Array, Float64),
make_test_case!(i8_vals, Int8Array, Int8),
make_test_case!(i16_vals, Int16Array, Int16),
make_test_case!(i32_vals, Int32Array, Int32),
make_test_case!(i64_vals, Int64Array, Int64),
make_test_case!(u8_vals, UInt8Array, UInt8),
make_test_case!(u16_vals, UInt16Array, UInt16),
make_test_case!(u32_vals, UInt32Array, UInt32),
make_test_case!(u64_vals, UInt64Array, UInt64),
make_str_test_case!(str_vals, StringArray, Utf8),
make_str_test_case!(str_vals, LargeStringArray, LargeUtf8),
make_binary_test_case!(str_vals, BinaryArray, Binary),
make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary),
make_test_case!(i32_vals, Date32Array, Date32),
make_test_case!(i64_vals, Date64Array, Date64),
make_test_case!(i64_vals, TimestampSecondArray, TimestampSecond),
make_test_case!(i64_vals, TimestampMillisecondArray, TimestampMillisecond),
make_test_case!(i64_vals, TimestampMicrosecondArray, TimestampMicrosecond),
make_test_case!(i64_vals, TimestampNanosecondArray, TimestampNanosecond),
make_test_case!(i32_vals, IntervalYearMonthArray, IntervalYearMonth),
make_test_case!(i64_vals, IntervalDayTimeArray, IntervalDayTime),
make_str_dict_test_case!(str_vals, Int8Type, Utf8),
make_str_dict_test_case!(str_vals, Int16Type, Utf8),
make_str_dict_test_case!(str_vals, Int32Type, Utf8),
make_str_dict_test_case!(str_vals, Int64Type, Utf8),
make_str_dict_test_case!(str_vals, UInt8Type, Utf8),
make_str_dict_test_case!(str_vals, UInt16Type, Utf8),
make_str_dict_test_case!(str_vals, UInt32Type, Utf8),
make_str_dict_test_case!(str_vals, UInt64Type, Utf8),
];

for case in cases {
let TestCase { array, scalars } = case;
assert_eq!(array.len(), scalars.len());

for (index, scalar) in scalars.into_iter().enumerate() {
assert!(
scalar.eq_array(&array, index),
"Expected {:?} to be equal to {:?} at index {}",
scalar,
array,
index
);

// test that all other elements are *not* equal
for other_index in 0..array.len() {
if index != other_index {
assert!(
!scalar.eq_array(&array, other_index),
"Expected {:?} to be NOT equal to {:?} at index {}",
scalar,
array,
other_index
);
}
}
}
}
}

#[test]
fn scalar_partial_ordering() {
use ScalarValue::*;
Expand Down

0 comments on commit 832803c

Please sign in to comment.