From ed2252c35ead71dcf620f34b9276cf57235ac2d2 Mon Sep 17 00:00:00 2001 From: Sean Smith Date: Fri, 20 Dec 2024 09:28:19 -0600 Subject: [PATCH] unary agg --- .../src/exp/executors/aggregate/mod.rs | 26 +++ .../src/exp/executors/aggregate/unary.rs | 192 ++++++++++++++++++ .../rayexec_bullet/src/exp/executors/mod.rs | 37 ++++ .../src/exp/executors/scalar/binary.rs | 12 +- .../src/exp/executors/scalar/mod.rs | 19 -- .../src/exp/executors/scalar/unary.rs | 12 +- 6 files changed, 261 insertions(+), 37 deletions(-) create mode 100644 crates/rayexec_bullet/src/exp/executors/aggregate/mod.rs create mode 100644 crates/rayexec_bullet/src/exp/executors/aggregate/unary.rs diff --git a/crates/rayexec_bullet/src/exp/executors/aggregate/mod.rs b/crates/rayexec_bullet/src/exp/executors/aggregate/mod.rs new file mode 100644 index 000000000..0058c65a0 --- /dev/null +++ b/crates/rayexec_bullet/src/exp/executors/aggregate/mod.rs @@ -0,0 +1,26 @@ +pub mod unary; + +use std::fmt::Debug; + +use rayexec_error::Result; + +use super::OutputBuffer; +use crate::exp::buffer::addressable::MutableAddressableStorage; + +/// State for a single group's aggregate. +/// +/// An example state for SUM would be a struct that takes a running sum from +/// values provided in `update`. +pub trait AggregateState: Debug { + /// Merge other state into this state. + fn merge(&mut self, other: &mut Self) -> Result<()>; + + /// Update this state with some input. + fn update(&mut self, input: &Input) -> Result<()>; + + /// Produce a single value from the state, along with a bool indicating if + /// the value is valid. + fn finalize(&mut self, output: OutputBuffer) -> Result<()> + where + M: MutableAddressableStorage; +} diff --git a/crates/rayexec_bullet/src/exp/executors/aggregate/unary.rs b/crates/rayexec_bullet/src/exp/executors/aggregate/unary.rs new file mode 100644 index 000000000..23606e2b0 --- /dev/null +++ b/crates/rayexec_bullet/src/exp/executors/aggregate/unary.rs @@ -0,0 +1,192 @@ +use rayexec_error::Result; + +use super::AggregateState; +use crate::compute::util::IntoExactSizedIterator; +use crate::exp::array::Array; +use crate::exp::buffer::addressable::AddressableStorage; +use crate::exp::buffer::physical_type::PhysicalStorage; + +#[derive(Debug, Clone, Copy)] +pub struct UnaryNonNullUpdater; + +impl UnaryNonNullUpdater { + pub fn update( + array: &Array, + selection: impl IntoExactSizedIterator, + mapping: impl IntoExactSizedIterator, + states: &mut [State], + ) -> Result<()> + where + S: PhysicalStorage, + Output: ?Sized, + State: AggregateState, + { + // TODO: Length check. + + let input = S::get_storage(array.buffer())?; + let validity = array.validity(); + + if validity.all_valid() { + for (input_idx, state_idx) in selection.into_iter().zip(mapping.into_iter()) { + let val = input.get(input_idx).unwrap(); + let state = &mut states[state_idx]; + state.update(val)?; + } + } else { + for (input_idx, state_idx) in selection.into_iter().zip(mapping.into_iter()) { + if !validity.is_valid(input_idx) { + continue; + } + + let val = input.get(input_idx).unwrap(); + let state = &mut states[state_idx]; + state.update(val)?; + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::datatype::DataType; + use crate::exp::buffer::addressable::MutableAddressableStorage; + use crate::exp::buffer::physical_type::{PhysicalI32, PhysicalUtf8}; + use crate::exp::buffer::{Int32Builder, StringViewBufferBuilder}; + use crate::exp::executors::OutputBuffer; + use crate::exp::validity::Validity; + + #[derive(Debug, Default)] + struct TestSumState { + val: i32, + } + + impl AggregateState for TestSumState { + fn merge(&mut self, other: &mut Self) -> Result<()> { + self.val += other.val; + Ok(()) + } + + fn update(&mut self, &input: &i32) -> Result<()> { + self.val += input; + Ok(()) + } + + fn finalize(&mut self, output: OutputBuffer) -> Result<()> + where + M: MutableAddressableStorage, + { + output.put(&self.val); + Ok(()) + } + } + + #[test] + fn unary_primitive_single_state() { + let mut states = [TestSumState::default()]; + let array = Array::new( + DataType::Int32, + Int32Builder::from_iter([1, 2, 3, 4, 5]).unwrap(), + ); + + UnaryNonNullUpdater::update::( + &array, + [0, 1, 2, 4], + [0, 0, 0, 0], + &mut states, + ) + .unwrap(); + + assert_eq!(11, states[0].val); + } + + #[test] + fn unary_primitive_single_state_skip_null() { + let mut states = [TestSumState::default()]; + let mut validity = Validity::new_all_valid(5); + validity.set_invalid(0); + let array = Array::new_with_validity( + DataType::Int32, + Int32Builder::from_iter([1, 2, 3, 4, 5]).unwrap(), + validity, + ) + .unwrap(); + + UnaryNonNullUpdater::update::( + &array, + [0, 1, 2, 4], + [0, 0, 0, 0], + &mut states, + ) + .unwrap(); + + assert_eq!(10, states[0].val); + } + + #[test] + fn unary_primitive_multiple_states() { + let mut states = [TestSumState::default(), TestSumState::default()]; + let array = Array::new( + DataType::Int32, + Int32Builder::from_iter([1, 2, 3, 4, 5]).unwrap(), + ); + + UnaryNonNullUpdater::update::( + &array, + [0, 1, 2, 4, 0, 3, 3], + [0, 0, 0, 0, 1, 1, 0], + &mut states, + ) + .unwrap(); + + assert_eq!(15, states[0].val); + assert_eq!(5, states[1].val); + } + + #[derive(Debug, Default)] + struct TestStringAgg { + val: String, + } + + impl AggregateState for TestStringAgg { + fn merge(&mut self, other: &mut Self) -> Result<()> { + self.val.push_str(&other.val); + Ok(()) + } + + fn update(&mut self, input: &str) -> Result<()> { + self.val.push_str(input); + Ok(()) + } + + fn finalize(&mut self, output: OutputBuffer) -> Result<()> + where + M: MutableAddressableStorage, + { + output.put(&self.val); + Ok(()) + } + } + + #[test] + fn unary_string_single_state() { + // Test just checks to ensure working with varlen is sane. + let mut states = [TestStringAgg::default()]; + let array = Array::new( + DataType::Utf8, + StringViewBufferBuilder::from_iter(["aa", "bbb", "cccc"]).unwrap(), + ); + + UnaryNonNullUpdater::update::( + &array, + [0, 1, 2], + [0, 0, 0], + &mut states, + ) + .unwrap(); + + assert_eq!("aabbbcccc", &states[0].val); + } +} diff --git a/crates/rayexec_bullet/src/exp/executors/mod.rs b/crates/rayexec_bullet/src/exp/executors/mod.rs index cd930e279..85327840d 100644 --- a/crates/rayexec_bullet/src/exp/executors/mod.rs +++ b/crates/rayexec_bullet/src/exp/executors/mod.rs @@ -1 +1,38 @@ +pub mod aggregate; pub mod scalar; + +use super::buffer::addressable::MutableAddressableStorage; +use super::validity::Validity; + +/// Helper for assigning a value to a location in a buffer. +#[derive(Debug)] +pub struct OutputBuffer<'a, M> +where + M: MutableAddressableStorage, +{ + idx: usize, + buffer: &'a mut M, + validity: &'a mut Validity, +} + +impl<'a, M> OutputBuffer<'a, M> +where + M: MutableAddressableStorage, +{ + pub(crate) fn new(idx: usize, buffer: &'a mut M, validity: &'a mut Validity) -> Self { + debug_assert_eq!(buffer.len(), validity.len()); + OutputBuffer { + idx, + buffer, + validity, + } + } + + pub fn put(self, val: &M::T) { + self.buffer.put(self.idx, val) + } + + pub fn put_null(self) { + self.validity.set_invalid(self.idx) + } +} diff --git a/crates/rayexec_bullet/src/exp/executors/scalar/binary.rs b/crates/rayexec_bullet/src/exp/executors/scalar/binary.rs index b76a87304..43381a764 100644 --- a/crates/rayexec_bullet/src/exp/executors/scalar/binary.rs +++ b/crates/rayexec_bullet/src/exp/executors/scalar/binary.rs @@ -1,11 +1,11 @@ use rayexec_error::Result; -use super::OutputBuffer; use crate::compute::util::IntoExactSizedIterator; use crate::exp::array::Array; use crate::exp::buffer::addressable::AddressableStorage; use crate::exp::buffer::physical_type::{MutablePhysicalStorage, PhysicalStorage}; use crate::exp::buffer::ArrayBuffer; +use crate::exp::executors::OutputBuffer; use crate::exp::validity::Validity; #[derive(Debug, Clone)] @@ -47,10 +47,7 @@ impl BinaryExecutor { op( val1, val2, - OutputBuffer { - idx: output_idx, - buffer: &mut output, - }, + OutputBuffer::new(output_idx, &mut output, out_validity), ); } } else { @@ -64,10 +61,7 @@ impl BinaryExecutor { op( val1, val2, - OutputBuffer { - idx: output_idx, - buffer: &mut output, - }, + OutputBuffer::new(output_idx, &mut output, out_validity), ); } else { out_validity.set_invalid(output_idx); diff --git a/crates/rayexec_bullet/src/exp/executors/scalar/mod.rs b/crates/rayexec_bullet/src/exp/executors/scalar/mod.rs index 796cccba4..3a2e91f5a 100644 --- a/crates/rayexec_bullet/src/exp/executors/scalar/mod.rs +++ b/crates/rayexec_bullet/src/exp/executors/scalar/mod.rs @@ -2,22 +2,3 @@ pub mod binary; pub mod unary; use crate::exp::buffer::addressable::MutableAddressableStorage; -use crate::exp::buffer::ArrayBuffer; - -#[derive(Debug)] -pub struct OutputBuffer<'a, M> -where - M: MutableAddressableStorage, -{ - idx: usize, - buffer: &'a mut M, -} - -impl<'a, M> OutputBuffer<'a, M> -where - M: MutableAddressableStorage, -{ - pub fn put(self, val: &M::T) { - self.buffer.put(self.idx, val) - } -} diff --git a/crates/rayexec_bullet/src/exp/executors/scalar/unary.rs b/crates/rayexec_bullet/src/exp/executors/scalar/unary.rs index ebe4d0cd8..2e3cc2824 100644 --- a/crates/rayexec_bullet/src/exp/executors/scalar/unary.rs +++ b/crates/rayexec_bullet/src/exp/executors/scalar/unary.rs @@ -1,10 +1,10 @@ use rayexec_error::Result; -use super::OutputBuffer; use crate::exp::array::Array; use crate::exp::buffer::addressable::{AddressableStorage, MutableAddressableStorage}; use crate::exp::buffer::physical_type::{MutablePhysicalStorage, PhysicalStorage}; use crate::exp::buffer::ArrayBuffer; +use crate::exp::executors::OutputBuffer; use crate::exp::validity::Validity; #[derive(Debug, Clone)] @@ -33,10 +33,7 @@ impl UnaryExecutor { for (output_idx, input_idx) in selection.into_iter().enumerate() { op( input.get(input_idx).unwrap(), - OutputBuffer { - idx: output_idx, - buffer: &mut output, - }, + OutputBuffer::new(output_idx, &mut output, out_validity), ); } } else { @@ -44,10 +41,7 @@ impl UnaryExecutor { if validity.is_valid(input_idx) { op( input.get(input_idx).unwrap(), - OutputBuffer { - idx: output_idx, - buffer: &mut output, - }, + OutputBuffer::new(output_idx, &mut output, out_validity), ); } else { out_validity.set_invalid(output_idx);