Skip to content

Commit

Permalink
Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Feb 12, 2024
1 parent b081560 commit ee50ea7
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
7 changes: 5 additions & 2 deletions velox/docs/functions/spark/aggregate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,11 @@ General Aggregate Functions
For all other input types, the result type is BIGINT.

Note:
For DECIMAL type, when an overflow occurs in the accumulation, it returns null. For all other input types,
when the sum of input values exceeds its limit, it cycles to the overflowed value rather than raising an error.
When all input values is NULL, for all input types, the result of sum(x) is NULL.

For DECIMAL type, when an overflow occurs in the accumulation, it returns null. For REAL and DOUBLE type, it
returns Infinity. For all other input types, when the sum of input values exceeds its limit, it cycles to the
overflowed value rather than raising an error.

Example::

Expand Down
42 changes: 22 additions & 20 deletions velox/functions/sparksql/aggregates/DecimalSumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@

namespace facebook::velox::functions::aggregate::sparksql {

/// TInputType refer to the raw input data type. TSumType refer to the type of
/// sum in the output of partial aggregation or the final output type of final
/// aggregation.
/// @tparam TInputType The raw input data type.
/// @tparam TSumType The type of sum in the output of partial aggregation or the
/// final output type of final aggregation.
template <typename TInputType, typename TSumType>
class DecimalSumAggregate {
public:
Expand All @@ -34,6 +34,12 @@ class DecimalSumAggregate {

using OutputType = TSumType;

// Spark's decimal sum doesn't have the concept of a null group, each group is
// initialized with an initial value, where sum = 0 and isEmpty = true.
// Therefore, to maintain consistency, we need to use the parameter
// nonNullGroup in writeIntermediateResult to output a null group as sum =
// 0, isEmpty = true. nonNullGroup is only available when default-null
// behavior is disabled.
static constexpr bool default_null_behavior_ = false;

static bool toIntermediate(
Expand All @@ -54,17 +60,13 @@ class DecimalSumAggregate {
// the result will be null. If the isEmpty is false, then if sum is nullopt
// that means an overflow has happened, it returns null.
struct AccumulatorType {
std::optional<int128_t> sum_;
int64_t overflow_;
bool isEmpty_;
std::optional<int128_t> sum_{0};
int64_t overflow_{0};
bool isEmpty_{true};

AccumulatorType() = delete;

explicit AccumulatorType(HashStringAllocator* /*allocator*/) {
sum_ = 0;
overflow_ = 0;
isEmpty_ = true;
}
explicit AccumulatorType(HashStringAllocator* /*allocator*/) {}

std::optional<int128_t> computeFinalResult() const {
if (!sum_.has_value()) {
Expand All @@ -79,23 +81,23 @@ class DecimalSumAggregate {
DecimalUtil::valueInPrecisionRange(adjustedSum, maxPrecision)) {
return adjustedSum;
} else {
// Find overflow during computing adjusted sum.
// Found overflow during computing adjusted sum.
return std::nullopt;
}
}

bool addInput(
HashStringAllocator* /*allocator*/,
exec::optional_arg_type<TInputType> data) {
if (data.has_value()) {
int128_t result;
overflow_ +=
DecimalUtil::addWithOverflow(result, data.value(), sum_.value());
sum_ = result;
isEmpty_ = false;
return true;
if (!data.has_value()) {
return false;
}
return false;
int128_t result;
overflow_ +=
DecimalUtil::addWithOverflow(result, data.value(), sum_.value());
sum_ = result;
isEmpty_ = false;
return true;
}

bool combine(
Expand Down
1 change: 0 additions & 1 deletion velox/functions/sparksql/aggregates/SumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ exec::AggregateRegistrationResult registerSum(
}
case TypeKind::HUGEINT: {
if (inputType->isLongDecimal()) {
auto const sumType = getDecimalSumType(resultType, step);
// If inputType is long decimal,
// its output type is always long decimal.
return std::make_unique<exec::SimpleAggregateAdapter<
Expand Down
12 changes: 4 additions & 8 deletions velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class SumAggregationTest : public SumTestBase {
const std::vector<std::optional<TOut>>& output,
const TypePtr& outputType) {
std::vector<RowVectorPtr> vectors;
FlatVectorPtr<TIn> inputDecimalVector;
VectorPtr inputDecimalVector;
if constexpr (std::is_same_v<int64_t, TIn>) {
inputDecimalVector = makeNullableFlatVector<int64_t>(input, inputType);
} else {
Expand All @@ -90,7 +90,7 @@ class SumAggregationTest : public SumTestBase {
inputDecimalVector}));
}

FlatVectorPtr<TOut> outputDecimalVector;
VectorPtr outputDecimalVector;
if constexpr (std::is_same_v<int64_t, TOut>) {
outputDecimalVector = makeNullableFlatVector<int64_t>(output, outputType);
} else {
Expand Down Expand Up @@ -137,9 +137,7 @@ TEST_F(SumAggregationTest, decimalSumCompanionPartial) {
auto expected = makeRowVector({makeRowVector(
{makeFlatVector<int128_t>(sumVector, DECIMAL(20, 1)),
makeFlatVector<bool>(isEmptyVector)})});
AssertQueryBuilder assertQueryBuilder(plan);
auto result = assertQueryBuilder.copyResults(pool());
assertEqualResults({expected}, {result});
AssertQueryBuilder(plan).assertResults(expected);
}

TEST_F(SumAggregationTest, decimalSumCompanionMerge) {
Expand All @@ -155,9 +153,7 @@ TEST_F(SumAggregationTest, decimalSumCompanionMerge) {
auto expected = makeRowVector({makeRowVector(
{makeFlatVector<int128_t>(std::vector<int128_t>{6000}, DECIMAL(20, 1)),
makeFlatVector<bool>(std::vector<bool>{false})})});
AssertQueryBuilder assertQueryBuilder(plan);
auto result = assertQueryBuilder.copyResults(pool());
assertEqualResults({expected}, {result});
AssertQueryBuilder(plan).assertResults(expected);
}

TEST_F(SumAggregationTest, decimalSum) {
Expand Down

0 comments on commit ee50ea7

Please sign in to comment.