diff --git a/velox/docs/develop/aggregate-functions.rst b/velox/docs/develop/aggregate-functions.rst index 572c7446e16a..a64ce58be8e4 100644 --- a/velox/docs/develop/aggregate-functions.rst +++ b/velox/docs/develop/aggregate-functions.rst @@ -254,6 +254,9 @@ For aggregaiton functions of default-null behavior, the author defines an // Optional. Default is false. static constexpr bool use_external_memory_ = true; + // Optional. Default is false. + static constexpr bool aligned_accumulator_ = true; + explicit AccumulatorType(HashStringAllocator* allocator); void addInput(HashStringAllocator* allocator, exec::arg_type value1, ...); @@ -274,7 +277,9 @@ The author defines an optional flag `is_fixed_size_` indicating whether the every accumulator takes fixed amount of memory. This flag is true by default. Next, the author defines another optional flag `use_external_memory_` indicating whether the accumulator uses memory that is not tracked by Velox. -This flag is false by default. +This flag is false by default. Then, the author can define an optional flag +`aligned_accumulator_` indicating whether the accumulator requires aligned +access. This flag is false by default. The author defines a constructor that takes a single argument of `HashStringAllocator*`. This constructor is called before aggregation starts to @@ -345,6 +350,9 @@ For aggregaiton functions of non-default-null behavior, the author defines an // Optional. Default is false. static constexpr bool use_external_memory_ = true; + // Optional. Default is false. + static constexpr bool aligned_accumulator_ = true; + explicit AccumulatorType(HashStringAllocator* allocator); bool addInput(HashStringAllocator* allocator, exec::optional_arg_type value1, ...); @@ -361,9 +369,9 @@ For aggregaiton functions of non-default-null behavior, the author defines an void destroy(HashStringAllocator* allocator); }; -The definition of `is_fixed_size_`, `use_external_memory_`, the constructor, -and the `destroy` method are exactly the same as those for default-null -behavior. +The definition of `is_fixed_size_`, `use_external_memory_`, +`aligned_accumulator_`, the constructor, and the `destroy` method are exactly +the same as those for default-null behavior. On the other hand, the C++ function signatures of `addInput`, `combine`, `writeIntermediateResult`, and `writeFinalResult` are different. diff --git a/velox/docs/functions/spark/aggregate.rst b/velox/docs/functions/spark/aggregate.rst index d566e9b15581..a43c95042aca 100644 --- a/velox/docs/functions/spark/aggregate.rst +++ b/velox/docs/functions/spark/aggregate.rst @@ -107,13 +107,20 @@ General Aggregate Functions Returns the sum of `x`. - Supported types are TINYINT, SMALLINT, INTEGER, BIGINT, REAL and DOUBLE. + Supported types are TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE and DECIMAL. When x is of type DOUBLE, the result type is DOUBLE. When x is of type REAL, the result type is REAL. + When x is of type DECIMAL(p, s), the result type is DECIMAL(p + 10, s), where (p + 10) is capped at 38. + For all other input types, the result type is BIGINT. - Note: When the sum of BIGINT values exceeds its limit, it cycles to the overflowed value rather than raising an error. + Note: + When all input values is NULL, for all input types, the result is NULL. + + For DECIMAL type, when an overflow occurs in the accumulation, it returns NULL. For REAL and DOUBLE type, it + returns Infinity. For all other input types, when the sum of input values exceeds its limit, it cycles to the + overflowed value rather than raising an error. Example:: diff --git a/velox/exec/SimpleAggregateAdapter.h b/velox/exec/SimpleAggregateAdapter.h index bfeaada8b114..4d642aefd934 100644 --- a/velox/exec/SimpleAggregateAdapter.h +++ b/velox/exec/SimpleAggregateAdapter.h @@ -145,6 +145,18 @@ class SimpleAggregateAdapter : public Aggregate { struct support_to_intermediate> : std::true_type {}; + // Whether the accumulator requires aligned access. If it is defined, + // SimpleAggregateAdapter::accumulatorAlignmentSize() returns + // alignof(typename FUNC::AccumulatorType). + // Otherwise, SimpleAggregateAdapter::accumulatorAlignmentSize() returns + // Aggregate::accumulatorAlignmentSize(), with a default value of 1. + template + struct aligned_accumulator : std::false_type {}; + + template + struct aligned_accumulator> + : std::integral_constant {}; + static constexpr bool aggregate_default_null_behavior_ = aggregate_default_null_behavior::value; @@ -160,6 +172,8 @@ class SimpleAggregateAdapter : public Aggregate { static constexpr bool support_to_intermediate_ = support_to_intermediate::value; + static constexpr bool aligned_accumulator_ = aligned_accumulator::value; + bool isFixedSize() const override { return accumulator_is_fixed_size_; } @@ -172,6 +186,13 @@ class SimpleAggregateAdapter : public Aggregate { return sizeof(typename FUNC::AccumulatorType); } + int32_t accumulatorAlignmentSize() const override { + if constexpr (aligned_accumulator_) { + return alignof(typename FUNC::AccumulatorType); + } + return Aggregate::accumulatorAlignmentSize(); + } + void initializeNewGroups( char** groups, folly::Range indices) override { diff --git a/velox/functions/sparksql/aggregates/DecimalSumAggregate.h b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h new file mode 100644 index 000000000000..638f78753367 --- /dev/null +++ b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -0,0 +1,183 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/exec/SimpleAggregateAdapter.h" +#include "velox/type/DecimalUtil.h" + +namespace facebook::velox::functions::aggregate::sparksql { + +/// @tparam TInputType The raw input data type. +/// @tparam TSumType The type of sum in the output of partial aggregation or the +/// final output type of final aggregation. +template +class DecimalSumAggregate { + public: + using InputType = Row; + + using IntermediateType = + Row; + + using OutputType = TSumType; + + /// Spark's decimal sum doesn't have the concept of a null group, each group + /// is initialized with an initial value, where sum = 0 and isEmpty = true. + /// The final agg may fallback to being executed in Spark, so the meaning of + /// the intermediate data should be consistent with Spark. Therefore, we need + /// to use the parameter nonNullGroup in writeIntermediateResult to output a + /// null group as sum = 0, isEmpty = true. nonNullGroup is only available when + /// default-null behavior is disabled. + static constexpr bool default_null_behavior_ = false; + + static constexpr bool aligned_accumulator_ = true; + + static bool toIntermediate( + exec::out_type>& out, + exec::optional_arg_type in) { + if (in.has_value()) { + out.copy_from(std::make_tuple(static_cast(in.value()), false)); + } else { + out.copy_from(std::make_tuple(static_cast(0), true)); + } + return true; + } + + /// This struct stores the sum of input values, overflow during accumulation, + /// and a bool value isEmpty used to indicate whether all inputs are null. The + /// initial value of sum is 0. We need to keep sum unchanged if the input is + /// null, as sum function ignores null input. If the isEmpty is true, then it + /// means there were no values to begin with or all the values were null, so + /// the result will be null. If the isEmpty is false, then if sum is nullopt + /// that means an overflow has happened, it returns null. + struct AccumulatorType { + std::optional sum{0}; + int64_t overflow{0}; + bool isEmpty{true}; + + AccumulatorType() = delete; + + explicit AccumulatorType(HashStringAllocator* /*allocator*/) {} + + std::optional computeFinalResult() const { + if (!sum.has_value()) { + return std::nullopt; + } + auto const adjustedSum = + DecimalUtil::adjustSumForOverflow(sum.value(), overflow); + constexpr uint8_t maxPrecision = std::is_same_v + ? LongDecimalType::kMaxPrecision + : ShortDecimalType::kMaxPrecision; + if (adjustedSum.has_value() && + DecimalUtil::valueInPrecisionRange(adjustedSum, maxPrecision)) { + return adjustedSum; + } else { + // Found overflow during computing adjusted sum. + return std::nullopt; + } + } + + bool addInput( + HashStringAllocator* /*allocator*/, + exec::optional_arg_type data) { + if (!data.has_value()) { + return false; + } + if (!sum.has_value()) { + // sum is initialized to 0. When it is nullopt, it implies that the + // input data must not be empty. + VELOX_CHECK(!isEmpty) + return true; + } + int128_t result; + overflow += + DecimalUtil::addWithOverflow(result, data.value(), sum.value()); + sum = result; + isEmpty = false; + return true; + } + + bool combine( + HashStringAllocator* /*allocator*/, + exec::optional_arg_type> other) { + if (!other.has_value()) { + return false; + } + auto const otherSum = other.value().template at<0>(); + auto const otherIsEmpty = other.value().template at<1>(); + + // isEmpty is never null. + VELOX_CHECK(otherIsEmpty.has_value()); + if (isEmpty && otherIsEmpty.value()) { + // Both accumulators are empty, no need to do the combination. + return false; + } + + bool currentOverflow = !isEmpty && !sum.has_value(); + bool otherOverflow = !otherIsEmpty.value() && !otherSum.has_value(); + if (currentOverflow || otherOverflow) { + sum = std::nullopt; + isEmpty = false; + } else { + int128_t result; + overflow += + DecimalUtil::addWithOverflow(result, otherSum.value(), sum.value()); + sum = result; + isEmpty &= otherIsEmpty.value(); + } + return true; + } + + bool writeIntermediateResult( + bool nonNullGroup, + exec::out_type& out) { + if (!nonNullGroup) { + // If a group is null, all values in this group are null. In Spark, this + // group will be the initial value, where sum is 0 and isEmpty is true. + out = std::make_tuple(static_cast(0), true); + } else { + auto finalResult = computeFinalResult(); + if (finalResult.has_value()) { + out = std::make_tuple( + static_cast(finalResult.value()), isEmpty); + } else { + // Sum should be set to null on overflow, + // and isEmpty should be set to false. + out.template set_null_at<0>(); + out.template get_writer_at<1>() = false; + } + } + return true; + } + + bool writeFinalResult(bool nonNullGroup, exec::out_type& out) { + if (!nonNullGroup || isEmpty) { + // If isEmpty is true, we should set null. + return false; + } + auto finalResult = computeFinalResult(); + if (finalResult.has_value()) { + out = static_cast(finalResult.value()); + return true; + } else { + // Sum should be set to null on overflow. + return false; + } + } + }; +}; + +} // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/Register.cpp b/velox/functions/sparksql/aggregates/Register.cpp index b44d3c4397d1..bed5ede19ca3 100644 --- a/velox/functions/sparksql/aggregates/Register.cpp +++ b/velox/functions/sparksql/aggregates/Register.cpp @@ -34,6 +34,6 @@ void registerAggregateFunctions( registerBitwiseXorAggregate(prefix); registerBloomFilterAggAggregate(prefix + "bloom_filter_agg"); registerAverage(prefix + "avg", withCompanionFunctions); - registerSum(prefix + "sum"); + registerSum(prefix + "sum", withCompanionFunctions); } } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/SumAggregate.cpp b/velox/functions/sparksql/aggregates/SumAggregate.cpp index 486331631ec4..78b05a1935ee 100644 --- a/velox/functions/sparksql/aggregates/SumAggregate.cpp +++ b/velox/functions/sparksql/aggregates/SumAggregate.cpp @@ -16,6 +16,7 @@ #include "velox/functions/sparksql/aggregates/SumAggregate.h" #include "velox/functions/lib/aggregates/SumAggregateBase.h" +#include "velox/functions/sparksql/aggregates/DecimalSumAggregate.h" using namespace facebook::velox::functions::aggregate; @@ -24,9 +25,27 @@ namespace facebook::velox::functions::aggregate::sparksql { namespace { template using SumAggregate = SumAggregateBase; + +TypePtr getDecimalSumType(const TypePtr& resultType) { + if (resultType->isRow()) { + // If the resultType is ROW, then the type if sum is the type of the first + // child of the ROW. + return resultType->childAt(0); + } + return resultType; } -exec::AggregateRegistrationResult registerSum(const std::string& name) { +void checkAccumulatorRowType(const TypePtr& type) { + VELOX_CHECK_EQ(type->kind(), TypeKind::ROW); + VELOX_CHECK( + type->childAt(0)->isShortDecimal() || type->childAt(0)->isLongDecimal()); + VELOX_CHECK_EQ(type->childAt(1)->kind(), TypeKind::BOOLEAN); +} +} // namespace + +exec::AggregateRegistrationResult registerSum( + const std::string& name, + bool withCompanionFunctions) { std::vector> signatures{ exec::AggregateFunctionSignatureBuilder() .returnType("real") @@ -38,6 +57,15 @@ exec::AggregateRegistrationResult registerSum(const std::string& name) { .intermediateType("double") .argumentType("double") .build(), + exec::AggregateFunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("r_precision", "min(38, a_precision + 10)") + .integerVariable("r_scale", "min(38, a_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .intermediateType("ROW(DECIMAL(r_precision, r_scale), boolean)") + .returnType("DECIMAL(r_precision, r_scale)") + .build(), }; for (const auto& inputType : {"tinyint", "smallint", "integer", "bigint"}) { @@ -71,13 +99,25 @@ exec::AggregateRegistrationResult registerSum(const std::string& name) { BIGINT()); case TypeKind::BIGINT: { if (inputType->isShortDecimal()) { - VELOX_NYI(); + auto const sumType = getDecimalSumType(resultType); + if (sumType->isShortDecimal()) { + return std::make_unique>>(resultType); + } else if (sumType->isLongDecimal()) { + return std::make_unique>>(resultType); + } } return std::make_unique>( BIGINT()); } case TypeKind::HUGEINT: { - VELOX_NYI(); + if (inputType->isLongDecimal()) { + // If inputType is long decimal, + // its output type is always long decimal. + return std::make_unique>>(resultType); + } } case TypeKind::REAL: if (resultType->kind() == TypeKind::REAL) { @@ -93,6 +133,19 @@ exec::AggregateRegistrationResult registerSum(const std::string& name) { } return std::make_unique>( DOUBLE()); + case TypeKind::ROW: { + VELOX_DCHECK(!exec::isRawInput(step)); + checkAccumulatorRowType(inputType); + // For the intermediate aggregation step, input intermediate sum + // type is equal to final result sum type. + if (inputType->childAt(0)->isShortDecimal()) { + return std::make_unique>>(resultType); + } else if (inputType->childAt(0)->isLongDecimal()) { + return std::make_unique>>(resultType); + } + } default: VELOX_CHECK( false, @@ -100,7 +153,8 @@ exec::AggregateRegistrationResult registerSum(const std::string& name) { name, inputType->kindName()); } - }); + }, + withCompanionFunctions); } } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/SumAggregate.h b/velox/functions/sparksql/aggregates/SumAggregate.h index fbeb4198a9a3..c9059b0193d1 100644 --- a/velox/functions/sparksql/aggregates/SumAggregate.h +++ b/velox/functions/sparksql/aggregates/SumAggregate.h @@ -22,6 +22,8 @@ namespace facebook::velox::functions::aggregate::sparksql { -exec::AggregateRegistrationResult registerSum(const std::string& name); +exec::AggregateRegistrationResult registerSum( + const std::string& name, + bool withCompanionFunctions); } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index 10a088c2db20..40be380c0fa6 100644 --- a/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -14,9 +14,13 @@ * limitations under the License. */ +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/functions/lib/aggregates/tests/SumTestBase.h" #include "velox/functions/sparksql/aggregates/Register.h" +using facebook::velox::exec::test::PlanBuilder; +using namespace facebook::velox::exec::test; using namespace facebook::velox::functions::aggregate::test; namespace facebook::velox::functions::aggregate::sparksql::test { @@ -26,7 +30,101 @@ class SumAggregationTest : public SumTestBase { protected: void SetUp() override { SumTestBase::SetUp(); - registerAggregateFunctions("spark_"); + registerAggregateFunctions("spark_", true); + } + + protected: + // Check global partial agg overflow, and final agg output null. + void decimalGlobalSumOverflow( + const std::vector>& input, + const std::vector>& output) { + const TypePtr type = DECIMAL(38, 0); + auto in = makeRowVector({makeNullableFlatVector({input}, type)}); + auto expected = + makeRowVector({makeNullableFlatVector({output}, type)}); + testAggregations( + {in}, + {}, + {"spark_sum(c0)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + testAggregationsWithCompanion( + {in}, + [](auto& /*builder*/) {}, + {}, + {"spark_sum(c0)"}, + {{type}}, + {}, + {expected}, + {}); + } + + // Check group by partial agg overflow, and final agg output null. + void decimalGroupBySumOverflow( + const std::vector>& input) { + const TypePtr type = DECIMAL(38, 0); + auto in = makeRowVector( + {makeFlatVector(20, [](auto row) { return row % 10; }), + makeNullableFlatVector(input, type)}); + auto expected = makeRowVector( + {makeFlatVector(10, [](auto row) { return row; }), + makeNullableFlatVector( + std::vector>(10, std::nullopt), type)}); + testAggregations( + {in}, + {"c0"}, + {"spark_sum(c1)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + testAggregationsWithCompanion( + {in}, + [](auto& /*builder*/) {}, + {"c0"}, + {"spark_sum(c1)"}, + {{type}}, + {"c0", "a0"}, + {expected}, + {}); + } + + template + void decimalSumAllNulls( + const std::vector>& input, + const TypePtr& inputType, + const std::vector>& output, + const TypePtr& outputType) { + std::vector vectors; + VectorPtr inputDecimalVector = + makeNullableFlatVector(input, inputType); + for (int i = 0; i < 5; ++i) { + vectors.emplace_back(makeRowVector( + {makeFlatVector(20, [](auto row) { return row % 4; }), + inputDecimalVector})); + } + + VectorPtr outputDecimalVector = + makeNullableFlatVector(output, outputType); + auto expected = makeRowVector( + {makeFlatVector(std::vector{0, 1, 2, 3}), + outputDecimalVector}); + testAggregations( + {vectors}, + {"c0"}, + {"spark_sum(c1)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + testAggregationsWithCompanion( + {vectors}, + [](auto& /*builder*/) {}, + {"c0"}, + {"spark_sum(c1)"}, + {{inputType}}, + {"c0", "a0"}, + {expected}, + {}); } }; @@ -38,5 +136,353 @@ TEST_F(SumAggregationTest, hookLimits) { testHookLimits(); } +TEST_F(SumAggregationTest, decimalSumCompanionPartial) { + std::vector shortDecimalRawVector; + int128_t sum = 0; + for (int i = 0; i < 100; ++i) { + shortDecimalRawVector.emplace_back(i * 1000); + sum += i * 1000; + } + + auto input = makeRowVector( + {makeFlatVector(shortDecimalRawVector, DECIMAL(10, 1))}); + auto plan = PlanBuilder() + .values({input}) + .singleAggregation({}, {"spark_sum_partial(c0)"}) + .planNode(); + std::vector sumVector = {sum}; + std::vector isEmptyVector = {false}; + auto expected = makeRowVector({makeRowVector( + {makeFlatVector(sumVector, DECIMAL(20, 1)), + makeFlatVector(isEmptyVector)})}); + AssertQueryBuilder(plan).assertResults(expected); +} + +TEST_F(SumAggregationTest, decimalSumCompanionMerge) { + auto intermediateInput = makeRowVector({makeRowVector( + {makeFlatVector( + std::vector{1000, 2000, 3000}, DECIMAL(20, 1)), + makeFlatVector(std::vector{false, false, false})})}); + + auto plan = PlanBuilder() + .values({intermediateInput}) + .singleAggregation({}, {"spark_sum_merge(c0)"}) + .planNode(); + auto expected = makeRowVector({makeRowVector( + {makeFlatVector(std::vector{6000}, DECIMAL(20, 1)), + makeFlatVector(std::vector{false})})}); + AssertQueryBuilder(plan).assertResults(expected); +} + +TEST_F(SumAggregationTest, decimalSum) { + std::vector> shortDecimalRawVector; + std::vector> longDecimalRawVector; + for (int i = 0; i < 1000; ++i) { + shortDecimalRawVector.emplace_back(i * 1000); + longDecimalRawVector.emplace_back(HugeInt::build(i * 10, i * 100)); + } + shortDecimalRawVector.emplace_back(std::nullopt); + longDecimalRawVector.emplace_back(std::nullopt); + auto input = makeRowVector( + {makeNullableFlatVector(shortDecimalRawVector, DECIMAL(10, 1)), + makeNullableFlatVector(longDecimalRawVector, DECIMAL(23, 4))}); + createDuckDbTable({input}); + testAggregations( + {input}, + {}, + {"spark_sum(c0)", "spark_sum(c1)"}, + "SELECT sum(c0), sum(c1) FROM tmp", + /*config*/ {}, + /*testWithTableScan*/ false); + testAggregationsWithCompanion( + {input}, + [](auto& /*builder*/) {}, + {}, + {"spark_sum(c0)", "spark_sum(c1)"}, + {{DECIMAL(10, 1)}, {DECIMAL(23, 4)}}, + {}, + "SELECT sum(c0), sum(c1) FROM tmp", + {}); + + // Short decimal sum aggregation with multiple groups. + auto inputShortDecimalRows = { + makeRowVector( + {makeNullableFlatVector({1, 1}), + makeFlatVector( + std::vector{37220, 53450}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 2}), + makeFlatVector( + std::vector{10410, 9250}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({3, 3}), + makeFlatVector( + std::vector{-12783, 0}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({1, 2}), + makeFlatVector( + std::vector{23178, 41093}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 3}), + makeFlatVector( + std::vector{-10023, 5290}, DECIMAL(5, 2))}), + }; + + auto expectedShortDecimalResult = { + makeRowVector( + {makeNullableFlatVector({1}), + makeFlatVector( + std::vector{113848}, DECIMAL(15, 2))}), + makeRowVector( + {makeNullableFlatVector({2}), + makeFlatVector( + std::vector{50730}, DECIMAL(15, 2))}), + makeRowVector( + {makeNullableFlatVector({3}), + makeFlatVector( + std::vector{-7493}, DECIMAL(15, 2))})}; + + testAggregations( + inputShortDecimalRows, + {"c0"}, + {"spark_sum(c1)"}, + expectedShortDecimalResult, + /*config*/ {}, + /*testWithTableScan*/ false); + testAggregationsWithCompanion( + {inputShortDecimalRows}, + [](auto& /*builder*/) {}, + {"c0"}, + {"spark_sum(c1)"}, + {{DECIMAL(5, 2)}}, + {"c0", "a0"}, + expectedShortDecimalResult, + {}); + + // Long decimal sum aggregation with multiple groups. + auto inputLongDecimalRows = { + makeRowVector( + {makeNullableFlatVector({1, 1}), + makeFlatVector( + {HugeInt::build(13, 113848), HugeInt::build(12, 53450)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 2}), + makeFlatVector( + {HugeInt::build(21, 10410), HugeInt::build(17, 9250)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({3, 3}), + makeFlatVector( + {HugeInt::build(25, 12783), HugeInt::build(19, 0)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({1, 2}), + makeFlatVector( + {HugeInt::build(31, 23178), HugeInt::build(82, 41093)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 3}), + makeFlatVector( + {HugeInt::build(25, 10023), HugeInt::build(43, 5290)}, + DECIMAL(20, 2))}), + }; + + auto expectedLongDecimalResult = { + makeRowVector( + {makeNullableFlatVector({1}), + makeFlatVector( + std::vector{HugeInt::build(56, 190476)}, + DECIMAL(38, 2))}), + makeRowVector( + {makeNullableFlatVector({2}), + makeFlatVector( + std::vector{HugeInt::build(145, 70776)}, + DECIMAL(38, 2))}), + makeRowVector( + {makeNullableFlatVector({3}), + makeFlatVector( + std::vector{HugeInt::build(87, 18073)}, + DECIMAL(38, 2))})}; + + testAggregations( + inputLongDecimalRows, + {"c0"}, + {"spark_sum(c1)"}, + expectedLongDecimalResult, + /*config*/ {}, + /*testWithTableScan*/ false); + testAggregationsWithCompanion( + {inputShortDecimalRows}, + [](auto& /*builder*/) {}, + {"c0"}, + {"spark_sum(c1)"}, + {{DECIMAL(20, 2)}}, + {"c0", "a0"}, + expectedShortDecimalResult, + {}); +} + +TEST_F(SumAggregationTest, decimalGlobalSumOverflow) { + // Test Positive Overflow. + std::vector> longDecimalInput; + std::vector> longDecimalOutput; + // Create input with 2 kLongDecimalMax. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + // The sum must overflow, and will return null + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); + + // Now add kLongDecimalMin. + // The sum now must not overflow. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + longDecimalOutput.emplace_back(DecimalUtil::kLongDecimalMax); + decimalGlobalSumOverflow(longDecimalInput, longDecimalOutput); + + // Test Negative Overflow. + longDecimalInput.clear(); + longDecimalOutput.clear(); + + // Create input with 2 kLongDecimalMin. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + + // The sum must overflow, and will return null + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); + + // Now add kLongDecimalMax. + // The sum now must not overflow. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + longDecimalOutput.emplace_back(DecimalUtil::kLongDecimalMin); + decimalGlobalSumOverflow(longDecimalInput, longDecimalOutput); + + // Check value in range. + longDecimalInput.clear(); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + longDecimalInput.emplace_back(1); + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); + + longDecimalInput.clear(); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + longDecimalInput.emplace_back(-1); + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); +} + +TEST_F(SumAggregationTest, decimalGroupBySumOverflow) { + // Test Positive Overflow. + decimalGroupBySumOverflow( + std::vector>(20, DecimalUtil::kLongDecimalMax)); + + // Test Negative Overflow. + decimalGroupBySumOverflow( + std::vector>(20, DecimalUtil::kLongDecimalMin)); + + // Check value in range. + auto decimalVector = + std::vector>(10, DecimalUtil::kLongDecimalMax); + auto oneValueVector = std::vector>(10, 1); + decimalVector.insert( + decimalVector.end(), oneValueVector.begin(), oneValueVector.end()); + decimalGroupBySumOverflow(decimalVector); + + decimalVector = + std::vector>(10, DecimalUtil::kLongDecimalMin); + oneValueVector = std::vector>(10, -1); + decimalVector.insert( + decimalVector.end(), oneValueVector.begin(), oneValueVector.end()); + decimalGroupBySumOverflow(decimalVector); +} + +TEST_F(SumAggregationTest, decimalAllNullValues) { + std::vector> allNull(5, std::nullopt); + auto input = makeRowVector( + {makeNullableFlatVector(allNull, DECIMAL(20, 2))}); + std::vector> result = {std::nullopt}; + auto expected = + makeRowVector({makeNullableFlatVector(result, DECIMAL(30, 2))}); + testAggregations( + {input}, + {}, + {"spark_sum(c0)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + testAggregationsWithCompanion( + {input}, + [](auto& /*builder*/) {}, + {}, + {"spark_sum(c0)"}, + {{DECIMAL(20, 2)}}, + {}, + {expected}, + {}); +} + +// Test if all values in some groups are null, the final sum of this group +// should be null. +TEST_F(SumAggregationTest, decimalSomeGroupsAllnullValues) { + std::vector> shortDecimalNulls(20); + std::vector> longDecimalNulls(20); + for (int i = 0; i < 20; i++) { + if (i % 4 == 1 || i % 4 == 3) { + // not all groups are null + shortDecimalNulls[i] = 1; + longDecimalNulls[i] = 1; + } + } + + // Test short decimal inputs and the output sum is short decimal. + decimalSumAllNulls( + shortDecimalNulls, + DECIMAL(7, 2), + std::vector>{std::nullopt, 25, std::nullopt, 25}, + DECIMAL(17, 2)); + + // Test short decimal inputs and the output sum is long decimal. + decimalSumAllNulls( + shortDecimalNulls, + DECIMAL(17, 2), + std::vector>{std::nullopt, 25, std::nullopt, 25}, + DECIMAL(27, 2)); + + // Test long decimal inputs and the output sum is long decimal. + decimalSumAllNulls( + longDecimalNulls, + DECIMAL(25, 2), + std::vector>{std::nullopt, 25, std::nullopt, 25}, + DECIMAL(35, 2)); +} + +TEST_F(SumAggregationTest, decimalRangeOverflow) { + // HugeInt::build(542101086242752217, 68739955140067328) = + // 10'000'000'000'000'000'000'000'000'000'000'000'000, + // one followed by 37 zeros. + int128_t largeNumber = HugeInt::build(542101086242752217, 68739955140067328); + std::vector firstLargeDecimals(11, largeNumber); + std::vector secondLargeDecimals(1, largeNumber); + auto firstInput = makeRowVector( + {makeFlatVector(firstLargeDecimals, DECIMAL(38, 18))}); + auto secondInput = makeRowVector( + {makeFlatVector(secondLargeDecimals, DECIMAL(38, 18))}); + std::vector> result = {std::nullopt}; + auto expected = makeRowVector( + {makeNullableFlatVector(result, DECIMAL(38, 18))}); + testAggregations( + {firstInput, secondInput}, + {}, + {"spark_sum(c0)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + testAggregationsWithCompanion( + {firstInput, secondInput}, + [](auto& /*builder*/) {}, + {}, + {"spark_sum(c0)"}, + {{DECIMAL(38, 18)}}, + {}, + {expected}, + {}); +} } // namespace } // namespace facebook::velox::functions::aggregate::sparksql::test