Skip to content

Commit

Permalink
Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Feb 14, 2024
1 parent ee50ea7 commit 0b40e54
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 37 deletions.
4 changes: 2 additions & 2 deletions velox/docs/functions/spark/aggregate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ General Aggregate Functions
For all other input types, the result type is BIGINT.

Note:
When all input values are NULL, for all input types, the result of sum(x) is NULL.
When all input values are NULL, for all input types, the result is NULL.

For DECIMAL type, when an overflow occurs in the accumulation, it returns null. For REAL and DOUBLE type, it
For DECIMAL type, when an overflow occurs in the accumulation, it returns NULL. For REAL and DOUBLE type, it
returns Infinity. For all other input types, when the sum of input values exceeds its limit, it cycles to the
overflowed value rather than raising an error.

Expand Down
71 changes: 36 additions & 35 deletions velox/functions/sparksql/aggregates/DecimalSumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,23 @@ class DecimalSumAggregate {
using OutputType = TSumType;

// Spark's decimal sum doesn't have the concept of a null group, each group is
// initialized with an initial value, where sum = 0 and isEmpty = true.
// Therefore, to maintain consistency, we need to use the parameter
// nonNullGroup in writeIntermediateResult to output a null group as sum =
// 0, isEmpty = true. nonNullGroup is only available when default-null
// behavior is disabled.
// initialized with an initial value, where sum = 0 and isEmpty = true. The
// final agg may fall back to being executed in Spark, so the meaning of the
// intermediate data should be consistent with Spark. Therefore, we need to
// use the parameter nonNullGroup in writeIntermediateResult to output a null
// group as sum = 0, isEmpty = true. nonNullGroup is only available when
// default-null behavior is disabled.
static constexpr bool default_null_behavior_ = false;

static bool toIntermediate(
exec::out_type<Row<TSumType, bool>>& out,
exec::optional_arg_type<TInputType> in) {
if (in.has_value()) {
out.copy_from(std::make_tuple(static_cast<TSumType>(in.value()), false));
return true;
} else {
out.copy_from(std::make_tuple(static_cast<TSumType>(0), true));
}
return false;
return true;
}

// This struct stores the sum of input values, overflow during accumulation,
Expand All @@ -60,21 +62,21 @@ class DecimalSumAggregate {
// the result will be null. If isEmpty is false and sum is nullopt, an
// overflow has happened and the result is null.
struct AccumulatorType {
std::optional<int128_t> sum_{0};
int64_t overflow_{0};
bool isEmpty_{true};
std::optional<int128_t> sum{0};
int64_t overflow{0};
bool isEmpty{true};

AccumulatorType() = delete;

explicit AccumulatorType(HashStringAllocator* /*allocator*/) {}

std::optional<int128_t> computeFinalResult() const {
if (!sum_.has_value()) {
if (!sum.has_value()) {
return std::nullopt;
}
auto adjustedSum =
DecimalUtil::adjustSumForOverflow(sum_.value(), overflow_);
uint8_t maxPrecision = std::is_same_v<TSumType, int128_t>
DecimalUtil::adjustSumForOverflow(sum.value(), overflow);
constexpr uint8_t maxPrecision = std::is_same_v<TSumType, int128_t>
? LongDecimalType::kMaxPrecision
: ShortDecimalType::kMaxPrecision;
if (adjustedSum.has_value() &&
Expand All @@ -93,10 +95,10 @@ class DecimalSumAggregate {
return false;
}
int128_t result;
overflow_ +=
DecimalUtil::addWithOverflow(result, data.value(), sum_.value());
sum_ = result;
isEmpty_ = false;
overflow +=
DecimalUtil::addWithOverflow(result, data.value(), sum.value());
sum = result;
isEmpty = false;
return true;
}

Expand All @@ -109,40 +111,39 @@ class DecimalSumAggregate {
auto otherSum = other.value().template at<0>();
auto otherIsEmpty = other.value().template at<1>();

// isEmpty should always have a value.
// isEmpty is never null.
VELOX_CHECK(otherIsEmpty.has_value());

bool bufferOverflow = !isEmpty_ && !sum_.has_value();
bool inputOverflow = !otherIsEmpty.value() && !otherSum.has_value();
if (bufferOverflow || inputOverflow) {
sum_ = std::nullopt;
return false;
bool currentOverflow = !isEmpty && !sum.has_value();
bool otherOverflow = !otherIsEmpty.value() && !otherSum.has_value();
if (currentOverflow || otherOverflow) {
sum = std::nullopt;
isEmpty = false;
} else {
int128_t result;
overflow_ += DecimalUtil::addWithOverflow(
result, otherSum.value(), sum_.value());
sum_ = result;
isEmpty_ &= otherIsEmpty.value();
return true;
overflow +=
DecimalUtil::addWithOverflow(result, otherSum.value(), sum.value());
sum = result;
isEmpty &= otherIsEmpty.value();
}
return true;
}

bool writeIntermediateResult(
bool nonNullGroup,
exec::out_type<IntermediateType>& out) {
if (!nonNullGroup) {
// If a group is null, maybe all values in this group are null. In
// Spark, this group will be the initial value, where sum is 0 and
// isEmpty is true.
// If a group is null, all values in this group are null. In Spark, this
// group will be the initial value, where sum is 0 and isEmpty is true.
out = std::make_tuple(static_cast<TSumType>(0), true);
} else {
auto finalResult = computeFinalResult();
if (finalResult.has_value()) {
out = std::make_tuple(
static_cast<TSumType>(finalResult.value()), isEmpty_);
static_cast<TSumType>(finalResult.value()), isEmpty);
} else {
// Sum should be set to null on overflow, and
// isEmpty should be set to false.
// Sum should be set to null on overflow,
// and isEmpty should be set to false.
out.template set_null_at<0>();
out.template get_writer_at<1>() = false;
}
Expand All @@ -151,7 +152,7 @@ class DecimalSumAggregate {
}

bool writeFinalResult(bool nonNullGroup, exec::out_type<OutputType>& out) {
if (!nonNullGroup || isEmpty_) {
if (!nonNullGroup || isEmpty) {
// If isEmpty is true, we should set null.
return false;
}
Expand Down

0 comments on commit 0b40e54

Please sign in to comment.