Skip to content

Commit

Permalink
feat: mask
Browse files Browse the repository at this point in the history
Mask sets entries of an array to null. I like the analogy to light: the array is a sequence of
lights (each value might be a different wavelength). Null is represented by the absence of
light. Placing a mask (i.e. a piece of plastic with slits) over the array causes those values where
the mask is present (i.e. "on", "true") to be dark.

An example in pseudo-code:

```text
a = [1, 2, 3, 4, 5]
a_mask = [t, f, f, t, f]
mask(a, a_mask) == [null, 2, 3, null, 5]
```

Specializations
---------------

I only fall back to Arrow for two of the core arrays:

- Sparse. I was skeptical that I could do better than decompressing the array and applying the mask to the result.
- Constant. If the mask is sparse, SparseArray might be a good choice. I didn't investigate.

For the non-core arrays, I'm missing the following. It's not clear to me that I can beat
decompression for run end. The others are easy enough but require some amount of typing and testing.

- fastlanes
- fsst
- roaring
- runend
- runend-bool
- zigzag

Naming
------

Pandas also calls this operation
[`mask`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.mask.html) but accepts an
optional second argument which is an array of values to use instead of null (which makes Pandas'
mask more like an `if_else`).

Arrow-rs calls this [`nullif`](https://arrow.apache.org/rust/arrow/compute/fn.nullif.html).

Arrow-cpp has [`if_else(condition, consequent,
alternate)`](https://arrow.apache.org/docs/cpp/compute.html#cpp-compute-scalar-selections) and
[`replace_with_mask(array, mask,
replacements)`](https://arrow.apache.org/docs/cpp/compute.html#replace-functions) both of which can
implement our `mask` by passing a `NullArray` as the third argument.
  • Loading branch information
danking committed Jan 13, 2025
1 parent f97c0cd commit 6b3b101
Show file tree
Hide file tree
Showing 39 changed files with 1,583 additions and 93 deletions.
72 changes: 72 additions & 0 deletions encodings/alp/src/alp/compute/mask.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use vortex_array::compute::{mask, try_cast, FilterMask, MaskFn};
use vortex_array::{ArrayDType as _, ArrayData, IntoArrayData};
use vortex_error::VortexResult;

use crate::{ALPArray, ALPEncoding};

impl MaskFn<ALPArray> for ALPEncoding {
    /// Masks an ALP array by masking its encoded child and rebuilding the array around it.
    ///
    /// The exponents are reused unchanged. If the array has patches, the patch values are cast
    /// to the nullable form of their dtype before the array is reassembled.
    ///
    /// # Errors
    ///
    /// Propagates any error from masking the encoded child, casting the patch values, or
    /// constructing the resulting `ALPArray`.
    fn mask(&self, array: &ALPArray, filter_mask: FilterMask) -> VortexResult<ArrayData> {
        let masked_encoded = mask(&array.encoded(), filter_mask)?;
        let exponents = array.exponents();

        let nullable_patches = match array.patches() {
            Some(patches) => Some(patches.map_values(|values| {
                let nullable_dtype = values.dtype().as_nullable();
                try_cast(&values, &nullable_dtype)
            })?),
            None => None,
        };

        let masked = ALPArray::try_new(masked_encoded, exponents, nullable_patches)?;
        Ok(masked.into_array())
    }
}

#[cfg(test)]
mod tests {
    use vortex_array::array::PrimitiveArray;
    use vortex_array::compute::test_harness::test_mask;
    use vortex_array::validity::Validity;
    use vortex_array::IntoArrayData as _;
    use vortex_buffer::buffer;

    use crate::alp_encode;

    /// Runs the shared mask harness over small f32 inputs, once per validity flavour.
    #[test]
    fn test_mask_no_patches_alp_array() {
        for validity in [Validity::AllValid, Validity::NonNullable] {
            let floats = PrimitiveArray::new(buffer![1.0f32, 2.0, 3.0, 4.0, 5.0], validity);
            let encoded = alp_encode(&floats).unwrap();
            test_mask(encoded.into_array());
        }
    }

    /// Runs the shared mask harness over inputs whose last value forces ALP to emit patches.
    #[test]
    fn test_mask_patched_alp_array() {
        for validity in [Validity::AllValid, Validity::NonNullable] {
            let floats = PrimitiveArray::new(buffer![1.0f32, 2.0, 3.0, 4.0, 1e10], validity);
            let encoded = alp_encode(&floats).unwrap();
            // Guard the premise of the test: the encoding must actually contain patches.
            assert!(encoded.patches().is_some());
            test_mask(encoded.into_array());
        }
    }
}
10 changes: 8 additions & 2 deletions encodings/alp/src/alp/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod mask;

use vortex_array::compute::{
filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn,
TakeFn,
filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, MaskFn, ScalarAtFn,
SliceFn, TakeFn,
};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
Expand All @@ -14,6 +16,10 @@ impl ComputeVTable for ALPEncoding {
Some(self)
}

/// Exposes this encoding's `MaskFn` implementation through the compute vtable; always `Some`.
fn mask_fn(&self) -> Option<&dyn MaskFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's `ScalarAtFn` implementation through the compute vtable; always `Some`.
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<ArrayData>> {
Some(self)
}
Expand Down
57 changes: 57 additions & 0 deletions encodings/alp/src/alp_rd/compute/mask.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use vortex_array::compute::{mask, FilterMask, MaskFn};
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
use vortex_error::VortexResult;

use crate::{ALPRDArray, ALPRDEncoding};

impl MaskFn<ALPRDArray> for ALPRDEncoding {
fn mask(&self, array: &ALPRDArray, filter_mask: FilterMask) -> VortexResult<ArrayData> {
Ok(ALPRDArray::try_new(
array.dtype().as_nullable(),
mask(&array.left_parts(), filter_mask)?,
array.left_parts_dict(),
array.right_parts(),
array.right_bit_width(),
array.left_parts_patches(),
)?
.into_array())
}
}

#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::array::PrimitiveArray;
    use vortex_array::compute::test_harness::test_mask;
    use vortex_array::IntoArrayData as _;

    use crate::{ALPRDFloat, RDEncoder};

    /// Masks an ALP-RD array built from two sampled values plus an outlier value.
    #[rstest]
    #[case(0.1f32, 0.2f32, 3e25f32)]
    #[case(0.1f64, 0.2f64, 3e100f64)]
    fn test_mask_simple<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
        let encoder = RDEncoder::new(&[a, b]);
        let values = PrimitiveArray::from_iter([a, b, outlier, b, outlier]);
        let encoded = encoder.encode(&values);
        test_mask(encoded.into_array());
    }

    /// Masks an ALP-RD array whose input already contains null entries.
    #[rstest]
    #[case(0.1f32, 3e25f32)]
    #[case(0.5f64, 1e100f64)]
    fn test_mask_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] outlier: T) {
        let encoder = RDEncoder::new(&[a]);
        let values =
            PrimitiveArray::from_option_iter([Some(a), None, Some(outlier), Some(a), None]);
        let encoded = encoder.encode(&values);
        test_mask(encoded.into_array());
    }
}
7 changes: 6 additions & 1 deletion encodings/alp/src/alp_rd/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use vortex_array::compute::{ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn};
use vortex_array::compute::{ComputeVTable, FilterFn, MaskFn, ScalarAtFn, SliceFn, TakeFn};
use vortex_array::ArrayData;

