From 3b6da621f21ec0b621d054d412c2d7efe8aaed23 Mon Sep 17 00:00:00 2001 From: "Ma, Rong" Date: Mon, 27 May 2024 12:41:15 +0800 Subject: [PATCH 1/3] fix udaf register and refine example --- .../spark/sql/expression/UDFResolver.scala | 14 +- .../gluten/expression/VeloxUdfSuite.scala | 5 +- cpp/velox/udf/examples/MyUDAF.cc | 366 +++++++++++++++--- cpp/velox/udf/examples/MyUDF.cc | 30 +- cpp/velox/udf/examples/UdfCommon.h | 53 +++ 5 files changed, 378 insertions(+), 90 deletions(-) create mode 100644 cpp/velox/udf/examples/UdfCommon.h diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala index ec98e98f1c6e..915fc554584c 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala @@ -175,12 +175,16 @@ object UDFResolver extends Logging { intermediateTypes: ExpressionType, variableArity: Boolean): Unit = { assert(argTypes.dataType.isInstanceOf[StructType]) - assert(intermediateTypes.dataType.isInstanceOf[StructType]) - val aggBufferAttributes = - intermediateTypes.dataType.asInstanceOf[StructType].fields.zipWithIndex.map { - case (f, index) => - AttributeReference(s"inter_$index", f.dataType, f.nullable)() + val aggBufferAttributes: Seq[AttributeReference] = + intermediateTypes.dataType match { + case StructType(fields) => + fields.zipWithIndex.map { + case (f, index) => + AttributeReference(s"agg_inter_$index", f.dataType, f.nullable)() + } + case t => + Seq(AttributeReference(s"agg_inter", t)()) } val v = diff --git a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala index 4d2f9fae3147..534a8d9f1c74 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala +++ 
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala @@ -93,12 +93,13 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper { | myavg(1), | myavg(1L), | myavg(cast(1.0 as float)), - | myavg(cast(1.0 as double)) + | myavg(cast(1.0 as double)), + | mycount_if(true) |""".stripMargin) df.collect() assert( df.collect() - .sameElements(Array(Row(1.0, 1.0, 1.0, 1.0)))) + .sameElements(Array(Row(1.0, 1.0, 1.0, 1.0, 1L)))) } } diff --git a/cpp/velox/udf/examples/MyUDAF.cc b/cpp/velox/udf/examples/MyUDAF.cc index e6c4b1fea7e0..710bce53ae65 100644 --- a/cpp/velox/udf/examples/MyUDAF.cc +++ b/cpp/velox/udf/examples/MyUDAF.cc @@ -20,19 +20,22 @@ #include #include #include -#include + #include "udf/Udaf.h" +#include "udf/examples/UdfCommon.h" using namespace facebook::velox; using namespace facebook::velox::exec; namespace { +static const char* kBoolean = "boolean"; static const char* kInteger = "int"; static const char* kBigInt = "bigint"; static const char* kFloat = "float"; static const char* kDouble = "double"; +namespace myavg { // Copied from velox/exec/tests/SimpleAverageAggregate.cpp // Implementation of the average aggregation function through the @@ -98,84 +101,321 @@ class AverageAggregate { }; }; -exec::AggregateRegistrationResult registerSimpleAverageAggregate(const std::string& name) { - std::vector> signatures; +class MyAvgRegisterer final : public gluten::UdafRegisterer { + int getNumUdaf() override { + return 4; + } + + void populateUdafEntries(int& index, gluten::UdafEntry* udafEntries) override { + for (const auto& argTypes : {myAvgArg1_, myAvgArg2_, myAvgArg3_, myAvgArg4_}) { + udafEntries[index++] = {name_.c_str(), kDouble, 1, argTypes, myAvgIntermediateType_}; + } + } + + void registerSignatures() override { + registerSimpleAverageAggregate(); + } + + private: + exec::AggregateRegistrationResult registerSimpleAverageAggregate() { + std::vector> signatures; + + for (const auto& inputType : {"smallint", "integer", 
"bigint", "double"}) { + signatures.push_back(exec::AggregateFunctionSignatureBuilder() + .returnType("double") + .intermediateType("row(double,bigint)") + .argumentType(inputType) + .build()); + } - for (const auto& inputType : {"smallint", "integer", "bigint", "double"}) { signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .returnType("double") + .returnType("real") .intermediateType("row(double,bigint)") - .argumentType(inputType) + .argumentType("real") .build()); - } - signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .returnType("real") - .intermediateType("row(double,bigint)") - .argumentType("real") - .build()); - - return exec::registerAggregateFunction( - name, - std::move(signatures), - [name]( - core::AggregationNode::Step step, - const std::vector& argTypes, - const TypePtr& resultType, - const core::QueryConfig& /*config*/) -> std::unique_ptr { - VELOX_CHECK_LE(argTypes.size(), 1, "{} takes at most one argument", name); - auto inputType = argTypes[0]; - if (exec::isRawInput(step)) { - switch (inputType->kind()) { - case TypeKind::SMALLINT: - return std::make_unique>>(resultType); - case TypeKind::INTEGER: - return std::make_unique>>(resultType); - case TypeKind::BIGINT: - return std::make_unique>>(resultType); - case TypeKind::REAL: - return std::make_unique>>(resultType); - case TypeKind::DOUBLE: - return std::make_unique>>(resultType); - default: - VELOX_FAIL("Unknown input type for {} aggregation {}", name, inputType->kindName()); - } - } else { - switch (resultType->kind()) { - case TypeKind::REAL: - return std::make_unique>>(resultType); - case TypeKind::DOUBLE: - case TypeKind::ROW: - return std::make_unique>>(resultType); - default: - VELOX_FAIL("Unsupported result type for final aggregation: {}", resultType->kindName()); + return exec::registerAggregateFunction( + name_, + std::move(signatures), + [this]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType, + const 
core::QueryConfig& /*config*/) -> std::unique_ptr { + VELOX_CHECK_LE(argTypes.size(), 1, "{} takes at most one argument", name_); + auto inputType = argTypes[0]; + if (exec::isRawInput(step)) { + switch (inputType->kind()) { + case TypeKind::SMALLINT: + return std::make_unique>>(resultType); + case TypeKind::INTEGER: + return std::make_unique>>(resultType); + case TypeKind::BIGINT: + return std::make_unique>>(resultType); + case TypeKind::REAL: + return std::make_unique>>(resultType); + case TypeKind::DOUBLE: + return std::make_unique>>(resultType); + default: + VELOX_FAIL("Unknown input type for {} aggregation {}", name_, inputType->kindName()); + } + } else { + switch (resultType->kind()) { + case TypeKind::REAL: + return std::make_unique>>(resultType); + case TypeKind::DOUBLE: + case TypeKind::ROW: + return std::make_unique>>(resultType); + default: + VELOX_FAIL("Unsupported result type for final aggregation: {}", resultType->kindName()); + } } + }, + true /*registerCompanionFunctions*/, + true /*overwrite*/); + } + + const std::string name_ = "myavg"; + const char* myAvgArg1_[1] = {kInteger}; + const char* myAvgArg2_[1] = {kBigInt}; + const char* myAvgArg3_[1] = {kFloat}; + const char* myAvgArg4_[1] = {kDouble}; + + const char* myAvgIntermediateType_ = "struct"; +}; +} // namespace myavg + +namespace mycountif { + +// Copied from velox/functions/prestosql/aggregates/CountIfAggregate.cpp +class CountIfAggregate : public exec::Aggregate { + public: + explicit CountIfAggregate() : exec::Aggregate(BIGINT()) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(int64_t); + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) override { + extractValues(groups, numGroups, result); + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override { + auto* vector = (*result)->as>(); + VELOX_CHECK(vector); + vector->resize(numGroups); + + auto* rawValues = vector->mutableRawValues(); + for 
(vector_size_t i = 0; i < numGroups; ++i) { + rawValues[i] = *value(groups[i]); + } + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + DecodedVector decoded(*args[0], rows); + + if (decoded.isConstantMapping()) { + if (decoded.isNullAt(0)) { + return; + } + if (decoded.valueAt(0)) { + rows.applyToSelected([&](vector_size_t i) { addToGroup(groups[i], 1); }); + } + } else if (decoded.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decoded.isNullAt(i)) { + return; + } + if (decoded.valueAt(i)) { + addToGroup(groups[i], 1); } - }, - true /*registerCompanionFunctions*/, - true /*overwrite*/); + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + if (decoded.valueAt(i)) { + addToGroup(groups[i], 1); + } + }); + } + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + DecodedVector decoded(*args[0], rows); + + if (decoded.isConstantMapping()) { + auto numTrue = decoded.valueAt(0); + rows.applyToSelected([&](vector_size_t i) { addToGroup(groups[i], numTrue); }); + return; + } + + rows.applyToSelected([&](vector_size_t i) { + auto numTrue = decoded.valueAt(i); + addToGroup(groups[i], numTrue); + }); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + DecodedVector decoded(*args[0], rows); + + // Constant mapping - check once and add number of selected rows if true. 
+ if (decoded.isConstantMapping()) { + if (!decoded.isNullAt(0)) { + auto isTrue = decoded.valueAt(0); + if (isTrue) { + addToGroup(group, rows.countSelected()); + } + } + return; + } + + int64_t numTrue = 0; + if (decoded.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decoded.isNullAt(i)) { + return; + } + if (decoded.valueAt(i)) { + ++numTrue; + } + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + if (decoded.valueAt(i)) { + ++numTrue; + } + }); + } + addToGroup(group, numTrue); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + auto arg = args[0]->as>(); + + int64_t numTrue = 0; + rows.applyToSelected([&](auto row) { numTrue += arg->valueAt(row); }); + + addToGroup(group, numTrue); + } + + protected: + void initializeNewGroupsInternal(char** groups, folly::Range indices) override { + for (auto i : indices) { + *value(groups[i]) = 0; + } + } + + private: + inline void addToGroup(char* group, int64_t numTrue) { + *value(group) += numTrue; + } +}; + +class MyCountIfRegisterer final : public gluten::UdafRegisterer { + int getNumUdaf() override { + return 1; + } + + void populateUdafEntries(int& index, gluten::UdafEntry* udafEntries) override { + udafEntries[index++] = {name_.c_str(), kBigInt, 1, myCountIfArg_, kBigInt}; + } + + void registerSignatures() override { + registerCountIfAggregate(); + } + + private: + void registerCountIfAggregate() { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .returnType("bigint") + .intermediateType("bigint") + .argumentType("boolean") + .build(), + }; + + exec::registerAggregateFunction( + name_, + std::move(signatures), + [this]( + core::AggregationNode::Step step, + std::vector argTypes, + const TypePtr& /*resultType*/, + const core::QueryConfig& /*config*/) -> std::unique_ptr { + VELOX_CHECK_EQ(argTypes.size(), 1, "{} takes one argument", name_); + + auto 
isPartial = exec::isRawInput(step); + if (isPartial) { + VELOX_CHECK_EQ(argTypes[0]->kind(), TypeKind::BOOLEAN, "{} function only accepts boolean parameter", name_); + } + + return std::make_unique(); + }, + {false /*orderSensitive*/}, + true, + true); + } + + const std::string name_ = "mycount_if"; + const char* myCountIfArg_[1] = {kBoolean}; +}; +} // namespace mycountif + +std::vector>& globalRegisters() { + static std::vector> registerers; + return registerers; } -} // namespace -const int kNumMyUdaf = 4; +void setupRegisterers() { + static bool inited = false; + if (inited) { + return; + } + auto& registerers = globalRegisters(); + registerers.push_back(std::make_shared()); + registerers.push_back(std::make_shared()); + inited = true; +} +} // namespace DEFINE_GET_NUM_UDAF { - return kNumMyUdaf; + setupRegisterers(); + + int numUdf = 0; + for (const auto& registerer : globalRegisters()) { + numUdf += registerer->getNumUdaf(); + } + return numUdf; } -const char* myAvgArg1[] = {kInteger}; -const char* myAvgArg2[] = {kBigInt}; -const char* myAvgArg3[] = {kFloat}; -const char* myAvgArg4[] = {kDouble}; -const char* myAvgIntermediateType = "struct"; DEFINE_GET_UDAF_ENTRIES { + setupRegisterers(); + int index = 0; - udafEntries[index++] = {"myavg", kDouble, 1, myAvgArg1, myAvgIntermediateType}; - udafEntries[index++] = {"myavg", kDouble, 1, myAvgArg2, myAvgIntermediateType}; - udafEntries[index++] = {"myavg", kDouble, 1, myAvgArg3, myAvgIntermediateType}; - udafEntries[index++] = {"myavg", kDouble, 1, myAvgArg4, myAvgIntermediateType}; + for (const auto& registerer : globalRegisters()) { + registerer->populateUdafEntries(index, udafEntries); + } } DEFINE_REGISTER_UDAF { - registerSimpleAverageAggregate("myavg"); + setupRegisterers(); + + for (const auto& registerer : globalRegisters()) { + registerer->registerSignatures(); + } } diff --git a/cpp/velox/udf/examples/MyUDF.cc b/cpp/velox/udf/examples/MyUDF.cc index 88bc3ad85da3..ee20ca39d026 100644 --- 
a/cpp/velox/udf/examples/MyUDF.cc +++ b/cpp/velox/udf/examples/MyUDF.cc @@ -20,28 +20,17 @@ #include #include #include "udf/Udf.h" +#include "udf/examples/UdfCommon.h" using namespace facebook::velox; using namespace facebook::velox::exec; +namespace { + static const char* kInteger = "int"; static const char* kBigInt = "bigint"; static const char* kDate = "date"; -class UdfRegisterer { - public: - ~UdfRegisterer() = default; - - // Returns the number of UDFs in populateUdfEntries. - virtual int getNumUdf() = 0; - - // Populate the udfEntries, starting at the given index. - virtual void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) = 0; - - // Register all function signatures to velox. - virtual void registerSignatures() = 0; -}; - namespace myudf { template @@ -106,7 +95,7 @@ static std::shared_ptr makePlusConstant( // signatures: // bigint -> bigint // type: VectorFunction -class MyUdf1Registerer final : public UdfRegisterer { +class MyUdf1Registerer final : public gluten::UdfRegisterer { public: int getNumUdf() override { return 1; @@ -135,7 +124,7 @@ class MyUdf1Registerer final : public UdfRegisterer { // integer -> integer // bigint -> bigint // type: StatefulVectorFunction -class MyUdf2Registerer final : public UdfRegisterer { +class MyUdf2Registerer final : public gluten::UdfRegisterer { public: int getNumUdf() override { return 2; @@ -167,7 +156,7 @@ class MyUdf2Registerer final : public UdfRegisterer { // [integer,] ... -> integer // bigint, [bigint,] ... 
-> bigint // type: StatefulVectorFunction with variable arity -class MyUdf3Registerer final : public UdfRegisterer { +class MyUdf3Registerer final : public gluten::UdfRegisterer { public: int getNumUdf() override { return 2; @@ -215,7 +204,7 @@ struct MyDateSimpleFunction { // signatures: // date, integer -> bigint // type: SimpleFunction -class MyDateRegisterer final : public UdfRegisterer { +class MyDateRegisterer final : public gluten::UdfRegisterer { public: int getNumUdf() override { return 1; @@ -235,8 +224,8 @@ class MyDateRegisterer final : public UdfRegisterer { }; } // namespace mydate -std::vector>& globalRegisters() { - static std::vector> registerers; +std::vector>& globalRegisters() { + static std::vector> registerers; return registerers; } @@ -252,6 +241,7 @@ void setupRegisterers() { registerers.push_back(std::make_shared()); inited = true; } +} // namespace DEFINE_GET_NUM_UDF { setupRegisterers(); diff --git a/cpp/velox/udf/examples/UdfCommon.h b/cpp/velox/udf/examples/UdfCommon.h new file mode 100644 index 000000000000..e5278969cce0 --- /dev/null +++ b/cpp/velox/udf/examples/UdfCommon.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "udf/Udf.h" +#include "udf/Udaf.h" + +namespace gluten { + +class UdfRegisterer { + public: + virtual ~UdfRegisterer() = default; + + // Returns the number of UDFs in populateUdfEntries. + virtual int getNumUdf() = 0; + + // Populate the udfEntries, starting at the given index. + virtual void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) = 0; + + // Register all function signatures to velox. + virtual void registerSignatures() = 0; +}; + +class UdafRegisterer { + public: + virtual ~UdafRegisterer() = default; + + // Returns the number of UDAFs in populateUdafEntries. + virtual int getNumUdaf() = 0; + + // Populate the udafEntries, starting at the given index. + virtual void populateUdafEntries(int& index, gluten::UdafEntry* udafEntries) = 0; + + // Register all function signatures to velox. + virtual void registerSignatures() = 0; +}; + +} // namespace gluten \ No newline at end of file From d8314ac0ff0880ddc98c1c3b1a2d45a31e389e50 Mon Sep 17 00:00:00 2001 From: "Ma, Rong" Date: Mon, 27 May 2024 12:43:44 +0800 Subject: [PATCH 2/3] update doc --- docs/developers/VeloxUDF.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/developers/VeloxUDF.md b/docs/developers/VeloxUDF.md index b88c4de1515e..c896fd672657 100644 --- a/docs/developers/VeloxUDF.md +++ b/docs/developers/VeloxUDF.md @@ -137,13 +137,10 @@ You can also specify the local or HDFS URIs to the UDF libraries or archives. Lo ## Try the example We provided Velox UDF examples in file [MyUDF.cc](../../cpp/velox/udf/examples/MyUDF.cc) and UDAF examples in file [MyUDAF.cc](../../cpp/velox/udf/examples/MyUDAF.cc). -You need to build the gluten cpp project with `--build_example=ON` to get the example libraries. +You need to build the gluten project with `--build_examples=ON` to get the example libraries.
```shell -## compile Gluten cpp module -cd /path/to/gluten/cpp -## if you use custom velox_home, make sure specified here by --velox_home -./compile.sh --build_velox_backend=ON --build_examples=ON +./dev/buildbundle-veloxbe.sh --build_examples=ON ``` Then, you can find the example libraries at /path/to/gluten/cpp/build/velox/udf/examples/ From 67d200601ecee0bcca19f59bad9248e56b15439d Mon Sep 17 00:00:00 2001 From: "Ma, Rong" Date: Mon, 27 May 2024 13:29:41 +0800 Subject: [PATCH 3/3] style --- cpp/velox/udf/examples/UdfCommon.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/velox/udf/examples/UdfCommon.h b/cpp/velox/udf/examples/UdfCommon.h index e5278969cce0..a68c474607cd 100644 --- a/cpp/velox/udf/examples/UdfCommon.h +++ b/cpp/velox/udf/examples/UdfCommon.h @@ -17,8 +17,8 @@ #pragma once -#include "udf/Udf.h" #include "udf/Udaf.h" +#include "udf/Udf.h" namespace gluten {