Add support for sum(decimal) Spark aggregate function #5372
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include "velox/exec/SimpleAggregateAdapter.h" | ||
#include "velox/type/DecimalUtil.h" | ||
|
||
namespace facebook::velox::functions::aggregate::sparksql { | ||
|
||
/// @tparam TInputType The raw input data type. | ||
/// @tparam TSumType The type of sum in the output of partial aggregation or the | ||
/// final output type of final aggregation. | ||
template <typename TInputType, typename TSumType> | ||
class DecimalSumAggregate { | ||
public: | ||
using InputType = Row<TInputType>; | ||
|
||
using IntermediateType = | ||
Row</*sum*/ TSumType, | ||
/*isEmpty*/ bool>; | ||
|
||
using OutputType = TSumType; | ||
|
||
/// Spark's decimal sum doesn't have the concept of a null group; each group | ||
/// is initialized with an initial value, where sum = 0 and isEmpty = true. | ||
/// The final agg may fall back to being executed in Spark, so the meaning of | ||
/// the intermediate data should be consistent with Spark. Therefore, we need | ||
/// to use the parameter nonNullGroup in writeIntermediateResult to output a | ||
/// null group as sum = 0, isEmpty = true. nonNullGroup is only available when | ||
/// default-null behavior is disabled. | ||
static constexpr bool default_null_behavior_ = false; | ||
Why is this false? Would you add a comment to explain?

@mbasmanova The main reason is that Spark's decimal sum doesn't have the concept of a null group; each group is initialized with an initial value, where sum = 0 and isEmpty = true. Therefore, to maintain consistency, I need to use the parameter

I'm not sure I understand. What is a "null group"? Are you trying to match intermediate results to Spark's? If so, does this matter only when the query uses companion functions? Why can't the intermediate result for all-null inputs be sum = NULL? Do you need isEmpty = true to allow Spark to distinguish between all-NULL inputs and decimal overflow?

A "null group" means all input values for this group are null; we never call clearNull for this group.
Yes. Not only for companion functions. In Spark's decimal sum agg, sum is initialized to 0, not null. sum = 0, isEmpty = true means all input values are null. Spark uses isEmpty = true to distinguish all-NULL inputs from inputs whose sum happens to equal 0.

We need to ensure that the data output by the operator is consistent with Spark, because the final agg of the decimal sum may fall back to being executed in Spark for some reason. We need to make sure that the meaning of the intermediate data is consistent with Spark. Spark uses isEmpty and sum together to distinguish several situations, including all-null inputs and overflow.

Got it. Makes sense. Perhaps clarify this in the comments.
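
To make the encoding discussed in this thread concrete, here is a minimal sketch of how a (sum, isEmpty) pair could be decoded into three states. `SumState`, `IntermediateSum`, and `classify` are hypothetical names used only for illustration; they are not part of Velox or Spark, and `int64_t` stands in for the decimal sum type.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical names, for illustration only; not Velox or Spark APIs.
enum class SumState { kEmpty, kValid, kOverflow };

struct IntermediateSum {
  std::optional<int64_t> sum; // nullopt models a NULL sum field.
  bool isEmpty;               // Never null, per the thread above.
};

// Decodes the pair as described in the thread:
//   isEmpty = true                -> no non-null input was seen.
//   isEmpty = false, sum present  -> a regular sum (which may be 0).
//   isEmpty = false, sum is NULL  -> overflow.
inline SumState classify(const IntermediateSum& s) {
  if (s.isEmpty) {
    return SumState::kEmpty;
  }
  return s.sum.has_value() ? SumState::kValid : SumState::kOverflow;
}
```

Under this reading, `{0, true}` and `{0, false}` are different states, which is why a plain NULL sum cannot replace the isEmpty flag.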
||
|
||
static constexpr bool aligned_accumulator_ = true; | ||
|
||
static bool toIntermediate( | ||
exec::out_type<Row<TSumType, bool>>& out, | ||
exec::optional_arg_type<TInputType> in) { | ||
if (in.has_value()) { | ||
Is this implementation correct? Shouldn't we produce a non-null struct for every row, where sum = x and isEmpty = isNull(x)? How do you test this code path to ensure Spark can process the results correctly?

Yes, we need to set sum = x, isEmpty = false for non-null input, and sum = 0, isEmpty = true for null input. I will add a test case for decimal sum in Gluten that will set
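
The per-row mapping agreed on in this thread can be sketched in isolation as follows. `toIntermediateSketch` is a hypothetical free function and `int64_t` stands in for the decimal sum type; it is not the `exec::out_type` writer interface used in the PR.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

// Hypothetical helper, for illustration only.
// Maps one raw input to a non-null (sum, isEmpty) pair:
//   non-null x -> (x, false)
//   null       -> (0, true)
inline std::pair<int64_t, bool> toIntermediateSketch(
    std::optional<int64_t> in) {
  if (in.has_value()) {
    return {*in, false};
  }
  return {0, true};
}
```

Every row yields a non-null struct, so a consumer that expects Spark's initial value for all-null input never sees a null intermediate row.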
||
out.copy_from(std::make_tuple(static_cast<TSumType>(in.value()), false)); | ||
} else { | ||
out.copy_from(std::make_tuple(static_cast<TSumType>(0), true)); | ||
} | ||
return true; | ||
} | ||
|
||
/// This struct stores the sum of input values, overflow during accumulation, | ||
/// and a bool value isEmpty used to indicate whether all inputs are null. The | ||
/// initial value of sum is 0. We need to keep sum unchanged if the input is | ||
/// null, as the sum function ignores null input. If isEmpty is true, there | ||
/// were no values to begin with or all the values were null, so the result | ||
/// is null. If isEmpty is false and sum is nullopt, an overflow has | ||
/// happened and the result is also null. | ||
struct AccumulatorType { | ||
std::optional<int128_t> sum{0}; | ||
int64_t overflow{0}; | ||
bool isEmpty{true}; | ||
|
||
AccumulatorType() = delete; | ||
This is not needed, is it?

Is this done to prevent the direct use of the default no-argument constructor to create a struct? I'm not sure about the main purpose here. I've seen other implementations using the simple agg interface include this line of code, so I added it as well.

This is not required since another constructor with a HashStringAllocator* argument is defined. But I think deleting the default constructor explicitly is still helpful to clarify the purpose of the code.
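
The language rule behind this exchange can be checked with a small standalone sketch. `Allocator`, `ImplicitlySuppressed`, and `ExplicitlyDeleted` are hypothetical types; `Allocator` merely stands in for `HashStringAllocator`.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for HashStringAllocator.
struct Allocator {};

// Declaring any constructor suppresses the implicit default constructor.
struct ImplicitlySuppressed {
  explicit ImplicitlySuppressed(Allocator* /*allocator*/) {}
};

// Same effect, but the intent is spelled out for the reader.
struct ExplicitlyDeleted {
  ExplicitlyDeleted() = delete;
  explicit ExplicitlyDeleted(Allocator* /*allocator*/) {}
};

static_assert(!std::is_default_constructible_v<ImplicitlySuppressed>);
static_assert(!std::is_default_constructible_v<ExplicitlyDeleted>);
```

Both types reject default construction, so the `= delete` line is documentation rather than a behavioral change.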
||
|
||
explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} | ||
|
||
std::optional<int128_t> computeFinalResult() const { | ||
if (!sum.has_value()) { | ||
return std::nullopt; | ||
} | ||
auto const adjustedSum = | ||
DecimalUtil::adjustSumForOverflow(sum.value(), overflow); | ||
constexpr uint8_t maxPrecision = std::is_same_v<TSumType, int128_t> | ||
? LongDecimalType::kMaxPrecision | ||
: ShortDecimalType::kMaxPrecision; | ||
if (adjustedSum.has_value() && | ||
DecimalUtil::valueInPrecisionRange(adjustedSum, maxPrecision)) { | ||
return adjustedSum; | ||
} else { | ||
// Found overflow during computing adjusted sum. | ||
return std::nullopt; | ||
} | ||
} | ||
|
||
bool addInput( | ||
HashStringAllocator* /*allocator*/, | ||
exec::optional_arg_type<TInputType> data) { | ||
if (!data.has_value()) { | ||
return false; | ||
} | ||
if (!sum.has_value()) { | ||
// sum is initialized to 0. When it is nullopt, it implies that the | ||
// input data must not be empty. | ||
VELOX_CHECK(!isEmpty); | ||
return true; | ||
} | ||
int128_t result; | ||
overflow += | ||
DecimalUtil::addWithOverflow(result, data.value(), sum.value()); | ||
sum = result; | ||
isEmpty = false; | ||
nit: we only need to set isEmpty if it's true.

This will introduce an if branch, which may increase overhead.

Comment on lines +108 to +109

We have seen some use cases where the function (1) adds raw inputs, (2) extracts intermediate results, (3) adds intermediate results back to the accumulator, and (4) continues adding raw inputs to the accumulator (see the test code in AggregationTestBase::testStreaming()). Suppose the intermediate result added at step (3) is {null, false}; the code here at step (4) would overwrite it. Should this method also check whether the current accumulator is in the "overflow" status?

Yes, I was not aware of this scenario before. Thus, we need to verify
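
The streaming scenario above can be reproduced with a simplified sketch. `SketchAccumulator` and `checkedAdd` are hypothetical; the sketch uses `int64_t` and marks overflow immediately, whereas the PR tracks an overflow counter over `int128_t` and checks precision at the end. It only illustrates why raw inputs must not overwrite an overflowed state.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical helper: a + b, or nullopt if it would overflow int64_t.
inline std::optional<int64_t> checkedAdd(int64_t a, int64_t b) {
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  if ((b > 0 && a > kMax - b) || (b < 0 && a < kMin - b)) {
    return std::nullopt;
  }
  return a + b;
}

// Hypothetical, simplified accumulator; not the PR's AccumulatorType.
struct SketchAccumulator {
  std::optional<int64_t> sum{0}; // nullopt means overflow.
  bool isEmpty{true};

  // Step (1) and step (4): raw input.
  void addInput(std::optional<int64_t> data) {
    if (!data.has_value()) {
      return; // Null input is ignored.
    }
    if (!sum.has_value()) {
      return; // Already overflowed; keep the state as is.
    }
    sum = checkedAdd(*sum, *data);
    isEmpty = false;
  }

  // Step (3): intermediate input (sum, isEmpty).
  void combine(std::optional<int64_t> otherSum, bool otherIsEmpty) {
    if (otherIsEmpty) {
      return; // Nothing to merge.
    }
    isEmpty = false;
    if (!sum.has_value() || !otherSum.has_value()) {
      sum = std::nullopt; // Overflow on either side is sticky.
      return;
    }
    sum = checkedAdd(*sum, *otherSum);
  }
};
```

Without the early return on `!sum.has_value()` in `addInput`, a raw input arriving after `combine(std::nullopt, false)` would have no valid sum to add to, and the overflow marker would be lost or dereferenced.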
||
return true; | ||
} | ||
|
||
bool combine( | ||
HashStringAllocator* /*allocator*/, | ||
exec::optional_arg_type<Row<TSumType, bool>> other) { | ||
if (!other.has_value()) { | ||
return false; | ||
} | ||
nit: You can also return false for

I don't think so, if the current accumulator's

Oh yes, sorry, I meant
||
auto const otherSum = other.value().template at<0>(); | ||
auto const otherIsEmpty = other.value().template at<1>(); | ||
|
||
// isEmpty is never null. | ||
VELOX_CHECK(otherIsEmpty.has_value()); | ||
if (isEmpty && otherIsEmpty.value()) { | ||
// Both accumulators are empty, no need to do the combination. | ||
return false; | ||
} | ||
|
||
bool currentOverflow = !isEmpty && !sum.has_value(); | ||
bool otherOverflow = !otherIsEmpty.value() && !otherSum.has_value(); | ||
if (currentOverflow || otherOverflow) { | ||
sum = std::nullopt; | ||
isEmpty = false; | ||
} else { | ||
int128_t result; | ||
overflow += | ||
DecimalUtil::addWithOverflow(result, otherSum.value(), sum.value()); | ||
sum = result; | ||
isEmpty &= otherIsEmpty.value(); | ||
} | ||
return true; | ||
} | ||
|
||
bool writeIntermediateResult( | ||
bool nonNullGroup, | ||
exec::out_type<IntermediateType>& out) { | ||
if (!nonNullGroup) { | ||
// If a group is null, all values in this group are null. In Spark, this | ||
// group will be the initial value, where sum is 0 and isEmpty is true. | ||
out = std::make_tuple(static_cast<TSumType>(0), true); | ||
} else { | ||
auto finalResult = computeFinalResult(); | ||
if (finalResult.has_value()) { | ||
out = std::make_tuple( | ||
static_cast<TSumType>(finalResult.value()), isEmpty); | ||
} else { | ||
// Sum should be set to null on overflow, | ||
// and isEmpty should be set to false. | ||
out.template set_null_at<0>(); | ||
out.template get_writer_at<1>() = false; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
bool writeFinalResult(bool nonNullGroup, exec::out_type<OutputType>& out) { | ||
if (!nonNullGroup || isEmpty) { | ||
// If isEmpty is true, we should set null. | ||
return false; | ||
} | ||
auto finalResult = computeFinalResult(); | ||
if (finalResult.has_value()) { | ||
out = static_cast<TSumType>(finalResult.value()); | ||
return true; | ||
} else { | ||
// Sum should be set to null on overflow. | ||
return false; | ||
} | ||
} | ||
}; | ||
}; | ||
|
||
} // namespace facebook::velox::functions::aggregate::sparksql |
Great! Thank you for helping enhance the aggregation function interface! Could you also update the documentation about this new flag? (https://facebookincubator.github.io/velox/develop/aggregate-functions.html, source code is in velox/docs/develop/aggregate-functions.rst.)