diff --git a/velox/functions/sparksql/aggregates/DecimalSumAggregate.h b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h index 4048e5c57742f..b9645f8d849fc 100644 --- a/velox/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/velox/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -238,11 +238,8 @@ class DecimalSumAggregate : public exec::Aggregate { if (decodedPartial_.isConstantMapping()) { if (!decodedPartial_.isNullAt(0)) { auto decodedIndex = decodedPartial_.index(0); - if (!isEmptyVector->valueAt(decodedIndex) && - sumVector->isNullAt(decodedIndex)) { - // If isEmpty is false and sum is null, it means this intermediate - // result has an overflow. The final accumulator of this group will - // be null. + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { rows.applyToSelected([&](vector_size_t i) { setNull(groups[i]); }); } else { auto sum = sumVector->valueAt(decodedIndex); @@ -259,8 +256,8 @@ class DecimalSumAggregate : public exec::Aggregate { return; } auto decodedIndex = decodedPartial_.index(i); - if (!isEmptyVector->valueAt(decodedIndex) && - sumVector->isNullAt(decodedIndex)) { + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { setNull(groups[i]); } else { auto sum = sumVector->valueAt(decodedIndex); @@ -272,8 +269,8 @@ class DecimalSumAggregate : public exec::Aggregate { rows.applyToSelected([&](vector_size_t i) { clearNull(groups[i]); auto decodedIndex = decodedPartial_.index(i); - if (!isEmptyVector->valueAt(decodedIndex) && - sumVector->isNullAt(decodedIndex)) { + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { setNull(groups[i]); } else { auto sum = sumVector->valueAt(decodedIndex); @@ -298,8 +295,8 @@ class DecimalSumAggregate : public exec::Aggregate { if (decodedPartial_.isConstantMapping()) { if (!decodedPartial_.isNullAt(0)) { auto decodedIndex = decodedPartial_.index(0); - if (!isEmptyVector->valueAt(decodedIndex) && - sumVector->isNullAt(decodedIndex)) { + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { setNull(group); } else { auto sum = sumVector->valueAt(decodedIndex); @@ -318,8 +315,8 @@ class DecimalSumAggregate : public exec::Aggregate { return; } auto decodedIndex = decodedPartial_.index(i); - if (!isEmptyVector->valueAt(decodedIndex) && - sumVector->isNullAt(decodedIndex)) { + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { setNull(group); return; } else { @@ -335,8 +332,8 @@ class DecimalSumAggregate : public exec::Aggregate { } rows.applyToSelected([&](vector_size_t i) { auto decodedIndex = decodedPartial_.index(i); - if (!isEmptyVector->valueAt(decodedIndex) && - sumVector->isNullAt(decodedIndex)) { + if (isIntermediateResultOverflow( + isEmptyVector, sumVector, decodedIndex)) { setNull(group); return; } else { @@ -373,78 +370,19 @@ class DecimalSumAggregate : public exec::Aggregate { return exec::Aggregate::value(group); } + inline bool isIntermediateResultOverflow( + const SimpleVector* isEmptyVector, + const SimpleVector* sumVector, + vector_size_t index) { + // If isEmpty is false and sum is null, it means this intermediate + // result has an overflow. The final accumulator of this group will + // be null. + return !isEmptyVector->valueAt(index) && sumVector->isNullAt(index); + } + DecodedVector decodedRaw_; DecodedVector decodedPartial_; TypePtr sumType_; }; -exec::AggregateRegistrationResult registerDecimalSumAggregate( - const std::string& name) { - std::vector> signatures{ - exec::AggregateFunctionSignatureBuilder() - .integerVariable("a_precision") - .integerVariable("a_scale") - .integerVariable("r_precision", "min(38, a_precision + 10)") - .integerVariable("r_scale", "min(38, a_scale)") - .argumentType("DECIMAL(a_precision, a_scale)") - .intermediateType("ROW(DECIMAL(r_precision, r_scale), boolean)") - .returnType("DECIMAL(r_precision, r_scale)") - .build()}; - - return exec::registerAggregateFunction( - name, - std::move(signatures), - [name]( - core::AggregationNode::Step step, - const std::vector& argTypes, - const TypePtr& resultType, - const core::QueryConfig& /*config*/) - -> std::unique_ptr { - VELOX_CHECK_EQ(argTypes.size(), 1, "{} takes only one argument", name); - auto& inputType = argTypes[0]; - auto sumType = - exec::isPartialOutput(step) ? resultType->childAt(0) : resultType; - switch (inputType->kind()) { - case TypeKind::BIGINT: { - DCHECK(exec::isRawInput(step)); - if (inputType->isShortDecimal()) { - if (sumType->isShortDecimal()) { - return std::make_unique>( - resultType, sumType); - } else if (sumType->isLongDecimal()) { - return std::make_unique>( - resultType, sumType); - } - } - } - case TypeKind::HUGEINT: - if (inputType->isLongDecimal()) { - // If inputType is long decimal, - // its output type always be long decimal. - return std::make_unique>( - resultType, sumType); - } - case TypeKind::ROW: { - DCHECK(!exec::isRawInput(step)); - // For intermediate input agg, input intermediate sum type - // is equal to final result sum type. - if (inputType->childAt(0)->isShortDecimal()) { - return std::make_unique>( - resultType, sumType); - } else if (inputType->childAt(0)->isLongDecimal()) { - return std::make_unique>( - resultType, sumType); - } - } - default: - VELOX_CHECK( - false, - "Unknown input type for {} aggregation {}", - name, - inputType->kindName()); - } - }, - true); -} - } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/Register.cpp b/velox/functions/sparksql/aggregates/Register.cpp index 3f8d08c46efd6..5db71d74d539b 100644 --- a/velox/functions/sparksql/aggregates/Register.cpp +++ b/velox/functions/sparksql/aggregates/Register.cpp @@ -19,7 +19,6 @@ #include "velox/functions/sparksql/aggregates/AverageAggregate.h" #include "velox/functions/sparksql/aggregates/BitwiseXorAggregate.h" #include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h" -#include "velox/functions/sparksql/aggregates/DecimalSumAggregate.h" #include "velox/functions/sparksql/aggregates/SumAggregate.h" namespace facebook::velox::functions::aggregate::sparksql { @@ -34,6 +33,5 @@ void registerAggregateFunctions(const std::string& prefix) { registerBloomFilterAggAggregate(prefix + "bloom_filter_agg"); registerAverage(prefix + "avg"); registerSum(prefix + "sum"); - registerDecimalSumAggregate(prefix + "sum"); } } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/SumAggregate.cpp b/velox/functions/sparksql/aggregates/SumAggregate.cpp index d0f3df57c4bef..c70b9f08d3a71 100644 --- a/velox/functions/sparksql/aggregates/SumAggregate.cpp +++ b/velox/functions/sparksql/aggregates/SumAggregate.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/functions/lib/aggregates/SumAggregateBase.h" +#include "velox/functions/sparksql/aggregates/DecimalSumAggregate.h" using namespace facebook::velox::functions::aggregate; @@ -22,7 +23,13 @@ namespace facebook::velox::functions::aggregate::sparksql { namespace { template using SumAggregate = SumAggregateBase; + +TypePtr getDecimalSumType( + const TypePtr& resultType, + core::AggregationNode::Step step) { + return exec::isPartialOutput(step) ? resultType->childAt(0) : resultType; } +} // namespace void registerSum(const std::string& name) { std::vector> signatures{ @@ -36,6 +43,15 @@ void registerSum(const std::string& name) { .intermediateType("double") .argumentType("double") .build(), + exec::AggregateFunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("r_precision", "min(38, a_precision + 10)") + .integerVariable("r_scale", "min(38, a_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .intermediateType("ROW(DECIMAL(r_precision, r_scale), boolean)") + .returnType("DECIMAL(r_precision, r_scale)") + .build(), }; for (const auto& inputType : {"tinyint", "smallint", "integer", "bigint"}) { @@ -69,13 +85,26 @@ void registerSum(const std::string& name) { BIGINT()); case TypeKind::BIGINT: { if (inputType->isShortDecimal()) { - VELOX_NYI(); + auto sumType = getDecimalSumType(resultType, step); + if (sumType->isShortDecimal()) { + return std::make_unique>( + resultType, sumType); + } else if (sumType->isLongDecimal()) { + return std::make_unique>( + resultType, sumType); + } } return std::make_unique>( BIGINT()); } case TypeKind::HUGEINT: { - VELOX_NYI(); + if (inputType->isLongDecimal()) { + auto sumType = getDecimalSumType(resultType, step); + // If inputType is long decimal, + // its output type always be long decimal. + return std::make_unique>( + resultType, sumType); + } } case TypeKind::REAL: if (resultType->kind() == TypeKind::REAL) { @@ -91,6 +120,19 @@ void registerSum(const std::string& name) { } return std::make_unique>( DOUBLE()); + case TypeKind::ROW: { + DCHECK(!exec::isRawInput(step)); + auto sumType = getDecimalSumType(resultType, step); + // For intermediate input agg, input intermediate sum type + // is equal to final result sum type. + if (inputType->childAt(0)->isShortDecimal()) { + return std::make_unique>( + resultType, sumType); + } else if (inputType->childAt(0)->isLongDecimal()) { + return std::make_unique>( + resultType, sumType); + } + } default: VELOX_CHECK( false, diff --git a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt index 9d06aa73236e6..22730f9d7e578 100644 --- a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt +++ b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt @@ -16,7 +16,6 @@ add_executable( velox_functions_spark_aggregates_test BitwiseXorAggregationTest.cpp BloomFilterAggAggregateTest.cpp - DecimalSumAggregateTest.cpp FirstAggregateTest.cpp LastAggregateTest.cpp AverageAggregationTest.cpp diff --git a/velox/functions/sparksql/aggregates/tests/DecimalSumAggregateTest.cpp b/velox/functions/sparksql/aggregates/tests/DecimalSumAggregateTest.cpp deleted file mode 100644 index e2c2887248c87..0000000000000 --- a/velox/functions/sparksql/aggregates/tests/DecimalSumAggregateTest.cpp +++ /dev/null @@ -1,332 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/common/base/tests/GTestUtils.h" -#include "velox/exec/tests/utils/AssertQueryBuilder.h" -#include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/functions/lib/aggregates/tests/AggregationTestBase.h" -#include "velox/functions/sparksql/aggregates/Register.h" - -using facebook::velox::exec::test::PlanBuilder; -using namespace facebook::velox::exec::test; -using namespace facebook::velox::functions::aggregate::test; - -namespace facebook::velox::functions::aggregate::sparksql::test { -namespace { -class DecimalSumAggregateTest : public AggregationTestBase { - protected: - void SetUp() override { - AggregationTestBase::SetUp(); - registerAggregateFunctions("spark_"); - allowInputShuffle(); - } - - protected: - // check global partial agg overflow, and final agg output null - void decimalGlobalSumOverflow( - const std::vector>& input, - const std::vector>& output) { - const TypePtr type = DECIMAL(38, 0); - auto in = makeRowVector({makeNullableFlatVector({input}, type)}); - auto expected = - makeRowVector({makeNullableFlatVector({output}, type)}); - PlanBuilder builder(pool()); - builder.values({in}); - builder.partialAggregation({}, {"spark_sum(c0)"}).finalAggregation(); - AssertQueryBuilder queryBuilder( - builder.planNode(), this->duckDbQueryRunner_); - queryBuilder.assertResults({expected}); - } - - // check group by partial agg overflow, and final agg output null - void decimalGroupBySumOverflow( - const std::vector>& input) { - const TypePtr type = DECIMAL(38, 0); - auto in = makeRowVector( - {makeFlatVector(20, [](auto row) { return row % 10; }), - makeNullableFlatVector(input, type)}); - auto expected = makeRowVector( - {makeFlatVector(10, [](auto row) { return row; }), - makeNullableFlatVector( - std::vector>(10, std::nullopt), type)}); - PlanBuilder builder(pool()); - builder.values({in}); - builder.partialAggregation({"c0"}, {"spark_sum(c1)"}).finalAggregation(); - AssertQueryBuilder queryBuilder( - builder.planNode(), this->duckDbQueryRunner_); - queryBuilder.assertResults({expected}); - } - - template - void decimalSumAllNulls( - const std::vector>& input, - const TypePtr& inputType, - const std::vector>& output, - const TypePtr& outputType) { - std::vector vectors; - FlatVectorPtr inputDecimalVector; - if constexpr (std::is_same_v) { - inputDecimalVector = makeNullableFlatVector(input, inputType); - } else { - inputDecimalVector = makeNullableFlatVector(input, inputType); - } - for (int i = 0; i < 5; ++i) { - vectors.emplace_back(makeRowVector( - {makeFlatVector(20, [](auto row) { return row % 4; }), - inputDecimalVector})); - } - - FlatVectorPtr outputDecimalVector; - if constexpr (std::is_same_v) { - outputDecimalVector = makeNullableFlatVector(output, outputType); - } else { - outputDecimalVector = - makeNullableFlatVector(output, outputType); - } - auto expected = makeRowVector( - {makeFlatVector(std::vector{0, 1, 2, 3}), - outputDecimalVector}); - PlanBuilder builder(pool()); - builder.values({vectors}); - builder.singleAggregation({"c0"}, {"spark_sum(c1)"}); - AssertQueryBuilder queryBuilder( - builder.planNode(), this->duckDbQueryRunner_); - queryBuilder.assertResults({expected}); - } -}; - -TEST_F(DecimalSumAggregateTest, sumDecimal) { - std::vector> shortDecimalRawVector; - std::vector> longDecimalRawVector; - for (int i = 0; i < 1000; ++i) { - shortDecimalRawVector.emplace_back(i * 1000); - longDecimalRawVector.emplace_back(HugeInt::build(i * 10, i * 100)); - } - shortDecimalRawVector.emplace_back(std::nullopt); - longDecimalRawVector.emplace_back(std::nullopt); - auto input = makeRowVector( - {makeNullableFlatVector(shortDecimalRawVector, DECIMAL(10, 1)), - makeNullableFlatVector(longDecimalRawVector, DECIMAL(23, 4))}); - createDuckDbTable({input}); - testAggregations( - {input}, - {}, - {"spark_sum(c0)", "spark_sum(c1)"}, - "SELECT sum(c0), sum(c1) FROM tmp"); - - // Short decimal sum aggregation with multiple groups. - auto inputShortDecimalRows = { - makeRowVector( - {makeNullableFlatVector({1, 1}), - makeFlatVector( - std::vector{37220, 53450}, DECIMAL(5, 2))}), - makeRowVector( - {makeNullableFlatVector({2, 2}), - makeFlatVector( - std::vector{10410, 9250}, DECIMAL(5, 2))}), - makeRowVector( - {makeNullableFlatVector({3, 3}), - makeFlatVector( - std::vector{-12783, 0}, DECIMAL(5, 2))}), - makeRowVector( - {makeNullableFlatVector({1, 2}), - makeFlatVector( - std::vector{23178, 41093}, DECIMAL(5, 2))}), - makeRowVector( - {makeNullableFlatVector({2, 3}), - makeFlatVector( - std::vector{-10023, 5290}, DECIMAL(5, 2))}), - }; - - auto expectedShortDecimalResult = { - makeRowVector( - {makeNullableFlatVector({1}), - makeFlatVector( - std::vector{113848}, DECIMAL(15, 2))}), - makeRowVector( - {makeNullableFlatVector({2}), - makeFlatVector( - std::vector{50730}, DECIMAL(15, 2))}), - makeRowVector( - {makeNullableFlatVector({3}), - makeFlatVector( - std::vector{-7493}, DECIMAL(15, 2))})}; - - testAggregations( - inputShortDecimalRows, - {"c0"}, - {"spark_sum(c1)"}, - expectedShortDecimalResult); - - // Long decimal sum aggregation with multiple groups. - auto inputLongDecimalRows = { - makeRowVector( - {makeNullableFlatVector({1, 1}), - makeFlatVector( - {HugeInt::build(13, 113848), HugeInt::build(12, 53450)}, - DECIMAL(20, 2))}), - makeRowVector( - {makeNullableFlatVector({2, 2}), - makeFlatVector( - {HugeInt::build(21, 10410), HugeInt::build(17, 9250)}, - DECIMAL(20, 2))}), - makeRowVector( - {makeNullableFlatVector({3, 3}), - makeFlatVector( - {HugeInt::build(25, 12783), HugeInt::build(19, 0)}, - DECIMAL(20, 2))}), - makeRowVector( - {makeNullableFlatVector({1, 2}), - makeFlatVector( - {HugeInt::build(31, 23178), HugeInt::build(82, 41093)}, - DECIMAL(20, 2))}), - makeRowVector( - {makeNullableFlatVector({2, 3}), - makeFlatVector( - {HugeInt::build(25, 10023), HugeInt::build(43, 5290)}, - DECIMAL(20, 2))}), - }; - - auto expectedLongDecimalResult = { - makeRowVector( - {makeNullableFlatVector({1}), - makeFlatVector( - std::vector{HugeInt::build(56, 190476)}, - DECIMAL(38, 2))}), - makeRowVector( - {makeNullableFlatVector({2}), - makeFlatVector( - std::vector{HugeInt::build(145, 70776)}, - DECIMAL(38, 2))}), - makeRowVector( - {makeNullableFlatVector({3}), - makeFlatVector( - std::vector{HugeInt::build(87, 18073)}, - DECIMAL(38, 2))})}; - - testAggregations( - inputLongDecimalRows, - {"c0"}, - {"spark_sum(c1)"}, - expectedLongDecimalResult); -} - -TEST_F(DecimalSumAggregateTest, globalSumDecimalOverflow) { - // Test Positive Overflow. - std::vector> longDecimalInput; - std::vector> longDecimalOutput; - // Create input with 2 kLongDecimalMax. - longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); - longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); - // The sum must overflow, and will return null - decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); - - // Now add kLongDecimalMin. - // The sum now must not overflow. - longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); - longDecimalOutput.emplace_back(DecimalUtil::kLongDecimalMax); - decimalGlobalSumOverflow(longDecimalInput, longDecimalOutput); - - // Test Negative Overflow. - longDecimalInput.clear(); - longDecimalOutput.clear(); - - // Create input with 2 kLongDecimalMin. - longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); - longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); - - // The sum must overflow, and will return null - decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); - - // Now add kLongDecimalMax. - // The sum now must not overflow. - longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); - longDecimalOutput.emplace_back(DecimalUtil::kLongDecimalMin); - decimalGlobalSumOverflow(longDecimalInput, longDecimalOutput); - - // Check value in range. - longDecimalInput.clear(); - longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); - longDecimalInput.emplace_back(1); - decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); - - longDecimalInput.clear(); - longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); - longDecimalInput.emplace_back(-1); - decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); -} - -TEST_F(DecimalSumAggregateTest, groupBySumDecimalOverflow) { - // Test Positive Overflow. - decimalGroupBySumOverflow( - std::vector>(20, DecimalUtil::kLongDecimalMax)); - - // Test Negative Overflow. - decimalGroupBySumOverflow( - std::vector>(20, DecimalUtil::kLongDecimalMin)); - - // Check value in range. - auto decimalVector = - std::vector>(10, DecimalUtil::kLongDecimalMax); - auto oneValueVector = std::vector>(10, 1); - decimalVector.insert( - decimalVector.end(), oneValueVector.begin(), oneValueVector.end()); - decimalGroupBySumOverflow(decimalVector); - - decimalVector = - std::vector>(10, DecimalUtil::kLongDecimalMin); - oneValueVector = std::vector>(10, -1); - decimalVector.insert( - decimalVector.end(), oneValueVector.begin(), oneValueVector.end()); - decimalGroupBySumOverflow(decimalVector); -} - -/// Test if all values in some groups are null, -/// the final sum of this group should be null. -TEST_F(DecimalSumAggregateTest, someGroupsAllnullValues) { - std::vector> shortDecimalNulls(20); - std::vector> longDecimalNulls(20); - for (int i = 0; i < 20; i++) { - if (i % 4 == 1 || i % 4 == 3) { - // not all groups are null - shortDecimalNulls[i] = 1; - longDecimalNulls[i] = 1; - } - } - - // Test short decimal inputs and the output sum is short decimal. - decimalSumAllNulls( - shortDecimalNulls, - DECIMAL(7, 2), - std::vector>{std::nullopt, 25, std::nullopt, 25}, - DECIMAL(17, 2)); - - // Test short decimal inputs and the output sum is long decimal. - decimalSumAllNulls( - shortDecimalNulls, - DECIMAL(17, 2), - std::vector>{std::nullopt, 25, std::nullopt, 25}, - DECIMAL(27, 2)); - - // Test long decimal inputs and the output sum is long decimal. - decimalSumAllNulls( - longDecimalNulls, - DECIMAL(25, 2), - std::vector>{std::nullopt, 25, std::nullopt, 25}, - DECIMAL(35, 2)); -} -} // namespace -} // namespace facebook::velox::functions::aggregate::sparksql::test diff --git a/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index 9cf3b67b5ce0b..869b475760136 100644 --- a/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -14,9 +14,13 @@ * limitations under the License. */ +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/functions/lib/aggregates/tests/SumTestBase.h" #include "velox/functions/sparksql/aggregates/Register.h" +using facebook::velox::exec::test::PlanBuilder; +using namespace facebook::velox::exec::test; using namespace facebook::velox::functions::aggregate::test; namespace facebook::velox::functions::aggregate::sparksql::test { @@ -28,11 +32,313 @@ class SumAggregationTest : public SumTestBase { SumTestBase::SetUp(); registerAggregateFunctions("spark_"); } + + protected: + // check global partial agg overflow, and final agg output null + void decimalGlobalSumOverflow( + const std::vector>& input, + const std::vector>& output) { + const TypePtr type = DECIMAL(38, 0); + auto in = makeRowVector({makeNullableFlatVector({input}, type)}); + auto expected = + makeRowVector({makeNullableFlatVector({output}, type)}); + testAggregations( + {in}, + {}, + {"spark_sum(c0)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + } + + // check group by partial agg overflow, and final agg output null + void decimalGroupBySumOverflow( + const std::vector>& input) { + const TypePtr type = DECIMAL(38, 0); + auto in = makeRowVector( + {makeFlatVector(20, [](auto row) { return row % 10; }), + makeNullableFlatVector(input, type)}); + auto expected = makeRowVector( + {makeFlatVector(10, [](auto row) { return row; }), + makeNullableFlatVector( + std::vector>(10, std::nullopt), type)}); + testAggregations( + {in}, + {"c0"}, + {"spark_sum(c1)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + } + + template + void decimalSumAllNulls( + const std::vector>& input, + const TypePtr& inputType, + const std::vector>& output, + const TypePtr& outputType) { + std::vector vectors; + FlatVectorPtr inputDecimalVector; + if constexpr (std::is_same_v) { + inputDecimalVector = makeNullableFlatVector(input, inputType); + } else { + inputDecimalVector = makeNullableFlatVector(input, inputType); + } + for (int i = 0; i < 5; ++i) { + vectors.emplace_back(makeRowVector( + {makeFlatVector(20, [](auto row) { return row % 4; }), + inputDecimalVector})); + } + + FlatVectorPtr outputDecimalVector; + if constexpr (std::is_same_v) { + outputDecimalVector = makeNullableFlatVector(output, outputType); + } else { + outputDecimalVector = + makeNullableFlatVector(output, outputType); + } + auto expected = makeRowVector( + {makeFlatVector(std::vector{0, 1, 2, 3}), + outputDecimalVector}); + testAggregations( + {vectors}, + {"c0"}, + {"spark_sum(c1)"}, + {expected}, + /*config*/ {}, + /*testWithTableScan*/ false); + } }; TEST_F(SumAggregationTest, overflow) { SumTestBase::testAggregateOverflow("spark_sum"); } +TEST_F(SumAggregationTest, decimalSum) { + std::vector> shortDecimalRawVector; + std::vector> longDecimalRawVector; + for (int i = 0; i < 1000; ++i) { + shortDecimalRawVector.emplace_back(i * 1000); + longDecimalRawVector.emplace_back(HugeInt::build(i * 10, i * 100)); + } + shortDecimalRawVector.emplace_back(std::nullopt); + longDecimalRawVector.emplace_back(std::nullopt); + auto input = makeRowVector( + {makeNullableFlatVector(shortDecimalRawVector, DECIMAL(10, 1)), + makeNullableFlatVector(longDecimalRawVector, DECIMAL(23, 4))}); + createDuckDbTable({input}); + testAggregations( + {input}, + {}, + {"spark_sum(c0)", "spark_sum(c1)"}, + "SELECT sum(c0), sum(c1) FROM tmp", + /*config*/ {}, + /*testWithTableScan*/ false); + + // Short decimal sum aggregation with multiple groups. + auto inputShortDecimalRows = { + makeRowVector( + {makeNullableFlatVector({1, 1}), + makeFlatVector( + std::vector{37220, 53450}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 2}), + makeFlatVector( + std::vector{10410, 9250}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({3, 3}), + makeFlatVector( + std::vector{-12783, 0}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({1, 2}), + makeFlatVector( + std::vector{23178, 41093}, DECIMAL(5, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 3}), + makeFlatVector( + std::vector{-10023, 5290}, DECIMAL(5, 2))}), + }; + + auto expectedShortDecimalResult = { + makeRowVector( + {makeNullableFlatVector({1}), + makeFlatVector( + std::vector{113848}, DECIMAL(15, 2))}), + makeRowVector( + {makeNullableFlatVector({2}), + makeFlatVector( + std::vector{50730}, DECIMAL(15, 2))}), + makeRowVector( + {makeNullableFlatVector({3}), + makeFlatVector( + std::vector{-7493}, DECIMAL(15, 2))})}; + + testAggregations( + inputShortDecimalRows, + {"c0"}, + {"spark_sum(c1)"}, + expectedShortDecimalResult, + /*config*/ {}, + /*testWithTableScan*/ false); + + // Long decimal sum aggregation with multiple groups. + auto inputLongDecimalRows = { + makeRowVector( + {makeNullableFlatVector({1, 1}), + makeFlatVector( + {HugeInt::build(13, 113848), HugeInt::build(12, 53450)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 2}), + makeFlatVector( + {HugeInt::build(21, 10410), HugeInt::build(17, 9250)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({3, 3}), + makeFlatVector( + {HugeInt::build(25, 12783), HugeInt::build(19, 0)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({1, 2}), + makeFlatVector( + {HugeInt::build(31, 23178), HugeInt::build(82, 41093)}, + DECIMAL(20, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 3}), + makeFlatVector( + {HugeInt::build(25, 10023), HugeInt::build(43, 5290)}, + DECIMAL(20, 2))}), + }; + + auto expectedLongDecimalResult = { + makeRowVector( + {makeNullableFlatVector({1}), + makeFlatVector( + std::vector{HugeInt::build(56, 190476)}, + DECIMAL(38, 2))}), + makeRowVector( + {makeNullableFlatVector({2}), + makeFlatVector( + std::vector{HugeInt::build(145, 70776)}, + DECIMAL(38, 2))}), + makeRowVector( + {makeNullableFlatVector({3}), + makeFlatVector( + std::vector{HugeInt::build(87, 18073)}, + DECIMAL(38, 2))})}; + + testAggregations( + inputLongDecimalRows, + {"c0"}, + {"spark_sum(c1)"}, + expectedLongDecimalResult, + /*config*/ {}, + /*testWithTableScan*/ false); +} + +TEST_F(SumAggregationTest, decimalGlobalSumOverflow) { + // Test Positive Overflow. + std::vector> longDecimalInput; + std::vector> longDecimalOutput; + // Create input with 2 kLongDecimalMax. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + // The sum must overflow, and will return null + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); + + // Now add kLongDecimalMin. + // The sum now must not overflow. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + longDecimalOutput.emplace_back(DecimalUtil::kLongDecimalMax); + decimalGlobalSumOverflow(longDecimalInput, longDecimalOutput); + + // Test Negative Overflow. + longDecimalInput.clear(); + longDecimalOutput.clear(); + + // Create input with 2 kLongDecimalMin. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + + // The sum must overflow, and will return null + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); + + // Now add kLongDecimalMax. + // The sum now must not overflow. + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + longDecimalOutput.emplace_back(DecimalUtil::kLongDecimalMin); + decimalGlobalSumOverflow(longDecimalInput, longDecimalOutput); + + // Check value in range. + longDecimalInput.clear(); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMax); + longDecimalInput.emplace_back(1); + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); + + longDecimalInput.clear(); + longDecimalInput.emplace_back(DecimalUtil::kLongDecimalMin); + longDecimalInput.emplace_back(-1); + decimalGlobalSumOverflow(longDecimalInput, {std::nullopt}); +} + +TEST_F(SumAggregationTest, decimalGroupBySumOverflow) { + // Test Positive Overflow. + decimalGroupBySumOverflow( + std::vector>(20, DecimalUtil::kLongDecimalMax)); + + // Test Negative Overflow. + decimalGroupBySumOverflow( + std::vector>(20, DecimalUtil::kLongDecimalMin)); + + // Check value in range. + auto decimalVector = + std::vector>(10, DecimalUtil::kLongDecimalMax); + auto oneValueVector = std::vector>(10, 1); + decimalVector.insert( + decimalVector.end(), oneValueVector.begin(), oneValueVector.end()); + decimalGroupBySumOverflow(decimalVector); + + decimalVector = + std::vector>(10, DecimalUtil::kLongDecimalMin); + oneValueVector = std::vector>(10, -1); + decimalVector.insert( + decimalVector.end(), oneValueVector.begin(), oneValueVector.end()); + decimalGroupBySumOverflow(decimalVector); +} + +/// Test if all values in some groups are null, the final sum of +/// this group should be null. +TEST_F(SumAggregationTest, decimalSomeGroupsAllnullValues) { + std::vector> shortDecimalNulls(20); + std::vector> longDecimalNulls(20); + for (int i = 0; i < 20; i++) { + if (i % 4 == 1 || i % 4 == 3) { + // not all groups are null + shortDecimalNulls[i] = 1; + longDecimalNulls[i] = 1; + } + } + + // Test short decimal inputs and the output sum is short decimal. + decimalSumAllNulls( + shortDecimalNulls, + DECIMAL(7, 2), + std::vector>{std::nullopt, 25, std::nullopt, 25}, + DECIMAL(17, 2)); + + // Test short decimal inputs and the output sum is long decimal. + decimalSumAllNulls( + shortDecimalNulls, + DECIMAL(17, 2), + std::vector>{std::nullopt, 25, std::nullopt, 25}, + DECIMAL(27, 2)); + + // Test long decimal inputs and the output sum is long decimal. + decimalSumAllNulls( + longDecimalNulls, + DECIMAL(25, 2), + std::vector>{std::nullopt, 25, std::nullopt, 25}, + DECIMAL(35, 2)); +} } // namespace } // namespace facebook::velox::functions::aggregate::sparksql::test