Skip to content

Commit

Permalink
unary agg
Browse files Browse the repository at this point in the history
  • Loading branch information
scsmithr committed Dec 20, 2024
1 parent ca5fb3b commit ed2252c
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 37 deletions.
26 changes: 26 additions & 0 deletions crates/rayexec_bullet/src/exp/executors/aggregate/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
pub mod unary;

use std::fmt::Debug;

use rayexec_error::Result;

use super::OutputBuffer;
use crate::exp::buffer::addressable::MutableAddressableStorage;

/// State for a single group's aggregate.
///
/// An example state for SUM would be a struct that keeps a running sum of the
/// values provided in `update`.
///
/// `Input` is the type of the values passed to `update`, and `Output` is the
/// element type of the storage that `finalize` writes into. Both may be
/// unsized (for example `str`).
pub trait AggregateState<Input: ?Sized, Output: ?Sized>: Debug {
/// Merge other state into this state.
///
/// `other` is borrowed mutably, so an implementation is allowed to modify
/// it while merging.
fn merge(&mut self, other: &mut Self) -> Result<()>;

/// Update this state with a single input value.
fn update(&mut self, input: &Input) -> Result<()>;

/// Finalize this state, writing its result through `output`.
///
/// Nothing is returned besides the `Result`: an implementation either
/// stores a value with `OutputBuffer::put` or marks the output slot as
/// invalid with `OutputBuffer::put_null`.
fn finalize<M>(&mut self, output: OutputBuffer<M>) -> Result<()>
where
M: MutableAddressableStorage<T = Output>;
}
192 changes: 192 additions & 0 deletions crates/rayexec_bullet/src/exp/executors/aggregate/unary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
use rayexec_error::Result;

use super::AggregateState;
use crate::compute::util::IntoExactSizedIterator;
use crate::exp::array::Array;
use crate::exp::buffer::addressable::AddressableStorage;
use crate::exp::buffer::physical_type::PhysicalStorage;

/// Updates aggregate states from a single input array, skipping inputs that
/// the array's validity marks as not valid.
#[derive(Debug, Clone, Copy)]
pub struct UnaryNonNullUpdater;