use crate::ALPRDEncoding;

mod filter;
mod mask;
mod scalar_at;
mod slice;
mod take;
Expand All @@ -13,6 +14,10 @@ impl ComputeVTable for ALPRDEncoding {
Some(self)
}

/// Exposes this encoding's `MaskFn` implementation through the compute vtable; always `Some`.
fn mask_fn(&self) -> Option<&dyn MaskFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's `ScalarAtFn` implementation through the compute vtable; always `Some`.
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<ArrayData>> {
Some(self)
}
Expand Down
24 changes: 23 additions & 1 deletion encodings/bytebool/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use num_traits::AsPrimitive;
use vortex_array::compute::{ComputeVTable, FillForwardFn, ScalarAtFn, SliceFn, TakeFn};
use vortex_array::compute::{
ComputeVTable, FillForwardFn, FilterMask, MaskFn, ScalarAtFn, SliceFn, TakeFn,
};
use vortex_array::validity::{ArrayValidity, Validity};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData};
Expand All @@ -14,6 +16,10 @@ impl ComputeVTable for ByteBoolEncoding {
None
}

/// Exposes this encoding's `MaskFn` implementation through the compute vtable; always `Some`.
fn mask_fn(&self) -> Option<&dyn MaskFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's `ScalarAtFn` implementation through the compute vtable; always `Some`.
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<ArrayData>> {
Some(self)
}
Expand All @@ -27,6 +33,13 @@ impl ComputeVTable for ByteBoolEncoding {
}
}

