Commit 9c6e77b
[GLUTEN-3924][CORE] Match hive UDF name in case-insensitive mode during expression transformation (#3925)

* case-insensitive matching for hive udfs

* add uts

* fix failed uts

* fix failed uts
taiyang-li authored Dec 7, 2023
1 parent 0028393 commit 9c6e77b
Showing 6 changed files with 23 additions and 16 deletions.
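In short, the patch makes the hive-UDF lookup case-insensitive on both sides: names configured via `spark.gluten.supported.hive.udfs` are stored lower-cased in `UDFMappings.hiveUDFMap`, and `HiveSimpleUDFTransformer` lower-cases the UDF name before probing that map. A minimal sketch of the idea (the object and method names below are illustrative, not from the patch):

```scala
import java.util.Locale

import scala.collection.mutable

// Illustrative sketch only: normalize keys on insert and on lookup, so
// "my_add", "MY_ADD" and "My_Add" all resolve to the same entry.
object CaseInsensitiveUdfRegistry {
  private val udfs = mutable.Map[String, String]()

  def register(name: String, substraitName: String): Unit =
    udfs.put(name.toLowerCase(Locale.ROOT), substraitName)

  def lookup(name: String): Option[String] =
    udfs.get(name.toLowerCase(Locale.ROOT))
}
```

Using `Locale.ROOT` keeps the normalization locale-independent (e.g. it avoids the Turkish dotless-i pitfall of the default locale).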
@@ -97,6 +97,7 @@ class GlutenClickHouseHiveTableSuite()
"spark.sql.warehouse.dir",
getClass.getResource("/").getPath + "unit-tests-working-home/spark-warehouse")
.set("spark.hive.exec.dynamic.partition.mode", "nonstrict")
.set("spark.gluten.supported.hive.udfs", "my_add")
.setMaster("local[*]")
}

@@ -1060,4 +1061,14 @@ class GlutenClickHouseHiveTableSuite()
     compareResultsAgainstVanillaSpark(select_sql, compareResult = true, _ => {})
     spark.sql("DROP TABLE test_tbl_3548")
   }
+
+  test("test 'hive udf'") {
+    val jarPath = "src/test/resources/udfs/hive-test-udfs.jar"
+    val jarUrl = s"file://${System.getProperty("user.dir")}/$jarPath"
+    spark.sql(
+      s"CREATE FUNCTION my_add as " +
+        s"'org.apache.hadoop.hive.contrib.udf.example.UDFExampleAdd2' USING JAR '$jarUrl'")
+    runQueryAndCompare("select MY_ADD(id, id+1) from range(10)")(
+      checkOperatorMatch[ProjectExecTransformer])
+  }
 }
@@ -48,7 +48,6 @@ class GlutenClickHouseTPCHParquetSuite extends GlutenClickHouseTPCHAbstractSuite
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.set("spark.gluten.sql.columnar.backend.ch.use.v2", "false")
.set("spark.gluten.supported.scala.udfs", "my_add")
.set("spark.gluten.supported.hive.udfs", "my_add")
}

override protected val createNullableTables = true
@@ -1319,16 +1318,6 @@ class GlutenClickHouseTPCHParquetSuite extends GlutenClickHouseTPCHAbstractSuite
       checkOperatorMatch[ProjectExecTransformer])
   }
-
-  ignore("test 'hive udf'") {
-    val jarPath = "backends-clickhouse/src/test/resources/udfs/hive-test-udfs.jar"
-    val jarUrl = s"file://${System.getProperty("user.dir")}/$jarPath"
-    spark.sql(
-      s"CREATE FUNCTION my_add as " +
-        "'org.apache.hadoop.hive.contrib.udf.example.UDFExampleAdd2' USING JAR '$jarUrl'")
-    runQueryAndCompare("select my_add(id, id+1) from range(10)")(
-      checkOperatorMatch[ProjectExecTransformer])
-  }
 
   override protected def runTPCHQuery(
       queryNum: Int,
       tpchQueries: String = tpchQueries,
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/FunctionParser.h
@@ -76,7 +76,7 @@ class FunctionParser
    {
        return plan_parser->toFunctionNode(action_dag, func_name, args);
    }

    const DB::ActionsDAG::Node *
    toFunctionNode(DB::ActionsDAGPtr & action_dag, const String & func_name, const String & result_name, const DB::ActionsDAG::NodeRawConstPtrs & args) const
    {
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/tests/gtest_parquet_write.cpp
@@ -199,7 +199,7 @@ TEST(ParquetWrite, ComplexTypes)
     ch2arrow.chChunkToArrowTable(arrow_table, input_chunks, header.columns());
 
     /// Convert Arrow Table to CH Block
-    ArrowColumnToCHColumn arrow2ch(header, "Parquet", true, true, true);
+    ArrowColumnToCHColumn arrow2ch(header, "Parquet", true, true, FormatSettings::DateTimeOverflowBehavior::Ignore);
     Chunk output_chunk;
     arrow2ch.arrowTableToCHChunk(output_chunk, arrow_table, arrow_table->num_rows());

@@ -23,6 +23,8 @@ import org.apache.spark.internal.Logging

 import org.apache.commons.lang3.StringUtils
 
+import java.util.Locale
+
 import scala.collection.mutable.Map

object UDFMappings extends Logging {
@@ -41,7 +43,7 @@ object UDFMappings extends Logging {
s"will be replaced by value:$value")
}

res.put(key, value)
res.put(key.toLowerCase(Locale.ROOT), value)
}

private def parseStringToMap(input: String, res: Map[String, String]) {
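A hedged sketch of what the patched `put` changes when the comma-separated config value is folded into the map. The parsing shown here is an assumption for illustration; the real logic lives in `parseStringToMap` and its callers and may differ in detail:

```scala
import java.util.Locale

import scala.collection.mutable

val hiveUDFMap = mutable.Map[String, String]()

// Mirrors the patched put: keys are normalized before insertion.
def putToMap(res: mutable.Map[String, String], key: String, value: String): Unit =
  res.put(key.toLowerCase(Locale.ROOT), value)

// Assumed shape of a "spark.gluten.supported.hive.udfs" value: a plain
// comma-separated list of function names.
"MY_ADD, my_sub".split(",").map(_.trim).filter(_.nonEmpty).foreach {
  name => putToMap(hiveUDFMap, name, name)
}

assert(hiveUDFMap.contains("my_add")) // stored lower-cased regardless of input case
```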
@@ -20,6 +20,8 @@ import io.glutenproject.expression.{ExpressionConverter, ExpressionTransformer,

 import org.apache.spark.sql.catalyst.expressions._
 
+import java.util.Locale
+
 object HiveSimpleUDFTransformer {
   def isHiveSimpleUDF(expr: Expression): Boolean = {
     expr match {
@@ -36,15 +38,18 @@ object HiveSimpleUDFTransformer {
     }
 
     val udf = expr.asInstanceOf[HiveSimpleUDF]
-    val substraitExprName = UDFMappings.hiveUDFMap.get(udf.name.stripPrefix("default."))
+    val substraitExprName =
+      UDFMappings.hiveUDFMap.get(udf.name.stripPrefix("default.").toLowerCase(Locale.ROOT))
     substraitExprName match {
       case Some(name) =>
         GenericExpressionTransformer(
           name,
           udf.children.map(ExpressionConverter.replaceWithExpressionTransformer(_, attributeSeq)),
           udf)
       case _ =>
-        throw new UnsupportedOperationException(s"Not supported hive simple udf: $udf.")
+        throw new UnsupportedOperationException(
+          s"Not supported hive simple udf:$udf"
+            + s" name:${udf.name} hiveUDFMap:${UDFMappings.hiveUDFMap}")
     }
   }
 }
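For reference, a condensed sketch of the lookup path after this change. It compresses the transformer into one function and skips the argument conversion; `resolveSubstraitName` is an illustrative name, while the `stripPrefix`/`toLowerCase` steps mirror the diff above:

```scala
import java.util.Locale

// Spark may report a Hive UDF as e.g. "default.MY_ADD"; strip the database
// prefix and lower-case the remainder before probing the registry, as the
// patched HiveSimpleUDFTransformer does.
def resolveSubstraitName(
    udfName: String,
    hiveUDFMap: collection.Map[String, String]): String = {
  val key = udfName.stripPrefix("default.").toLowerCase(Locale.ROOT)
  hiveUDFMap.getOrElse(
    key,
    throw new UnsupportedOperationException(
      s"Not supported hive simple udf name:$udfName hiveUDFMap:$hiveUDFMap"))
}
```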
