Skip to content

Commit

Permalink
Add config for registration (7110)
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf authored and zhztheplayer committed Dec 29, 2023
1 parent b7baa7c commit 3d82641
Show file tree
Hide file tree
Showing 18 changed files with 182 additions and 71 deletions.
9 changes: 7 additions & 2 deletions velox/functions/lib/aggregates/BitwiseAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ class BitwiseAggregateBase : public SimpleNumericAggregate<T, T, T> {
};

template <template <typename U> class T>
exec::AggregateRegistrationResult registerBitwise(const std::string& name) {
exec::AggregateRegistrationResult registerBitwise(
const std::string& name,
bool withCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
for (const auto& inputType : {"tinyint", "smallint", "integer", "bigint"}) {
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
Expand Down Expand Up @@ -106,7 +109,9 @@ exec::AggregateRegistrationResult registerBitwise(const std::string& name) {
name,
inputType->kindName());
}
});
},
withCompanionFunctions,
overwrite);
}

} // namespace facebook::velox::functions::aggregate
11 changes: 8 additions & 3 deletions velox/functions/prestosql/aggregates/BitwiseAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,14 @@ class BitwiseAndAggregate : public BitwiseAggregateBase<T> {

} // namespace

void registerBitwiseAggregates(const std::string& prefix) {
registerBitwise<BitwiseOrAggregate>(prefix + kBitwiseOr);
registerBitwise<BitwiseAndAggregate>(prefix + kBitwiseAnd);
void registerBitwiseAggregates(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite) {
registerBitwise<BitwiseOrAggregate>(
prefix + kBitwiseOr, withCompanionFunctions, overwrite);
registerBitwise<BitwiseAndAggregate>(
prefix + kBitwiseAnd, withCompanionFunctions, overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
9 changes: 7 additions & 2 deletions velox/functions/prestosql/aggregates/CountAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ class CountAggregate : public SimpleNumericAggregate<bool, int64_t, int64_t> {

} // namespace

void registerCountAggregate(const std::string& prefix) {
void registerCountAggregate(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
exec::AggregateFunctionSignatureBuilder()
.returnType("bigint")
Expand All @@ -177,7 +180,9 @@ void registerCountAggregate(const std::string& prefix) {
VELOX_CHECK_LE(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<CountAggregate>();
});
},
withCompanionFunctions,
overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
28 changes: 20 additions & 8 deletions velox/functions/prestosql/aggregates/CovarianceAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,10 @@ template <
typename TIntermediateInput,
typename TIntermediateResult,
typename TResultAccessor>
exec::AggregateRegistrationResult registerCovariance(const std::string& name) {
exec::AggregateRegistrationResult registerCovariance(
const std::string& name,
bool withCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures = {
// (double, double) -> double
exec::AggregateFunctionSignatureBuilder()
Expand Down Expand Up @@ -620,37 +623,46 @@ exec::AggregateRegistrationResult registerCovariance(const std::string& name) {
"Unsupported raw input type: {}. Expected DOUBLE or REAL.",
rawInputType->toString())
}
});
},
withCompanionFunctions,
overwrite);
}

} // namespace

void registerCovarianceAggregates(const std::string& prefix) {
void registerCovarianceAggregates(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite) {
registerCovariance<
CovarAccumulator,
CovarIntermediateInput,
CovarIntermediateResult,
CovarPopResultAccessor>(prefix + kCovarPop);
CovarPopResultAccessor>(
prefix + kCovarPop, withCompanionFunctions, overwrite);
registerCovariance<
CovarAccumulator,
CovarIntermediateInput,
CovarIntermediateResult,
CovarSampResultAccessor>(prefix + kCovarSamp);
CovarSampResultAccessor>(
prefix + kCovarSamp, withCompanionFunctions, overwrite);
registerCovariance<
CorrAccumulator,
CorrIntermediateInput,
CorrIntermediateResult,
CorrResultAccessor>(prefix + kCorr);
CorrResultAccessor>(prefix + kCorr, withCompanionFunctions, overwrite);
registerCovariance<
RegrAccumulator,
RegrIntermediateInput,
RegrIntermediateResult,
RegrInterceptResultAccessor>(prefix + kRegrIntercept);
RegrInterceptResultAccessor>(
prefix + kRegrIntercept, withCompanionFunctions, overwrite);
registerCovariance<
RegrAccumulator,
RegrIntermediateInput,
RegrIntermediateResult,
RegrSlopeResultAccessor>(prefix + kRegrSlop);
RegrSlopeResultAccessor>(
prefix + kRegrSlop, withCompanionFunctions, overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
18 changes: 13 additions & 5 deletions velox/functions/prestosql/aggregates/MinMaxAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,10 @@ template <
typename TNonNumeric,
template <typename T>
class TNumericN>
exec::AggregateRegistrationResult registerMinMax(const std::string& name) {
exec::AggregateRegistrationResult registerMinMax(
const std::string& name,
bool withCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
.orderableTypeVariable("T")
Expand Down Expand Up @@ -1008,16 +1011,21 @@ exec::AggregateRegistrationResult registerMinMax(const std::string& name) {
inputType->kindName());
}
}
});
},
withCompanionFunctions,
overwrite);
}

} // namespace

