diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index d82c286d4..f5c506688 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -699,17 +699,27 @@ impl PhysicalPlanner {
                 let right =
                     self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
                 let return_type = left.data_type(&input_schema)?;
-                let args = vec![left, right];
+                let args = vec![Arc::clone(&left), right];
                 let datafusion_array_append =
                     Arc::new(ScalarUDF::new_from_impl(ArrayAppend::new()));
-                let scalar_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
+                let array_append_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
                     "array_append",
                     datafusion_array_append,
                     args,
                     return_type,
                 ));
-                Ok(scalar_expr)
+                let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(left));
+                let null_literal_expr: Arc<dyn PhysicalExpr> =
+                    Arc::new(Literal::new(ScalarValue::Null));
+
+                let case_expr = CaseExpr::try_new(
+                    None,
+                    vec![(is_null_expr, null_literal_expr)],
+                    Some(array_append_expr),
+                )
+                .unwrap();
+                Ok(Arc::new(case_expr))
             }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index d64ee71b9..a5c19eb89 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2259,11 +2259,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           None
         }
       case ArrayAppend(left, right) =>
-        val leftExpr = exprToProto(left, inputs, binding)
-        val rightExpr = exprToProto(right, inputs, binding)
-
-        val optExpr = scalarExprToProto("array_append", leftExpr, rightExpr)
-        optExprWithInfo(optExpr, expr, left, right)
+        createBinaryExpr(left, right, inputs).map { builder =>
+          ExprOuterClass.Expr
+            .newBuilder()
+            .setArrayAppend(builder)
+            .build()
+        }

       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 9466fcd92..9105023c0 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2315,6 +2315,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }

   test("array_append") {
+    assume(isSpark34Plus)
     Seq(true, false).foreach { dictionaryEnabled =>
       withTempDir { dir =>
         val path = new Path(dir.toURI.toString, "test.parquet")
@@ -2328,6 +2329,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         checkSparkAnswerAndOperator(df.select(array_append(array(col("_6"), col("_7")), 6.5)))
         checkSparkAnswerAndOperator(df.select(array_append(array(col("_8")), "test")))
         checkSparkAnswerAndOperator(df.select(array_append(array(col("_19")), col("_19"))))
+        checkSparkAnswerAndOperator(
+          df.select(array_append(expr("CASE WHEN _2=_3 THEN array(_4) END"), col("_4"))))
       }
     }
   }
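For context (not part of the patch): a minimal standalone Spark sketch of the null semantics that the CASE WHEN ... IS NULL wrapper in planner.rs is reproducing. It assumes a local Spark 3.4+ session (where array_append is available and returns NULL when the array argument is NULL); the object name ArrayAppendNullSemantics is purely illustrative.

import org.apache.spark.sql.SparkSession

object ArrayAppendNullSemantics {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[1]")
      .appName("array-append-null")
      .getOrCreate()

    // NULL array argument: Spark returns NULL, which the native plan now
    // emulates by wrapping array_append in a CASE expression instead of
    // passing the null array straight to DataFusion's array_append.
    spark.sql("SELECT array_append(CAST(NULL AS ARRAY<INT>), 1)").show(truncate = false)

    // Non-null array argument: the element is appended as usual.
    spark.sql("SELECT array_append(array(1, 2), 3)").show(truncate = false)

    spark.stop()
  }
}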