Skip to content

Commit

Permalink
Fix comments and add testAggregationsWithCompanion test
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Feb 23, 2024
1 parent 7a01252 commit 52f573a
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 3 deletions.
19 changes: 18 additions & 1 deletion velox/exec/SimpleAggregateAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ class SimpleAggregateAdapter : public Aggregate {
struct support_to_intermediate<T, std::void_t<decltype(&T::toIntermediate)>>
: std::true_type {};

// Whether the accumulator requires aligned access. If it is defined,
// SimpleAggregateAdapter::accumulatorAlignmentSize() returns
// alignof(typename FUNC::AccumulatorType).
// Otherwise, SimpleAggregateAdapter::accumulatorAlignmentSize() returns
// Aggregate::accumulatorAlignmentSize(), with a default value of 1.
template <typename T, typename = void>
struct aligned_accumulator : std::false_type {};

template <typename T>
struct aligned_accumulator<T, std::void_t<decltype(T::aligned_accumulator_)>>
: std::integral_constant<bool, T::aligned_accumulator_> {};

static constexpr bool aggregate_default_null_behavior_ =
aggregate_default_null_behavior<FUNC>::value;

Expand All @@ -160,6 +172,8 @@ class SimpleAggregateAdapter : public Aggregate {
static constexpr bool support_to_intermediate_ =
support_to_intermediate<FUNC>::value;

static constexpr bool aligned_accumulator_ = aligned_accumulator<FUNC>::value;

bool isFixedSize() const override {
return accumulator_is_fixed_size_;
}
Expand All @@ -173,7 +187,10 @@ class SimpleAggregateAdapter : public Aggregate {
}

int32_t accumulatorAlignmentSize() const override {
return alignof(typename FUNC::AccumulatorType);
if constexpr (aligned_accumulator_) {
return alignof(typename FUNC::AccumulatorType);
}
return Aggregate::accumulatorAlignmentSize();
}

void initializeNewGroups(
Expand Down
8 changes: 8 additions & 0 deletions velox/functions/sparksql/aggregates/DecimalSumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class DecimalSumAggregate {
/// default-null behavior is disabled.
static constexpr bool default_null_behavior_ = false;

static constexpr bool aligned_accumulator_ = true;

static bool toIntermediate(
exec::out_type<Row<TSumType, bool>>& out,
exec::optional_arg_type<TInputType> in) {
Expand Down Expand Up @@ -94,6 +96,12 @@ class DecimalSumAggregate {
if (!data.has_value()) {
return false;
}
if (!sum.has_value()) {
// sum is initialized to 0. When it is nullopt, it implies that the
// input data must not be empty.
VELOX_CHECK(!isEmpty)
return true;
}
int128_t result;
overflow +=
DecimalUtil::addWithOverflow(result, data.value(), sum.value());
Expand Down
4 changes: 2 additions & 2 deletions velox/functions/sparksql/aggregates/SumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ exec::AggregateRegistrationResult registerSum(
case TypeKind::ROW: {
VELOX_DCHECK(!exec::isRawInput(step));
checkAccumulatorRowType(inputType);
// For intermediate input agg, input intermediate sum type
// is equal to final result sum type.
// For the intermediate aggregation step, input intermediate sum
// type is equal to final result sum type.
if (inputType->childAt(0)->isShortDecimal()) {
return std::make_unique<exec::SimpleAggregateAdapter<
DecimalSumAggregate<int64_t, int64_t>>>(resultType);
Expand Down
72 changes: 72 additions & 0 deletions velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ class SumAggregationTest : public SumTestBase {
{expected},
/*config*/ {},
/*testWithTableScan*/ false);
testAggregationsWithCompanion(
{in},
[](auto& /*builder*/) {},
{},
{"spark_sum(c0)"},
{{type}},
{},
{expected},
{});
}

// Check group by partial agg overflow, and final agg output null.
Expand All @@ -69,6 +78,15 @@ class SumAggregationTest : public SumTestBase {
{expected},
/*config*/ {},
/*testWithTableScan*/ false);
testAggregationsWithCompanion(
{in},
[](auto& /*builder*/) {},
{"c0"},
{"spark_sum(c1)"},
{{type}},
{"c0", "a0"},
{expected},
{});
}

template <typename TIn, typename TOut>
Expand Down Expand Up @@ -107,6 +125,15 @@ class SumAggregationTest : public SumTestBase {
{expected},
/*config*/ {},
/*testWithTableScan*/ false);
testAggregationsWithCompanion(
{vectors},
[](auto& /*builder*/) {},
{"c0"},
{"spark_sum(c1)"},
{{inputType}},
{"c0", "a0"},
{expected},
{});
}
};

Expand Down Expand Up @@ -176,6 +203,15 @@ TEST_F(SumAggregationTest, decimalSum) {
"SELECT sum(c0), sum(c1) FROM tmp",
/*config*/ {},
/*testWithTableScan*/ false);
testAggregationsWithCompanion(
{input},
[](auto& /*builder*/) {},
{},
{"spark_sum(c0)", "spark_sum(c1)"},
{{DECIMAL(10, 1)}, {DECIMAL(23, 4)}},
{},
"SELECT sum(c0), sum(c1) FROM tmp",
{});

// Short decimal sum aggregation with multiple groups.
auto inputShortDecimalRows = {
Expand Down Expand Up @@ -222,6 +258,15 @@ TEST_F(SumAggregationTest, decimalSum) {
expectedShortDecimalResult,
/*config*/ {},
/*testWithTableScan*/ false);
testAggregationsWithCompanion(
{inputShortDecimalRows},
[](auto& /*builder*/) {},
{"c0"},
{"spark_sum(c1)"},
{{DECIMAL(5, 2)}},
{"c0", "a0"},
expectedShortDecimalResult,
{});

// Long decimal sum aggregation with multiple groups.
auto inputLongDecimalRows = {
Expand Down Expand Up @@ -276,6 +321,15 @@ TEST_F(SumAggregationTest, decimalSum) {
expectedLongDecimalResult,
/*config*/ {},
/*testWithTableScan*/ false);
testAggregationsWithCompanion(
{inputShortDecimalRows},
[](auto& /*builder*/) {},
{"c0"},
{"spark_sum(c1)"},
{{DECIMAL(20, 2)}},
{"c0", "a0"},
expectedShortDecimalResult,
{});
}

TEST_F(SumAggregationTest, decimalGlobalSumOverflow) {
Expand Down Expand Up @@ -362,6 +416,15 @@ TEST_F(SumAggregationTest, decimalAllNullValues) {
{expected},
/*config*/ {},
/*testWithTableScan*/ false);
testAggregationsWithCompanion(
{input},
[](auto& /*builder*/) {},
{},
{"spark_sum(c0)"},
{{DECIMAL(20, 2)}},
{},
{expected},
{});
}

// Test if all values in some groups are null, the final sum of this group
Expand Down Expand Up @@ -420,6 +483,15 @@ TEST_F(SumAggregationTest, decimalRangeOverflow) {
{expected},
/*config*/ {},
/*testWithTableScan*/ false);
testAggregationsWithCompanion(
{firstInput, secondInput},
[](auto& /*builder*/) {},
{},
{"spark_sum(c0)"},
{{DECIMAL(38, 18)}},
{},
{expected},
{});
}
} // namespace
} // namespace facebook::velox::functions::aggregate::sparksql::test

0 comments on commit 52f573a

Please sign in to comment.