void registerMinMaxAggregates(const std::string& prefix) {
void registerMinMaxAggregates(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite) {
registerMinMax<MinAggregate, NonNumericMinAggregate, MinNAggregate>(
prefix + kMin);
prefix + kMin, withCompanionFunctions, overwrite);
registerMinMax<MaxAggregate, NonNumericMaxAggregate, MaxNAggregate>(
prefix + kMax);
prefix + kMax, withCompanionFunctions, overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
38 changes: 27 additions & 11 deletions velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ extern void registerAverageAggregate(
bool withCompanionFunctions);
extern void registerBitwiseXorAggregate(const std::string& prefix);
extern void registerChecksumAggregate(const std::string& prefix);
extern void registerCountAggregate(const std::string& prefix);
extern void registerCountAggregate(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite);
extern void registerCountIfAggregate(const std::string& prefix);
extern void registerEntropyAggregate(const std::string& prefix);
extern void registerGeometricMeanAggregate(const std::string& prefix);
Expand All @@ -49,32 +52,45 @@ extern void registerSetUnionAggregate(const std::string& prefix);
extern void registerApproxDistinctAggregates(
const std::string& prefix,
bool withCompanionFunctions);
extern void registerBitwiseAggregates(const std::string& prefix);
extern void registerBitwiseAggregates(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite);
extern void registerBoolAggregates(const std::string& prefix);
extern void registerCentralMomentsAggregates(const std::string& prefix);
extern void registerCovarianceAggregates(const std::string& prefix);
extern void registerMinMaxAggregates(const std::string& prefix);
extern void registerCovarianceAggregates(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite);
extern void registerMinMaxAggregates(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite);
extern void registerMinMaxByAggregates(const std::string& prefix);
extern void registerSumAggregate(const std::string& prefix);
extern void registerVarianceAggregates(const std::string& prefix);
extern void registerVarianceAggregates(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite);

void registerAllAggregateFunctions(
const std::string& prefix,
bool withCompanionFunctions) {
bool withCompanionFunctions,
bool overwrite) {
registerApproxDistinctAggregates(prefix, withCompanionFunctions);
registerApproxMostFrequentAggregate(prefix);
registerApproxPercentileAggregate(prefix, withCompanionFunctions);
registerArbitraryAggregate(prefix);
registerArrayAggAggregate(prefix, withCompanionFunctions);
registerAverageAggregate(prefix, withCompanionFunctions);
registerBitwiseAggregates(prefix);
registerBitwiseAggregates(prefix, withCompanionFunctions, overwrite);
registerBitwiseXorAggregate(prefix);
registerBoolAggregates(prefix);
registerCentralMomentsAggregates(prefix);
registerChecksumAggregate(prefix);
registerCountAggregate(prefix);
registerCountAggregate(prefix, withCompanionFunctions, overwrite);
registerCountIfAggregate(prefix);
registerCovarianceAggregates(prefix);
registerCovarianceAggregates(prefix, withCompanionFunctions, overwrite);
registerEntropyAggregate(prefix);
registerGeometricMeanAggregate(prefix);
registerHistogramAggregate(prefix);
Expand All @@ -84,13 +100,13 @@ void registerAllAggregateFunctions(
registerMaxDataSizeForStatsAggregate(prefix);
registerMultiMapAggAggregate(prefix);
registerSumDataSizeForStatsAggregate(prefix);
registerMinMaxAggregates(prefix);
registerMinMaxAggregates(prefix, withCompanionFunctions, overwrite);
registerMinMaxByAggregates(prefix);
registerReduceAgg(prefix);
registerSetAggAggregate(prefix);
registerSetUnionAggregate(prefix);
registerSumAggregate(prefix);
registerVarianceAggregates(prefix);
registerVarianceAggregates(prefix, withCompanionFunctions, overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace facebook::velox::aggregate::prestosql {

void registerAllAggregateFunctions(
const std::string& prefix = "",
bool withCompanionFunctions = true);
bool withCompanionFunctions = true,
bool overwrite = false);

} // namespace facebook::velox::aggregate::prestosql
32 changes: 23 additions & 9 deletions velox/functions/prestosql/aggregates/VarianceAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,10 @@ void checkSumCountRowType(
}

template <template <typename TInput> class TClass>
exec::AggregateRegistrationResult registerVariance(const std::string& name) {
exec::AggregateRegistrationResult registerVariance(
const std::string& name,
bool withCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
std::vector<std::string> inputTypes = {
"smallint", "integer", "bigint", "real", "double"};
Expand Down Expand Up @@ -514,18 +517,29 @@ exec::AggregateRegistrationResult registerVariance(const std::string& name) {
"(count:bigint, mean:double, m2:double) struct");
return std::make_unique<TClass<int64_t>>(resultType);
}
});
},
withCompanionFunctions,
overwrite);
}

} // namespace

