Skip to content

Commit

Permalink
Support creating temporary functions for native Hive UDFs
Browse files Browse the repository at this point in the history
  • Loading branch information
marin-ma committed Aug 14, 2024
1 parent fc7f9cd commit 9af3037
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 9 deletions.
6 changes: 6 additions & 0 deletions backends-velox/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<!-- hive-exec (test scope) provides the Hive UDF classes, e.g.
     org.apache.hadoop.hive.ql.udf.UDFLog10, registered via
     CREATE TEMPORARY FUNCTION in the native Hive UDF tests. -->
<dependency>
<groupId>org.apache.hive</groupId>
<artifactId>hive-exec</artifactId>
<version>${hive.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {

/**
 * Returns the function signatures discovered by [[UDFResolver]] so Spark can
 * register them as injected (built-in) functions at session initialization.
 *
 * Note: the rendered diff artifact left both the pre-change call
 * (`getFunctionSignatures`) and the post-change call (`getFunctionSignatures()`)
 * in place, which would resolve the UDF libraries twice and discard the first
 * result; only the single post-change call is kept here.
 */
override def genInjectedFunctions()
: Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
UDFResolver.getFunctionSignatures()
}

override def rewriteSpillPath(path: String): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.expression

import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer, ExpressionType, GenericExpressionTransformer, Transformable}
import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer, ExpressionType, GenericExpressionTransformer, Transformable, UDFMappings}
import org.apache.gluten.udf.UdfJniWrapper
import org.apache.gluten.vectorized.JniWorkspace

Expand Down Expand Up @@ -331,7 +331,7 @@ object UDFResolver extends Logging {
.mkString(",")
}

def getFunctionSignatures: Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
def getFunctionSignatures(): Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
val sparkContext = SparkContext.getActive.get
val sparkConf = sparkContext.conf
val udfLibPaths = sparkConf.getOption(VeloxBackendSettings.GLUTEN_VELOX_UDF_LIB_PATHS)
Expand All @@ -341,6 +341,7 @@ object UDFResolver extends Logging {
Seq.empty
case Some(_) =>
UdfJniWrapper.getFunctionSignatures()
UDFNames.foreach(UDFMappings.nativeHiveUDF.add)

UDFNames.map {
name =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
.builder()
.master(master)
.config(sparkConf)
.enableHiveSupport()
.getOrCreate()
}

Expand Down Expand Up @@ -91,10 +92,12 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {

test("test udf allow type conversion") {
withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") {
val df = spark.sql("""select myudf1("100"), myudf1(1), mydate('2024-03-25', 5)""")
val df =
spark.sql("""select myudf1("100"), myudf1(1), mydate('2024-03-25', 5)""")
assert(
df.collect()
.sameElements(Array(Row(105L, 6L, Date.valueOf("2024-03-30")))))
.sameElements(
Array(Row("c4ca4238a0b923820dcc509a6f75849b", 105L, 6L, Date.valueOf("2024-03-30")))))
}

withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "false") {
Expand Down Expand Up @@ -128,6 +131,44 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
.sameElements(Array(Row(1.0, 1.0, 1L))))
}
}

// Verifies that a Hive UDF created via CREATE TEMPORARY FUNCTION is replaced by
// its native implementation when its fully-qualified class name is present in
// UDFMappings.nativeHiveUDF, and that execution falls back to vanilla Spark
// (producing identical results) when the class name is absent or the native
// side has no implementation for it.
test("test hive udf replacement") {
val tbl = "test_hive_udf_replacement"
withTempPath {
dir =>
try {
// Single-column test table backed by a temporary directory.
spark.sql(s"""
|CREATE EXTERNAL TABLE $tbl
|LOCATION 'file://$dir'
|AS select * from values ('1'), ('2'), ('3')
|""".stripMargin)

// Precondition: the native library registered UDFLog10 by class name.
assert(UDFMappings.nativeHiveUDF.contains("org.apache.hadoop.hive.ql.udf.UDFLog10"))

spark.sql(
"""CREATE TEMPORARY FUNCTION hive_log10 AS 'org.apache.hadoop.hive.ql.udf.UDFLog10'""")

val nativeResult =
spark.sql(s"""SELECT hive_log10(col1) FROM $tbl""").collect()
// Unregister the native Hive UDF so the same query falls back to vanilla
// Spark execution; both paths must produce identical results.
UDFMappings.nativeHiveUDF.remove("org.apache.hadoop.hive.ql.udf.UDFLog10")
val fallbackResult =
spark.sql(s"""SELECT hive_log10(col1) FROM $tbl""").collect()
assert(nativeResult.sameElements(fallbackResult))

// Register a UDF class with no native implementation (UDFMd5) to verify
// that a registered-but-unimplemented native Hive UDF still falls back
// and matches Spark's built-in md5.
UDFMappings.nativeHiveUDF.add("org.apache.hadoop.hive.ql.udf.UDFMd5")
spark.sql("CREATE TEMPORARY FUNCTION hive_md5 AS 'org.apache.hadoop.hive.ql.udf.UDFMd5'")
val df =
spark.sql(s"""select hive_md5(col1) from $tbl""")
val expected = spark.sql(s"""SELECT md5(col1) FROM $tbl""")
assert(df.collect().sameElements(expected.collect()))
} finally {
// Always drop the test table, even when an assertion above fails.
spark.sql(s"DROP TABLE IF EXISTS $tbl")
}
}
}

