Skip to content

Commit

Permalink
register decimal sum in SumAggregate and fix ut
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Oct 28, 2023
1 parent 78d847a commit af9c524
Show file tree
Hide file tree
Showing 6 changed files with 372 additions and 421 deletions.
106 changes: 22 additions & 84 deletions velox/functions/sparksql/aggregates/DecimalSumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,8 @@ class DecimalSumAggregate : public exec::Aggregate {
if (decodedPartial_.isConstantMapping()) {
if (!decodedPartial_.isNullAt(0)) {
auto decodedIndex = decodedPartial_.index(0);
if (!isEmptyVector->valueAt(decodedIndex) &&
sumVector->isNullAt(decodedIndex)) {
// If isEmpty is false and sum is null, it means this intermediate
// result has an overflow. The final accumulator of this group will
// be null.
if (isIntermediateResultOverflow(
isEmptyVector, sumVector, decodedIndex)) {
rows.applyToSelected([&](vector_size_t i) { setNull(groups[i]); });
} else {
auto sum = sumVector->valueAt(decodedIndex);
Expand All @@ -259,8 +256,8 @@ class DecimalSumAggregate : public exec::Aggregate {
return;
}
auto decodedIndex = decodedPartial_.index(i);
if (!isEmptyVector->valueAt(decodedIndex) &&
sumVector->isNullAt(decodedIndex)) {
if (isIntermediateResultOverflow(
isEmptyVector, sumVector, decodedIndex)) {
setNull(groups[i]);
} else {
auto sum = sumVector->valueAt(decodedIndex);
Expand All @@ -272,8 +269,8 @@ class DecimalSumAggregate : public exec::Aggregate {
rows.applyToSelected([&](vector_size_t i) {
clearNull(groups[i]);
auto decodedIndex = decodedPartial_.index(i);
if (!isEmptyVector->valueAt(decodedIndex) &&
sumVector->isNullAt(decodedIndex)) {
if (isIntermediateResultOverflow(
isEmptyVector, sumVector, decodedIndex)) {
setNull(groups[i]);
} else {
auto sum = sumVector->valueAt(decodedIndex);
Expand All @@ -298,8 +295,8 @@ class DecimalSumAggregate : public exec::Aggregate {
if (decodedPartial_.isConstantMapping()) {
if (!decodedPartial_.isNullAt(0)) {
auto decodedIndex = decodedPartial_.index(0);
if (!isEmptyVector->valueAt(decodedIndex) &&
sumVector->isNullAt(decodedIndex)) {
if (isIntermediateResultOverflow(
isEmptyVector, sumVector, decodedIndex)) {
setNull(group);
} else {
auto sum = sumVector->valueAt(decodedIndex);
Expand All @@ -318,8 +315,8 @@ class DecimalSumAggregate : public exec::Aggregate {
return;
}
auto decodedIndex = decodedPartial_.index(i);
if (!isEmptyVector->valueAt(decodedIndex) &&
sumVector->isNullAt(decodedIndex)) {
if (isIntermediateResultOverflow(
isEmptyVector, sumVector, decodedIndex)) {
setNull(group);
return;
} else {
Expand All @@ -335,8 +332,8 @@ class DecimalSumAggregate : public exec::Aggregate {
}
rows.applyToSelected([&](vector_size_t i) {
auto decodedIndex = decodedPartial_.index(i);
if (!isEmptyVector->valueAt(decodedIndex) &&
sumVector->isNullAt(decodedIndex)) {
if (isIntermediateResultOverflow(
isEmptyVector, sumVector, decodedIndex)) {
setNull(group);
return;
} else {
Expand Down Expand Up @@ -373,78 +370,19 @@ class DecimalSumAggregate : public exec::Aggregate {
return exec::Aggregate::value<DecimalSum>(group);
}

inline bool isIntermediateResultOverflow(
const SimpleVector<bool>* isEmptyVector,
const SimpleVector<TResultType>* sumVector,
vector_size_t index) {
// If isEmpty is false and sum is null, it means this intermediate
// result has an overflow. The final accumulator of this group will
// be null.
return !isEmptyVector->valueAt(index) && sumVector->isNullAt(index);
}

DecodedVector decodedRaw_;
DecodedVector decodedPartial_;
TypePtr sumType_;
};

exec::AggregateRegistrationResult registerDecimalSumAggregate(
const std::string& name) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
exec::AggregateFunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("r_precision", "min(38, a_precision + 10)")
.integerVariable("r_scale", "min(38, a_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.intermediateType("ROW(DECIMAL(r_precision, r_scale), boolean)")
.returnType("DECIMAL(r_precision, r_scale)")
.build()};

return exec::registerAggregateFunction(
name,
std::move(signatures),
[name](
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_EQ(argTypes.size(), 1, "{} takes only one argument", name);
auto& inputType = argTypes[0];
auto sumType =
exec::isPartialOutput(step) ? resultType->childAt(0) : resultType;
switch (inputType->kind()) {
case TypeKind::BIGINT: {
DCHECK(exec::isRawInput(step));
if (inputType->isShortDecimal()) {
if (sumType->isShortDecimal()) {
return std::make_unique<DecimalSumAggregate<int64_t, int64_t>>(
resultType, sumType);
} else if (sumType->isLongDecimal()) {
return std::make_unique<DecimalSumAggregate<int64_t, int128_t>>(
resultType, sumType);
}
}
}
case TypeKind::HUGEINT:
if (inputType->isLongDecimal()) {
// If inputType is long decimal,
// its output type always be long decimal.
return std::make_unique<DecimalSumAggregate<int128_t, int128_t>>(
resultType, sumType);
}
case TypeKind::ROW: {
DCHECK(!exec::isRawInput(step));
// For intermediate input agg, input intermediate sum type
// is equal to final result sum type.
if (inputType->childAt(0)->isShortDecimal()) {
return std::make_unique<DecimalSumAggregate<int64_t, int64_t>>(
resultType, sumType);
} else if (inputType->childAt(0)->isLongDecimal()) {
return std::make_unique<DecimalSumAggregate<int128_t, int128_t>>(
resultType, sumType);
}
}
default:
VELOX_CHECK(
false,
"Unknown input type for {} aggregation {}",
name,
inputType->kindName());
}
},
true);
}

} // namespace facebook::velox::functions::aggregate::sparksql
2 changes: 0 additions & 2 deletions velox/functions/sparksql/aggregates/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "velox/functions/sparksql/aggregates/AverageAggregate.h"
#include "velox/functions/sparksql/aggregates/BitwiseXorAggregate.h"
#include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h"
#include "velox/functions/sparksql/aggregates/DecimalSumAggregate.h"
#include "velox/functions/sparksql/aggregates/SumAggregate.h"

namespace facebook::velox::functions::aggregate::sparksql {
Expand All @@ -34,6 +33,5 @@ void registerAggregateFunctions(const std::string& prefix) {
registerBloomFilterAggAggregate(prefix + "bloom_filter_agg");
registerAverage(prefix + "avg");
registerSum(prefix + "sum");
registerDecimalSumAggregate(prefix + "sum");
}
} // namespace facebook::velox::functions::aggregate::sparksql
46 changes: 44 additions & 2 deletions velox/functions/sparksql/aggregates/SumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "velox/functions/lib/aggregates/SumAggregateBase.h"
#include "velox/functions/sparksql/aggregates/DecimalSumAggregate.h"

using namespace facebook::velox::functions::aggregate;

Expand All @@ -22,7 +23,13 @@ namespace facebook::velox::functions::aggregate::sparksql {
namespace {
template <typename TInput, typename TAccumulator, typename ResultType>
using SumAggregate = SumAggregateBase<TInput, TAccumulator, ResultType, true>;

TypePtr getDecimalSumType(
const TypePtr& resultType,
core::AggregationNode::Step step) {
return exec::isPartialOutput(step) ? resultType->childAt(0) : resultType;
}
} // namespace

void registerSum(const std::string& name) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
Expand All @@ -36,6 +43,15 @@ void registerSum(const std::string& name) {
.intermediateType("double")
.argumentType("double")
.build(),
exec::AggregateFunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.integerVariable("r_precision", "min(38, a_precision + 10)")
.integerVariable("r_scale", "min(38, a_scale)")
.argumentType("DECIMAL(a_precision, a_scale)")
.intermediateType("ROW(DECIMAL(r_precision, r_scale), boolean)")
.returnType("DECIMAL(r_precision, r_scale)")
.build(),
};

for (const auto& inputType : {"tinyint", "smallint", "integer", "bigint"}) {
Expand Down Expand Up @@ -69,13 +85,26 @@ void registerSum(const std::string& name) {
BIGINT());
case TypeKind::BIGINT: {
if (inputType->isShortDecimal()) {
VELOX_NYI();
auto sumType = getDecimalSumType(resultType, step);
if (sumType->isShortDecimal()) {
return std::make_unique<DecimalSumAggregate<int64_t, int64_t>>(
resultType, sumType);
} else if (sumType->isLongDecimal()) {
return std::make_unique<DecimalSumAggregate<int64_t, int128_t>>(
resultType, sumType);
}
}
return std::make_unique<SumAggregate<int64_t, int64_t, int64_t>>(
BIGINT());
}
case TypeKind::HUGEINT: {
VELOX_NYI();
if (inputType->isLongDecimal()) {
auto sumType = getDecimalSumType(resultType, step);
// If inputType is long decimal,
// its output type always be long decimal.
return std::make_unique<DecimalSumAggregate<int128_t, int128_t>>(
resultType, sumType);
}
}
case TypeKind::REAL:
if (resultType->kind() == TypeKind::REAL) {
Expand All @@ -91,6 +120,19 @@ void registerSum(const std::string& name) {
}
return std::make_unique<SumAggregate<double, double, double>>(
DOUBLE());
case TypeKind::ROW: {
DCHECK(!exec::isRawInput(step));
auto sumType = getDecimalSumType(resultType, step);
// For intermediate input agg, input intermediate sum type
// is equal to final result sum type.
if (inputType->childAt(0)->isShortDecimal()) {
return std::make_unique<DecimalSumAggregate<int64_t, int64_t>>(
resultType, sumType);
} else if (inputType->childAt(0)->isLongDecimal()) {
return std::make_unique<DecimalSumAggregate<int128_t, int128_t>>(
resultType, sumType);
}
}
default:
VELOX_CHECK(
false,
Expand Down
1 change: 0 additions & 1 deletion velox/functions/sparksql/aggregates/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ add_executable(
velox_functions_spark_aggregates_test
BitwiseXorAggregationTest.cpp
BloomFilterAggAggregateTest.cpp
DecimalSumAggregateTest.cpp
FirstAggregateTest.cpp
LastAggregateTest.cpp
AverageAggregationTest.cpp
Expand Down
Loading

0 comments on commit af9c524

Please sign in to comment.