impl MaskFn<ByteBoolArray> for ByteBoolEncoding {
    /// Masks a byte-bool array by masking its validity only.
    ///
    /// The value buffer is cloned as-is; the new array pairs it with the array's validity after
    /// `mask` has been applied to that validity.
    ///
    /// # Errors
    ///
    /// Propagates any error from masking the validity or constructing the new `ByteBoolArray`.
    fn mask(&self, array: &ByteBoolArray, mask: FilterMask) -> VortexResult<ArrayData> {
        let values = array.buffer().clone();
        let masked_validity = array.validity().mask(&mask)?;
        let masked = ByteBoolArray::try_new(values, masked_validity)?;
        Ok(masked.into_array())
    }
}

impl ScalarAtFn<ByteBoolArray> for ByteBoolEncoding {
fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
Ok(Scalar::bool(
Expand Down Expand Up @@ -136,6 +149,7 @@ impl FillForwardFn<ByteBoolArray> for ByteBoolEncoding {

#[cfg(test)]
mod tests {
use vortex_array::compute::test_harness::test_mask;
use vortex_array::compute::{compare, scalar_at, slice, Operator};

use super::*;
Expand Down Expand Up @@ -208,4 +222,12 @@ mod tests {
let s = scalar_at(&arr, 4).unwrap();
assert!(s.is_null());
}

/// Runs the shared mask harness over a byte-bool array built from plain bools and over one
/// built from optional bools containing `None` entries.
#[test]
fn test_mask_byte_bool() {
    let from_bools = ByteBoolArray::from(vec![true, false, true, true, false]);
    test_mask(from_bools.into_array());

    let from_options =
        ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
    test_mask(from_options.into_array());
}
}
25 changes: 25 additions & 0 deletions encodings/datetime-parts/src/compute/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use vortex_array::compute::{try_cast, CastFn};
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
use vortex_dtype::DType;
use vortex_error::{vortex_bail, VortexResult};

use crate::{DateTimePartsArray, DateTimePartsEncoding};

impl CastFn<DateTimePartsArray> for DateTimePartsEncoding {
    /// Casts a date-time-parts array to `dtype`, which may differ from the array's dtype only
    /// in nullability.
    ///
    /// Only the `days` child is cast (to the requested nullability); `seconds` and `subsecond`
    /// are reused unchanged. The returned array carries exactly the requested `dtype`, so a
    /// cast to a non-nullable dtype yields a non-nullable array rather than a nullable one.
    ///
    /// # Errors
    ///
    /// Fails if `dtype` differs from the array's dtype in anything other than nullability, if
    /// casting the `days` child fails, or if the resulting array cannot be constructed.
    fn cast(&self, array: &DateTimePartsArray, dtype: &DType) -> VortexResult<ArrayData> {
        if !array.dtype().eq_ignore_nullability(dtype) {
            vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype);
        }

        Ok(DateTimePartsArray::try_new(
            // The check above guarantees `dtype` is the array's dtype up to nullability, so the
            // requested dtype is used verbatim instead of unconditionally forcing nullable.
            dtype.clone(),
            try_cast(
                array.days().as_ref(),
                &array.days().dtype().with_nullability(dtype.nullability()),
            )?,
            array.seconds(),
            array.subsecond(),
        )?
        .into_array())
    }
}
58 changes: 58 additions & 0 deletions encodings/datetime-parts/src/compute/mask.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use vortex_array::compute::{mask, FilterMask, MaskFn};
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
use vortex_error::VortexResult;

use crate::{DateTimePartsArray, DateTimePartsEncoding};

impl MaskFn<DateTimePartsArray> for DateTimePartsEncoding {
fn mask(&self, array: &DateTimePartsArray, filter_mask: FilterMask) -> VortexResult<ArrayData> {
Ok(DateTimePartsArray::try_new(
array.dtype().clone().as_nullable(),
mask(array.days().as_ref(), filter_mask)?,
array.seconds(),
array.subsecond(),
)?
.into_array())
}
}

#[cfg(test)]
mod tests {
    use vortex_array::array::TemporalArray;
    use vortex_array::compute::test_harness::test_mask;
    use vortex_array::IntoArrayData as _;
    use vortex_buffer::buffer;
    use vortex_datetime_dtype::TimeUnit;
    use vortex_dtype::DType;

    use crate::{split_temporal, DateTimePartsArray, TemporalParts};

    /// Runs the shared mask harness over a millisecond timestamp array split into
    /// day / second / sub-second parts.
    #[test]
    fn test_mask_datetime_parts_array() {
        // The timestamps are in milliseconds, so one day is 86_400 s * 1000 ms. Using the
        // millisecond value makes the fixture exercise a non-zero day component, as the
        // per-element comments below describe.
        const MILLIS_PER_DAY: i64 = 86_400_000;

        let raw_millis = buffer![
            MILLIS_PER_DAY,             // element with only day component
            MILLIS_PER_DAY + 1000,      // element with day + second components
            MILLIS_PER_DAY + 1000 + 1,  // element with day + second + sub-second components
            MILLIS_PER_DAY + 1000 + 5,  // element with day + second + sub-second components
            MILLIS_PER_DAY + 1000 + 55, // element with day + second + sub-second components
        ]
        .into_array();
        let temporal_array =
            TemporalArray::new_timestamp(raw_millis, TimeUnit::Ms, Some("UTC".to_string()));
        let TemporalParts {
            days,
            seconds,
            subseconds,
        } = split_temporal(temporal_array.clone()).unwrap();
        let date_times = DateTimePartsArray::try_new(
            DType::Extension(temporal_array.ext_dtype()),
            days,
            seconds,
            subseconds,
        )
        .unwrap()
        .into_array();

        // `date_times` is not used afterwards, so it is moved rather than cloned.
        test_mask(date_times);
    }
}
13 changes: 12 additions & 1 deletion encodings/datetime-parts/src/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
mod cast;
mod filter;
mod mask;
mod take;

use vortex_array::array::{PrimitiveArray, TemporalArray};
use vortex_array::compute::{
scalar_at, slice, try_cast, ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn,
scalar_at, slice, try_cast, CastFn, ComputeVTable, FilterFn, MaskFn, ScalarAtFn, SliceFn,
TakeFn,
};
use vortex_array::validity::ArrayValidity;
use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant};
Expand All @@ -17,10 +20,18 @@ use vortex_scalar::{PrimitiveScalar, Scalar};
use crate::{DateTimePartsArray, DateTimePartsEncoding};

impl ComputeVTable for DateTimePartsEncoding {
/// Exposes this encoding's `CastFn` implementation through the compute vtable; always `Some`.
fn cast_fn(&self) -> Option<&dyn CastFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's `FilterFn` implementation through the compute vtable; always `Some`.
fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's `MaskFn` implementation through the compute vtable; always `Some`.
fn mask_fn(&self) -> Option<&dyn MaskFn<ArrayData>> {
Some(self)
}

/// Exposes this encoding's `ScalarAtFn` implementation through the compute vtable; always `Some`.
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<ArrayData>> {
Some(self)
}
Expand Down
4 changes: 4 additions & 0 deletions encodings/dict/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@ vortex-array = { workspace = true, features = ["test-harness"] }
[[bench]]
name = "dict_compress"
harness = false

[[bench]]
name = "dict_mask"
harness = false
Loading

0 comments on commit 6b3b101

Please sign in to comment.