@UDFTest
Expand Down
39 changes: 39 additions & 0 deletions cpp/velox/udf/examples/MyUDF.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace {
static const char* kInteger = "int";
static const char* kBigInt = "bigint";
static const char* kDate = "date";
static const char* kDouble = "double";

namespace myudf {

Expand Down Expand Up @@ -248,6 +249,43 @@ class MyDate2Registerer final : public gluten::UdfRegisterer {
};
} // namespace mydate

namespace mylog10 {
// Example native implementation backing the Hive UDF
// org.apache.hadoop.hive.ql.udf.UDFLog10.
// Returning false from call() for non-positive input yields no result for
// that row (presumably surfaced as NULL by the Velox simple-function
// convention — confirm against Velox docs).
template <typename T>
struct HiveLog10Function {
FOLLY_ALWAYS_INLINE bool call(double& result, double a) {
if (a <= 0.0) {
// log10 is undefined for a <= 0; produce no result instead of NaN/-inf.
return false;
}
result = std::log10(a);
return true;
}
};

// name: org.apache.hadoop.hive.ql.udf.UDFLog10
// signatures:
// double -> double
// type: SimpleFunction
class HiveLog10Registerer final : public gluten::UdfRegisterer {
 public:
// This registerer contributes exactly one UDF entry.
int getNumUdf() override {
return 1;
}

void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
// Entry fields: name, return type, arity, arg types, then two flags —
// the last (`true`) enables `allowTypeConversion` for this hive udf.
udfEntries[index++] = {name_.c_str(), kDouble, 1, arg_, false, true};
}

void registerSignatures() override {
// Register the double -> double signature under the Hive class name.
facebook::velox::registerFunction<HiveLog10Function, double, double>({name_});
}

 private:
// The UDF is looked up by the fully-qualified Hive class name.
const std::string name_ = "org.apache.hadoop.hive.ql.udf.UDFLog10";
const char* arg_[1] = {kDouble};
};
} // namespace mylog10

std::vector<std::shared_ptr<gluten::UdfRegisterer>>& globalRegisters() {
static std::vector<std::shared_ptr<gluten::UdfRegisterer>> registerers;
return registerers;
Expand All @@ -264,6 +302,7 @@ void setupRegisterers() {
registerers.push_back(std::make_shared<myudf::MyUdf3Registerer>());
registerers.push_back(std::make_shared<mydate::MyDateRegisterer>());
registerers.push_back(std::make_shared<mydate::MyDate2Registerer>());
registerers.push_back(std::make_shared<mylog10::HiveLog10Registerer>());
inited = true;
}
} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ import org.apache.commons.lang3.StringUtils

import java.util.Locale

import scala.collection.mutable
import scala.collection.mutable.Map

object UDFMappings extends Logging {
val hiveUDFMap: Map[String, String] = Map()
val pythonUDFMap: Map[String, String] = Map()
val scalaUDFMap: Map[String, String] = Map()
val nativeHiveUDF: mutable.Set[String] = mutable.HashSet()

private def appendKVToMap(key: String, value: String, res: Map[String, String]): Unit = {
if (key.isEmpty || value.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,22 @@ object HiveUDFTransformer {
def replaceWithExpressionTransformer(
expr: Expression,
attributeSeq: Seq[Attribute]): ExpressionTransformer = {
val udfName = expr match {
val (udfName, udfClassName) = expr match {
case s: HiveSimpleUDF =>
s.name.stripPrefix("default.")
(s.name.stripPrefix("default."), s.funcWrapper.functionClassName)
case g: HiveGenericUDF =>
g.name.stripPrefix("default.")
(g.name.stripPrefix("default."), g.funcWrapper.functionClassName)
case _ =>
throw new GlutenNotSupportException(
s"Expression $expr is not a HiveSimpleUDF or HiveGenericUDF")
}

UDFMappings.hiveUDFMap.get(udfName.toLowerCase(Locale.ROOT)) match {
val udf = if (UDFMappings.nativeHiveUDF.contains(udfClassName)) {
Some(udfClassName)
} else {
UDFMappings.hiveUDFMap.get(udfName.toLowerCase(Locale.ROOT))
}
udf match {
case Some(name) =>
GenericExpressionTransformer(
name,
Expand Down
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
<arrow-gluten.version>15.0.0-gluten</arrow-gluten.version>
<arrow-memory.artifact>arrow-memory-unsafe</arrow-memory.artifact>
<hadoop.version>2.7.4</hadoop.version>
<hive.version>2.3.9</hive.version>
<slf4j.version>2.0.7</slf4j.version>
<log4j.version>2.20.0</log4j.version>
<antlr4.version>4.9.3</antlr4.version>
Expand Down

0 comments on commit 9af3037

Please sign in to comment.