Skip to content

Commit

Permalink
rewrite concat_internal (#7748)
Browse files (browse the repository at this point in the history)
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Oct 6, 2023
1 parent 8504600 commit 2e7fd62
Showing 1 changed file with 45 additions and 37 deletions.
82 changes: 45 additions & 37 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
aligned_args
}

// Concatenate arrays on the same row.
fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
let args = align_array_dimensions(args.to_vec())?;

Expand All @@ -818,49 +819,56 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {

// Assume the number of rows is the same for all arrays
let row_count = list_arrays[0].len();
let capacity = Capacities::Array(list_arrays.iter().map(|a| a.len()).sum());
let array_data: Vec<_> = list_arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
let array_data: Vec<&ArrayData> = array_data.iter().collect();

let mut mutable = MutableArrayData::with_capacities(array_data, true, capacity);

let mut array_lens = vec![0; row_count];
let mut null_bit_map: Vec<bool> = vec![true; row_count];
let mut array_lengths = vec![];
let mut arrays = vec![];
let mut valid = BooleanBufferBuilder::new(row_count);
for i in 0..row_count {
let nulls = list_arrays
.iter()
.map(|arr| arr.is_null(i))
.collect::<Vec<_>>();

// If all the arrays are null, the concatenated array is null
let is_null = nulls.iter().all(|&x| x);
if is_null {
array_lengths.push(0);
valid.append(false);
} else {
// Get all the arrays on the i-th row
let values = list_arrays
.iter()
.map(|arr| arr.value(i))
.collect::<Vec<_>>();

for (i, array_len) in array_lens.iter_mut().enumerate().take(row_count) {
let null_count = mutable.null_count();
for (j, a) in list_arrays.iter().enumerate() {
mutable.extend(j, i, i + 1);
*array_len += a.value_length(i);
}
let elements = values
.iter()
.map(|a| a.as_ref())
.collect::<Vec<&dyn Array>>();

// This means all arrays are null
if mutable.null_count() == null_count + list_arrays.len() {
null_bit_map[i] = false;
// Concatenated array on the i-th row
let concated_array = arrow::compute::concat(elements.as_slice())?;
array_lengths.push(concated_array.len());
arrays.push(concated_array);
valid.append(true);
}
}
// Assume all arrays have the same data type
let data_type = list_arrays[0].value_type().clone();
let buffer = valid.finish();

let mut buffer = BooleanBufferBuilder::new(row_count);
buffer.append_slice(null_bit_map.as_slice());
let nulls = Some(NullBuffer::from(buffer.finish()));

let offsets: Vec<i32> = std::iter::once(0)
.chain(array_lens.iter().scan(0, |state, &x| {
*state += x;
Some(*state)
}))
.collect();

let builder = mutable.into_builder();

let list = builder
.len(row_count)
.buffers(vec![Buffer::from_vec(offsets)])
.nulls(nulls)
.build()?;

let list = arrow::array::make_array(list);
Ok(Arc::new(list))
let elements = arrays
.iter()
.map(|a| a.as_ref())
.collect::<Vec<&dyn Array>>();

let list_arr = ListArray::new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::from_lengths(array_lengths),
Arc::new(arrow::compute::concat(elements.as_slice())?),
Some(NullBuffer::new(buffer)),
);
Ok(Arc::new(list_arr))
}

/// Array_concat/Array_cat SQL function
Expand Down

0 comments on commit 2e7fd62

Please sign in to comment.