From f583e5c3887974436753de83a0d08e95f4e053f6 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 9 Nov 2024 17:10:58 +0100 Subject: [PATCH 01/18] Part of the implementation of array_insert --- native/proto/src/proto/expr.proto | 10 ++ native/spark-expr/src/list.rs | 156 +++++++++++++++++- .../apache/comet/serde/QueryPlanSerde.scala | 23 +++ 3 files changed, 182 insertions(+), 7 deletions(-) diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 3a8193f4a..2b993104f 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -82,6 +82,7 @@ message Expr { ToJson to_json = 55; ListExtract list_extract = 56; GetArrayStructFields get_array_struct_fields = 57; + ArrayInsert array_insert = 58; } } @@ -402,6 +403,15 @@ enum NullOrdering { NullsLast = 1; } +// Array functions +message ArrayInsert { + Expr src_array_expr = 1; + Expr pos_expr = 2; + Expr item_expr = 3; + bool legacy_negative_index = 4; + DataType element_type = 5; +} + message DataType { enum DataTypeId { BOOL = 0; diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index a376198db..e2a263c33 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch}; -use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait, StructArray}; -use arrow_schema::{DataType, FieldRef, Schema}; +use arrow::{array::{Capacities, MutableArrayData}, buffer::{NullBuffer, OffsetBuffer}, datatypes::ArrowNativeType, record_batch::RecordBatch}; +use arrow_array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, StructArray}; +use arrow_schema::{DataType, Field, FieldRef, Schema}; use datafusion::logical_expr::ColumnarValue; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{ @@ -26,10 +26,7 @@ use datafusion_common::{ }; use datafusion_physical_expr::PhysicalExpr; use std::{ - any::Any, - fmt::{Display, Formatter}, - hash::{Hash, Hasher}, - sync::Arc, + any::Any, fmt::{Display, Formatter}, hash::{Hash, Hasher}, sync::Arc }; #[derive(Debug, Hash)] pub struct ListExtract { @@ -413,6 +410,151 @@ impl PartialEq for GetArrayStructFields { } } +#[derive(Debug, Hash)] +pub struct ArrayInsert { + src_array_expr: Arc, + pos_expr: Arc, + item_expr: Arc, + legacy_negative_index: bool, +} + +impl ArrayInsert { + pub fn new( + src_array_expr: Arc, + pos_expr: Arc, + item_expr: Arc, + legacy_negative_index: bool, + ) -> Self { + Self { + src_array_expr, + pos_expr, + item_expr, + legacy_negative_index, + } + } +} + +impl PhysicalExpr for ArrayInsert { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> DataFusionResult { + match self.src_array_expr.data_type(input_schema)? 
{
+            DataType::List(field) => Ok(DataType::List(field)),
+            DataType::LargeList(field) => Ok(DataType::LargeList(field))
+            data_type => Err(DataFusionError::Internal(format!("Unexpected data type in ArrayInsert: {:?}", data_type)))
+        }
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
+        todo!()
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
+        let pos_value = self.pos_expr.evaluate(batch)?.into_array(batch.num_rows())?;
+        // Check that index value is integer-like
+        match pos_value.data_type() {
+            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {}
+            data_type => {
+                return Err(DataFusionError::Internal(format!("Unexpected index data type in ArrayInsert: {:?}", data_type)))
+            }
+        }
+
+        // Check that src array is actually an array and get its value type
+        let src_value = self.src_array_expr.evaluate(batch)?.into_array(batch.num_rows())?;
+        let src_element_type = match src_value.data_type() {
+            DataType::List(field) => field.data_type(),
+            DataType::LargeList(field) => field.data_type(),
+            data_type => {
+                return Err(DataFusionError::Internal(format!("Unexpected src array type in ArrayInsert: {:?}", data_type)))
+            }
+        };
+
+        // Check that inserted value has the same type as an array
+        let item_value = self.item_expr.evaluate(batch)?.into_array(batch.num_rows())?;
+        if item_value.data_type() != src_element_type {
+            return Err(DataFusionError::Internal(format!("Type mismatch in ArrayInsert: array type is {:?} but item type is {:?}", src_element_type, item_value.data_type())))
+        }
+        todo!()
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        todo!()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
+        todo!()
+    }
+
+    fn dyn_hash(&self, _state: &mut dyn Hasher) {
+        todo!()
+    }
+}
+
+fn array_insert<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    items_array: &ArrayRef,
+    pos_array: &ArrayRef,
+) -> DataFusionResult<ColumnarValue> {
+    // TODO: support spark's legacy mode!
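+    // Spark's array_insert(arr, pos, item) takes 1-based positions, and the
+    // spark.sql.legacy.negativeIndexInArrayInsert flag (carried here as
+    // legacy_negative_index) decides whether a negative pos addresses the
+    // last element itself or the slot one past it.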
+ + // Heavily inspired by + // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513 + + let values = list_array.values(); + let offsets = list_array.offsets(); + let values_data = values.to_data(); + let item_data = items_array.to_data(); + let new_capacity = Capacities::Array(values_data.len() + item_data.len()); + + let mut mutable_values = MutableArrayData::with_capacities(vec![&values_data, &item_data], false, new_capacity); + + let mut new_offsets = vec![O::usize_as(0)]; + let mut new_nulls = Vec::::with_capacity(list_array.len()); + + let pos_data = pos_array.to_data(); + + for (i, offset_window) in offsets.windows(2).enumerate() { + let start = offset_window[0].as_usize(); + let end = offset_window[1].as_usize(); + let pos = pos_data.buffers()[0][i].as_usize(); + let is_item_null = items_array.is_null(i); + + mutable_values.extend(0, start, pos); + mutable_values.extend(1, i, i + 1); + mutable_values.extend(0, pos, end); + if is_item_null { + if start == end { + new_nulls.push(false) + } else { + if values.is_null(i) { + new_nulls.push(false) + } else { + new_nulls.push(true) + } + } + } else { + new_nulls.push(true) + } + new_offsets.push(offsets[i] + O::usize_as(end - start + 1)); + } + + let data = mutable_values.freeze(); + let new_array = GenericListArray::::try_new( + Arc::new(Field::new("item", list_array.data_type().to_owned(), true)), + OffsetBuffer::new(new_offsets.into()), + make_array(data), + Some(NullBuffer::new(new_nulls.into())) + )?; + + Ok(ColumnarValue::Array(Arc::new(new_array))) +} + + #[cfg(test)] mod test { use crate::list::{list_extract, zero_based_index}; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8bdc886de..7ef3e3c5b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2235,6 +2235,29 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None } + case ArrayInsert(srcArrayExpr, posExpr, itemExpr, legacyNegativeIndex) => + val srcExprProto = exprToProto(srcArrayExpr, inputs, binding) + val posExprProto = exprToProto(posExpr, inputs, binding) + val itemExprProto = exprToProto(itemExpr, inputs, binding) + if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) { + val arrayInsertBuilder = ExprOuterClass.ArrayInsert + .newBuilder() + .setSrcArrayExpr(srcExprProto.get) + .setPosExpr(posExprProto.get) + .setItemExpr(itemExprProto.get) + .setLegacyNegativeIndex(legacyNegativeIndex) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setArrayInsert(arrayInsertBuilder) + .build() + ) + } else { + withInfo(expr, "unsupported arguments for ArrayInsert", srcArrayExpr, posExpr, itemExpr) + None + } + case ElementAt(child, ordinal, defaultValue, failOnError) if child.dataType.isInstanceOf[ArrayType] => val childExpr = exprToProto(child, inputs, binding) From e870c21c49acbc392ae695856824abc61f7129b4 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 11 Nov 2024 09:11:17 +0100 Subject: [PATCH 02/18] Missing methods --- native/spark-expr/src/list.rs | 59 +++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index e2a263c33..651f044f6 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -15,7 +15,7 @@ // specific language governing 
permissions and limitations // under the License. -use arrow::{array::{Capacities, MutableArrayData}, buffer::{NullBuffer, OffsetBuffer}, datatypes::ArrowNativeType, record_batch::RecordBatch}; +use arrow::{array::{as_large_list_array, as_list_array, Capacities, MutableArrayData}, buffer::{NullBuffer, OffsetBuffer}, datatypes::ArrowNativeType, record_batch::RecordBatch}; use arrow_array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, StructArray}; use arrow_schema::{DataType, Field, FieldRef, Schema}; use datafusion::logical_expr::ColumnarValue; @@ -476,7 +476,18 @@ impl PhysicalExpr for ArrayInsert { if item_value.data_type() != src_element_type { return Err(DataFusionError::Internal(format!("Type mismatch in ArrayInsert: array type is {:?} but item type is {:?}", src_element_type, item_value.data_type()))) } - todo!() + + match src_value.data_type() { + DataType::List(_) => { + let list_array = as_list_array(&src_value); + array_insert(list_array, &pos_value, &item_value) + }, + DataType::LargeList(_) => { + let list_array = as_large_list_array(&src_value); + array_insert(list_array, &pos_value, &item_value) + }, + _ => unreachable!() // This case is checked already + } } fn children(&self) -> Vec<&Arc> { @@ -554,15 +565,34 @@ fn array_insert( Ok(ColumnarValue::Array(Arc::new(new_array))) } +impl Display for ArrayInsert { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ArrayInsert [array: {:?}, pos: {:?}, item: {:?}]", + self.src_array_expr, self.pos_expr, self.item_expr + ) + } +} + +impl PartialEq for ArrayInsert { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.src_array_expr.eq(&x.src_array_expr) && self.pos_expr.eq(&x.pos_expr) && self.item_expr.eq(&x.item_expr) && self.legacy_negative_index.eq(&x.legacy_negative_index)) + .unwrap_or(false) + } +} #[cfg(test)] mod test { - use crate::list::{list_extract, zero_based_index}; + use crate::list::{array_insert, list_extract, zero_based_index}; use arrow::datatypes::Int32Type; use arrow_array::{Array, Int32Array, ListArray}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::ColumnarValue; + use std::{ops::Deref, sync::Arc}; #[test] fn test_list_extract_default_value() -> Result<()> { @@ -600,4 +630,27 @@ mod test { ); Ok(()) } + + #[test] + fn test_array_insert() -> Result<()> { + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + None, + ]); + let positions = Int32Array::from(vec![1, 0, 0]); + let items = Int32Array::from(vec![Some(10), Some(20), Some(30)]); + + let ColumnarValue::Array(result) = array_insert(&list, &Arc::new(items), &Arc::new(positions))? 
else { + unreachable!() + }; + + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(10), Some(2), Some(3)]), + Some(vec![Some(20), Some(4), Some(5)]), + None, + ]); + + assert_eq!(result.to_data(), expected.to_data()); + Ok(()) } } From ac7a2b3b867141b58eb51026ac5bc7d416203311 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 11 Nov 2024 13:50:22 +0100 Subject: [PATCH 03/18] Working version --- .../core/src/execution/datafusion/planner.rs | 20 +- native/proto/src/proto/expr.proto | 1 - native/spark-expr/src/lib.rs | 2 +- native/spark-expr/src/list.rs | 195 ++++++++++++------ .../apache/comet/serde/QueryPlanSerde.scala | 10 +- .../apache/comet/CometExpressionSuite.scala | 12 ++ 6 files changed, 172 insertions(+), 68 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 6f41bf0ad..5bea9ba56 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -97,8 +97,8 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ - Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField, HourExpr, IfExpr, - ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson, + ArrayInsert, Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField, + HourExpr, IfExpr, ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson, }; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ @@ -691,6 +691,22 @@ impl PhysicalPlanner { expr.ordinal as usize, ))) } + ExprStruct::ArrayInsert(expr) => { + let src_array_expr = self.create_expr( + expr.src_array_expr.as_ref().unwrap(), + Arc::clone(&input_schema), + )?; + let pos_expr = + self.create_expr(expr.pos_expr.as_ref().unwrap(), Arc::clone(&input_schema))?; + let item_expr = + self.create_expr(expr.item_expr.as_ref().unwrap(), Arc::clone(&input_schema))?; + Ok(Arc::new(ArrayInsert::new( + src_array_expr, + pos_expr, + item_expr, + expr.legacy_negative_index, + ))) + } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", expr diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 2b993104f..00e1ead55 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -409,7 +409,6 @@ message ArrayInsert { Expr pos_expr = 2; Expr item_expr = 3; bool legacy_negative_index = 4; - DataType element_type = 5; } message DataType { diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 614b48f2b..3ec2e886b 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -37,7 +37,7 @@ pub mod utils; pub use cast::{spark_cast, Cast}; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; -pub use list::{GetArrayStructFields, ListExtract}; +pub use list::{ArrayInsert, GetArrayStructFields, ListExtract}; pub use regexp::RLike; pub use structs::{CreateNamedStruct, GetStructField}; pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr}; diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index 651f044f6..304c1f623 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -15,8 +15,16 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::{array::{as_large_list_array, as_list_array, Capacities, MutableArrayData}, buffer::{NullBuffer, OffsetBuffer}, datatypes::ArrowNativeType, record_batch::RecordBatch}; -use arrow_array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, StructArray}; +use arrow::{ + array::{as_primitive_array, Capacities, MutableArrayData}, + buffer::{NullBuffer, OffsetBuffer}, + datatypes::ArrowNativeType, + record_batch::RecordBatch, +}; +use arrow_array::{ + make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, + StructArray, +}; use arrow_schema::{DataType, Field, FieldRef, Schema}; use datafusion::logical_expr::ColumnarValue; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; @@ -26,7 +34,10 @@ use datafusion_common::{ }; use datafusion_physical_expr::PhysicalExpr; use std::{ - any::Any, fmt::{Display, Formatter}, hash::{Hash, Hasher}, sync::Arc + any::Any, + fmt::{Debug, Display, Formatter}, + hash::{Hash, Hasher}, + sync::Arc, }; #[derive(Debug, Hash)] pub struct ListExtract { @@ -442,67 +453,109 @@ impl PhysicalExpr for ArrayInsert { fn data_type(&self, input_schema: &Schema) -> DataFusionResult { match self.src_array_expr.data_type(input_schema)? { DataType::List(field) => Ok(DataType::List(field)), - DataType::LargeList(field) => Ok(DataType::LargeList(field)) - data_type => Err(DataFusionError::Internal(format!("Unexpected data type in ArrayInsert: {:?}", data_type))) + DataType::LargeList(field) => Ok(DataType::LargeList(field)), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected data type in ArrayInsert: {:?}", + data_type + ))), } } fn nullable(&self, input_schema: &Schema) -> DataFusionResult { - todo!() + self.src_array_expr.nullable(input_schema) } fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { - let pos_value = self.pos_expr.evaluate(batch)?.into_array(batch.num_rows())?; + let pos_value = self + .pos_expr + .evaluate(batch)? + .into_array(batch.num_rows())?; // Check that index value is integer-like match pos_value.data_type() { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {} + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => {} data_type => { - return Err(DataFusionError::Internal(format!("Unexpected index data type in ArrayInsert: {:?}", data_type))) - } + return Err(DataFusionError::Internal(format!( + "Unexpected index data type in ArrayInsert: {:?}", + data_type + ))) + } } // Check that src array is actually an array and get it's value type - let src_value = self.src_array_expr.evaluate(batch)?.into_array(batch.num_rows())?; + let src_value = self + .src_array_expr + .evaluate(batch)? + .into_array(batch.num_rows())?; let src_element_type = match src_value.data_type() { DataType::List(field) => field.data_type(), DataType::LargeList(field) => field.data_type(), data_type => { - return Err(DataFusionError::Internal(format!("Unexpected src array type in ArrayInsert: {:?}", data_type))) + return Err(DataFusionError::Internal(format!( + "Unexpected src array type in ArrayInsert: {:?}", + data_type + ))) } }; // Check that inserted value has the same type as an array - let item_value = self.item_expr.evaluate(batch)?.into_array(batch.num_rows())?; + let item_value = self + .item_expr + .evaluate(batch)? 
+ .into_array(batch.num_rows())?; if item_value.data_type() != src_element_type { - return Err(DataFusionError::Internal(format!("Type mismatch in ArrayInsert: array type is {:?} but item type is {:?}", src_element_type, item_value.data_type()))) + return Err(DataFusionError::Internal(format!( + "Type mismatch in ArrayInsert: array type is {:?} but item type is {:?}", + src_element_type, + item_value.data_type() + ))); } match src_value.data_type() { DataType::List(_) => { - let list_array = as_list_array(&src_value); + let list_array = as_list_array(&src_value)?; array_insert(list_array, &pos_value, &item_value) - }, + } DataType::LargeList(_) => { - let list_array = as_large_list_array(&src_value); + let list_array = as_large_list_array(&src_value)?; array_insert(list_array, &pos_value, &item_value) - }, - _ => unreachable!() // This case is checked already + } + _ => unreachable!(), // This case is checked already } } fn children(&self) -> Vec<&Arc> { - todo!() + vec![&self.src_array_expr, &self.pos_expr, &self.item_expr] } fn with_new_children( self: Arc, children: Vec>, ) -> DataFusionResult> { - todo!() + match children.len() { + 3 => Ok(Arc::new(ArrayInsert::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + Arc::clone(&children[2]), + self.legacy_negative_index, + ))), + _ => internal_err!("ArrayInsert should have exactly three childrens"), + } } fn dyn_hash(&self, _state: &mut dyn Hasher) { - todo!() + let mut s = _state; + self.src_array_expr.hash(&mut s); + self.pos_expr.hash(&mut s); + self.item_expr.hash(&mut s); + self.legacy_negative_index.hash(&mut s); + self.hash(&mut s); } } @@ -512,37 +565,38 @@ fn array_insert( pos_array: &ArrayRef, ) -> DataFusionResult { // TODO: support spark's legacy mode! - - // Heavily inspired by + + // Heavily inspired by // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513 - + let values = list_array.values(); let offsets = list_array.offsets(); let values_data = values.to_data(); let item_data = items_array.to_data(); let new_capacity = Capacities::Array(values_data.len() + item_data.len()); - let mut mutable_values = MutableArrayData::with_capacities(vec![&values_data, &item_data], false, new_capacity); + let mut mutable_values = + MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity); let mut new_offsets = vec![O::usize_as(0)]; let mut new_nulls = Vec::::with_capacity(list_array.len()); - let pos_data = pos_array.to_data(); + let pos_data: &Int32Array = as_primitive_array(&pos_array); // TODO: How to make it works in generic version? 
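+    // NOTE: as_primitive_array panics on a type mismatch, and evaluate still
+    // admits any integer width for the position, so a non-Int32 position
+    // array would abort here instead of returning an Err (the check is
+    // narrowed to Int32 in PATCH 06).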
- for (i, offset_window) in offsets.windows(2).enumerate() { + for (row_index, offset_window) in offsets.windows(2).enumerate() { let start = offset_window[0].as_usize(); let end = offset_window[1].as_usize(); - let pos = pos_data.buffers()[0][i].as_usize(); - let is_item_null = items_array.is_null(i); + let pos = (pos_data.values()[row_index] - 1).as_usize(); // Spark uses indexes started from one + let is_item_null = items_array.is_null(row_index); - mutable_values.extend(0, start, pos); - mutable_values.extend(1, i, i + 1); - mutable_values.extend(0, pos, end); + mutable_values.extend(0, start, start + pos); + mutable_values.extend(1, row_index, row_index + 1); + mutable_values.extend(0, start + pos, end); if is_item_null { if start == end { new_nulls.push(false) } else { - if values.is_null(i) { + if values.is_null(row_index) { new_nulls.push(false) } else { new_nulls.push(true) @@ -551,15 +605,20 @@ fn array_insert( } else { new_nulls.push(true) } - new_offsets.push(offsets[i] + O::usize_as(end - start + 1)); + new_offsets.push(new_offsets[row_index] + O::usize_as(end - start + 1)); } - let data = mutable_values.freeze(); + let data = make_array(mutable_values.freeze()); + let data_type = match list_array.data_type() { + DataType::List(field) => field.data_type(), + DataType::LargeList(field) => field.data_type(), + _ => unreachable!() + }; let new_array = GenericListArray::::try_new( - Arc::new(Field::new("item", list_array.data_type().to_owned(), true)), - OffsetBuffer::new(new_offsets.into()), - make_array(data), - Some(NullBuffer::new(new_nulls.into())) + Arc::new(Field::new("item", data_type.clone(), true)), + OffsetBuffer::new(new_offsets.into()), + data, + Some(NullBuffer::new(new_nulls.into())), )?; Ok(ColumnarValue::Array(Arc::new(new_array))) @@ -579,7 +638,12 @@ impl PartialEq for ArrayInsert { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() - .map(|x| self.src_array_expr.eq(&x.src_array_expr) && self.pos_expr.eq(&x.pos_expr) && self.item_expr.eq(&x.item_expr) && self.legacy_negative_index.eq(&x.legacy_negative_index)) + .map(|x| { + self.src_array_expr.eq(&x.src_array_expr) + && self.pos_expr.eq(&x.pos_expr) + && self.item_expr.eq(&x.item_expr) + && self.legacy_negative_index.eq(&x.legacy_negative_index) + }) .unwrap_or(false) } } @@ -589,10 +653,10 @@ mod test { use crate::list::{array_insert, list_extract, zero_based_index}; use arrow::datatypes::Int32Type; - use arrow_array::{Array, Int32Array, ListArray}; + use arrow_array::{Array, ArrayRef, Int32Array, ListArray}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::ColumnarValue; - use std::{ops::Deref, sync::Arc}; + use std::sync::Arc; #[test] fn test_list_extract_default_value() -> Result<()> { @@ -633,24 +697,33 @@ mod test { #[test] fn test_array_insert() -> Result<()> { - let list = ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1), Some(2), Some(3)]), - Some(vec![Some(4), Some(5)]), - None, - ]); - let positions = Int32Array::from(vec![1, 0, 0]); - let items = Int32Array::from(vec![Some(10), Some(20), Some(30)]); - - let ColumnarValue::Array(result) = array_insert(&list, &Arc::new(items), &Arc::new(positions))? 
else { - unreachable!() - }; + // Test inserting an item into a list array + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + None, + ]); + + let positions = Int32Array::from(vec![2, 1, 1]); + let items = Int32Array::from(vec![Some(10), Some(20), Some(30)]); - let expected = ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1), Some(10), Some(2), Some(3)]), - Some(vec![Some(20), Some(4), Some(5)]), - None, - ]); + let ColumnarValue::Array(result) = array_insert( + &list, + &(Arc::new(items) as ArrayRef), + &(Arc::new(positions) as ArrayRef), + )? + else { + unreachable!() + }; + + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(10), Some(2), Some(3)]), + Some(vec![Some(20), Some(4), Some(5)]), + Some(vec![Some(30)]), + ]); + + assert_eq!(&result.to_data(), &expected.to_data()); - assert_eq!(result.to_data(), expected.to_data()); - Ok(()) } + Ok(()) + } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7ef3e3c5b..66e953ef9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2251,10 +2251,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim ExprOuterClass.Expr .newBuilder() .setArrayInsert(arrayInsertBuilder) - .build() - ) + .build()) } else { - withInfo(expr, "unsupported arguments for ArrayInsert", srcArrayExpr, posExpr, itemExpr) + withInfo( + expr, + "unsupported arguments for ArrayInsert", + srcArrayExpr, + posExpr, + itemExpr) None } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 0d00867d1..ffdb16608 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2313,4 +2313,16 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("ArrayInsert") { + Seq(true, false).foreach(dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + val df = spark.read + .parquet(path.toString) + .select(array_insert(array(col("_4"), lit(null)), lit(1), lit(1)).alias("arr")) + checkSparkAnswerAndOperator(df.select("arr")) + }) + } } From 9d9518e2658489b762760e564a1d306f50f765db Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 11 Nov 2024 13:53:46 +0100 Subject: [PATCH 04/18] Reformat code --- native/spark-expr/src/list.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index 304c1f623..604446fd6 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -22,8 +22,7 @@ use arrow::{ record_batch::RecordBatch, }; use arrow_array::{ - make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, - StructArray, + make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, StructArray, }; use arrow_schema::{DataType, Field, FieldRef, Schema}; use datafusion::logical_expr::ColumnarValue; @@ -612,7 +611,7 @@ fn array_insert( let data_type = match list_array.data_type() { DataType::List(field) => field.data_type(), DataType::LargeList(field) => field.data_type(), - _ => unreachable!() + 
_ => unreachable!(), }; let new_array = GenericListArray::::try_new( Arc::new(Field::new("item", data_type.clone(), true)), From 6e0d5f45a641b9d37be6f90d96a757d4636f0b9b Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 11 Nov 2024 14:02:43 +0100 Subject: [PATCH 05/18] Fix code-style --- native/spark-expr/src/list.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index 604446fd6..7b1a82c8f 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -592,14 +592,10 @@ fn array_insert( mutable_values.extend(1, row_index, row_index + 1); mutable_values.extend(0, start + pos, end); if is_item_null { - if start == end { + if (start == end) || (values.is_null(row_index)) { new_nulls.push(false) } else { - if values.is_null(row_index) { - new_nulls.push(false) - } else { - new_nulls.push(true) - } + new_nulls.push(true) } } else { new_nulls.push(true) From e4b5e4c6f0033340af03e4c5f76c48527d2b1264 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 11 Nov 2024 14:13:55 +0100 Subject: [PATCH 06/18] Add comments about spark's implementation. --- native/spark-expr/src/list.rs | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index 7b1a82c8f..0c67df341 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -469,19 +469,14 @@ impl PhysicalExpr for ArrayInsert { .pos_expr .evaluate(batch)? .into_array(batch.num_rows())?; - // Check that index value is integer-like + + // Spark supports only IntegerType (Int32): + // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4737 match pos_value.data_type() { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => {} + DataType::Int32 => {} data_type => { return Err(DataFusionError::Internal(format!( - "Unexpected index data type in ArrayInsert: {:?}", + "Unexpected index data type in ArrayInsert: {:?}, expected type is Int32", data_type ))) } @@ -563,7 +558,7 @@ fn array_insert( items_array: &ArrayRef, pos_array: &ArrayRef, ) -> DataFusionResult { - // TODO: support spark's legacy mode! + // TODO: support spark's legacy mode and negative indices! // Heavily inspired by // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513 @@ -580,7 +575,7 @@ fn array_insert( let mut new_offsets = vec![O::usize_as(0)]; let mut new_nulls = Vec::::with_capacity(list_array.len()); - let pos_data: &Int32Array = as_primitive_array(&pos_array); // TODO: How to make it works in generic version? 
+ let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions for (row_index, offset_window) in offsets.windows(2).enumerate() { let start = offset_window[0].as_usize(); From 19230bfe5d717d51fbdd2adf60a840c4be5d63f1 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Wed, 13 Nov 2024 12:34:38 +0100 Subject: [PATCH 07/18] Implement negative indices + fix tests for spark < 3.4 --- native/spark-expr/src/list.rs | 78 +++++++++++++++++-- .../apache/comet/serde/QueryPlanSerde.scala | 17 ++-- 2 files changed, 80 insertions(+), 15 deletions(-) diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index 0c67df341..81ecad4bd 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -38,6 +38,12 @@ use std::{ hash::{Hash, Hasher}, sync::Arc, }; + +// 2147483632 == java.lang.Integer.MAX_VALUE - 15 +// It is a value of ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH +// https://github.com/apache/spark/blob/master/common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java +const MAX_ROUNDED_ARRAY_LENGTH: usize = 2147483632; + #[derive(Debug, Hash)] pub struct ListExtract { child: Arc, @@ -514,11 +520,21 @@ impl PhysicalExpr for ArrayInsert { match src_value.data_type() { DataType::List(_) => { let list_array = as_list_array(&src_value)?; - array_insert(list_array, &pos_value, &item_value) + array_insert( + list_array, + &pos_value, + &item_value, + self.legacy_negative_index, + ) } DataType::LargeList(_) => { let list_array = as_large_list_array(&src_value)?; - array_insert(list_array, &pos_value, &item_value) + array_insert( + list_array, + &pos_value, + &item_value, + self.legacy_negative_index, + ) } _ => unreachable!(), // This case is checked already } @@ -557,10 +573,11 @@ fn array_insert( list_array: &GenericListArray, items_array: &ArrayRef, pos_array: &ArrayRef, + legacy_mode: bool, ) -> DataFusionResult { // TODO: support spark's legacy mode and negative indices! 
- // Heavily inspired by + // The code is based on the implementation of array_append from DataFusion // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513 let values = list_array.values(); @@ -578,14 +595,59 @@ fn array_insert( let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions for (row_index, offset_window) in offsets.windows(2).enumerate() { + let pos = pos_data.values()[row_index]; let start = offset_window[0].as_usize(); let end = offset_window[1].as_usize(); - let pos = (pos_data.values()[row_index] - 1).as_usize(); // Spark uses indexes started from one let is_item_null = items_array.is_null(row_index); - mutable_values.extend(0, start, start + pos); - mutable_values.extend(1, row_index, row_index + 1); - mutable_values.extend(0, start + pos, end); + if pos == 0 { + return Err(DataFusionError::Internal(format!( + "Position for array_insert should be greter or less than zero" + ))); + } + + if (pos > 0) || ((-pos).as_usize() < (start - end + 1)) { + let new_array_len = std::cmp::max(end - start + 1, pos.as_usize()); + if new_array_len > MAX_ROUNDED_ARRAY_LENGTH { + return Err(DataFusionError::Internal(format!( + "Max array length in Spark is {:?}, but got {:?}", + MAX_ROUNDED_ARRAY_LENGTH, new_array_len + ))); + } + + let corrected_pos = if pos > 0 { + (pos - 1).as_usize() + } else { + (pos + if legacy_mode { 0 } else { 1 }).as_usize() + (end - start + 1) + }; + if corrected_pos < end { + mutable_values.extend(0, start, start + corrected_pos); + mutable_values.extend(1, row_index, row_index + 1); + mutable_values.extend(0, start + corrected_pos, end); + } else { + mutable_values.extend(0, start, end); + mutable_values.extend_nulls(new_array_len - (end - start + 1)); + mutable_values.extend(1, row_index, row_index + 1); + } + new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len)); + } else { + // This comment is takes from the Apache Spark source code as is: + // special case- if the new position is negative but larger than the current array size + // place the new item at start of array, place the current array contents at the end + // and fill the newly created array elements inbetween with a null + let base_offset = if legacy_mode { 1 } else { 0 }; + let new_array_len = (-pos + base_offset).as_usize(); + if new_array_len > MAX_ROUNDED_ARRAY_LENGTH { + return Err(DataFusionError::Internal(format!( + "Max array length in Spark is {:?}, but got {:?}", + MAX_ROUNDED_ARRAY_LENGTH, new_array_len + ))); + } + mutable_values.extend(1, row_index, row_index + 1); + mutable_values.extend_nulls(new_array_len - 1 - (start - end + 1)); + mutable_values.extend(0, new_array_len - (start - end + 1), new_array_len); + new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len)); + } if is_item_null { if (start == end) || (values.is_null(row_index)) { new_nulls.push(false) @@ -595,7 +657,6 @@ fn array_insert( } else { new_nulls.push(true) } - new_offsets.push(new_offsets[row_index] + O::usize_as(end - start + 1)); } let data = make_array(mutable_values.freeze()); @@ -701,6 +762,7 @@ mod test { &list, &(Arc::new(items) as ArrayRef), &(Arc::new(positions) as ArrayRef), + false, )? 
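+        // `false` = non-legacy negative-index semantics, i.e.
+        // spark.sql.legacy.negativeIndexInArrayInsert unset (the 3.4+ default)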
        else {
            unreachable!()
        };

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 66e953ef9..36d8c1891 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2235,10 +2235,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           None
         }
 
-      case ArrayInsert(srcArrayExpr, posExpr, itemExpr, legacyNegativeIndex) =>
-        val srcExprProto = exprToProto(srcArrayExpr, inputs, binding)
-        val posExprProto = exprToProto(posExpr, inputs, binding)
-        val itemExprProto = exprToProto(itemExpr, inputs, binding)
+      case expr if expr.prettyName == "array_insert" => {
+        val srcExprProto = exprToProto(expr.children(0), inputs, binding)
+        val posExprProto = exprToProto(expr.children(1), inputs, binding)
+        val itemExprProto = exprToProto(expr.children(2), inputs, binding)
+        val legacyNegativeIndex =
+          SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
         if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) {
           val arrayInsertBuilder = ExprOuterClass.ArrayInsert
             .newBuilder()
@@ -2256,11 +2258,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           withInfo(
             expr,
             "unsupported arguments for ArrayInsert",
-            srcArrayExpr,
-            posExpr,
-            itemExpr)
+            expr.children(0),
+            expr.children(1),
+            expr.children(2))
           None
         }
+      }
 
       case ElementAt(child, ordinal, defaultValue, failOnError)
           if child.dataType.isInstanceOf[ArrayType] =>

From 58ecb82bd0d3500e1191aceac7c9ec218cab8253 Mon Sep 17 00:00:00 2001
From: semyonsinchenko
Date: Wed, 13 Nov 2024 12:40:25 +0100
Subject: [PATCH 08/18] Fix code-style

---
 native/spark-expr/src/list.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index 81ecad4bd..7780b3837 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -601,9 +601,9 @@ fn array_insert(
         let is_item_null = items_array.is_null(row_index);
 
         if pos == 0 {
-            return Err(DataFusionError::Internal(format!(
-                "Position for array_insert should be greter or less than zero"
-            )));
+            return Err(DataFusionError::Internal(
+                "Position for array_insert should be greater or less than zero".to_string(),
+            ));
         }
 
From a248567162394fae55e3594c01f092648149c486 Mon Sep 17 00:00:00 2001
From: semyonsinchenko
Date: Wed, 13 Nov 2024 12:53:45 +0100
Subject: [PATCH 09/18] Fix scalastyle

---
 .../src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 36d8c1891..e468be966 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2235,7 +2235,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim expr.children(2)) None } - } case ElementAt(child, ordinal, defaultValue, failOnError) if child.dataType.isInstanceOf[ArrayType] => From e4349f5c4d594c9b58bf0d13276667fa17543277 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Wed, 13 Nov 2024 13:13:56 +0100 Subject: [PATCH 10/18] Fix tests for spark < 3.4 --- .../scala/org/apache/comet/CometExpressionSuite.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index ffdb16608..dbe6e6e13 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2319,10 +2319,13 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllTypes(path, dictionaryEnabled, 10000) - val df = spark.read - .parquet(path.toString) - .select(array_insert(array(col("_4"), lit(null)), lit(1), lit(1)).alias("arr")) - checkSparkAnswerAndOperator(df.select("arr")) + if (spark.version >= "3.4.0") { + val df = spark.read + .parquet(path.toString) + .withColumn("arr", array(col("_4"), lit(null))) + .select(expr("array_insert(arr, 1, 1)").alias("arrInsertResult")) + checkSparkAnswerAndOperator(df.select("arrInsertResult")) + } }) } } From 0d38ef01cdef9859f225cf250e5a34d195baef6a Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Wed, 13 Nov 2024 17:28:49 +0100 Subject: [PATCH 11/18] Fixes & tests - added test for the negative index - added test for the legacy spark mode --- native/spark-expr/src/list.rs | 88 ++++++++++++++++--- .../apache/comet/CometExpressionSuite.scala | 6 +- 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index 7780b3837..ead101f04 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -522,8 +522,8 @@ impl PhysicalExpr for ArrayInsert { let list_array = as_list_array(&src_value)?; array_insert( list_array, - &pos_value, &item_value, + &pos_value, self.legacy_negative_index, ) } @@ -531,8 +531,8 @@ impl PhysicalExpr for ArrayInsert { let list_array = as_large_list_array(&src_value)?; array_insert( list_array, - &pos_value, &item_value, + &pos_value, self.legacy_negative_index, ) } @@ -575,8 +575,6 @@ fn array_insert( pos_array: &ArrayRef, legacy_mode: bool, ) -> DataFusionResult { - // TODO: support spark's legacy mode and negative indices! 
- // The code is based on the implementation of array_append from DataFusion // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513 @@ -606,8 +604,13 @@ fn array_insert( )); } - if (pos > 0) || ((-pos).as_usize() < (start - end + 1)) { - let new_array_len = std::cmp::max(end - start + 1, pos.as_usize()); + if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) { + let corrected_pos = if pos > 0 { + (pos - 1).as_usize() + } else { + end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 } + }; + let new_array_len = std::cmp::max(end - start + 1, corrected_pos); if new_array_len > MAX_ROUNDED_ARRAY_LENGTH { return Err(DataFusionError::Internal(format!( "Max array length in Spark is {:?}, but got {:?}", @@ -615,11 +618,6 @@ fn array_insert( ))); } - let corrected_pos = if pos > 0 { - (pos - 1).as_usize() - } else { - (pos + if legacy_mode { 0 } else { 1 }).as_usize() + (end - start + 1) - }; if corrected_pos < end { mutable_values.extend(0, start, start + corrected_pos); mutable_values.extend(1, row_index, row_index + 1); @@ -644,8 +642,8 @@ fn array_insert( ))); } mutable_values.extend(1, row_index, row_index + 1); - mutable_values.extend_nulls(new_array_len - 1 - (start - end + 1)); - mutable_values.extend(0, new_array_len - (start - end + 1), new_array_len); + mutable_values.extend_nulls(new_array_len - (end - start + 1)); + mutable_values.extend(0, new_array_len - (end - start + 1), new_array_len); new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len)); } if is_item_null { @@ -778,4 +776,68 @@ mod test { Ok(()) } + + #[test] + fn test_array_insert_negative_index() -> Result<()> { + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + None, + ]); + + let positions = Int32Array::from(vec![-2, -1, -1]); + let items = Int32Array::from(vec![Some(10), Some(20), Some(30)]); + + let ColumnarValue::Array(result) = array_insert( + &list, + &(Arc::new(items) as ArrayRef), + &(Arc::new(positions) as ArrayRef), + false, + )? + else { + unreachable!() + }; + + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(10), Some(3)]), + Some(vec![Some(4), Some(5), Some(20)]), + Some(vec![Some(30)]), + ]); + + assert_eq!(&result.to_data(), &expected.to_data()); + + Ok(()) + } + + #[test] + fn test_array_insert_legacy_mode() -> Result<()> { + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + None, + ]); + + let positions = Int32Array::from(vec![-1, -1, -1]); + let items = Int32Array::from(vec![Some(10), Some(20), Some(30)]); + + let ColumnarValue::Array(result) = array_insert( + &list, + &(Arc::new(items) as ArrayRef), + &(Arc::new(positions) as ArrayRef), + true, + )? 
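+        // `true` reproduces spark.sql.legacy.negativeIndexInArrayInsert=true,
+        // where -1 addresses the last element itself rather than the end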
+ else { + unreachable!() + }; + + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(10), Some(3)]), + Some(vec![Some(4), Some(20), Some(5)]), + Some(vec![Some(30), None]), + ]); + + assert_eq!(&result.to_data(), &expected.to_data()); + + Ok(()) + } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index dbe6e6e13..8600ce652 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2322,9 +2322,11 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { if (spark.version >= "3.4.0") { val df = spark.read .parquet(path.toString) - .withColumn("arr", array(col("_4"), lit(null))) - .select(expr("array_insert(arr, 1, 1)").alias("arrInsertResult")) + .withColumn("arr", array(col("_4"), lit(null), col("_4"))) + .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)")) + .withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr, -1, 1)")) checkSparkAnswerAndOperator(df.select("arrInsertResult")) + checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult")) } }) } From c7f26f9bc9b68d78e729be6c468739826e12b3bf Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Thu, 14 Nov 2024 06:55:49 +0100 Subject: [PATCH 12/18] Use assume(isSpark34Plus) in tests --- .../org/apache/comet/CometExpressionSuite.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 8600ce652..78430c394 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2315,19 +2315,18 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("ArrayInsert") { + assume(isSpark34Plus) Seq(true, false).foreach(dictionaryEnabled => withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllTypes(path, dictionaryEnabled, 10000) - if (spark.version >= "3.4.0") { - val df = spark.read - .parquet(path.toString) - .withColumn("arr", array(col("_4"), lit(null), col("_4"))) - .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)")) - .withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr, -1, 1)")) - checkSparkAnswerAndOperator(df.select("arrInsertResult")) - checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult")) - } + val df = spark.read + .parquet(path.toString) + .withColumn("arr", array(col("_4"), lit(null), col("_4"))) + .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)")) + .withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr, -1, 1)")) + checkSparkAnswerAndOperator(df.select("arrInsertResult")) + checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult")) }) } } From f832cf09c4818c2dd1f63f9deb4ed9871e3d0027 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Thu, 14 Nov 2024 09:17:37 +0100 Subject: [PATCH 13/18] Test else-branch & improve coverage --- .../org/apache/comet/CometExpressionSuite.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 0754b5c71..9b62ec476 100644 --- 
a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2353,4 +2353,19 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult"))
     })
   }
+
+  test("ArrayInsertUnsupportedArgs") {
+    // This test checks that the else branch in the ArrayInsert serde mapping
+    // is valid and that the fallback to Spark works correctly.
+    assume(isSpark34Plus)
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      makeParquetFileAllTypes(path, dictionaryEnabled = false, 10000)
+      spark.read
+        .parquet(path.toString)
+        .withColumn("arr", array(col("_4"), lit(null), col("_4")))
+        .withColumn("idx", udf((x: Int) => x).apply(col("_4")))
+        .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)"))
+    }
+  }
 }

From 6e41858cc68b030201e8798de270ab52d7f6ff4b Mon Sep 17 00:00:00 2001
From: Sem
Date: Tue, 19 Nov 2024 07:39:20 +0100
Subject: [PATCH 14/18] Update native/spark-expr/src/list.rs

Co-authored-by: Andy Grove
---
 native/spark-expr/src/list.rs | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index ead101f04..009d6aeba 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -478,14 +478,11 @@ impl PhysicalExpr for ArrayInsert {
 
         // Spark supports only IntegerType (Int32):
         // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4737
-        match pos_value.data_type() {
-            DataType::Int32 => {}
-            data_type => {
-                return Err(DataFusionError::Internal(format!(
-                    "Unexpected index data type in ArrayInsert: {:?}, expected type is Int32",
-                    data_type
-                )))
-            }
+        if !matches!(pos_value.data_type(), DataType::Int32) {
+            return Err(DataFusionError::Internal(format!(
+                "Unexpected index data type in ArrayInsert: {:?}, expected type is Int32",
+                pos_value.data_type()
+            )))
         }
 
         // Check that src array is actually an array and get its value type

From 4770fcea0b656a53271e1ef494ca2f1ee80b900d Mon Sep 17 00:00:00 2001
From: semyonsinchenko
Date: Tue, 19 Nov 2024 19:33:34 +0100
Subject: [PATCH 15/18] Fix fallback test

In one case the index was zero and the test failed due to a Spark error
---
 .../src/test/scala/org/apache/comet/CometExpressionSuite.scala  | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 744d81c28..35f374bf0 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2385,7 +2385,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       val df = spark.read
         .parquet(path.toString)
         .withColumn("arr", array(col("_4"), lit(null), col("_4")))
-        .withColumn("idx", udf((x: Int) => x).apply(col("_4")))
+        .withColumn("idx", udf((_: Int) => 1).apply(col("_4")))
         .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)"))
       checkSparkAnswer(df.select("arrUnsupportedArgs"))
     }

From e9ef94135a459a0c8fc7b437bda54df30f6d7a0f Mon Sep 17 00:00:00 2001
From: semyonsinchenko
Date: Wed, 20 Nov 2024 17:32:36 +0100
Subject: [PATCH 16/18] Adjust the behaviour for the NULL case to match Spark

---
 native/spark-expr/src/list.rs | 14 +++++++++++---
 1 file changed, 11
insertions(+), 3 deletions(-)

diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index a4ac77cc2..1cd099467 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -598,6 +598,14 @@ fn array_insert(
         let end = offset_window[1].as_usize();
         let is_item_null = items_array.is_null(row_index);
 
+        if list_array.is_null(row_index) {
+            // In Spark, if the value of the array is NULL then nothing happens
+            mutable_values.extend_nulls(1);
+            new_offsets.push(new_offsets[row_index] + O::one());
+            new_nulls.push(false);
+            continue;
+        }
+
         if pos == 0 {
             return Err(DataFusionError::Internal(
                 "Position for array_insert should be greater or less than zero".to_string(),
@@ -786,7 +794,7 @@ mod test {
             Some(vec![Some(30), None]),
             Some(vec![Some(1), Some(2), Some(3), None, Some(100)]),
             Some(vec![Some(1), Some(2), Some(3), None, None, Some(100)]),
-            Some(vec![Some(40)]),
+            None,
         ]);
 
         assert_eq!(&result.to_data(), &expected.to_data());
@@ -822,7 +830,7 @@ mod test {
             Some(vec![Some(1), Some(2), Some(10), Some(3)]),
             Some(vec![Some(4), Some(5), Some(20)]),
             Some(vec![Some(100), None, Some(1)]),
-            Some(vec![Some(30)]),
+            None,
         ]);
 
         assert_eq!(&result.to_data(), &expected.to_data());
@@ -855,7 +863,7 @@ mod test {
         let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
             Some(vec![Some(1), Some(2), Some(10), Some(3)]),
             Some(vec![Some(4), Some(20), Some(5)]),
-            Some(vec![Some(30), None]),
+            None,
         ]);
 
         assert_eq!(&result.to_data(), &expected.to_data());

From 6431ad9d53818fac5c862b7dabbe160d41c8de4b Mon Sep 17 00:00:00 2001
From: semyonsinchenko
Date: Thu, 21 Nov 2024 10:25:41 +0100
Subject: [PATCH 17/18] Move the logic of type checking to the method

---
 native/spark-expr/src/list.rs | 36 ++++++++++++++++++-----------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index 1cd099467..765412f56 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -448,6 +448,19 @@ impl ArrayInsert {
             legacy_negative_index,
         }
     }
+
+    pub fn array_type(&self, data_type: &DataType) -> DataFusionResult<DataType> {
+        match data_type {
+            DataType::List(field) => Ok(DataType::List(Arc::clone(field))),
+            DataType::LargeList(field) => Ok(DataType::LargeList(Arc::clone(field))),
+            data_type => {
+                return Err(DataFusionError::Internal(format!(
+                    "Unexpected src array type in ArrayInsert: {:?}",
+                    data_type
+                )))
+            }
+        }
+    }
 }
 
 impl PhysicalExpr for ArrayInsert {
@@ -456,14 +469,7 @@ impl PhysicalExpr for ArrayInsert {
     }
 
     fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
-        match self.src_array_expr.data_type(input_schema)? {
-            DataType::List(field) => Ok(DataType::List(field)),
-            DataType::LargeList(field) => Ok(DataType::LargeList(field)),
-            data_type => Err(DataFusionError::Internal(format!(
-                "Unexpected data type in ArrayInsert: {:?}",
-                data_type
-            ))),
-        }
+        self.array_type(&self.src_array_expr.data_type(input_schema)?)
     }
 
     fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
@@ -490,15 +496,11 @@ impl PhysicalExpr for ArrayInsert {
         .src_array_expr
         .evaluate(batch)?
         .into_array(batch.num_rows())?;
-        let src_element_type = match src_value.data_type() {
-            DataType::List(field) => field.data_type(),
-            DataType::LargeList(field) => field.data_type(),
-            data_type => {
-                return Err(DataFusionError::Internal(format!(
-                    "Unexpected src array type in ArrayInsert: {:?}",
-                    data_type
-                )))
-            }
+
+        let src_element_type = match self.array_type(src_value.data_type())?
{ + DataType::List(field) => &field.data_type().clone(), + DataType::LargeList(field) => &field.data_type().clone(), + _ => unreachable!(), }; // Check that inserted value has the same type as an array From e02d20f17ee4f6c330152c80d13492151a7b0aba Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Thu, 21 Nov 2024 10:29:24 +0100 Subject: [PATCH 18/18] Fix code-style --- native/spark-expr/src/list.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index 765412f56..7dc17b568 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -453,12 +453,10 @@ impl ArrayInsert { match data_type { DataType::List(field) => Ok(DataType::List(Arc::clone(field))), DataType::LargeList(field) => Ok(DataType::LargeList(Arc::clone(field))), - data_type => { - return Err(DataFusionError::Internal(format!( - "Unexpected src array type in ArrayInsert: {:?}", - data_type - ))) - } + data_type => Err(DataFusionError::Internal(format!( + "Unexpected src array type in ArrayInsert: {:?}", + data_type + ))), } } }
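
For reference, a minimal sketch of the semantics the finished kernel implements.
It mirrors the unit tests from PATCH 03, 11 and 16, would live in the test module
of native/spark-expr/src/list.rs (which already imports array_insert), and assumes
the final state of the file after PATCH 18; it is not itself part of the series:

    #[test]
    fn array_insert_semantics_sketch() -> datafusion_common::Result<()> {
        use arrow::datatypes::Int32Type;
        use arrow_array::{ArrayRef, Int32Array, ListArray};
        use datafusion_expr::ColumnarValue;
        use std::sync::Arc;

        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            None,
        ]);
        // Positions are 1-based; `false` selects the non-legacy semantics.
        let positions: ArrayRef = Arc::new(Int32Array::from(vec![2, 1]));
        let items: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), Some(20)]));

        let ColumnarValue::Array(result) = array_insert(&list, &items, &positions, false)? else {
            unreachable!()
        };

        let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(10), Some(2), Some(3)]), // inserted at 1-based position 2
            None,                                            // a NULL input array stays NULL
        ]);
        assert_eq!(&result.to_data(), &expected.to_data());
        Ok(())
    }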