From 0251654db3164b0b985e309ea9a382082aab1b85 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Fri, 10 Nov 2023 19:32:40 +0800 Subject: [PATCH] Add skewness Spark agg function --- .../exec/tests/SparkAggregationFuzzerTest.cpp | 6 +- velox/functions/lib/aggregates/CMakeLists.txt | 6 +- .../CentralMomentsAggregatesBase.cpp | 52 ++ .../aggregates/CentralMomentsAggregatesBase.h | 440 ++++++++++++++++ .../aggregates/CentralMomentsAggregates.cpp | 468 +----------------- .../sparksql/aggregates/CMakeLists.txt | 1 + .../aggregates/CentralMomentsAggregate.cpp | 107 ++++ .../aggregates/CentralMomentsAggregate.h | 27 + .../sparksql/aggregates/Register.cpp | 2 + .../sparksql/aggregates/tests/CMakeLists.txt | 3 +- .../tests/CentralMomentsAggregationTest.cpp | 60 +++ 11 files changed, 713 insertions(+), 459 deletions(-) create mode 100644 velox/functions/lib/aggregates/CentralMomentsAggregatesBase.cpp create mode 100644 velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h create mode 100644 velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp create mode 100644 velox/functions/sparksql/aggregates/CentralMomentsAggregate.h create mode 100644 velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp diff --git a/velox/exec/tests/SparkAggregationFuzzerTest.cpp b/velox/exec/tests/SparkAggregationFuzzerTest.cpp index 257523bebe01f..313ab7d891333 100644 --- a/velox/exec/tests/SparkAggregationFuzzerTest.cpp +++ b/velox/exec/tests/SparkAggregationFuzzerTest.cpp @@ -66,7 +66,11 @@ int main(int argc, char** argv) { {"first", nullptr}, {"first_ignore_null", nullptr}, {"max_by", nullptr}, - {"min_by", nullptr}}; + {"min_by", nullptr}, + // The skewness functions of Velox and DuckDB use different + // algorithms. + // https://github.com/facebookincubator/velox/issues/4845 + {"skewness", nullptr}}; size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; auto duckQueryRunner = diff --git a/velox/functions/lib/aggregates/CMakeLists.txt b/velox/functions/lib/aggregates/CMakeLists.txt index 19e3eee7de655..1eabebce1f16d 100644 --- a/velox/functions/lib/aggregates/CMakeLists.txt +++ b/velox/functions/lib/aggregates/CMakeLists.txt @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_functions_aggregates SingleValueAccumulator.cpp - AverageAggregateBase.cpp ValueSet.cpp) +add_library( + velox_functions_aggregates + AverageAggregateBase.cpp CentralMomentsAggregatesBase.cpp + SingleValueAccumulator.cpp ValueSet.cpp) target_link_libraries(velox_functions_aggregates velox_exec velox_presto_serializer Folly::folly) diff --git a/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.cpp b/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.cpp new file mode 100644 index 0000000000000..4e972fb24795a --- /dev/null +++ b/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h" + +namespace facebook::velox::functions::aggregate { + +void checkAccumulatorRowType( + const TypePtr& type, + const std::string& errorMessage) { + VELOX_CHECK_EQ(type->kind(), TypeKind::ROW, "{}", errorMessage); + VELOX_CHECK_EQ( + type->childAt(kCentralMomentsIndices.count)->kind(), + TypeKind::BIGINT, + "{}", + errorMessage); + VELOX_CHECK_EQ( + type->childAt(kCentralMomentsIndices.m1)->kind(), + TypeKind::DOUBLE, + "{}", + errorMessage); + VELOX_CHECK_EQ( + type->childAt(kCentralMomentsIndices.m2)->kind(), + TypeKind::DOUBLE, + "{}", + errorMessage); + VELOX_CHECK_EQ( + type->childAt(kCentralMomentsIndices.m3)->kind(), + TypeKind::DOUBLE, + "{}", + errorMessage); + VELOX_CHECK_EQ( + type->childAt(kCentralMomentsIndices.m4)->kind(), + TypeKind::DOUBLE, + "{}", + errorMessage); +} + +} // namespace facebook::velox::functions::aggregate diff --git a/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h b/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h new file mode 100644 index 0000000000000..d0bd97874715b --- /dev/null +++ b/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h @@ -0,0 +1,440 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/exec/Aggregate.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::functions::aggregate { + +// Indices into RowType representing intermediate results of skewness and +// kurtosis. Columns appear in alphabetical order. +struct CentralMomentsIndices { + int32_t count; + int32_t m1; + int32_t m2; + int32_t m3; + int32_t m4; +}; +constexpr CentralMomentsIndices kCentralMomentsIndices{0, 1, 2, 3, 4}; + +struct CentralMomentsAccumulator { + double count() const { + return count_; + } + + double m1() const { + return m1_; + } + + double m2() const { + return m2_; + } + + double m3() const { + return m3_; + } + + double m4() const { + return m4_; + } + + void update(double value) { + double oldCount = count(); + count_ += 1; + double oldM1 = m1(); + double oldM2 = m2(); + double oldM3 = m3(); + double delta = value - oldM1; + double deltaN = delta / count(); + double deltaN2 = deltaN * deltaN; + double dm2 = delta * deltaN * oldCount; + + m1_ += deltaN; + m2_ += dm2; + m3_ += dm2 * deltaN * (count() - 2) - 3 * deltaN * oldM2; + m4_ += dm2 * deltaN2 * (count() * (double)count() - 3 * count() + 3) + + 6 * deltaN2 * oldM2 - 4 * deltaN * oldM3; + } + + inline void merge(const CentralMomentsAccumulator& other) { + merge(other.count(), other.m1(), other.m2(), other.m3(), other.m4()); + } + + void merge( + double otherCount, + double otherM1, + double otherM2, + double otherM3, + double otherM4) { + if (otherCount == 0) { + return; + } + + double oldCount = count(); + count_ += otherCount; + + double oldM1 = m1(); + double oldM2 = m2(); + double oldM3 = m3(); + double delta = otherM1 - oldM1; + double delta2 = delta * delta; + double delta3 = delta * delta2; + double delta4 = delta2 * delta2; + + m1_ = (oldCount * oldM1 + otherCount * otherM1) / count(); + m2_ += otherM2 + delta2 * oldCount * otherCount / count(); + m3_ += otherM3 + + delta3 * oldCount * otherCount * (oldCount - otherCount) / + (count() * count()) + + 3 * delta * (oldCount * otherM2 - otherCount * oldM2) / count(); + m4_ += otherM4 + + delta4 * oldCount * otherCount * + (oldCount * oldCount - oldCount * otherCount + + otherCount * otherCount) / + (count() * count() * count()) + + 6 * delta2 * + (oldCount * oldCount * otherM2 + otherCount * otherCount * oldM2) / + (count() * count()) + + 4 * delta * (oldCount * otherM3 - otherCount * oldM3) / count(); + } + + private: + int64_t count_{0}; + double m1_{0}; + double m2_{0}; + double m3_{0}; + double m4_{0}; +}; + +template +SimpleVector* asSimpleVector( + const RowVector* rowVector, + int32_t childIndex) { + auto result = rowVector->childAt(childIndex)->as>(); + VELOX_CHECK_NOT_NULL(result); + return result; +} + +class CentralMomentsIntermediateInput { + public: + explicit CentralMomentsIntermediateInput( + const RowVector* rowVector, + const CentralMomentsIndices& indices = kCentralMomentsIndices) + : count_{asSimpleVector(rowVector, indices.count)}, + m1_{asSimpleVector(rowVector, indices.m1)}, + m2_{asSimpleVector(rowVector, indices.m2)}, + m3_{asSimpleVector(rowVector, indices.m3)}, + m4_{asSimpleVector(rowVector, indices.m4)} {} + + void mergeInto(CentralMomentsAccumulator& accumulator, vector_size_t row) { + accumulator.merge( + count_->valueAt(row), + m1_->valueAt(row), + m2_->valueAt(row), + m3_->valueAt(row), + m4_->valueAt(row)); + } + + protected: + SimpleVector* count_; + SimpleVector* m1_; + SimpleVector* m2_; + SimpleVector* m3_; + SimpleVector* m4_; +}; + +template +T* mutableRawValues(const RowVector* rowVector, int32_t childIndex) { + return rowVector->childAt(childIndex) + ->as>() + ->mutableRawValues(); +} + +class CentralMomentsIntermediateResult { + public: + explicit CentralMomentsIntermediateResult( + const RowVector* rowVector, + const CentralMomentsIndices& indices = kCentralMomentsIndices) + : count_{mutableRawValues(rowVector, indices.count)}, + m1_{mutableRawValues(rowVector, indices.m1)}, + m2_{mutableRawValues(rowVector, indices.m2)}, + m3_{mutableRawValues(rowVector, indices.m3)}, + m4_{mutableRawValues(rowVector, indices.m4)} {} + + static std::string type() { + return "row(bigint,double,double,double,double)"; + } + + void set(vector_size_t row, const CentralMomentsAccumulator& accumulator) { + count_[row] = accumulator.count(); + m1_[row] = accumulator.m1(); + m2_[row] = accumulator.m2(); + m3_[row] = accumulator.m3(); + m4_[row] = accumulator.m4(); + } + + private: + int64_t* count_; + double* m1_; + double* m2_; + double* m3_; + double* m4_; +}; + +// T is the input type for partial aggregation, it can be integer, double or +// float. Not used for final aggregation. TResultAccessor is the type of the +// static struct that will access the result in a certain way from the +// CentralMoments Accumulator. +template +class CentralMomentsAggregatesBase : public exec::Aggregate { + public: + explicit CentralMomentsAggregatesBase(TypePtr resultType) + : exec::Aggregate(resultType) {} + + int32_t accumulatorAlignmentSize() const override { + return alignof(CentralMomentsAccumulator); + } + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(CentralMomentsAccumulator); + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + setAllNulls(groups, indices); + for (auto i : indices) { + new (groups[i] + offset_) CentralMomentsAccumulator(); + } + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedRaw_.decode(*args[0], rows); + if (decodedRaw_.isConstantMapping()) { + if (!decodedRaw_.isNullAt(0)) { + auto value = decodedRaw_.valueAt(0); + rows.applyToSelected( + [&](vector_size_t i) { updateNonNullValue(groups[i], value); }); + } + } else if (decodedRaw_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedRaw_.isNullAt(i)) { + return; + } + updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); + }); + } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + auto data = decodedRaw_.data(); + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue(groups[i], data[i]); + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); + }); + } + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodedRaw_.decode(*args[0], rows); + + if (decodedRaw_.isConstantMapping()) { + if (!decodedRaw_.isNullAt(0)) { + auto value = decodedRaw_.valueAt(0); + rows.applyToSelected( + [&](vector_size_t i) { updateNonNullValue(group, value); }); + } + } else if (decodedRaw_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (!decodedRaw_.isNullAt(i)) { + updateNonNullValue(group, decodedRaw_.valueAt(i)); + } + }); + } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + auto data = decodedRaw_.data(); + CentralMomentsAccumulator accData; + rows.applyToSelected([&](vector_size_t i) { accData.update(data[i]); }); + updateNonNullValue(group, accData); + } else { + CentralMomentsAccumulator accData; + rows.applyToSelected( + [&](vector_size_t i) { accData.update(decodedRaw_.valueAt(i)); }); + updateNonNullValue(group, accData); + } + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedPartial_.decode(*args[0], rows); + + auto baseRowVector = dynamic_cast(decodedPartial_.base()); + CentralMomentsIntermediateInput input{baseRowVector}; + + if (decodedPartial_.isConstantMapping()) { + if (!decodedPartial_.isNullAt(0)) { + auto decodedIndex = decodedPartial_.index(0); + rows.applyToSelected([&](vector_size_t i) { + exec::Aggregate::clearNull(groups[i]); + input.mergeInto(*accumulator(groups[i]), decodedIndex); + }); + } + } else if (decodedPartial_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedPartial_.isNullAt(i)) { + return; + } + auto decodedIndex = decodedPartial_.index(i); + exec::Aggregate::clearNull(groups[i]); + input.mergeInto(*accumulator(groups[i]), decodedIndex); + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + auto decodedIndex = decodedPartial_.index(i); + exec::Aggregate::clearNull(groups[i]); + input.mergeInto(*accumulator(groups[i]), decodedIndex); + }); + } + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedPartial_.decode(*args[0], rows); + auto baseRowVector = dynamic_cast(decodedPartial_.base()); + CentralMomentsIntermediateInput input{baseRowVector}; + + if (decodedPartial_.isConstantMapping()) { + if (!decodedPartial_.isNullAt(0)) { + auto decodedIndex = decodedPartial_.index(0); + CentralMomentsAccumulator accData; + rows.applyToSelected( + [&](vector_size_t i) { input.mergeInto(accData, decodedIndex); }); + updateNonNullValue(group, accData); + } + } else if (decodedPartial_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedPartial_.isNullAt(i)) { + return; + } + auto decodedIndex = decodedPartial_.index(i); + exec::Aggregate::clearNull(group); + input.mergeInto(*accumulator(group), decodedIndex); + }); + } else { + CentralMomentsAccumulator accData; + rows.applyToSelected([&](vector_size_t i) { + auto decodedIndex = decodedPartial_.index(i); + input.mergeInto(accData, decodedIndex); + }); + updateNonNullValue(group, accData); + } + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto vector = (*result)->as>(); + VELOX_CHECK(vector); + vector->resize(numGroups); + uint64_t* rawNulls = getRawNulls(vector); + + double* rawValues = vector->mutableRawValues(); + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + if (isNull(group)) { + vector->setNull(i, true); + } else { + auto* accData = accumulator(group); + if (TResultAccessor::hasResult(*accData)) { + clearNull(rawNulls, i); + rawValues[i] = TResultAccessor::result(*accData); + } else { + vector->setNull(i, true); + } + } + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto rowVector = (*result)->as(); + rowVector->resize(numGroups); + for (auto& child : rowVector->children()) { + child->resize(numGroups); + } + + uint64_t* rawNulls = getRawNulls(rowVector); + + CentralMomentsIntermediateResult centralMomentsResult{rowVector}; + + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + if (isNull(group)) { + rowVector->setNull(i, true); + } else { + clearNull(rawNulls, i); + centralMomentsResult.set(i, *accumulator(group)); + } + } + } + + private: + template + inline void updateNonNullValue(char* group, T value) { + if constexpr (tableHasNulls) { + exec::Aggregate::clearNull(group); + } + CentralMomentsAccumulator* accData = accumulator(group); + accData->update((double)value); + } + + template + inline void updateNonNullValue( + char* group, + const CentralMomentsAccumulator& accData) { + if constexpr (tableHasNulls) { + exec::Aggregate::clearNull(group); + } + CentralMomentsAccumulator* thisAccData = accumulator(group); + thisAccData->merge(accData); + } + + inline CentralMomentsAccumulator* accumulator(char* group) { + return exec::Aggregate::value(group); + } + + DecodedVector decodedRaw_; + DecodedVector decodedPartial_; +}; + +void checkAccumulatorRowType( + const TypePtr& type, + const std::string& errorMessage); + +} // namespace facebook::velox::functions::aggregate diff --git a/velox/functions/prestosql/aggregates/CentralMomentsAggregates.cpp b/velox/functions/prestosql/aggregates/CentralMomentsAggregates.cpp index 2c5b9c28d2a33..5d03ab2c861a9 100644 --- a/velox/functions/prestosql/aggregates/CentralMomentsAggregates.cpp +++ b/velox/functions/prestosql/aggregates/CentralMomentsAggregates.cpp @@ -13,112 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "velox/exec/Aggregate.h" +#include "velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" -#include "velox/vector/FlatVector.h" - -namespace facebook::velox::aggregate::prestosql { - -namespace { -// Indices into RowType representing intermediate results of skewness and -// kurtosis. Columns appear in alphabetical order. -struct CentralMomentsIndices { - int32_t count; - int32_t m1; - int32_t m2; - int32_t m3; - int32_t m4; -}; -constexpr CentralMomentsIndices kCentralMomentsIndices{0, 1, 2, 3, 4}; - -struct CentralMomentsAccumulator { - double count() const { - return count_; - } - - double m1() const { - return m1_; - } - - double m2() const { - return m2_; - } - - double m3() const { - return m3_; - } - - double m4() const { - return m4_; - } - - void update(double value) { - double oldCount = count(); - count_ += 1; - double oldM1 = m1(); - double oldM2 = m2(); - double oldM3 = m3(); - double delta = value - oldM1; - double deltaN = delta / count(); - double deltaN2 = deltaN * deltaN; - double dm2 = delta * deltaN * oldCount; - - m1_ += deltaN; - m2_ += dm2; - m3_ += dm2 * deltaN * (count() - 2) - 3 * deltaN * oldM2; - m4_ += dm2 * deltaN2 * (count() * (double)count() - 3 * count() + 3) + - 6 * deltaN2 * oldM2 - 4 * deltaN * oldM3; - } - inline void merge(const CentralMomentsAccumulator& other) { - merge(other.count(), other.m1(), other.m2(), other.m3(), other.m4()); - } - - void merge( - double otherCount, - double otherM1, - double otherM2, - double otherM3, - double otherM4) { - if (otherCount == 0) { - return; - } - - double oldCount = count(); - count_ += otherCount; - - double oldM1 = m1(); - double oldM2 = m2(); - double oldM3 = m3(); - double delta = otherM1 - oldM1; - double delta2 = delta * delta; - double delta3 = delta * delta2; - double delta4 = delta2 * delta2; - - m1_ = (oldCount * oldM1 + otherCount * otherM1) / count(); - m2_ += otherM2 + delta2 * oldCount * otherCount / count(); - m3_ += otherM3 + - delta3 * oldCount * otherCount * (oldCount - otherCount) / - (count() * count()) + - 3 * delta * (oldCount * otherM2 - otherCount * oldM2) / count(); - m4_ += otherM4 + - delta4 * oldCount * otherCount * - (oldCount * oldCount - oldCount * otherCount + - otherCount * otherCount) / - (count() * count() * count()) + - 6 * delta2 * - (oldCount * oldCount * otherM2 + otherCount * otherCount * oldM2) / - (count() * count()) + - 4 * delta * (oldCount * otherM3 - otherCount * oldM3) / count(); - } +using namespace facebook::velox::functions::aggregate; - private: - int64_t count_{0}; - double m1_{0}; - double m2_{0}; - double m3_{0}; - double m4_{0}; -}; +namespace facebook::velox::aggregate::prestosql { struct SkewnessResultAccessor { static bool hasResult(const CentralMomentsAccumulator& accumulator) { @@ -146,350 +48,6 @@ struct KurtosisResultAccessor { } }; -template -SimpleVector* asSimpleVector( - const RowVector* rowVector, - int32_t childIndex) { - auto result = rowVector->childAt(childIndex)->as>(); - VELOX_CHECK_NOT_NULL(result); - return result; -} - -class CentralMomentsIntermediateInput { - public: - explicit CentralMomentsIntermediateInput( - const RowVector* rowVector, - const CentralMomentsIndices& indices = kCentralMomentsIndices) - : count_{asSimpleVector(rowVector, indices.count)}, - m1_{asSimpleVector(rowVector, indices.m1)}, - m2_{asSimpleVector(rowVector, indices.m2)}, - m3_{asSimpleVector(rowVector, indices.m3)}, - m4_{asSimpleVector(rowVector, indices.m4)} {} - - void mergeInto(CentralMomentsAccumulator& accumulator, vector_size_t row) { - accumulator.merge( - count_->valueAt(row), - m1_->valueAt(row), - m2_->valueAt(row), - m3_->valueAt(row), - m4_->valueAt(row)); - } - - protected: - SimpleVector* count_; - SimpleVector* m1_; - SimpleVector* m2_; - SimpleVector* m3_; - SimpleVector* m4_; -}; - -template -T* mutableRawValues(const RowVector* rowVector, int32_t childIndex) { - return rowVector->childAt(childIndex) - ->as>() - ->mutableRawValues(); -} - -class CentralMomentsIntermediateResult { - public: - explicit CentralMomentsIntermediateResult( - const RowVector* rowVector, - const CentralMomentsIndices& indices = kCentralMomentsIndices) - : count_{mutableRawValues(rowVector, indices.count)}, - m1_{mutableRawValues(rowVector, indices.m1)}, - m2_{mutableRawValues(rowVector, indices.m2)}, - m3_{mutableRawValues(rowVector, indices.m3)}, - m4_{mutableRawValues(rowVector, indices.m4)} {} - - static std::string type() { - return "row(bigint,double,double,double,double)"; - } - - void set(vector_size_t row, const CentralMomentsAccumulator& accumulator) { - count_[row] = accumulator.count(); - m1_[row] = accumulator.m1(); - m2_[row] = accumulator.m2(); - m3_[row] = accumulator.m3(); - m4_[row] = accumulator.m4(); - } - - private: - int64_t* count_; - double* m1_; - double* m2_; - double* m3_; - double* m4_; -}; - -// T is the input type for partial aggregation, it can be integer, double or -// float. Not used for final aggregation. TResultAccessor is the type of the -// static struct that will access the result in a certain way from the -// CentralMoments Accumulator. -template -class CentralMomentsAggregate : public exec::Aggregate { - public: - explicit CentralMomentsAggregate(TypePtr resultType) - : exec::Aggregate(resultType) {} - - int32_t accumulatorAlignmentSize() const override { - return alignof(CentralMomentsAccumulator); - } - - int32_t accumulatorFixedWidthSize() const override { - return sizeof(CentralMomentsAccumulator); - } - - void initializeNewGroups( - char** groups, - folly::Range indices) override { - setAllNulls(groups, indices); - for (auto i : indices) { - new (groups[i] + offset_) CentralMomentsAccumulator(); - } - } - - void addRawInput( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /* mayPushdown */) override { - decodedRaw_.decode(*args[0], rows); - if (decodedRaw_.isConstantMapping()) { - if (!decodedRaw_.isNullAt(0)) { - auto value = decodedRaw_.valueAt(0); - rows.applyToSelected( - [&](vector_size_t i) { updateNonNullValue(groups[i], value); }); - } - } else if (decodedRaw_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decodedRaw_.isNullAt(i)) { - return; - } - updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); - }); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { - auto data = decodedRaw_.data(); - rows.applyToSelected([&](vector_size_t i) { - updateNonNullValue(groups[i], data[i]); - }); - } else { - rows.applyToSelected([&](vector_size_t i) { - updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); - }); - } - } - - void addSingleGroupRawInput( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - decodedRaw_.decode(*args[0], rows); - - if (decodedRaw_.isConstantMapping()) { - if (!decodedRaw_.isNullAt(0)) { - auto value = decodedRaw_.valueAt(0); - rows.applyToSelected( - [&](vector_size_t i) { updateNonNullValue(group, value); }); - } - } else if (decodedRaw_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (!decodedRaw_.isNullAt(i)) { - updateNonNullValue(group, decodedRaw_.valueAt(i)); - } - }); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { - auto data = decodedRaw_.data(); - CentralMomentsAccumulator accData; - rows.applyToSelected([&](vector_size_t i) { accData.update(data[i]); }); - updateNonNullValue(group, accData); - } else { - CentralMomentsAccumulator accData; - rows.applyToSelected( - [&](vector_size_t i) { accData.update(decodedRaw_.valueAt(i)); }); - updateNonNullValue(group, accData); - } - } - - void addIntermediateResults( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /* mayPushdown */) override { - decodedPartial_.decode(*args[0], rows); - - auto baseRowVector = dynamic_cast(decodedPartial_.base()); - CentralMomentsIntermediateInput input{baseRowVector}; - - if (decodedPartial_.isConstantMapping()) { - if (!decodedPartial_.isNullAt(0)) { - auto decodedIndex = decodedPartial_.index(0); - rows.applyToSelected([&](vector_size_t i) { - exec::Aggregate::clearNull(groups[i]); - input.mergeInto(*accumulator(groups[i]), decodedIndex); - }); - } - } else if (decodedPartial_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decodedPartial_.isNullAt(i)) { - return; - } - auto decodedIndex = decodedPartial_.index(i); - exec::Aggregate::clearNull(groups[i]); - input.mergeInto(*accumulator(groups[i]), decodedIndex); - }); - } else { - rows.applyToSelected([&](vector_size_t i) { - auto decodedIndex = decodedPartial_.index(i); - exec::Aggregate::clearNull(groups[i]); - input.mergeInto(*accumulator(groups[i]), decodedIndex); - }); - } - } - - void addSingleGroupIntermediateResults( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /* mayPushdown */) override { - decodedPartial_.decode(*args[0], rows); - auto baseRowVector = dynamic_cast(decodedPartial_.base()); - CentralMomentsIntermediateInput input{baseRowVector}; - - if (decodedPartial_.isConstantMapping()) { - if (!decodedPartial_.isNullAt(0)) { - auto decodedIndex = decodedPartial_.index(0); - CentralMomentsAccumulator accData; - rows.applyToSelected( - [&](vector_size_t i) { input.mergeInto(accData, decodedIndex); }); - updateNonNullValue(group, accData); - } - } else if (decodedPartial_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decodedPartial_.isNullAt(i)) { - return; - } - auto decodedIndex = decodedPartial_.index(i); - exec::Aggregate::clearNull(group); - input.mergeInto(*accumulator(group), decodedIndex); - }); - } else { - CentralMomentsAccumulator accData; - rows.applyToSelected([&](vector_size_t i) { - auto decodedIndex = decodedPartial_.index(i); - input.mergeInto(accData, decodedIndex); - }); - updateNonNullValue(group, accData); - } - } - - void extractValues(char** groups, int32_t numGroups, VectorPtr* result) - override { - auto vector = (*result)->as>(); - VELOX_CHECK(vector); - vector->resize(numGroups); - uint64_t* rawNulls = getRawNulls(vector); - - double* rawValues = vector->mutableRawValues(); - for (auto i = 0; i < numGroups; ++i) { - char* group = groups[i]; - if (isNull(group)) { - vector->setNull(i, true); - } else { - auto* accData = accumulator(group); - if (TResultAccessor::hasResult(*accData)) { - clearNull(rawNulls, i); - rawValues[i] = TResultAccessor::result(*accData); - } else { - vector->setNull(i, true); - } - } - } - } - - void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) - override { - auto rowVector = (*result)->as(); - rowVector->resize(numGroups); - for (auto& child : rowVector->children()) { - child->resize(numGroups); - } - - uint64_t* rawNulls = getRawNulls(rowVector); - - CentralMomentsIntermediateResult centralMomentsResult{rowVector}; - - for (auto i = 0; i < numGroups; ++i) { - char* group = groups[i]; - if (isNull(group)) { - rowVector->setNull(i, true); - } else { - clearNull(rawNulls, i); - centralMomentsResult.set(i, *accumulator(group)); - } - } - } - - private: - template - inline void updateNonNullValue(char* group, T value) { - if constexpr (tableHasNulls) { - exec::Aggregate::clearNull(group); - } - CentralMomentsAccumulator* accData = accumulator(group); - accData->update((double)value); - } - - template - inline void updateNonNullValue( - char* group, - const CentralMomentsAccumulator& accData) { - if constexpr (tableHasNulls) { - exec::Aggregate::clearNull(group); - } - CentralMomentsAccumulator* thisAccData = accumulator(group); - thisAccData->merge(accData); - } - - inline CentralMomentsAccumulator* accumulator(char* group) { - return exec::Aggregate::value(group); - } - - DecodedVector decodedRaw_; - DecodedVector decodedPartial_; -}; - -void checkAccumulatorRowType( - const TypePtr& type, - const std::string& errorMessage) { - VELOX_CHECK_EQ(type->kind(), TypeKind::ROW, "{}", errorMessage); - VELOX_CHECK_EQ( - type->childAt(kCentralMomentsIndices.count)->kind(), - TypeKind::BIGINT, - "{}", - errorMessage); - VELOX_CHECK_EQ( - type->childAt(kCentralMomentsIndices.m1)->kind(), - TypeKind::DOUBLE, - "{}", - errorMessage); - VELOX_CHECK_EQ( - type->childAt(kCentralMomentsIndices.m2)->kind(), - TypeKind::DOUBLE, - "{}", - errorMessage); - VELOX_CHECK_EQ( - type->childAt(kCentralMomentsIndices.m3)->kind(), - TypeKind::DOUBLE, - "{}", - errorMessage); - VELOX_CHECK_EQ( - type->childAt(kCentralMomentsIndices.m4)->kind(), - TypeKind::DOUBLE, - "{}", - errorMessage); -} - template exec::AggregateRegistrationResult registerCentralMoments( const std::string& name) { @@ -521,22 +79,24 @@ exec::AggregateRegistrationResult registerCentralMoments( switch (inputType->kind()) { case TypeKind::SMALLINT: return std::make_unique< - CentralMomentsAggregate>( + CentralMomentsAggregatesBase>( resultType); case TypeKind::INTEGER: return std::make_unique< - CentralMomentsAggregate>( + CentralMomentsAggregatesBase>( resultType); case TypeKind::BIGINT: return std::make_unique< - CentralMomentsAggregate>( + CentralMomentsAggregatesBase>( resultType); case TypeKind::DOUBLE: return std::make_unique< - CentralMomentsAggregate>(resultType); + CentralMomentsAggregatesBase>( + resultType); case TypeKind::REAL: return std::make_unique< - CentralMomentsAggregate>(resultType); + CentralMomentsAggregatesBase>( + resultType); default: VELOX_UNSUPPORTED( "Unsupported input type: {}. " @@ -548,15 +108,13 @@ exec::AggregateRegistrationResult registerCentralMoments( inputType, "Input type for final aggregation must be " "(count:bigint, m1:double, m2:double, m3:double, m4:double) struct"); - // final agg not use template T, int64_t here has no effect. return std::make_unique< - CentralMomentsAggregate>(resultType); + CentralMomentsAggregatesBase>( + resultType); } }); } -} // namespace - void registerCentralMomentsAggregates(const std::string& prefix) { registerCentralMoments(prefix + kKurtosis); registerCentralMoments(prefix + kSkewness); diff --git a/velox/functions/sparksql/aggregates/CMakeLists.txt b/velox/functions/sparksql/aggregates/CMakeLists.txt index 77101bb8717ce..011ff1dfeb398 100644 --- a/velox/functions/sparksql/aggregates/CMakeLists.txt +++ b/velox/functions/sparksql/aggregates/CMakeLists.txt @@ -16,6 +16,7 @@ add_library( AverageAggregate.cpp BitwiseXorAggregate.cpp BloomFilterAggAggregate.cpp + CentralMomentsAggregate.cpp FirstLastAggregate.cpp MinMaxByAggregate.cpp Register.cpp diff --git a/velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp b/velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp new file mode 100644 index 0000000000000..a2d90a1658148 --- /dev/null +++ b/velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/sparksql/aggregates/CentralMomentsAggregate.h" +#include "velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h" + +namespace facebook::velox::functions::aggregate::sparksql { + +namespace { +struct SkewnessResultAccessor { + static bool hasResult(const CentralMomentsAccumulator& accumulator) { + return accumulator.count() >= 1 && accumulator.m2() != 0; + } + + static double result(const CentralMomentsAccumulator& accumulator) { + return std::sqrt(accumulator.count()) * accumulator.m3() / + std::pow(accumulator.m2(), 1.5); + } +}; + +template +exec::AggregateRegistrationResult registerCentralMoments( + const std::string& name) { + std::vector> signatures; + std::vector inputTypes = { + "smallint", "integer", "bigint", "real", "double"}; + for (const auto& inputType : inputTypes) { + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .returnType("double") + .intermediateType(CentralMomentsIntermediateResult::type()) + .argumentType(inputType) + .build()); + } + + return exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + VELOX_CHECK_LE( + argTypes.size(), 1, "{} takes at most one argument", name); + const auto& inputType = argTypes[0]; + if (exec::isRawInput(step)) { + switch (inputType->kind()) { + case TypeKind::SMALLINT: + return std::make_unique< + CentralMomentsAggregatesBase>( + resultType); + case TypeKind::INTEGER: + return std::make_unique< + CentralMomentsAggregatesBase>( + resultType); + case TypeKind::BIGINT: + return std::make_unique< + CentralMomentsAggregatesBase>( + resultType); + case TypeKind::DOUBLE: + return std::make_unique< + CentralMomentsAggregatesBase>( + resultType); + case TypeKind::REAL: + return std::make_unique< + CentralMomentsAggregatesBase>( + resultType); + default: + VELOX_UNSUPPORTED( + "Unsupported input type: {}. " + "Expected SMALLINT, INTEGER, BIGINT, DOUBLE or REAL.", + inputType->toString()) + } + } else { + checkAccumulatorRowType( + inputType, + "Input type for final aggregation must be " + "(count:bigint, m1:double, m2:double, m3:double, m4:double) struct"); + return std::make_unique< + CentralMomentsAggregatesBase>( + resultType); + } + }, + true); +} +} // namespace + +void registerCentralMomentsAggregate(const std::string& prefix) { + registerCentralMoments(prefix + "skewness"); +} + +} // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/CentralMomentsAggregate.h b/velox/functions/sparksql/aggregates/CentralMomentsAggregate.h new file mode 100644 index 0000000000000..601037176810b --- /dev/null +++ b/velox/functions/sparksql/aggregates/CentralMomentsAggregate.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/exec/Aggregate.h" + +namespace facebook::velox::functions::aggregate::sparksql { + +void registerCentralMomentsAggregate(const std::string& name); + +} // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/Register.cpp b/velox/functions/sparksql/aggregates/Register.cpp index b44d3c4397d14..8beb67633639f 100644 --- a/velox/functions/sparksql/aggregates/Register.cpp +++ b/velox/functions/sparksql/aggregates/Register.cpp @@ -19,6 +19,7 @@ #include "velox/functions/sparksql/aggregates/AverageAggregate.h" #include "velox/functions/sparksql/aggregates/BitwiseXorAggregate.h" #include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h" +#include "velox/functions/sparksql/aggregates/CentralMomentsAggregate.h" #include "velox/functions/sparksql/aggregates/SumAggregate.h" namespace facebook::velox::functions::aggregate::sparksql { @@ -35,5 +36,6 @@ void registerAggregateFunctions( registerBloomFilterAggAggregate(prefix + "bloom_filter_agg"); registerAverage(prefix + "avg", withCompanionFunctions); registerSum(prefix + "sum"); + registerCentralMomentsAggregate(prefix); } } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt index 22730f9d7e578..f7a5fdf05cdc1 100644 --- a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt +++ b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt @@ -14,11 +14,12 @@ add_executable( velox_functions_spark_aggregates_test + AverageAggregationTest.cpp BitwiseXorAggregationTest.cpp BloomFilterAggAggregateTest.cpp + CentralMomentsAggregationTest.cpp FirstAggregateTest.cpp LastAggregateTest.cpp - AverageAggregationTest.cpp Main.cpp MinMaxByAggregationTest.cpp SumAggregationTest.cpp) diff --git a/velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp b/velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp new file mode 100644 index 0000000000000..1a557f4b5b06b --- /dev/null +++ b/velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" +#include "velox/functions/sparksql/aggregates/Register.h" + +using namespace facebook::velox::exec::test; +using namespace facebook::velox::functions::aggregate::test; + +namespace facebook::velox::functions::aggregate::sparksql::test { + +namespace { +class CentralMomentsAggregationTest : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + registerAggregateFunctions("spark_"); + } + + void testSkewnessResult( + const RowVectorPtr& input, + const RowVectorPtr& expected) { + PlanBuilder builder(pool()); + builder.values({input}); + builder.singleAggregation({}, {"spark_skewness(c0)"}); + AssertQueryBuilder queryBuilder( + builder.planNode(), this->duckDbQueryRunner_); + queryBuilder.assertResults({expected}); + } +}; + +TEST_F(CentralMomentsAggregationTest, skewnessHasResult) { + auto input = makeRowVector({makeFlatVector({1, 2})}); + // Even when the count is 2, Spark still produces output. + auto expected = + makeRowVector({makeFlatVector(std::vector{0.0})}); + testSkewnessResult(input, expected); + + input = makeRowVector({makeFlatVector({1, 1})}); + expected = makeRowVector({makeNullableFlatVector( + std::vector>{std::nullopt})}); + testSkewnessResult(input, expected); +} + +} // namespace +} // namespace facebook::velox::functions::aggregate::sparksql::test