Skip to content

Commit

Permalink
Port ArrayResize to functions-array subcrate (#9570)
Browse files Browse the repository at this point in the history
* Issue-9569 - Port ArrayResize to functions-array subcrate

* Issue-9569 - Address review comment

* Remove unused variants

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
erenavsarogullari and alamb authored Mar 14, 2024
1 parent 5911d18 commit 9d0c05b
Show file tree
Hide file tree
Showing 15 changed files with 200 additions and 135 deletions.
8 changes: 0 additions & 8 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ pub enum BuiltinScalarFunction {
ArrayUnion,
/// array_except
ArrayExcept,
/// array_resize
ArrayResize,

// string functions
/// ascii
Expand Down Expand Up @@ -311,7 +309,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
BuiltinScalarFunction::ArrayUnion => Volatility::Immutable,
BuiltinScalarFunction::ArrayResize => Volatility::Immutable,
BuiltinScalarFunction::Ascii => Volatility::Immutable,
BuiltinScalarFunction::BitLength => Volatility::Immutable,
BuiltinScalarFunction::Btrim => Volatility::Immutable,
Expand Down Expand Up @@ -393,7 +390,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayReverse => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayResize => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayIntersect => {
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
(DataType::Null, DataType::Null) | (DataType::Null, _) => {
Expand Down Expand Up @@ -608,9 +604,6 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayResize => {
Signature::variadic_any(self.volatility())
}

BuiltinScalarFunction::Concat
| BuiltinScalarFunction::ConcatWithSeparator => {
Expand Down Expand Up @@ -990,7 +983,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReverse => &["array_reverse", "list_reverse"],
BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"],
BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"],
BuiltinScalarFunction::ArrayResize => &["array_resize", "list_resize"],
BuiltinScalarFunction::ArrayIntersect => {
&["array_intersect", "list_intersect"]
}
Expand Down
7 changes: 0 additions & 7 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,13 +672,6 @@ scalar_expr!(
);
// Invokes `scalar_expr!` for the `ArrayUnion` variant with the function name
// `array_union`, the argument names `array1 array2`, and a description string.
// NOTE(review): the macro body is not visible here; presumably it generates the
// `array_union` expression builder — confirm against the macro definition.
scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates.");

// Same macro, invoked for the `ArrayResize` variant: function name
// `array_resize`, argument names `array size value`, plus its description.
scalar_expr!(
ArrayResize,
array_resize,
array size value,
"returns an array with the specified size filled with the given value."
);

scalar_expr!(
ArrayIntersect,
array_intersect,
Expand Down
100 changes: 98 additions & 2 deletions datafusion/functions-array/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use arrow::datatypes::{
};
use arrow::row::{RowConverter, SortField};
use arrow_array::new_null_array;
use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer};
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer};
use arrow_schema::FieldRef;
use arrow_schema::SortOptions;

Expand All @@ -39,9 +39,11 @@ use datafusion_common::cast::{
as_string_array,
};
use datafusion_common::{
exec_err, internal_err, not_impl_datafusion_err, DataFusionError, Result,
exec_err, internal_datafusion_err, internal_err, not_impl_datafusion_err,
DataFusionError, Result, ScalarValue,
};
use itertools::Itertools;

use std::any::type_name;
use std::sync::Arc;

Expand Down Expand Up @@ -893,6 +895,100 @@ pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// array_resize SQL function
///
/// `arg` holds the list array to resize, the requested sizes (read as an
/// `Int64` array) and, optionally, a third array providing the fill values
/// used when a list has to grow.
///
/// Only `List` and `LargeList` inputs are handled; any other input type, or a
/// wrong number of arguments, produces an execution error.
pub fn array_resize(arg: &[ArrayRef]) -> Result<ArrayRef> {
    if !(2..=3).contains(&arg.len()) {
        return exec_err!("array_resize needs two or three arguments");
    }

    let sizes = as_int64_array(&arg[1])?;
    // The fill value only exists in the three-argument form.
    let fill = arg.get(2).cloned();

    let input = &arg[0];
    match input.data_type() {
        DataType::List(field) => {
            let lists = as_list_array(input)?;
            general_list_resize::<i32>(lists, sizes, field, fill)
        }
        DataType::LargeList(field) => {
            let lists = as_large_list_array(input)?;
            general_list_resize::<i64>(lists, sizes, field, fill)
        }
        array_type => exec_err!("array_resize does not support type '{array_type:?}'."),
    }
}

/// Resizes every list in `array` to the length requested by the entry of
/// `count_array` at the same row.
///
/// A list longer than the requested size is truncated. A shorter list keeps
/// its elements and is padded at the end with the entry of `default_element`
/// found at the same row index, or with nulls when `default_element` is
/// `None`.
///
/// # Errors
///
/// Returns an internal error when a requested size cannot be converted to
/// `usize`, and propagates any error raised while building the null fill array
/// or the resulting list array.
fn general_list_resize<O: OffsetSizeTrait>(
    array: &GenericListArray<O>,
    count_array: &Int64Array,
    field: &FieldRef,
    default_element: Option<ArrayRef>,
) -> Result<ArrayRef>
where
    O: TryInto<i64>,
{
    let data_type = array.value_type();

    let values = array.values();
    let original_data = values.to_data();

    // create default element array
    //
    // The fill value for a row is read at index `row_index` below, so the null
    // fallback needs one entry per list row. Sizing it by the number of child
    // values is not enough: a list array can have more rows than child values
    // (e.g. when it contains empty lists), which made the lookup go out of
    // bounds.
    let default_element = if let Some(default_element) = default_element {
        default_element
    } else {
        let null_scalar = ScalarValue::try_from(&data_type)?;
        null_scalar.to_array_of_size(array.len())?
    };
    let default_value_data = default_element.to_data();

    // create a mutable array to store the original data
    let capacity = Capacities::Array(original_data.len() + default_value_data.len());
    let mut offsets = vec![O::usize_as(0)];
    let mut mutable = MutableArrayData::with_capacities(
        vec![&original_data, &default_value_data],
        false,
        capacity,
    );

    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
        // NOTE(review): `value` is read without consulting the validity of
        // `count_array`; confirm how null sizes should be treated.
        let count = count_array.value(row_index).to_usize().ok_or_else(|| {
            internal_datafusion_err!("array_resize: failed to convert size to usize")
        })?;
        let count = O::usize_as(count);
        let start = offset_window[0];
        if start + count > offset_window[1] {
            // The list has to grow: copy all of it, then pad with the default.
            let extra_count =
                (start + count - offset_window[1]).try_into().map_err(|_| {
                    internal_datafusion_err!(
                        "array_resize: failed to convert size to i64"
                    )
                })?;
            let end = offset_window[1];
            mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
            // append default element
            for _ in 0..extra_count {
                mutable.extend(1, row_index, row_index + 1);
            }
        } else {
            // The list is long enough: keep only its first `count` elements.
            let end = start + count;
            mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
        };
        offsets.push(offsets[row_index] + count);
    }

    let data = mutable.freeze();
    // NOTE(review): no validity buffer is passed (`None`), so the result list
    // array carries no list-level nulls regardless of the input's.
    Ok(Arc::new(GenericListArray::<O>::try_new(
        field.clone(),
        OffsetBuffer::<O>::new(offsets.into()),
        arrow_array::make_array(data),
        None,
    )?))
}

/// Array_sort SQL function
pub fn array_sort(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.is_empty() || args.len() > 3 {
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub mod expr_fn {
pub use super::udf::array_length;
pub use super::udf::array_ndims;
pub use super::udf::array_repeat;
pub use super::udf::array_resize;
pub use super::udf::array_sort;
pub use super::udf::array_to_string;
pub use super::udf::cardinality;
Expand Down Expand Up @@ -89,6 +90,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
udf::array_sort_udf(),
udf::array_distinct_udf(),
udf::array_repeat_udf(),
udf::array_resize_udf(),
];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
Expand Down
56 changes: 56 additions & 0 deletions datafusion/functions-array/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,62 @@ impl ScalarUDFImpl for ArraySort {
}
}

// Invokes `make_udf_function!` for the `ArrayResize` UDF type with the
// expression-function name `array_resize`, the argument names
// `array size value`, a description string, and the name `array_resize_udf`.
// NOTE(review): the macro body is not visible here; presumably it generates the
// `array_resize` and `array_resize_udf` items that lib.rs re-exports and
// registers — confirm against the macro definition.
make_udf_function!(
ArrayResize,
array_resize,
array size value,
"returns an array with the specified size filled with the given value.",
array_resize_udf
);

/// Scalar UDF type for `array_resize`; stores its signature and alias names.
#[derive(Debug)]
pub(super) struct ArrayResize {
    signature: Signature,
    aliases: Vec<String>,
}

impl ArrayResize {
    /// Builds the UDF with an immutable `variadic_any` signature and the
    /// `array_resize` / `list_resize` aliases.
    pub fn new() -> Self {
        let aliases = ["array_resize", "list_resize"]
            .iter()
            .map(|alias| alias.to_string())
            .collect();
        let signature = Signature::variadic_any(Volatility::Immutable);
        Self { signature, aliases }
    }
}

impl ScalarUDFImpl for ArrayResize {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name of the function: `array_resize`.
    fn name(&self) -> &str {
        "array_resize"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Derives the return type from the first argument type.
    ///
    /// `List` and `FixedSizeList` inputs map to `List` of the same field, and
    /// `LargeList` maps to `LargeList`. Any other first argument, or an empty
    /// `arg_types`, yields an execution error.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        use DataType::*;
        // `first()` rather than `arg_types[0]`, so an empty argument list is
        // reported as an error instead of panicking.
        match arg_types.first() {
            Some(List(field)) | Some(FixedSizeList(field, _)) => Ok(List(field.clone())),
            Some(LargeList(field)) => Ok(LargeList(field.clone())),
            _ => exec_err!(
                "Not reachable, data_type should be List, LargeList or FixedSizeList"
            ),
        }
    }

    /// Converts the arguments to arrays and runs the `array_resize` kernel.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let args = ColumnarValue::values_to_arrays(args)?;
        crate::kernels::array_resize(&args).map(ColumnarValue::Array)
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}

make_udf_function!(
Cardinality,
cardinality,
Expand Down
97 changes: 1 addition & 96 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::buffer::OffsetBuffer;
use arrow::compute::{self};
use arrow::datatypes::{DataType, Field, UInt64Type};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::{ArrowNativeType, NullBuffer};
use arrow_buffer::NullBuffer;

use arrow_schema::FieldRef;
use datafusion_common::cast::{
Expand All @@ -35,7 +35,6 @@ use datafusion_common::cast::{
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err, DataFusionError, Result,
ScalarValue,
};
use itertools::Itertools;

Expand Down Expand Up @@ -1393,100 +1392,6 @@ pub fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
)?))
}

/// array_resize SQL function
///
/// `arg` holds the list array to resize, the requested sizes (read as an
/// `Int64` array) and, optionally, a third array providing the fill values
/// used when a list has to grow.
///
/// Only `List` and `LargeList` inputs are handled; any other input type, or a
/// wrong number of arguments, produces an execution error.
pub fn array_resize(arg: &[ArrayRef]) -> Result<ArrayRef> {
    if !(2..=3).contains(&arg.len()) {
        return exec_err!("array_resize needs two or three arguments");
    }

    let sizes = as_int64_array(&arg[1])?;
    // The fill value only exists in the three-argument form.
    let fill = arg.get(2).cloned();

    let input = &arg[0];
    match input.data_type() {
        DataType::List(field) => {
            let lists = as_list_array(input)?;
            general_list_resize::<i32>(lists, sizes, field, fill)
        }
        DataType::LargeList(field) => {
            let lists = as_large_list_array(input)?;
            general_list_resize::<i64>(lists, sizes, field, fill)
        }
        array_type => exec_err!("array_resize does not support type '{array_type:?}'."),
    }
}

/// Resizes every list in `array` to the length requested by the entry of
/// `count_array` at the same row.
///
/// A list longer than the requested size is truncated. A shorter list keeps
/// its elements and is padded at the end with the entry of `default_element`
/// found at the same row index, or with nulls when `default_element` is
/// `None`.
///
/// # Errors
///
/// Returns an internal error when a requested size cannot be converted to
/// `usize`, and propagates any error raised while building the null fill array
/// or the resulting list array.
fn general_list_resize<O: OffsetSizeTrait>(
    array: &GenericListArray<O>,
    count_array: &Int64Array,
    field: &FieldRef,
    default_element: Option<ArrayRef>,
) -> Result<ArrayRef>
where
    O: TryInto<i64>,
{
    let data_type = array.value_type();

    let values = array.values();
    let original_data = values.to_data();

    // create default element array
    //
    // The fill value for a row is read at index `row_index` below, so the null
    // fallback needs one entry per list row. Sizing it by the number of child
    // values is not enough: a list array can have more rows than child values
    // (e.g. when it contains empty lists), which made the lookup go out of
    // bounds.
    let default_element = if let Some(default_element) = default_element {
        default_element
    } else {
        let null_scalar = ScalarValue::try_from(&data_type)?;
        null_scalar.to_array_of_size(array.len())?
    };
    let default_value_data = default_element.to_data();

    // create a mutable array to store the original data
    let capacity = Capacities::Array(original_data.len() + default_value_data.len());
    let mut offsets = vec![O::usize_as(0)];
    let mut mutable = MutableArrayData::with_capacities(
        vec![&original_data, &default_value_data],
        false,
        capacity,
    );

    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
        // NOTE(review): `value` is read without consulting the validity of
        // `count_array`; confirm how null sizes should be treated.
        let count = count_array.value(row_index).to_usize().ok_or_else(|| {
            internal_datafusion_err!("array_resize: failed to convert size to usize")
        })?;
        let count = O::usize_as(count);
        let start = offset_window[0];
        if start + count > offset_window[1] {
            // The list has to grow: copy all of it, then pad with the default.
            let extra_count =
                (start + count - offset_window[1]).try_into().map_err(|_| {
                    internal_datafusion_err!(
                        "array_resize: failed to convert size to i64"
                    )
                })?;
            let end = offset_window[1];
            mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
            // append default element
            for _ in 0..extra_count {
                mutable.extend(1, row_index, row_index + 1);
            }
        } else {
            // The list is long enough: keep only its first `count` elements.
            let end = start + count;
            mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
        };
        offsets.push(offsets[row_index] + count);
    }

    let data = mutable.freeze();
    // NOTE(review): no validity buffer is passed (`None`), so the result list
    // array carries no list-level nulls regardless of the input's.
    Ok(Arc::new(GenericListArray::<O>::try_new(
        field.clone(),
        OffsetBuffer::<O>::new(offsets.into()),
        arrow_array::make_array(data),
        None,
    )?))
}

/// array_reverse SQL function
pub fn array_reverse(arg: &[ArrayRef]) -> Result<ArrayRef> {
if arg.len() != 1 {
Expand Down
3 changes: 0 additions & 3 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,6 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_intersect)(args)
}),
BuiltinScalarFunction::ArrayResize => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_resize)(args)
}),
BuiltinScalarFunction::ArrayUnion => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_union)(args)
}),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ enum ScalarFunction {
FindInSet = 127;
/// 128 was ArraySort
/// 129 was ArrayDistinct
ArrayResize = 130;
/// 130 was ArrayResize
EndsWith = 131;
/// 132 was InStr
MakeDate = 133;
Expand Down
Loading

0 comments on commit 9d0c05b

Please sign in to comment.