void registerVarianceAggregates(const std::string& prefix) {
registerVariance<StdDevSampAggregate>(prefix + kStdDev);
registerVariance<StdDevPopAggregate>(prefix + kStdDevPop);
registerVariance<StdDevSampAggregate>(prefix + kStdDevSamp);
registerVariance<VarSampAggregate>(prefix + kVariance);
registerVariance<VarPopAggregate>(prefix + kVarPop);
registerVariance<VarSampAggregate>(prefix + kVarSamp);
void registerVarianceAggregates(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite) {
registerVariance<StdDevSampAggregate>(
prefix + kStdDev, withCompanionFunctions, overwrite);
registerVariance<StdDevPopAggregate>(
prefix + kStdDevPop, withCompanionFunctions, overwrite);
registerVariance<StdDevSampAggregate>(
prefix + kStdDevSamp, withCompanionFunctions, overwrite);
registerVariance<VarSampAggregate>(
prefix + kVariance, withCompanionFunctions, overwrite);
registerVariance<VarPopAggregate>(
prefix + kVarPop, withCompanionFunctions, overwrite);
registerVariance<VarSampAggregate>(
prefix + kVarSamp, withCompanionFunctions, overwrite);
}

} // namespace facebook::velox::aggregate::prestosql
6 changes: 4 additions & 2 deletions velox/functions/sparksql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ class DecimalAverageAggregate : public DecimalAggregate<TInputType> {
/// DECIMAL | DECIMAL | DECIMAL
exec::AggregateRegistrationResult registerAverage(
const std::string& name,
bool withCompanionFunctions) {
bool withCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;

for (const auto& inputType :
Expand Down Expand Up @@ -494,7 +495,8 @@ exec::AggregateRegistrationResult registerAverage(
}
}
},
withCompanionFunctions);
withCompanionFunctions,
overwrite);
}

} // namespace facebook::velox::functions::aggregate::sparksql
3 changes: 2 additions & 1 deletion velox/functions/sparksql/aggregates/AverageAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace facebook::velox::functions::aggregate::sparksql {

exec::AggregateRegistrationResult registerAverage(
const std::string& name,
bool withCompanionFunctions);
bool withCompanionFunctions,
bool overwrite);

} // namespace facebook::velox::functions::aggregate::sparksql
6 changes: 4 additions & 2 deletions velox/functions/sparksql/aggregates/BitwiseXorAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,11 @@ class BitwiseXorAggregate : public BitwiseAggregateBase<T> {
} // namespace

exec::AggregateRegistrationResult registerBitwiseXorAggregate(
const std::string& prefix) {
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite) {
return functions::aggregate::registerBitwise<BitwiseXorAggregate>(
prefix + "bit_xor");
prefix + "bit_xor", withCompanionFunctions, overwrite);
}

} // namespace facebook::velox::functions::aggregate::sparksql
4 changes: 3 additions & 1 deletion velox/functions/sparksql/aggregates/BitwiseXorAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
namespace facebook::velox::functions::aggregate::sparksql {

exec::AggregateRegistrationResult registerBitwiseXorAggregate(
const std::string& name);
const std::string& name,
bool registerCompanionFunctions,
bool overwrite);

} // namespace facebook::velox::functions::aggregate::sparksql
Loading

0 comments on commit 3d82641

Please sign in to comment.