impl UnaryNonNullUpdater {
    /// Feed selected values from `array` into their mapped aggregate states.
    ///
    /// `selection` yields indices into `array`; `mapping` yields, pairwise, the
    /// index into `states` of the state that receives the corresponding
    /// selected value. Inputs that the array's validity reports as not valid
    /// are skipped without touching any state.
    ///
    /// The two iterators are zipped, and their lengths are not yet checked
    /// against each other (see the TODO below).
    ///
    /// # Errors
    ///
    /// Returns an error if the array's buffer cannot be viewed as storage `S`,
    /// or if a state's `update` fails.
    ///
    /// # Panics
    ///
    /// Panics if the storage has no value for a selected index, or if a
    /// mapped index is out of bounds for `states`.
    pub fn update<S, State, Output>(
        array: &Array,
        selection: impl IntoExactSizedIterator<Item = usize>,
        mapping: impl IntoExactSizedIterator<Item = usize>,
        states: &mut [State],
    ) -> Result<()>
    where
        S: PhysicalStorage,
        Output: ?Sized,
        State: AggregateState<S::StorageType, Output>,
    {
        // TODO: Length check.

        let storage = S::get_storage(array.buffer())?;
        let validity = array.validity();

        // Computed once up front so that fully valid arrays never consult the
        // per-row validity inside the loop.
        let may_have_nulls = !validity.all_valid();

        let pairs = selection.into_iter().zip(mapping.into_iter());
        for (row_idx, group_idx) in pairs {
            if may_have_nulls && !validity.is_valid(row_idx) {
                continue;
            }

            let value = storage.get(row_idx).unwrap();
            states[group_idx].update(value)?;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::datatype::DataType;
    use crate::exp::buffer::addressable::MutableAddressableStorage;
    use crate::exp::buffer::physical_type::{PhysicalI32, PhysicalUtf8};
    use crate::exp::buffer::{Int32Builder, StringViewBufferBuilder};
    use crate::exp::executors::OutputBuffer;
    use crate::exp::validity::Validity;

    /// Minimal SUM-like state over `i32` inputs.
    #[derive(Debug, Default)]
    struct TestSumState {
        val: i32,
    }

    impl AggregateState<i32, i32> for TestSumState {
        fn merge(&mut self, other: &mut Self) -> Result<()> {
            self.val += other.val;
            Ok(())
        }

        fn update(&mut self, input: &i32) -> Result<()> {
            self.val += *input;
            Ok(())
        }

        fn finalize<M>(&mut self, output: OutputBuffer<M>) -> Result<()>
        where
            M: MutableAddressableStorage<T = i32>,
        {
            output.put(&self.val);
            Ok(())
        }
    }

    /// Minimal string-concatenating state, used to exercise varlen inputs.
    #[derive(Debug, Default)]
    struct TestStringAgg {
        val: String,
    }

    impl AggregateState<str, str> for TestStringAgg {
        fn merge(&mut self, other: &mut Self) -> Result<()> {
            self.val.push_str(&other.val);
            Ok(())
        }

        fn update(&mut self, input: &str) -> Result<()> {
            self.val.push_str(input);
            Ok(())
        }

        fn finalize<M>(&mut self, output: OutputBuffer<M>) -> Result<()>
        where
            M: MutableAddressableStorage<T = str>,
        {
            output.put(&self.val);
            Ok(())
        }
    }

    #[test]
    fn unary_primitive_single_state() {
        let mut sums = [TestSumState::default()];
        let input = Array::new(
            DataType::Int32,
            Int32Builder::from_iter([1, 2, 3, 4, 5]).unwrap(),
        );

        // Index 3 (value 4) is not selected: 1 + 2 + 3 + 5.
        UnaryNonNullUpdater::update::<PhysicalI32, _, _>(
            &input,
            [0, 1, 2, 4],
            [0, 0, 0, 0],
            &mut sums,
        )
        .unwrap();

        assert_eq!(sums[0].val, 11);
    }

    #[test]
    fn unary_primitive_single_state_skip_null() {
        let mut sums = [TestSumState::default()];

        let mut nulls = Validity::new_all_valid(5);
        nulls.set_invalid(0);

        let input = Array::new_with_validity(
            DataType::Int32,
            Int32Builder::from_iter([1, 2, 3, 4, 5]).unwrap(),
            nulls,
        )
        .unwrap();

        // Index 0 is selected but invalid, so only 2 + 3 + 5 is summed.
        UnaryNonNullUpdater::update::<PhysicalI32, _, _>(
            &input,
            [0, 1, 2, 4],
            [0, 0, 0, 0],
            &mut sums,
        )
        .unwrap();

        assert_eq!(sums[0].val, 10);
    }

    #[test]
    fn unary_primitive_multiple_states() {
        let mut sums = [TestSumState::default(), TestSumState::default()];
        let input = Array::new(
            DataType::Int32,
            Int32Builder::from_iter([1, 2, 3, 4, 5]).unwrap(),
        );

        // State 0 receives 1 + 2 + 3 + 5 + 4; state 1 receives 1 + 4.
        UnaryNonNullUpdater::update::<PhysicalI32, _, _>(
            &input,
            [0, 1, 2, 4, 0, 3, 3],
            [0, 0, 0, 0, 1, 1, 0],
            &mut sums,
        )
        .unwrap();

        assert_eq!(sums[0].val, 15);
        assert_eq!(sums[1].val, 5);
    }

    #[test]
    fn unary_string_single_state() {
        // Test just checks to ensure working with varlen is sane.
        let mut concat = [TestStringAgg::default()];
        let input = Array::new(
            DataType::Utf8,
            StringViewBufferBuilder::from_iter(["aa", "bbb", "cccc"]).unwrap(),
        );

        UnaryNonNullUpdater::update::<PhysicalUtf8, _, _>(
            &input,
            [0, 1, 2],
            [0, 0, 0],
            &mut concat,
        )
        .unwrap();

        assert_eq!(concat[0].val, "aabbbcccc");
    }
}
37 changes: 37 additions & 0 deletions crates/rayexec_bullet/src/exp/executors/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,38 @@
pub mod aggregate;
pub mod scalar;

use super::buffer::addressable::MutableAddressableStorage;
use super::validity::Validity;

/// Write handle for a single slot of an output buffer.
///
/// Bundles the target index with mutable access to the storage and its
/// validity, so a caller can either store a value or mark the slot as invalid.
/// Both write methods consume the handle.
#[derive(Debug)]
pub struct OutputBuffer<'a, M>
where
    M: MutableAddressableStorage,
{
    /// Index of the slot this handle writes to.
    idx: usize,
    /// Storage that receives the value on `put`.
    buffer: &'a mut M,
    /// Validity that is updated on `put_null`.
    validity: &'a mut Validity,
}

impl<'a, M> OutputBuffer<'a, M>
where
    M: MutableAddressableStorage,
{
    /// Create a handle for slot `idx` of `buffer` and `validity`.
    ///
    /// In debug builds, asserts that the buffer and validity have the same
    /// length.
    pub(crate) fn new(idx: usize, buffer: &'a mut M, validity: &'a mut Validity) -> Self {
        debug_assert_eq!(buffer.len(), validity.len());
        Self {
            idx,
            buffer,
            validity,
        }
    }

    /// Store `val` at this handle's index in the buffer.
    pub fn put(self, val: &M::T) {
        let Self { idx, buffer, .. } = self;
        buffer.put(idx, val)
    }

    /// Mark this handle's index as invalid in the validity.
    pub fn put_null(self) {
        let Self { idx, validity, .. } = self;
        validity.set_invalid(idx)
    }
}
12 changes: 3 additions & 9 deletions crates/rayexec_bullet/src/exp/executors/scalar/binary.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use rayexec_error::Result;

use super::OutputBuffer;
use crate::compute::util::IntoExactSizedIterator;
use crate::exp::array::Array;
use crate::exp::buffer::addressable::AddressableStorage;
use crate::exp::buffer::physical_type::{MutablePhysicalStorage, PhysicalStorage};
use crate::exp::buffer::ArrayBuffer;
use crate::exp::executors::OutputBuffer;
use crate::exp::validity::Validity;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -47,10 +47,7 @@ impl BinaryExecutor {
op(
val1,
val2,
OutputBuffer {
idx: output_idx,
buffer: &mut output,
},
OutputBuffer::new(output_idx, &mut output, out_validity),
);
}
} else {
Expand All @@ -64,10 +61,7 @@ impl BinaryExecutor {
op(
val1,
val2,
OutputBuffer {
idx: output_idx,
buffer: &mut output,
},
OutputBuffer::new(output_idx, &mut output, out_validity),
);
} else {
out_validity.set_invalid(output_idx);
Expand Down
19 changes: 0 additions & 19 deletions crates/rayexec_bullet/src/exp/executors/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,3 @@ pub mod binary;
pub mod unary;

use crate::exp::buffer::addressable::MutableAddressableStorage;
use crate::exp::buffer::ArrayBuffer;

/// Helper for assigning a value to a location in a buffer.
#[derive(Debug)]
pub struct OutputBuffer<'a, M>
where
M: MutableAddressableStorage,
{
// Index in `buffer` that `put` writes to.
idx: usize,
// Storage that receives the value.
buffer: &'a mut M,
}

impl<'a, M> OutputBuffer<'a, M>
where
M: MutableAddressableStorage,
{
/// Store `val` at this handle's index in the buffer, consuming the handle.
pub fn put(self, val: &M::T) {
self.buffer.put(self.idx, val)
}
}
12 changes: 3 additions & 9 deletions crates/rayexec_bullet/src/exp/executors/scalar/unary.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use rayexec_error::Result;

use super::OutputBuffer;
use crate::exp::array::Array;
use crate::exp::buffer::addressable::{AddressableStorage, MutableAddressableStorage};
use crate::exp::buffer::physical_type::{MutablePhysicalStorage, PhysicalStorage};
use crate::exp::buffer::ArrayBuffer;
use crate::exp::executors::OutputBuffer;
use crate::exp::validity::Validity;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -33,21 +33,15 @@ impl UnaryExecutor {
for (output_idx, input_idx) in selection.into_iter().enumerate() {
op(
input.get(input_idx).unwrap(),
OutputBuffer {
idx: output_idx,
buffer: &mut output,
},
OutputBuffer::new(output_idx, &mut output, out_validity),
);
}
} else {
for (output_idx, input_idx) in selection.into_iter().enumerate() {
if validity.is_valid(input_idx) {
op(
input.get(input_idx).unwrap(),
OutputBuffer {
idx: output_idx,
buffer: &mut output,
},
OutputBuffer::new(output_idx, &mut output, out_validity),
);
} else {
out_validity.set_invalid(output_idx);
Expand Down

0 comments on commit ed2252c

Please sign in to comment.