
Add support for sum(decimal) Spark aggregate function #5372

Closed
16 changes: 12 additions & 4 deletions velox/docs/develop/aggregate-functions.rst
@@ -254,6 +254,9 @@ For aggregation functions of default-null behavior, the author defines an
// Optional. Default is false.
static constexpr bool use_external_memory_ = true;

// Optional. Default is false.
static constexpr bool aligned_accumulator_ = true;

explicit AccumulatorType(HashStringAllocator* allocator);

void addInput(HashStringAllocator* allocator, exec::arg_type<T1> value1, ...);
@@ -274,7 +277,9 @@ The author defines an optional flag `is_fixed_size_` indicating whether
every accumulator takes a fixed amount of memory. This flag is true by default.
Next, the author defines another optional flag `use_external_memory_`
indicating whether the accumulator uses memory that is not tracked by Velox.
This flag is false by default.
This flag is false by default. Then, the author can define an optional flag
`aligned_accumulator_` indicating whether the accumulator requires aligned
access. This flag is false by default.
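The three optional flags can be sketched together in a trimmed accumulator declaration. This is an illustrative struct only (the required `HashStringAllocator*` constructor and aggregation methods are omitted), and the `__int128` member is an assumption chosen to show why aligned access can matter:

```cpp
#include <cstdint>

// Trimmed sketch of an accumulator declaring the three optional flags
// described above. Constructor and aggregation methods are omitted.
struct AccumulatorType {
  static constexpr bool is_fixed_size_ = true;        // fixed memory per group
  static constexpr bool use_external_memory_ = false; // no untracked memory
  static constexpr bool aligned_accumulator_ = true;  // requires aligned access

  __int128 sum{0}; // 16-byte member forces 16-byte alignment
};
```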

The author defines a constructor that takes a single argument of
`HashStringAllocator*`. This constructor is called before aggregation starts to
@@ -345,6 +350,9 @@ For aggregation functions of non-default-null behavior, the author defines an
// Optional. Default is false.
static constexpr bool use_external_memory_ = true;

// Optional. Default is false.
static constexpr bool aligned_accumulator_ = true;

explicit AccumulatorType(HashStringAllocator* allocator);

bool addInput(HashStringAllocator* allocator, exec::optional_arg_type<T1> value1, ...);
@@ -361,9 +369,9 @@ For aggregation functions of non-default-null behavior, the author defines an
void destroy(HashStringAllocator* allocator);
};

The definition of `is_fixed_size_`, `use_external_memory_`, the constructor,
and the `destroy` method are exactly the same as those for default-null
behavior.
The definition of `is_fixed_size_`, `use_external_memory_`,
`aligned_accumulator_`, the constructor, and the `destroy` method are exactly
the same as those for default-null behavior.

On the other hand, the C++ function signatures of `addInput`, `combine`,
`writeIntermediateResult`, and `writeFinalResult` are different.
11 changes: 9 additions & 2 deletions velox/docs/functions/spark/aggregate.rst
@@ -107,13 +107,20 @@ General Aggregate Functions

Returns the sum of `x`.

Supported types are TINYINT, SMALLINT, INTEGER, BIGINT, REAL and DOUBLE.
Supported types are TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE and DECIMAL.

When x is of type DOUBLE, the result type is DOUBLE.
When x is of type REAL, the result type is REAL.
When x is of type DECIMAL(p, s), the result type is DECIMAL(p + 10, s), where (p + 10) is capped at 38.

For all other input types, the result type is BIGINT.

Note: When the sum of BIGINT values exceeds its limit, it cycles to the overflowed value rather than raising an error.
Note:
When all input values are NULL, the result is NULL for all input types.

For the DECIMAL type, when an overflow occurs during accumulation, the result is NULL. For REAL and DOUBLE
types, the result is Infinity. For all other input types, when the sum of input values exceeds its limit, it
cycles to the overflowed value rather than raising an error.
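The result-type rule for DECIMAL inputs can be written as a one-line helper. This is a minimal sketch; `sumResultPrecision` and `kMaxDecimalPrecision` are names chosen for illustration, with 38 being the maximum decimal precision:

```cpp
#include <algorithm>
#include <cstdint>

// Result precision for sum over DECIMAL(p, s): p + 10, capped at 38.
// The scale s is unchanged.
constexpr int kMaxDecimalPrecision = 38;

constexpr uint8_t sumResultPrecision(uint8_t p) {
  return std::min<int>(p + 10, kMaxDecimalPrecision);
}
```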

Example::

21 changes: 21 additions & 0 deletions velox/exec/SimpleAggregateAdapter.h
@@ -145,6 +145,18 @@ class SimpleAggregateAdapter : public Aggregate {
struct support_to_intermediate<T, std::void_t<decltype(&T::toIntermediate)>>
: std::true_type {};

// Whether the accumulator requires aligned access. If it is defined,
// SimpleAggregateAdapter::accumulatorAlignmentSize() returns
// alignof(typename FUNC::AccumulatorType).
// Otherwise, SimpleAggregateAdapter::accumulatorAlignmentSize() returns
// Aggregate::accumulatorAlignmentSize(), with a default value of 1.
template <typename T, typename = void>
struct aligned_accumulator : std::false_type {};
Contributor:

Great! Thank you for helping enhance the aggregation function interface! Could you also update the documentation about this new flag? (https://facebookincubator.github.io/velox/develop/aggregate-functions.html, source code is in velox/docs/develop/aggregate-functions.rst.)

template <typename T>
struct aligned_accumulator<T, std::void_t<decltype(T::aligned_accumulator_)>>
: std::integral_constant<bool, T::aligned_accumulator_> {};

static constexpr bool aggregate_default_null_behavior_ =
aggregate_default_null_behavior<FUNC>::value;

@@ -160,6 +172,8 @@ class SimpleAggregateAdapter : public Aggregate {
static constexpr bool support_to_intermediate_ =
support_to_intermediate<FUNC>::value;

static constexpr bool aligned_accumulator_ = aligned_accumulator<FUNC>::value;

bool isFixedSize() const override {
return accumulator_is_fixed_size_;
}
@@ -172,6 +186,13 @@
return sizeof(typename FUNC::AccumulatorType);
}

int32_t accumulatorAlignmentSize() const override {
if constexpr (aligned_accumulator_) {
return alignof(typename FUNC::AccumulatorType);
}
return Aggregate::accumulatorAlignmentSize();
}
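The detection idiom used for `aligned_accumulator` can be exercised standalone. The following self-contained sketch mirrors the two templates above and checks them against hypothetical structs with and without the flag:

```cpp
#include <type_traits>

// Primary template: defaults to false when the flag is absent.
template <typename T, typename = void>
struct aligned_accumulator : std::false_type {};

// Partial specialization: viable only when T::aligned_accumulator_ is a
// valid expression; then it reports the flag's value.
template <typename T>
struct aligned_accumulator<T, std::void_t<decltype(T::aligned_accumulator_)>>
    : std::integral_constant<bool, T::aligned_accumulator_> {};

struct WithFlag {
  static constexpr bool aligned_accumulator_ = true;
};

struct WithoutFlag {};
```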

void initializeNewGroups(
char** groups,
folly::Range<const vector_size_t*> indices) override {
183 changes: 183 additions & 0 deletions velox/functions/sparksql/aggregates/DecimalSumAggregate.h
@@ -0,0 +1,183 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/exec/SimpleAggregateAdapter.h"
#include "velox/type/DecimalUtil.h"

namespace facebook::velox::functions::aggregate::sparksql {

/// @tparam TInputType The raw input data type.
/// @tparam TSumType The type of sum in the output of partial aggregation or the
/// final output type of final aggregation.
template <typename TInputType, typename TSumType>
class DecimalSumAggregate {
public:
using InputType = Row<TInputType>;

using IntermediateType =
Row</*sum*/ TSumType,
/*isEmpty*/ bool>;

using OutputType = TSumType;

/// Spark's decimal sum doesn't have the concept of a null group, each group
/// is initialized with an initial value, where sum = 0 and isEmpty = true.
/// The final agg may fallback to being executed in Spark, so the meaning of
/// the intermediate data should be consistent with Spark. Therefore, we need
/// to use the parameter nonNullGroup in writeIntermediateResult to output a
/// null group as sum = 0, isEmpty = true. nonNullGroup is only available when
/// default-null behavior is disabled.
static constexpr bool default_null_behavior_ = false;
Contributor:

Why is this false? Would you add a comment to explain?

Author (@liujiayi771, Feb 10, 2024):

@mbasmanova The main reason is that Spark's decimal sum doesn't have the concept of a null group; each group is initialized with an initial value, where sum = 0 and isEmpty = true. Therefore, to maintain consistency, I need to use the parameter nonNullGroup in writeIntermediateResult (this parameter is only available when non-default-null behavior is enabled) to output a null group as sum = 0, isEmpty = true.

Contributor:

I'm not sure I understand. What is a "null group"? Are you trying to match intermediate results to Spark's? If so, does this matter only when the query uses companion functions? Why can't the intermediate result for all-null inputs be sum = NULL? Do you need isEmpty = true to allow Spark to distinguish between all-NULL inputs and decimal overflow?

Author (@liujiayi771, Feb 12, 2024):

A "null group" means all input values for this group are null; we never call clearNull for this group.

Yes, I am trying to match intermediate results to Spark's, and not only for companion functions. In Spark's decimal sum aggregate, sum is initialized to 0, not null. sum = 0, isEmpty = true means all input values are null. Spark uses isEmpty = true to distinguish between all-NULL inputs and inputs whose sum just equals 0.

We need to ensure that the data output by the operator is consistent with Spark, because the final aggregation of the decimal sum may fall back to being executed in Spark for some reason. We need to make sure that the meaning of the intermediate data is consistent with Spark. Spark uses isEmpty and sum to distinguish many different situations for all-null inputs and overflow.

Contributor:

Got it. Makes sense. Perhaps clarify this in the comments.



static constexpr bool aligned_accumulator_ = true;

static bool toIntermediate(
exec::out_type<Row<TSumType, bool>>& out,
exec::optional_arg_type<TInputType> in) {
if (in.has_value()) {
Contributor:

Is this implementation correct? Shouldn't we produce a non-null struct for every row, where sum = x and isEmpty = isNull(x)? How do you test this code path to ensure Spark can process the results correctly?

Author (@liujiayi771):

Yes, we need to set sum = x, isEmpty = false for non-null input, and sum = 0, isEmpty = true for null input. I will add a test case for decimal sum in Gluten that sets kAbandonPartialAggregationMinPct and kAbandonPartialAggregationMinRows to very small values to trigger abandoning partial aggregation.

out.copy_from(std::make_tuple(static_cast<TSumType>(in.value()), false));
} else {
out.copy_from(std::make_tuple(static_cast<TSumType>(0), true));
}
return true;
}

/// This struct stores the sum of input values, overflow during accumulation,
/// and a bool value isEmpty used to indicate whether all inputs are null. The
/// initial value of sum is 0. We need to keep sum unchanged if the input is
/// null, as sum function ignores null input. If the isEmpty is true, then it
/// means there were no values to begin with or all the values were null, so
/// the result will be null. If the isEmpty is false, then if sum is nullopt
/// that means an overflow has happened, it returns null.
struct AccumulatorType {
std::optional<int128_t> sum{0};
int64_t overflow{0};
bool isEmpty{true};

AccumulatorType() = delete;
Contributor:

This is not needed, is it?

Author (@liujiayi771):

Is this done to prevent the direct use of the default no-argument constructor to create a struct? I'm not sure about the main purpose here. I've seen other implementations using the simple agg interface include this line of code, so I added it as well.

Contributor:

This is not required since another constructor with a HashStringAllocator* argument is defined. But I think deleting the default constructor explicitly is still helpful to clarify the purpose of the code.


explicit AccumulatorType(HashStringAllocator* /*allocator*/) {}

std::optional<int128_t> computeFinalResult() const {
if (!sum.has_value()) {
return std::nullopt;
}
auto const adjustedSum =
DecimalUtil::adjustSumForOverflow(sum.value(), overflow);
constexpr uint8_t maxPrecision = std::is_same_v<TSumType, int128_t>
? LongDecimalType::kMaxPrecision
: ShortDecimalType::kMaxPrecision;
if (adjustedSum.has_value() &&
DecimalUtil::valueInPrecisionRange(adjustedSum, maxPrecision)) {
return adjustedSum;
} else {
// Found overflow during computing adjusted sum.
return std::nullopt;
}
}
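The wrap-and-adjust scheme behind `computeFinalResult` can be modeled standalone. This is a simplified sketch, not Velox's actual `DecimalUtil` implementation: additions wrap on 128-bit overflow while a side counter records the net number of wraps, and the wrapped sum is taken as exact only when the wraps cancel out (the real `adjustSumForOverflow` may recover more cases):

```cpp
#include <cstdint>
#include <optional>

using int128_t = __int128;

// Model of overflow-tracked addition: wraps the result and reports the
// wrap direction (+1, -1, or 0). Signed overflow only happens when both
// operands share a sign, so the sign of `a` gives the direction.
int64_t addWithOverflowModel(int128_t& result, int128_t a, int128_t b) {
  bool overflowed = __builtin_add_overflow(a, b, &result);
  return overflowed ? (a > 0 ? 1 : -1) : 0;
}

// Model of the final adjustment: wraps that cancel leave the wrapped sum
// exact; otherwise the true sum does not fit in 128 bits, so return nullopt.
std::optional<int128_t> adjustSumForOverflowModel(
    int128_t sum,
    int64_t overflow) {
  return overflow == 0 ? std::optional<int128_t>(sum) : std::nullopt;
}
```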

bool addInput(
HashStringAllocator* /*allocator*/,
exec::optional_arg_type<TInputType> data) {
if (!data.has_value()) {
return false;
}
if (!sum.has_value()) {
// sum is initialized to 0. When it is nullopt, it implies that the
// input data must not be empty.
VELOX_CHECK(!isEmpty);
return true;
}
int128_t result;
overflow +=
DecimalUtil::addWithOverflow(result, data.value(), sum.value());
sum = result;
isEmpty = false;
Collaborator:

nit: we only need to set isEmpty if it's true.

Author (@liujiayi771):

This will introduce an if branch, which may increase overhead.

Contributor (on lines +108 to +109):

We have seen some use cases where the function (1) adds raw inputs, (2) extracts intermediate results, (3) adds intermediate results back to the accumulator, and (4) continues adding raw inputs to the accumulator (see the test code in AggregationTestBase::testStreaming()).

Suppose the intermediate result added at step (3) is {null, false}; the code here at step (4) would overwrite it. Should this method also check whether the current accumulator is in the "overflow" status?

Author (@liujiayi771):

Yes, I was not aware of this scenario before. Thus, we need to verify sum.has_value() here, and if it is false, we should ignore the input data and return true directly.


return true;
}

bool combine(
HashStringAllocator* /*allocator*/,
exec::optional_arg_type<Row<TSumType, bool>> other) {
if (!other.has_value()) {
return false;
}
Contributor:

nit: You can also return false for if (isEmpty || otherIsEmpty) upfront here.

Author (@liujiayi771):

I don't think so; if the current accumulator's isEmpty is true and otherIsEmpty is false, the combined accumulator's isEmpty will be false, and the sum will equal otherSum. But we can return false for if (isEmpty && otherIsEmpty).

Contributor:

Oh yes, sorry, I meant if (isEmpty && otherIsEmpty). My typo.


auto const otherSum = other.value().template at<0>();
auto const otherIsEmpty = other.value().template at<1>();

// isEmpty is never null.
VELOX_CHECK(otherIsEmpty.has_value());
if (isEmpty && otherIsEmpty.value()) {
// Both accumulators are empty, no need to do the combination.
return false;
}

bool currentOverflow = !isEmpty && !sum.has_value();
bool otherOverflow = !otherIsEmpty.value() && !otherSum.has_value();
if (currentOverflow || otherOverflow) {
sum = std::nullopt;
isEmpty = false;
} else {
int128_t result;
overflow +=
DecimalUtil::addWithOverflow(result, otherSum.value(), sum.value());
sum = result;
isEmpty &= otherIsEmpty.value();
}
return true;
}
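The branch structure of combine() can be captured in a small standalone model. In this sketch (names `SumState` and `merge` are chosen for illustration), a side with isEmpty == false and sum == nullopt encodes overflow; the real accumulator additionally tracks an overflow counter, which this model folds into plain 128-bit addition for brevity:

```cpp
#include <optional>

using int128_t = __int128;

// (sum, isEmpty) pair; sum == nullopt with isEmpty == false means overflow.
struct SumState {
  std::optional<int128_t> sum{0};
  bool isEmpty{true};
};

SumState merge(const SumState& a, const SumState& b) {
  if (a.isEmpty && b.isEmpty) {
    return a; // both empty, nothing to combine
  }
  const bool aOverflow = !a.isEmpty && !a.sum.has_value();
  const bool bOverflow = !b.isEmpty && !b.sum.has_value();
  if (aOverflow || bOverflow) {
    return {std::nullopt, false}; // overflow is sticky
  }
  // An empty side contributes its initial sum of 0.
  return {a.sum.value() + b.sum.value(), a.isEmpty && b.isEmpty};
}
```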

bool writeIntermediateResult(
bool nonNullGroup,
exec::out_type<IntermediateType>& out) {
if (!nonNullGroup) {
// If a group is null, all values in this group are null. In Spark, this
// group will be the initial value, where sum is 0 and isEmpty is true.
out = std::make_tuple(static_cast<TSumType>(0), true);
} else {
auto finalResult = computeFinalResult();
if (finalResult.has_value()) {
out = std::make_tuple(
static_cast<TSumType>(finalResult.value()), isEmpty);
} else {
// Sum should be set to null on overflow,
// and isEmpty should be set to false.
out.template set_null_at<0>();
out.template get_writer_at<1>() = false;
}
}
return true;
}

bool writeFinalResult(bool nonNullGroup, exec::out_type<OutputType>& out) {
if (!nonNullGroup || isEmpty) {
// If isEmpty is true, we should set null.
return false;
}
auto finalResult = computeFinalResult();
if (finalResult.has_value()) {
out = static_cast<TSumType>(finalResult.value());
return true;
} else {
// Sum should be set to null on overflow.
return false;
}
}
};
};

} // namespace facebook::velox::functions::aggregate::sparksql
2 changes: 1 addition & 1 deletion velox/functions/sparksql/aggregates/Register.cpp
@@ -34,6 +34,6 @@ void registerAggregateFunctions(
registerBitwiseXorAggregate(prefix);
registerBloomFilterAggAggregate(prefix + "bloom_filter_agg");
registerAverage(prefix + "avg", withCompanionFunctions);
registerSum(prefix + "sum");
registerSum(prefix + "sum", withCompanionFunctions);
}
} // namespace facebook::velox::functions::aggregate::sparksql