diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index 06998e79e2c0..27b479d765b3 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -245,10 +245,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction( const core::QueryConfig& config) -> std::unique_ptr { if (auto func = getAggregateFunctionEntry(name)) { + core::AggregationNode::Step usedStep{ + core::AggregationNode::Step::kPartial}; if (!exec::isRawInput(step)) { - step = core::AggregationNode::Step::kIntermediate; + usedStep = core::AggregationNode::Step::kIntermediate; } - auto fn = func->factory(step, argTypes, resultType, config); + auto fn = + func->factory(usedStep, argTypes, resultType, config); VELOX_CHECK_NOT_NULL(fn); return std::make_unique< AggregateCompanionAdapter::PartialFunction>( @@ -366,56 +369,50 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction( const std::string& name, const std::vector& signatures, bool overwrite) { + bool registered = false; if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures( signatures)) { - return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite); + registered |= + registerMergeExtractFunctionWithSuffix(name, signatures, overwrite); } auto mergeExtractSignatures = CompanionSignatures::mergeExtractFunctionSignatures(signatures); if (mergeExtractSignatures.empty()) { - return false; + return registered; } auto mergeExtractFunctionName = CompanionSignatures::mergeExtractFunctionName(name); - return exec::registerAggregateFunction( - mergeExtractFunctionName, - std::move(mergeExtractSignatures), - [name, mergeExtractFunctionName]( - core::AggregationNode::Step /*step*/, - const std::vector& argTypes, - const TypePtr& resultType, - const core::QueryConfig& config) - -> std::unique_ptr { - const auto& [originalResultType, _] = - resolveAggregateFunction(mergeExtractFunctionName, argTypes); - if (!originalResultType) { - // TODO: limitation -- result type must be resolveable given - // intermediate type of the original UDAF. - VELOX_UNREACHABLE( - "Signatures whose result types are not resolvable given intermediate types should have been excluded."); - } - - if (auto func = getAggregateFunctionEntry(name)) { - auto fn = func->factory( - core::AggregationNode::Step::kFinal, - argTypes, - originalResultType, - config); - VELOX_CHECK_NOT_NULL(fn); - return std::make_unique< - AggregateCompanionAdapter::MergeExtractFunction>( - std::move(fn), resultType); - } - VELOX_FAIL( - "Original aggregation function {} not found: {}", - name, - mergeExtractFunctionName); - }, - /*registerCompanionFunctions*/ false, - overwrite) - .mainFunction; + registered |= + exec::registerAggregateFunction( + mergeExtractFunctionName, + std::move(mergeExtractSignatures), + [name, mergeExtractFunctionName]( + core::AggregationNode::Step /*step*/, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& config) -> std::unique_ptr { + if (auto func = getAggregateFunctionEntry(name)) { + auto fn = func->factory( + core::AggregationNode::Step::kFinal, + argTypes, + resultType, + config); + VELOX_CHECK_NOT_NULL(fn); + return std::make_unique< + AggregateCompanionAdapter::MergeExtractFunction>( + std::move(fn), resultType); + } + VELOX_FAIL( + "Original aggregation function {} not found: {}", + name, + mergeExtractFunctionName); + }, + /*registerCompanionFunctions*/ false, + overwrite) + .mainFunction; + return registered; } bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix( diff --git a/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp b/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp index 38341d99fe8e..c0b24eae9003 100644 --- a/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp +++ b/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp @@ -288,7 +288,9 @@ class BloomFilterAggAggregate : public exec::Aggregate { } // namespace exec::AggregateRegistrationResult registerBloomFilterAggAggregate( - const std::string& name) { + const std::string& name, + bool withCompanionFunctions, + bool overwrite) { std::vector> signatures{ exec::AggregateFunctionSignatureBuilder() .argumentType("bigint") @@ -318,6 +320,8 @@ exec::AggregateRegistrationResult registerBloomFilterAggAggregate( const TypePtr& resultType, const core::QueryConfig& config) -> std::unique_ptr { return std::make_unique(resultType, config); - }); + }, + withCompanionFunctions, + overwrite); } } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h b/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h index 48c69f305dbc..7cd54e1140e7 100644 --- a/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h +++ b/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h @@ -23,6 +23,8 @@ namespace facebook::velox::functions::aggregate::sparksql { exec::AggregateRegistrationResult registerBloomFilterAggAggregate( - const std::string& name); + const std::string& name, + bool withCompanionFunctions, + bool overwrite); } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/Register.cpp b/velox/functions/sparksql/aggregates/Register.cpp index bcbf012ddac1..79e9c076aa1a 100644 --- a/velox/functions/sparksql/aggregates/Register.cpp +++ b/velox/functions/sparksql/aggregates/Register.cpp @@ -39,7 +39,8 @@ void registerAggregateFunctions( registerFirstLastAggregates(prefix, withCompanionFunctions, overwrite); registerMinMaxByAggregates(prefix, withCompanionFunctions, overwrite); registerBitwiseXorAggregate(prefix, withCompanionFunctions, overwrite); - registerBloomFilterAggAggregate(prefix + "bloom_filter_agg"); + registerBloomFilterAggAggregate( + prefix + "bloom_filter_agg", withCompanionFunctions, overwrite); registerAverage(prefix + "avg", withCompanionFunctions, overwrite); registerSum(prefix + "sum", withCompanionFunctions, overwrite); }