diff --git a/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp b/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp index 625e8eb4c5f7..80a727778351 100644 --- a/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp +++ b/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp @@ -48,7 +48,9 @@ extern void registerMaxDataSizeForStatsAggregate(const std::string& prefix); extern void registerMultiMapAggAggregate(const std::string& prefix); extern void registerSumDataSizeForStatsAggregate(const std::string& prefix); extern void registerReduceAgg(const std::string& prefix); -extern void registerSetAggAggregate(const std::string& prefix); +extern void registerSetAggAggregate( + const std::string& prefix, + bool withCompanionFunctions); extern void registerSetUnionAggregate(const std::string& prefix); extern void registerApproxDistinctAggregates( @@ -108,7 +110,7 @@ void registerAllAggregateFunctions( registerMinMaxAggregates(prefix, withCompanionFunctions, overwrite); registerMinMaxByAggregates(prefix); registerReduceAgg(prefix); - registerSetAggAggregate(prefix); + registerSetAggAggregate(prefix, withCompanionFunctions); registerSetUnionAggregate(prefix); registerSumAggregate(prefix); registerVarianceAggregates(prefix, withCompanionFunctions, overwrite); diff --git a/velox/functions/prestosql/aggregates/SetAggregates.cpp b/velox/functions/prestosql/aggregates/SetAggregates.cpp index c83846884fb4..c1f69f0c1d43 100644 --- a/velox/functions/prestosql/aggregates/SetAggregates.cpp +++ b/velox/functions/prestosql/aggregates/SetAggregates.cpp @@ -427,7 +427,9 @@ std::unique_ptr create( } // namespace -void registerSetAggAggregate(const std::string& prefix) { +void registerSetAggAggregate( + const std::string& prefix, + bool withCompanionFunctions) { std::vector> signatures = { exec::AggregateFunctionSignatureBuilder() .typeVariable("T") @@ -491,7 +493,8 @@ void registerSetAggAggregate(const std::string& prefix) { VELOX_UNREACHABLE( "Unexpected type {}", mapTypeKindToName(typeKind)); } - }); + }, + withCompanionFunctions); } void registerSetUnionAggregate(const std::string& prefix) {