From baee29c847f774c93b08a89bb025198f8bc02e7a Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Tue, 27 Aug 2024 06:19:11 +0000 Subject: [PATCH] support native hive udaf --- .../HashAggregateExecTransformer.scala | 17 +- .../gluten/expression/VeloxUdfSuite.scala | 64 +++-- .../SubstraitToVeloxPlanValidator.cc | 4 +- cpp/velox/udf/examples/MyUDAF.cc | 226 +----------------- .../spark/sql/hive/HiveUDAFInspector.scala | 30 +++ .../AggregateFunctionsBuilder.scala | 2 +- 6 files changed, 104 insertions(+), 239 deletions(-) create mode 100644 gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDAFInspector.scala diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala index 9c5b68e7bff1..fe5e0d92d6d5 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala @@ -30,6 +30,8 @@ import org.apache.gluten.utils.VeloxIntermediateData import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution._ +import org.apache.spark.sql.expression.UDFResolver +import org.apache.spark.sql.hive.HiveUDAFInspector import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -681,14 +683,25 @@ object VeloxAggregateFunctionsBuilder { aggregateFunc: AggregateFunction, mode: AggregateMode): Long = { val functionMap = args.asInstanceOf[JHashMap[String, JLong]] - val sigName = AggregateFunctionsBuilder.getSubstraitFunctionName(aggregateFunc) + val (sigName, aggFunc) = + try { + (AggregateFunctionsBuilder.getSubstraitFunctionName(aggregateFunc), aggregateFunc) + } catch { + case e: GlutenNotSupportException => + HiveUDAFInspector.getUDAFClassName(aggregateFunc) match { + case Some(udafClass) if UDFResolver.UDAFNames.contains(udafClass) => + (udafClass, UDFResolver.getUdafExpression(udafClass)(aggregateFunc.children)) + case _ => throw e + } + case e: Throwable => throw e + } ExpressionBuilder.newScalarFunction( functionMap, ConverterUtils.makeFuncName( // Substrait-to-Velox procedure will choose appropriate companion function if needed. sigName, - VeloxIntermediateData.getInputTypes(aggregateFunc, mode == PartialMerge || mode == Final), + VeloxIntermediateData.getInputTypes(aggFunc, mode == PartialMerge || mode == Final), FunctionConfig.REQ ) ) diff --git a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala index 3abab1376891..3a3fe203595f 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala @@ -16,7 +16,6 @@ */ package org.apache.gluten.expression -import org.apache.gluten.backendsapi.velox.VeloxBackendSettings import org.apache.gluten.tags.{SkipTestTags, UDFTest} import org.apache.spark.SparkConf @@ -91,26 +90,50 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper { .set("spark.memory.offHeap.size", "1024MB") } - ignore("test udaf") { - val df = spark.sql("""select - | myavg(1), - | myavg(1L), - | myavg(cast(1.0 as float)), - | myavg(cast(1.0 as double)), - | mycount_if(true) - |""".stripMargin) - df.collect() - assert( - df.collect() - .sameElements(Array(Row(1.0, 1.0, 1.0, 1.0, 1L)))) - } + test("test native hive udaf") { + val tbl = "test_hive_udaf_replacement" + withTempPath { + dir => + try { + // Check native hive udaf has been registered. + val udafClass = "test.org.apache.spark.sql.MyDoubleAvg" + assert(UDFResolver.UDAFNames.contains(udafClass)) - ignore("test udaf allow type conversion") { - withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") { - val df = spark.sql("""select myavg("1"), myavg("1.0"), mycount_if("true")""") - assert( - df.collect() - .sameElements(Array(Row(1.0, 1.0, 1L)))) + spark.sql(s""" + |CREATE TEMPORARY FUNCTION my_double_avg + |AS '$udafClass' + |""".stripMargin) + spark.sql(s""" + |CREATE EXTERNAL TABLE $tbl + |LOCATION 'file://$dir' + |AS select * from values (1, '1'), (2, '2'), (3, '3') + |""".stripMargin) + val df = spark.sql(s"""select + | my_double_avg(cast(col1 as double)), + | my_double_avg(cast(col2 as double)) + | from $tbl + |""".stripMargin) + val nativeImplicitConversionDF = spark.sql(s"""select + | my_double_avg(col1), + | my_double_avg(col2) + | from $tbl + |""".stripMargin) + val nativeResult = df.collect() + val nativeImplicitConversionResult = nativeImplicitConversionDF.collect() + + UDFResolver.UDAFNames.remove(udafClass) + val fallbackDF = spark.sql(s"""select + | my_double_avg(cast(col1 as double)), + | my_double_avg(cast(col2 as double)) + | from $tbl + |""".stripMargin) + val fallbackResult = fallbackDF.collect() + assert(nativeResult.sameElements(fallbackResult)) + assert(nativeImplicitConversionResult.sameElements(fallbackResult)) + } finally { + spark.sql(s"DROP TABLE IF EXISTS $tbl") + spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS my_double_avg") + } } } @@ -205,6 +228,7 @@ class VeloxUdfSuiteLocal extends VeloxUdfSuite { super.sparkConf .set("spark.files", udfLibPath) .set("spark.gluten.sql.columnar.backend.velox.udfLibraryPaths", udfLibRelativePath) + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") } } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index 7bb0eab77758..60a8d38d192a 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -1162,11 +1162,11 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag "regr_sxy", "regr_replacement"}; - auto udfFuncs = UdfLoader::getInstance()->getRegisteredUdafNames(); + auto udafFuncs = UdfLoader::getInstance()->getRegisteredUdafNames(); for (const auto& funcSpec : funcSpecs) { auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec); - if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end() && udfFuncs.find(funcName) == udfFuncs.end()) { + if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end() && udafFuncs.find(funcName) == udafFuncs.end()) { LOG_VALIDATION_MSG(funcName + " was not supported in AggregateRel."); return false; } diff --git a/cpp/velox/udf/examples/MyUDAF.cc b/cpp/velox/udf/examples/MyUDAF.cc index 710bce53ae65..516404b55c3f 100644 --- a/cpp/velox/udf/examples/MyUDAF.cc +++ b/cpp/velox/udf/examples/MyUDAF.cc @@ -90,7 +90,7 @@ class AverageAggregate { } bool writeFinalResult(exec::out_type& out) { - out = sum_ / count_; + out = sum_ / count_ + 100.0; return true; } @@ -103,12 +103,12 @@ class AverageAggregate { class MyAvgRegisterer final : public gluten::UdafRegisterer { int getNumUdaf() override { - return 4; + return 2; } void populateUdafEntries(int& index, gluten::UdafEntry* udafEntries) override { - for (const auto& argTypes : {myAvgArg1_, myAvgArg2_, myAvgArg3_, myAvgArg4_}) { - udafEntries[index++] = {name_.c_str(), kDouble, 1, argTypes, myAvgIntermediateType_}; + for (const auto& argTypes : {myAvgArgFloat_, myAvgArgDouble_}) { + udafEntries[index++] = {name_.c_str(), kDouble, 1, argTypes, myAvgIntermediateType_, false, true}; } } @@ -120,13 +120,11 @@ class MyAvgRegisterer final : public gluten::UdafRegisterer { exec::AggregateRegistrationResult registerSimpleAverageAggregate() { std::vector> signatures; - for (const auto& inputType : {"smallint", "integer", "bigint", "double"}) { - signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .returnType("double") - .intermediateType("row(double,bigint)") - .argumentType(inputType) - .build()); - } + signatures.push_back(exec::AggregateFunctionSignatureBuilder() + .returnType("double") + .intermediateType("row(double,bigint)") + .argumentType("double") + .build()); signatures.push_back(exec::AggregateFunctionSignatureBuilder() .returnType("real") @@ -146,12 +144,6 @@ class MyAvgRegisterer final : public gluten::UdafRegisterer { auto inputType = argTypes[0]; if (exec::isRawInput(step)) { switch (inputType->kind()) { - case TypeKind::SMALLINT: - return std::make_unique>>(resultType); - case TypeKind::INTEGER: - return std::make_unique>>(resultType); - case TypeKind::BIGINT: - return std::make_unique>>(resultType); case TypeKind::REAL: return std::make_unique>>(resultType); case TypeKind::DOUBLE: @@ -175,207 +167,14 @@ class MyAvgRegisterer final : public gluten::UdafRegisterer { true /*overwrite*/); } - const std::string name_ = "myavg"; - const char* myAvgArg1_[1] = {kInteger}; - const char* myAvgArg2_[1] = {kBigInt}; - const char* myAvgArg3_[1] = {kFloat}; - const char* myAvgArg4_[1] = {kDouble}; + const std::string name_ = "test.org.apache.spark.sql.MyDoubleAvg"; + const char* myAvgArgFloat_[1] = {kFloat}; + const char* myAvgArgDouble_[1] = {kDouble}; const char* myAvgIntermediateType_ = "struct"; }; } // namespace myavg -namespace mycountif { - -// Copied from velox/functions/prestosql/aggregates/CountIfAggregate.cpp -class CountIfAggregate : public exec::Aggregate { - public: - explicit CountIfAggregate() : exec::Aggregate(BIGINT()) {} - - int32_t accumulatorFixedWidthSize() const override { - return sizeof(int64_t); - } - - void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) override { - extractValues(groups, numGroups, result); - } - - void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override { - auto* vector = (*result)->as>(); - VELOX_CHECK(vector); - vector->resize(numGroups); - - auto* rawValues = vector->mutableRawValues(); - for (vector_size_t i = 0; i < numGroups; ++i) { - rawValues[i] = *value(groups[i]); - } - } - - void addRawInput( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - DecodedVector decoded(*args[0], rows); - - if (decoded.isConstantMapping()) { - if (decoded.isNullAt(0)) { - return; - } - if (decoded.valueAt(0)) { - rows.applyToSelected([&](vector_size_t i) { addToGroup(groups[i], 1); }); - } - } else if (decoded.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decoded.isNullAt(i)) { - return; - } - if (decoded.valueAt(i)) { - addToGroup(groups[i], 1); - } - }); - } else { - rows.applyToSelected([&](vector_size_t i) { - if (decoded.valueAt(i)) { - addToGroup(groups[i], 1); - } - }); - } - } - - void addIntermediateResults( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - DecodedVector decoded(*args[0], rows); - - if (decoded.isConstantMapping()) { - auto numTrue = decoded.valueAt(0); - rows.applyToSelected([&](vector_size_t i) { addToGroup(groups[i], numTrue); }); - return; - } - - rows.applyToSelected([&](vector_size_t i) { - auto numTrue = decoded.valueAt(i); - addToGroup(groups[i], numTrue); - }); - } - - void addSingleGroupRawInput( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - DecodedVector decoded(*args[0], rows); - - // Constant mapping - check once and add number of selected rows if true. - if (decoded.isConstantMapping()) { - if (!decoded.isNullAt(0)) { - auto isTrue = decoded.valueAt(0); - if (isTrue) { - addToGroup(group, rows.countSelected()); - } - } - return; - } - - int64_t numTrue = 0; - if (decoded.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decoded.isNullAt(i)) { - return; - } - if (decoded.valueAt(i)) { - ++numTrue; - } - }); - } else { - rows.applyToSelected([&](vector_size_t i) { - if (decoded.valueAt(i)) { - ++numTrue; - } - }); - } - addToGroup(group, numTrue); - } - - void addSingleGroupIntermediateResults( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - auto arg = args[0]->as>(); - - int64_t numTrue = 0; - rows.applyToSelected([&](auto row) { numTrue += arg->valueAt(row); }); - - addToGroup(group, numTrue); - } - - protected: - void initializeNewGroupsInternal(char** groups, folly::Range indices) override { - for (auto i : indices) { - *value(groups[i]) = 0; - } - } - - private: - inline void addToGroup(char* group, int64_t numTrue) { - *value(group) += numTrue; - } -}; - -class MyCountIfRegisterer final : public gluten::UdafRegisterer { - int getNumUdaf() override { - return 1; - } - - void populateUdafEntries(int& index, gluten::UdafEntry* udafEntries) override { - udafEntries[index++] = {name_.c_str(), kBigInt, 1, myCountIfArg_, kBigInt}; - } - - void registerSignatures() override { - registerCountIfAggregate(); - } - - private: - void registerCountIfAggregate() { - std::vector> signatures{ - exec::AggregateFunctionSignatureBuilder() - .returnType("bigint") - .intermediateType("bigint") - .argumentType("boolean") - .build(), - }; - - exec::registerAggregateFunction( - name_, - std::move(signatures), - [this]( - core::AggregationNode::Step step, - std::vector argTypes, - const TypePtr& /*resultType*/, - const core::QueryConfig& /*config*/) -> std::unique_ptr { - VELOX_CHECK_EQ(argTypes.size(), 1, "{} takes one argument", name_); - - auto isPartial = exec::isRawInput(step); - if (isPartial) { - VELOX_CHECK_EQ(argTypes[0]->kind(), TypeKind::BOOLEAN, "{} function only accepts boolean parameter", name_); - } - - return std::make_unique(); - }, - {false /*orderSensitive*/}, - true, - true); - } - - const std::string name_ = "mycount_if"; - const char* myCountIfArg_[1] = {kBoolean}; -}; -} // namespace mycountif - std::vector>& globalRegisters() { static std::vector> registerers; return registerers; @@ -388,7 +187,6 @@ void setupRegisterers() { } auto& registerers = globalRegisters(); registerers.push_back(std::make_shared()); - registerers.push_back(std::make_shared()); inited = true; } } // namespace diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDAFInspector.scala b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDAFInspector.scala new file mode 100644 index 000000000000..7c6401b6765a --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDAFInspector.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.aggregate.ScalaUDAF + +object HiveUDAFInspector { + def getUDAFClassName(expr: Expression): Option[String] = { + expr match { + case func: HiveUDAFFunction => Some(func.funcWrapper.functionClassName) + case scalaUDAF: ScalaUDAF => Some(scalaUDAF.udaf.getClass.getName) + case _ => None + } + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala index bd73b7b7aa54..15de4a734d53 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala @@ -56,7 +56,7 @@ object AggregateFunctionsBuilder { case _ => val nameOpt = ExpressionMappings.expressionsMap.get(aggregateFunc.getClass) if (nameOpt.isEmpty) { - throw new UnsupportedOperationException( + throw new GlutenNotSupportException( s"Could not find a valid substrait mapping name for $aggregateFunc.") } nameOpt.get match {