diff --git a/backends-velox/pom.xml b/backends-velox/pom.xml
index 0fe8f5f6fd8e..417f64999b95 100755
--- a/backends-velox/pom.xml
+++ b/backends-velox/pom.xml
@@ -140,6 +140,13 @@
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <type>test-jar</type>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-hive_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index abb39c5bb23d..438895b25ae9 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -47,7 +47,7 @@ private object VeloxRuleApi {
// Regular Spark rules.
injector.injectOptimizerRule(CollectRewriteRule.apply)
injector.injectOptimizerRule(HLLRewriteRule.apply)
- UDFResolver.getFunctionSignatures.foreach(injector.injectFunction)
+ UDFResolver.getFunctionSignatures().foreach(injector.injectFunction)
injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index bd390004feda..554b3791dad3 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -50,6 +50,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.utils.ExecUtil
import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction}
+import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -819,4 +820,10 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
case other => other
}
}
+
+ override def genHiveUDFTransformer(
+ expr: Expression,
+ attributeSeq: Seq[Attribute]): ExpressionTransformer = {
+ VeloxHiveUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
+ }
}
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index ab83c55ee306..39032e46f381 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.expression
import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
-import org.apache.gluten.exception.GlutenException
+import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer, ExpressionType, GenericExpressionTransformer, Transformable}
import org.apache.gluten.udf.UdfJniWrapper
import org.apache.gluten.vectorized.JniWorkspace
@@ -95,11 +95,14 @@ case class UDAFSignature(
case class UDFExpression(
name: String,
+ alias: String,
dataType: DataType,
nullable: Boolean,
children: Seq[Expression])
extends Unevaluable
with Transformable {
+ override def nodeName: String = alias
+
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
this.copy(children = newChildren)
@@ -118,11 +121,11 @@ case class UDFExpression(
}
object UDFResolver extends Logging {
- private val UDFNames = mutable.HashSet[String]()
+ val UDFNames = mutable.HashSet[String]()
// (udf_name, arg1, arg2, ...) => return type
private val UDFMap = mutable.HashMap[String, mutable.ListBuffer[UDFSignature]]()
- private val UDAFNames = mutable.HashSet[String]()
+ val UDAFNames = mutable.HashSet[String]()
// (udaf_name, arg1, arg2, ...) => return type, intermediate attributes
private val UDAFMap =
mutable.HashMap[String, mutable.ListBuffer[UDAFSignature]]()
@@ -331,7 +334,7 @@ object UDFResolver extends Logging {
.mkString(",")
}
- def getFunctionSignatures: Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
+ def getFunctionSignatures(): Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
val sparkContext = SparkContext.getActive.get
val sparkConf = sparkContext.conf
val udfLibPaths = sparkConf.getOption(VeloxBackendSettings.GLUTEN_VELOX_UDF_LIB_PATHS)
@@ -341,13 +344,12 @@ object UDFResolver extends Logging {
Seq.empty
case Some(_) =>
UdfJniWrapper.getFunctionSignatures()
-
UDFNames.map {
name =>
(
new FunctionIdentifier(name),
new ExpressionInfo(classOf[UDFExpression].getName, name),
- (e: Seq[Expression]) => getUdfExpression(name)(e))
+ (e: Seq[Expression]) => getUdfExpression(name, name)(e))
}.toSeq ++ UDAFNames.map {
name =>
(
@@ -364,27 +366,29 @@ object UDFResolver extends Logging {
.toBoolean
}
- private def getUdfExpression(name: String)(children: Seq[Expression]) = {
+ def getUdfExpression(name: String, alias: String)(children: Seq[Expression]): UDFExpression = {
def errorMessage: String =
s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."
val allowTypeConversion = checkAllowTypeConversion
val signatures =
- UDFMap.getOrElse(name, throw new UnsupportedOperationException(errorMessage));
+ UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage));
signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
case Some(sig) =>
UDFExpression(
name,
+ alias,
sig.expressionType.dataType,
sig.expressionType.nullable,
if (!allowTypeConversion && !sig.allowTypeConversion) children
- else applyCast(children, sig))
+ else applyCast(children, sig)
+ )
case None =>
- throw new UnsupportedOperationException(errorMessage)
+ throw new GlutenNotSupportException(errorMessage)
}
}
- private def getUdafExpression(name: String)(children: Seq[Expression]) = {
+ def getUdafExpression(name: String)(children: Seq[Expression]): UserDefinedAggregateFunction = {
def errorMessage: String =
s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."
@@ -392,7 +396,7 @@ object UDFResolver extends Logging {
val signatures =
UDAFMap.getOrElse(
name,
- throw new UnsupportedOperationException(errorMessage)
+ throw new GlutenNotSupportException(errorMessage)
)
signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
case Some(sig) =>
@@ -405,7 +409,7 @@ object UDFResolver extends Logging {
sig.intermediateAttrs
)
case None =>
- throw new UnsupportedOperationException(errorMessage)
+ throw new GlutenNotSupportException(errorMessage)
}
}
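
The move from `UnsupportedOperationException` to `GlutenNotSupportException` is what enables graceful fallback: Gluten's validation catches `GlutenNotSupportException` and keeps the expression on vanilla Spark instead of failing the query. A minimal sketch of a caller relying on that contract (the helper itself is hypothetical; the resolver call mirrors the diff):

```scala
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.expression.{UDFExpression, UDFResolver}

// Hypothetical validation-time caller: an unbindable signature now surfaces as
// GlutenNotSupportException, which Gluten treats as "fall back", not "fail".
def tryOffload(udfClass: String, alias: String, children: Seq[Expression]): Option[UDFExpression] =
  try {
    Some(UDFResolver.getUdfExpression(udfClass, alias)(children))
  } catch {
    case _: GlutenNotSupportException => None // the query proceeds on vanilla Spark
  }
```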
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
new file mode 100644
index 000000000000..d895faa31702
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive
+
+import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.expression.{ExpressionConverter, ExpressionTransformer}
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.expression.UDFResolver
+
+object VeloxHiveUDFTransformer {
+ def replaceWithExpressionTransformer(
+ expr: Expression,
+ attributeSeq: Seq[Attribute]): ExpressionTransformer = {
+ val (udfName, udfClassName) = expr match {
+ case s: HiveSimpleUDF =>
+ (s.name.stripPrefix("default."), s.funcWrapper.functionClassName)
+ case g: HiveGenericUDF =>
+ (g.name.stripPrefix("default."), g.funcWrapper.functionClassName)
+ case _ =>
+ throw new GlutenNotSupportException(
+ s"Expression $expr is not a HiveSimpleUDF or HiveGenericUDF")
+ }
+
+ if (UDFResolver.UDFNames.contains(udfClassName)) {
+ UDFResolver
+ .getUdfExpression(udfClassName, udfName)(expr.children)
+ .getTransformer(
+ ExpressionConverter.replaceWithExpressionTransformer(expr.children, attributeSeq)
+ )
+ } else {
+ HiveUDFTransformer.genTransformerFromUDFMappings(udfName, expr, attributeSeq)
+ }
+ }
+}
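
Note the lookup key: `UDFResolver.UDFNames` holds the UDF's implementing class name (taken from `funcWrapper.functionClassName`), while the session alias only survives as `UDFExpression.nodeName` for display. An illustrative flow, assuming a Hive-enabled session with the example UDF library already loaded (see the test suite below):

```scala
// The native registry is keyed by the implementing class, not the SQL alias.
spark.sql("""
    |CREATE TEMPORARY FUNCTION hive_string_string
    |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
    |""".stripMargin)

// Offloaded only when UDFResolver.UDFNames contains the class name above;
// "hive_string_string" merely becomes the display name of the UDFExpression.
spark.sql("SELECT hive_string_string('hello', 'world')").show()
```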
diff --git a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
index 008337b9400e..596757df35d9 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.tags.{SkipTestTags, UDFTest}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{GlutenQueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.expression.UDFResolver
import java.nio.file.Paths
import java.sql.Date
@@ -56,12 +57,31 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
.builder()
.master(master)
.config(sparkConf)
+ .enableHiveSupport()
.getOrCreate()
}
_spark.sparkContext.setLogLevel("info")
}
+ override def afterAll(): Unit = {
+ try {
+ super.afterAll()
+ if (_spark != null) {
+ try {
+ _spark.sessionState.catalog.reset()
+ } finally {
+ _spark.stop()
+ _spark = null
+ }
+ }
+ } finally {
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
+ doThreadPostAudit()
+ }
+ }
+
override protected def spark = _spark
protected def sparkConf: SparkConf = {
@@ -128,6 +148,85 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
.sameElements(Array(Row(1.0, 1.0, 1L))))
}
}
+
+ test("test hive udf replacement") {
+ val tbl = "test_hive_udf_replacement"
+ withTempPath {
+ dir =>
+ try {
+ spark.sql(s"""
+ |CREATE EXTERNAL TABLE $tbl
+ |LOCATION 'file://$dir'
+ |AS select * from values (1, '1'), (2, '2'), (3, '3')
+ |""".stripMargin)
+
+ // Check that the native Hive UDF has been registered.
+ assert(
+ UDFResolver.UDFNames.contains("org.apache.spark.sql.hive.execution.UDFStringString"))
+
+ spark.sql("""
+ |CREATE TEMPORARY FUNCTION hive_string_string
+ |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
+ |""".stripMargin)
+
+ val nativeResult =
+ spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""").collect()
+ // Unregister the native Hive UDF to force fallback.
+ UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
+ val fallbackResult =
+ spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""").collect()
+ assert(nativeResult.sameElements(fallbackResult))
+
+ // Add an unimplemented UDF class to UDFNames to test fallback of a registered native Hive UDF.
+ UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString")
+ spark.sql("""
+ |CREATE TEMPORARY FUNCTION hive_int_to_string
+ |AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString'
+ |""".stripMargin)
+ val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""")
+ checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
+ } finally {
+ spark.sql(s"DROP TABLE IF EXISTS $tbl")
+ spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string")
+ spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string")
+ }
+ }
+ }
+
+ test("test udf fallback in partition filter") {
+ withTempPath {
+ dir =>
+ try {
+ spark.sql("""
+ |CREATE TEMPORARY FUNCTION hive_int_to_string
+ |AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString'
+ |""".stripMargin)
+
+ spark.sql(s"""
+ |CREATE EXTERNAL TABLE t(i INT, p INT)
+ |LOCATION 'file://$dir'
+ |PARTITIONED BY (p)""".stripMargin)
+
+ spark
+ .range(0, 10, 1)
+ .selectExpr("id as col")
+ .createOrReplaceTempView("temp")
+
+ for (part <- Seq(1, 2, 3, 4)) {
+ spark.sql(s"""
+ |INSERT OVERWRITE TABLE t PARTITION (p=$part)
+ |SELECT col FROM temp""".stripMargin)
+ }
+
+ val df = spark.sql("SELECT i FROM t WHERE hive_int_to_string(p) = '4'")
+ checkAnswer(df, (0 until 10).map(Row(_)))
+ } finally {
+ spark.sql("DROP TABLE IF EXISTS t")
+ spark.sql("DROP VIEW IF EXISTS temp")
+ spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string")
+ }
+ }
+ }
}
@UDFTest
diff --git a/cpp/velox/udf/examples/MyUDF.cc b/cpp/velox/udf/examples/MyUDF.cc
index db1c5d7709f0..75e68413a842 100644
--- a/cpp/velox/udf/examples/MyUDF.cc
+++ b/cpp/velox/udf/examples/MyUDF.cc
@@ -30,6 +30,7 @@ namespace {
static const char* kInteger = "int";
static const char* kBigInt = "bigint";
static const char* kDate = "date";
+static const char* kVarChar = "varchar";
namespace myudf {
@@ -248,6 +249,43 @@ class MyDate2Registerer final : public gluten::UdfRegisterer {
};
} // namespace mydate
+namespace hivestringstring {
+template <typename T>
+struct HiveStringStringFunction {
+ VELOX_DEFINE_FUNCTION_TYPES(T);
+
+  FOLLY_ALWAYS_INLINE void call(out_type<Varchar>& result, const arg_type<Varchar>& a, const arg_type<Varchar>& b) {
+ result.append(a.data());
+ result.append(" ");
+ result.append(b.data());
+ }
+};
+
+// name: org.apache.spark.sql.hive.execution.UDFStringString
+// signatures:
+// varchar, varchar -> varchar
+// type: SimpleFunction
+class HiveStringStringRegisterer final : public gluten::UdfRegisterer {
+ public:
+ int getNumUdf() override {
+ return 1;
+ }
+
+ void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+ // Set `allowTypeConversion` for hive udf.
+ udfEntries[index++] = {name_.c_str(), kVarChar, 2, arg_, false, true};
+ }
+
+ void registerSignatures() override {
+    facebook::velox::registerFunction<HiveStringStringFunction, Varchar, Varchar, Varchar>({name_});
+ }
+
+ private:
+ const std::string name_ = "org.apache.spark.sql.hive.execution.UDFStringString";
+ const char* arg_[2] = {kVarChar, kVarChar};
+};
+} // namespace hivestringstring
+
std::vector<std::shared_ptr<gluten::UdfRegisterer>>& globalRegisters() {
  static std::vector<std::shared_ptr<gluten::UdfRegisterer>> registerers;
return registerers;
@@ -264,6 +302,7 @@ void setupRegisterers() {
  registerers.push_back(std::make_shared<myudf::MyUdf3Registerer>());
  registerers.push_back(std::make_shared<mydate::MyDateRegisterer>());
  registerers.push_back(std::make_shared<mydate::MyDate2Registerer>());
+  registerers.push_back(std::make_shared<hivestringstring::HiveStringStringRegisterer>());
inited = true;
}
} // namespace
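
None of this resolves unless the compiled example library is loaded at session startup, so that `UdfJniWrapper.getFunctionSignatures()` can populate `UDFResolver.UDFNames`. A minimal sketch, assuming the `udfLibraryPaths` key from Gluten's Velox UDF documentation (the library path is illustrative):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative wiring: once the library loads, UDFResolver.UDFNames contains
// "org.apache.spark.sql.hive.execution.UDFStringString" among others.
val spark = SparkSession
  .builder()
  .config("spark.gluten.sql.columnar.backend.velox.udfLibraryPaths", "file:///path/to/libmyudf.so")
  .enableHiveSupport()
  .getOrCreate()
```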
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 0227ed5da127..fb87a9ac93c0 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
-import org.apache.spark.sql.hive.HiveTableScanExecTransformer
+import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, HiveUDFTransformer}
import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -670,4 +670,10 @@ trait SparkPlanExecApi {
DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
}
}
+
+ def genHiveUDFTransformer(
+ expr: Expression,
+ attributeSeq: Seq[Attribute]): ExpressionTransformer = {
+ HiveUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
+ }
}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 8bca5dbf8605..d5ca31bb5e78 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -128,7 +128,9 @@ object ExpressionConverter extends SQLConfHelper with Logging {
case s: ScalaUDF =>
return replaceScalaUDFWithExpressionTransformer(s, attributeSeq, expressionsMap)
case _ if HiveUDFTransformer.isHiveUDF(expr) =>
- return HiveUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
+ return BackendsApiManager.getSparkPlanExecApiInstance.genHiveUDFTransformer(
+ expr,
+ attributeSeq)
case i: StaticInvoke =>
val objectName = i.staticObject.getName.stripSuffix("$")
if (objectName.endsWith("UrlCodec")) {
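
With this indirection the converter only detects that an expression is a Hive UDF; how it is transformed is the backend's decision, and backends that do not override `genHiveUDFTransformer` inherit the mapping-only default shown above. A condensed view of the dispatch (the helper is hypothetical; the two calls match the hunk):

```scala
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.ExpressionTransformer
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.hive.HiveUDFTransformer

// Detection stays generic; transformation is per-backend. Velox tries natively
// registered UDFs first, then falls back to the shared UDFMappings path.
def transformHiveUDF(expr: Expression, attrs: Seq[Attribute]): Option[ExpressionTransformer] =
  if (HiveUDFTransformer.isHiveUDF(expr)) {
    Some(BackendsApiManager.getSparkPlanExecApiInstance.genHiveUDFTransformer(expr, attrs))
  } else {
    None
  }
```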
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
index 5cd64cc212f8..52739aaca439 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
@@ -43,7 +43,13 @@ object HiveUDFTransformer {
throw new GlutenNotSupportException(
s"Expression $expr is not a HiveSimpleUDF or HiveGenericUDF")
}
+ genTransformerFromUDFMappings(udfName, expr, attributeSeq)
+ }
+ def genTransformerFromUDFMappings(
+ udfName: String,
+ expr: Expression,
+ attributeSeq: Seq[Attribute]): GenericExpressionTransformer = {
UDFMappings.hiveUDFMap.get(udfName.toLowerCase(Locale.ROOT)) match {
case Some(name) =>
GenericExpressionTransformer(