Skip to content

Commit

Permalink
[VL] Enable collect_set, min, max for complex types
Browse files (browse the repository at this point in the history)
  • Loading branch information
zhli1142015 committed Jul 30, 2024
1 parent 4ae223c commit 818afd8
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import org.apache.gluten.extension.columnar.validator.FallbackInjects
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
- import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -1141,13 +1140,10 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
StructField("lastUpdated", LongType, true),
StructField("version", LongType, true))),
true)))
- val df = spark.read.schema(jsonSchema).json(Seq(jsonStr).toDS)
- df.select(collect_set(col("txn"))).collect
-
- df.select(min(col("txn"))).collect
-
- df.select(max(col("txn"))).collect
-
+ spark.read.schema(jsonSchema).json(Seq(jsonStr).toDS).createOrReplaceTempView("t1")
+ runQueryAndCompare("select collect_set(txn), min(txn), max(txn) from t1") {
+ checkGlutenOperatorMatch[HashAggregateExecTransformer]
+ }
}

test("drop redundant partial sort which has pre-project when offload sortAgg") {
Expand Down
1 change: 0 additions & 1 deletion cpp/velox/substrait/SubstraitParser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,6 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
{"xxhash64", "xxhash64_with_seed"},
{"modulus", "remainder"},
{"date_format", "format_datetime"},
- {"collect_set", "set_agg"},
{"negative", "unaryminus"},
{"get_array_item", "get"}};

Expand Down
9 changes: 0 additions & 9 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1048,15 +1048,6 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait
LOG_VALIDATION_MSG("Validation failed for function " + funcName + " resolve type in AggregateRel.");
return false;
}
- static const std::unordered_set<std::string> notSupportComplexTypeAggFuncs = {"set_agg", "min", "max"};
- if (notSupportComplexTypeAggFuncs.find(baseFuncName) != notSupportComplexTypeAggFuncs.end() &&
- exec::isRawInput(funcStep)) {
- auto type = binder.tryResolveType(signature->argumentTypes()[0]);
- if (type->isArray() || type->isMap() || type->isRow()) {
- LOG_VALIDATION_MSG("Validation failed for function " + baseFuncName + " complex type is not supported.");
- return false;
- }
- }

resolved = true;
break;
Expand Down

0 comments on commit 818afd8

Please sign in to comment.