diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala index 4f6f4eb224d0..ae6306cc0d4a 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala @@ -22,7 +22,9 @@ import org.apache.gluten.extension.columnar.validator.FallbackInjects import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite { @@ -1112,6 +1114,27 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu } } } + + test("complex type with null") { + val jsonStr = """{"txn":{"appId":"txnId","version":0,"lastUpdated":null}}""" + val jsonSchema = StructType( + Seq( + StructField( + "txn", + StructType( + Seq( + StructField("appId", StringType, true), + StructField("lastUpdated", LongType, true), + StructField("version", LongType, true))), + true))) + val df = spark.read.schema(jsonSchema).json(Seq(jsonStr).toDS) + df.select(collect_set(col("txn"))).collect + + df.select(min(col("txn"))).collect + + df.select(max(col("txn"))).collect + + } } class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite { diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index abb2bbc560f4..a3b46d7d08e1 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -1045,6 +1045,16 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait LOG_VALIDATION_MSG("Validation failed for function " + funcName + " resolve type in AggregateRel."); return false; } + static const std::unordered_set notSupportComplexTypeAggFuncs = {"set_agg", "min", "max"}; + if (notSupportComplexTypeAggFuncs.find(baseFuncName) != notSupportComplexTypeAggFuncs.end() && + exec::isRawInput(funcStep)) { + auto type = binder.tryResolveType(signature->argumentTypes()[0]); + if (type->isArray() || type->isMap() || type->isRow()) { + LOG_VALIDATION_MSG("Validation failed for function " + baseFuncName + " complex type is not supported."); + return false; + } + } + resolved = true; break; }