diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index 0a2f7fef6..13263a595 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -34,7 +34,7 @@ use arrow::{
 };
 use arrow_array::builder::StringBuilder;
 use arrow_array::{DictionaryArray, StringArray, StructArray};
-use arrow_schema::{DataType, Schema};
+use arrow_schema::{DataType, Field, Schema};
 use datafusion_common::{
     cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
 };
@@ -714,6 +714,14 @@ fn cast_array(
         (DataType::Struct(_), DataType::Utf8) => {
             Ok(casts_struct_to_string(array.as_struct(), &timezone)?)
         }
+        (DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct(
+            array.as_struct(),
+            from_type,
+            to_type,
+            eval_mode,
+            timezone,
+            allow_incompat,
+        )?),
         _ if is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => {
             // use DataFusion cast only when we know that it is compatible with Spark
             Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
@@ -811,6 +819,35 @@ fn is_datafusion_spark_compatible(
     }
 }
 
+/// Cast between struct types based on logic in
+/// `org.apache.spark.sql.catalyst.expressions.Cast#castStruct`.
+fn cast_struct_to_struct(
+    array: &StructArray,
+    from_type: &DataType,
+    to_type: &DataType,
+    eval_mode: EvalMode,
+    timezone: String,
+    allow_incompat: bool,
+) -> DataFusionResult<ArrayRef> {
+    match (from_type, to_type) {
+        (DataType::Struct(_), DataType::Struct(to_fields)) => {
+            let mut cast_fields: Vec<(Arc<Field>, ArrayRef)> = Vec::with_capacity(to_fields.len());
+            for i in 0..to_fields.len() {
+                let cast_field = cast_array(
+                    Arc::clone(array.column(i)),
+                    to_fields[i].data_type(),
+                    eval_mode,
+                    timezone.clone(),
+                    allow_incompat,
+                )?;
+                cast_fields.push((Arc::clone(&to_fields[i]), cast_field));
+            }
+            Ok(Arc::new(StructArray::from(cast_fields)))
+        }
+        _ => unreachable!(),
+    }
+}
+
 fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResult<ArrayRef> {
     // cast each field to a string
     let string_arrays: Vec<ArrayRef> = array
@@ -1929,7 +1966,7 @@ fn trim_end(s: &str) -> &str {
 mod tests {
     use arrow::datatypes::TimestampMicrosecondType;
     use arrow_array::StringArray;
-    use arrow_schema::{Field, TimeUnit};
+    use arrow_schema::{Field, Fields, TimeUnit};
     use std::str::FromStr;
 
     use super::*;
@@ -2336,4 +2373,75 @@ mod tests {
         assert_eq!(r#"{4, d}"#, string_array.value(3));
         assert_eq!(r#"{5, e}"#, string_array.value(4));
     }
+
+    #[test]
+    fn test_cast_struct_to_struct() {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![
+            Some(1),
+            Some(2),
+            None,
+            Some(4),
+            Some(5),
+        ]));
+        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
+        let c: ArrayRef = Arc::new(StructArray::from(vec![
+            (Arc::new(Field::new("a", DataType::Int32, true)), a),
+            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
+        ]));
+        // change type of "a" from Int32 to Utf8
+        let fields = Fields::from(vec![
+            Field::new("a", DataType::Utf8, true),
+            Field::new("b", DataType::Utf8, true),
+        ]);
+        let cast_array = spark_cast(
+            ColumnarValue::Array(c),
+            &DataType::Struct(fields),
+            EvalMode::Legacy,
+            "UTC",
+            false,
+        )
+        .unwrap();
+        if let ColumnarValue::Array(cast_array) = cast_array {
+            assert_eq!(5, cast_array.len());
+            let a = cast_array.as_struct().column(0).as_string::<i32>();
+            assert_eq!("1", a.value(0));
+        } else {
+            unreachable!()
+        }
+    }
+
+    #[test]
+    fn test_cast_struct_to_struct_drop_column() {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![
+            Some(1),
+            Some(2),
+            None,
+            Some(4),
+            Some(5),
+        ]));
+        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
+        let c: ArrayRef = Arc::new(StructArray::from(vec![
+            (Arc::new(Field::new("a", DataType::Int32, true)), a),
+            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
+        ]));
+        // change type of "a" from Int32 to Utf8 and drop "b"
+        let fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
+        let cast_array = spark_cast(
+            ColumnarValue::Array(c),
+            &DataType::Struct(fields),
+            EvalMode::Legacy,
+            "UTC",
+            false,
+        )
+        .unwrap();
+        if let ColumnarValue::Array(cast_array) = cast_array {
+            assert_eq!(5, cast_array.len());
+            let struct_array = cast_array.as_struct();
+            assert_eq!(1, struct_array.columns().len());
+            let a = struct_array.column(0).as_string::<i32>();
+            assert_eq!("1", a.value(0));
+        } else {
+            unreachable!()
+        }
+    }
 }
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index eb9800b8d..11d6d049f 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -95,6 +95,16 @@ object CometCast {
         canCastFromFloat(toType)
       case (DataTypes.DoubleType, _) =>
         canCastFromDouble(toType)
+      case (from_struct: StructType, to_struct: StructType) =>
+        from_struct.fields.zip(to_struct.fields).foreach { case (a, b) =>
+          isSupported(a.dataType, b.dataType, timeZoneId, evalMode) match {
+            case Compatible(_) =>
+            // all good
+            case other =>
+              return other
+          }
+        }
+        Compatible()
       case _ => Unsupported
     }
   }
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index b5517f40f..1170c55a3 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -933,26 +933,26 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         handleCast(child, inputs, dt, timeZoneId, evalMode(c))
 
       case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
-        createMathExpression(left, right, inputs, add.dataType, getFailOnError(add)).map {
-          expr =>
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setAdd(expr)
-              .build()
-        }
+        createMathExpression(
+          left,
+          right,
+          inputs,
+          add.dataType,
+          getFailOnError(add),
+          (builder, mathExpr) => builder.setAdd(mathExpr))
 
       case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
         withInfo(add, s"Unsupported datatype ${left.dataType}")
         None
 
       case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) =>
-        createMathExpression(left, right, inputs, sub.dataType, getFailOnError(sub)).map {
-          expr =>
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setSubtract(expr)
-              .build()
-        }
+        createMathExpression(
+          left,
+          right,
+          inputs,
+          sub.dataType,
+          getFailOnError(sub),
+          (builder, mathExpr) => builder.setSubtract(mathExpr))
 
       case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
         withInfo(sub, s"Unsupported datatype ${left.dataType}")
@@ -960,13 +960,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
       case mul @ Multiply(left, right, _)
           if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
-        createMathExpression(left, right, inputs, mul.dataType, getFailOnError(mul)).map {
-          expr =>
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setMultiply(expr)
-              .build()
-        }
+        createMathExpression(
+          left,
+          right,
+          inputs,
+          mul.dataType,
+          getFailOnError(mul),
+          (builder, mathExpr) => builder.setMultiply(mathExpr))
 
       case mul @ Multiply(left, _, _) =>
         if (!supportedDataType(left.dataType)) {
@@ -984,13 +984,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         // For now, use NullIf to swap zeros with nulls.
         val rightExpr = nullIfWhenPrimitive(right)
 
-        createMathExpression(left, rightExpr, inputs, div.dataType, getFailOnError(div)).map {
-          expr =>
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setDivide(expr)
-              .build()
-        }
+        createMathExpression(
+          left,
+          rightExpr,
+          inputs,
+          div.dataType,
+          getFailOnError(div),
+          (builder, mathExpr) => builder.setDivide(mathExpr))
 
       case div @ Divide(left, _, _) =>
         if (!supportedDataType(left.dataType)) {
@@ -1005,13 +1005,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
         val rightExpr = nullIfWhenPrimitive(right)
 
-        createMathExpression(left, rightExpr, inputs, rem.dataType, getFailOnError(rem)).map {
-          expr =>
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setRemainder(expr)
-              .build()
-        }
+        createMathExpression(
+          left,
+          rightExpr,
+          inputs,
+          rem.dataType,
+          getFailOnError(rem),
+          (builder, mathExpr) => builder.setRemainder(mathExpr))
 
      case rem @ Remainder(left, _, _) =>
        if (!supportedDataType(left.dataType)) {
@@ -1023,68 +1023,60 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         None
 
       case EqualTo(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setEq(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setEq(binaryExpr))
 
       case Not(EqualTo(left, right)) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setNeq(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setNeq(binaryExpr))
 
       case EqualNullSafe(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setEqNullSafe(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setEqNullSafe(binaryExpr))
 
       case Not(EqualNullSafe(left, right)) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setNeqNullSafe(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr))
 
       case GreaterThan(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setGt(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setGt(binaryExpr))
 
       case GreaterThanOrEqual(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setGtEq(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setGtEq(binaryExpr))
 
       case LessThan(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setLt(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setLt(binaryExpr))
 
       case LessThanOrEqual(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setLtEq(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setLtEq(binaryExpr))
 
       case Literal(value, dataType)
           if supportedDataType(dataType, allowStruct = value == null) =>
@@ -1221,12 +1213,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
       case Like(left, right, escapeChar) =>
         if (escapeChar == '\\') {
-          createBinaryExpr(left, right, inputs).map { builder =>
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setLike(builder)
-              .build()
-          }
+          createBinaryExpr(
+            left,
+            right,
+            inputs,
+            (builder, binaryExpr) => builder.setLike(binaryExpr))
         } else {
           // TODO custom escape char
           withInfo(expr, s"custom escape character $escapeChar not supported in LIKE")
@@ -1251,36 +1242,32 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             return None
           }
 
-          createBinaryExpr(left, right, inputs).map { builder =>
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setRlike(builder)
-              .build()
-          }
+          createBinaryExpr(
+            left,
+            right,
+            inputs,
+            (builder, binaryExpr) => builder.setRlike(binaryExpr))
 
       case StartsWith(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setStartsWith(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setStartsWith(binaryExpr))
 
       case EndsWith(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setEndsWith(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setEndsWith(binaryExpr))
 
       case Contains(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setContains(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setContains(binaryExpr))
 
       case StringSpace(child) =>
         createUnaryExpr(
@@ -1461,20 +1448,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         }
 
       case And(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setAnd(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setAnd(binaryExpr))
 
       case Or(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setOr(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setOr(binaryExpr))
 
       case UnaryExpression(child) if expr.prettyName == "promote_precision" =>
         // `UnaryExpression` includes `PromotePrecision` for Spark 3.3
@@ -1911,31 +1896,28 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         }
 
       case BitwiseAnd(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setBitwiseAnd(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
 
       case BitwiseNot(child) =>
         createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setBitwiseNot(unaryExpr))
 
      case BitwiseOr(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setBitwiseOr(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
 
      case BitwiseXor(left, right) =>
-        createBinaryExpr(left, right, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setBitwiseXor(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          right,
+          inputs,
+          (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
 
       case ShiftRight(left, right) =>
         // DataFusion bitwise shift right expression requires
@@ -1946,12 +1928,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           right
         }
 
-        createBinaryExpr(left, rightExpression, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setBitwiseShiftRight(builder)
-            .build()
-        }
+        createBinaryExpr(
+          left,
+          rightExpression,
+          inputs,
+          (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
 
       case ShiftLeft(left, right) =>
         // DataFusion bitwise shift right expression requires
@@ -1962,13 +1943,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           right
         }
 
-        createBinaryExpr(left, rightExpression, inputs).map { builder =>
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setBitwiseShiftLeft(builder)
-            .build()
-        }
-
+        createBinaryExpr(
+          left,
+          rightExpression,
+          inputs,
+          (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
 
       case In(value, list) =>
         in(expr, value, list, inputs, false)
@@ -2308,16 +2287,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     def createBinaryExpr(
         left: Expression,
         right: Expression,
-        inputs: Seq[Attribute]): Option[ExprOuterClass.BinaryExpr] = {
+        inputs: Seq[Attribute],
+        f: (
+            ExprOuterClass.Expr.Builder,
+            ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder)
+        : Option[ExprOuterClass.Expr] = {
       val leftExpr = exprToProtoInternal(left, inputs)
       val rightExpr = exprToProtoInternal(right, inputs)
       if (leftExpr.isDefined && rightExpr.isDefined) {
+        // create the generic BinaryExpr message
+        val inner = ExprOuterClass.BinaryExpr
+          .newBuilder()
+          .setLeft(leftExpr.get)
+          .setRight(rightExpr.get)
+          .build()
+        // call the user-supplied function to wrap BinaryExpr in a top-level Expr
+        // such as Expr.And or Expr.Or
         Some(
-          ExprOuterClass.BinaryExpr
-            .newBuilder()
-            .setLeft(leftExpr.get)
-            .setRight(rightExpr.get)
-            .build())
+          f(
+            ExprOuterClass.Expr
+              .newBuilder(),
+            inner).build())
       } else {
         withInfo(expr, left, right)
         None
@@ -2329,11 +2319,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         right: Expression,
         inputs: Seq[Attribute],
         dataType: DataType,
-        failOnError: Boolean): Option[ExprOuterClass.MathExpr] = {
+        failOnError: Boolean,
+        f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder)
+        : Option[ExprOuterClass.Expr] = {
       val leftExpr = exprToProtoInternal(left, inputs)
       val rightExpr = exprToProtoInternal(right, inputs)
       if (leftExpr.isDefined && rightExpr.isDefined) {
+        // create the generic MathExpr message
        val builder = ExprOuterClass.MathExpr.newBuilder()
        builder.setLeft(leftExpr.get)
        builder.setRight(rightExpr.get)
@@ -2341,7 +2334,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       serializeDataType(dataType).foreach { t =>
         builder.setReturnType(t)
       }
-      Some(builder.build())
+      val inner = builder.build()
+      // call the user-supplied function to wrap MathExpr in a top-level Expr
+      // such as Expr.Add or Expr.Divide
+      Some(
+        f(
+          ExprOuterClass.Expr
+            .newBuilder(),
+          inner).build())
     } else {
       withInfo(expr, left, right)
       None
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 817545c5d..db9a870dc 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -881,6 +881,20 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("cast StructType to StructType") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
+        withParquetTable(path.toString, "tbl") {
+          checkSparkAnswerAndOperator(
+            "SELECT CAST(CASE WHEN _1 THEN struct(_1, _2, _3, _4) ELSE null END as " +
+              "struct<_1:string, _2:string, _3:string, _4:string>) FROM tbl")
+        }
+      }
+    }
+  }
+
   private def generateFloats(): DataFrame = {
     withNulls(gen.generateFloats(dataSize)).toDF("a")
   }
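Note on the Rust side: because cast_struct_to_struct calls back into cast_array for each child column, nested structs are cast level by level without any extra code. Below is a minimal sketch of an additional unit test (hypothetical, not part of this patch) that would exercise that recursive path; it assumes the same `use super::*;` imports and the spark_cast signature used by the tests in cast.rs above.

#[test]
fn test_cast_nested_struct_to_struct() {
    // struct<a: struct<x: int32>> -> struct<a: struct<x: utf8>>
    let x: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let inner: ArrayRef = Arc::new(StructArray::from(vec![(
        Arc::new(Field::new("x", DataType::Int32, true)),
        x,
    )]));
    let outer: ArrayRef = Arc::new(StructArray::from(vec![(
        Arc::new(Field::new(
            "a",
            DataType::Struct(Fields::from(vec![Field::new("x", DataType::Int32, true)])),
            true,
        )),
        inner,
    )]));
    // change the type of the nested field "x" from Int32 to Utf8
    let to_type = DataType::Struct(Fields::from(vec![Field::new(
        "a",
        DataType::Struct(Fields::from(vec![Field::new("x", DataType::Utf8, true)])),
        true,
    )]));
    let cast_array = spark_cast(
        ColumnarValue::Array(outer),
        &to_type,
        EvalMode::Legacy,
        "UTC",
        false,
    )
    .unwrap();
    if let ColumnarValue::Array(cast_array) = cast_array {
        // the outer struct is rebuilt and its child struct is cast recursively
        let inner = cast_array.as_struct().column(0).as_struct();
        let x = inner.column(0).as_string::<i32>();
        assert_eq!("1", x.value(0));
    } else {
        unreachable!()
    }
}

As in the tests above, the sketch goes through spark_cast end to end rather than calling cast_struct_to_struct directly, since the recursion happens inside cast_array.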