Skip to content

Commit

Permalink
[VL] Fall back collect_set, min and max when input is complex type (#…
Browse files Browse the repository at this point in the history
…5934)

[VL] Fall back collect_set, min and max when input is complex type.
  • Loading branch information
zhli1142015 authored May 31, 2024
1 parent f48b9fa commit e870de8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import org.apache.gluten.extension.columnar.validator.FallbackInjects
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite {

Expand Down Expand Up @@ -1112,6 +1114,27 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}
}
}

// Smoke test: collect_set, min and max over a struct column whose row carries a null
// field must execute without failing. Only successful execution is checked; the
// returned rows are not asserted.
test("complex type with null") {
  // Nested struct type of the single "txn" column; every field is nullable.
  val txnType = new StructType()
    .add("appId", StringType, nullable = true)
    .add("lastUpdated", LongType, nullable = true)
    .add("version", LongType, nullable = true)
  val jsonSchema = new StructType().add("txn", txnType, nullable = true)

  // One record in which "lastUpdated" is explicitly null.
  val jsonStr = """{"txn":{"appId":"txnId","version":0,"lastUpdated":null}}"""
  val df = spark.read.schema(jsonSchema).json(Seq(jsonStr).toDS)

  // Run each aggregation as its own query, in the same order as before.
  Seq(collect_set(col("txn")), min(col("txn")), max(col("txn")))
    .foreach(agg => df.select(agg).collect())
}
}

class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {
Expand Down
10 changes: 10 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,16 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait
LOG_VALIDATION_MSG("Validation failed for function " + funcName + " resolve type in AggregateRel.");
return false;
}
// Aggregate functions for which complex-typed (array / map / row) raw input is rejected
// by this validator.
static const std::unordered_set<std::string> notSupportComplexTypeAggFuncs = {"set_agg", "min", "max"};
if (notSupportComplexTypeAggFuncs.count(baseFuncName) > 0 && exec::isRawInput(funcStep)) {
  const auto& argTypes = signature->argumentTypes();
  // Only inspect the first argument when the signature actually declares one.
  if (!argTypes.empty()) {
    const auto type = binder.tryResolveType(argTypes[0]);
    // Guard against an unresolved (null) type before dereferencing it; an input type
    // that cannot be resolved cannot be validated, so fail validation.
    if (type == nullptr) {
      LOG_VALIDATION_MSG("Validation failed for function " + baseFuncName + " resolve input type in AggregateRel.");
      return false;
    }
    if (type->isArray() || type->isMap() || type->isRow()) {
      LOG_VALIDATION_MSG("Validation failed for function " + baseFuncName + " complex type is not supported.");
      return false;
    }
  }
}

resolved = true;
break;
}
Expand Down

0 comments on commit e870de8

Please sign in to comment.