Skip to content

Commit

Permalink
[VL] Support Spark assert_true function (#6329)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyangxiaozhu authored Jul 11, 2024
1 parent 0448115 commit 6f189c7
Show file tree
Hide file tree
Showing 25 changed files with 1,062 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ object CHExpressionUtil {
UNIX_MICROS -> DefaultValidator(),
TIMESTAMP_MILLIS -> DefaultValidator(),
TIMESTAMP_MICROS -> DefaultValidator(),
STACK -> DefaultValidator()
STACK -> DefaultValidator(),
TRANSFORM_KEYS -> DefaultValidator(),
TRANSFORM_VALUES -> DefaultValidator(),
RAISE_ERROR -> DefaultValidator()
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import org.apache.gluten.datasource.ArrowConvertorRule
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.expression.ExpressionNames.{TRANSFORM_KEYS, TRANSFORM_VALUES}
import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar.FallbackTags
Expand Down Expand Up @@ -835,8 +834,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
Sig[VeloxCollectSet](ExpressionNames.COLLECT_SET),
Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
Sig[TransformKeys](TRANSFORM_KEYS),
Sig[TransformValues](TRANSFORM_VALUES),
// For test purpose.
Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.execution

import org.apache.spark.SparkException
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -663,6 +664,19 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
}
}

// Verifies Velox backend support for assert_true (which SPARK-32793 rewrites
// into raise_error): the passing case must be offloaded to the native
// ProjectExecTransformer, and the failing case must surface a SparkException
// whose cause is a RuntimeException.
test("Test raise_error, assert_true function") {
// Passing case: assumes every l_orderkey in lineitem is >= 1 — TODO confirm
// against the TPC-H fixture. The projection should run natively.
runQueryAndCompare("""SELECT assert_true(l_orderkey >= 1), l_orderkey
| from lineitem limit 100""".stripMargin) {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
// Failing case: presumably some of the first rows have l_orderkey < 100, so
// the assertion fires at execution time (hence .collect() to force it).
val e = intercept[SparkException] {
sql("""SELECT assert_true(l_orderkey >= 100), l_orderkey from
| lineitem limit 100""".stripMargin).collect()
}
assert(e.getCause.isInstanceOf[RuntimeException])
// The default assert_true error message embeds the failing expression text,
// so the column name should appear in the exception message.
assert(e.getMessage.contains("l_orderkey"))
}

test("Test E function") {
runQueryAndCompare("""SELECT E() from lineitem limit 100""".stripMargin) {
checkGlutenOperatorMatch[ProjectExecTransformer]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ object ExpressionMappings {
Sig[MapEntries](MAP_ENTRIES),
Sig[MapZipWith](MAP_ZIP_WITH),
Sig[StringToMap](STR_TO_MAP),
Sig[TransformKeys](TRANSFORM_KEYS),
Sig[TransformValues](TRANSFORM_VALUES),
// Struct functions
Sig[GetStructField](GET_STRUCT_FIELD),
Sig[CreateNamedStruct](NAMED_STRUCT),
Expand All @@ -284,6 +286,7 @@ object ExpressionMappings {
Sig[SparkPartitionID](SPARK_PARTITION_ID),
Sig[WidthBucket](WIDTH_BUCKET),
Sig[ReplicateRows](REPLICATE_ROWS),
Sig[RaiseError](RAISE_ERROR),
// Decimal
Sig[UnscaledValue](UNSCALED_VALUE),
// Generator function
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- test for misc functions

-- typeof
select typeof(null);
select typeof(true);
select typeof(1Y), typeof(1S), typeof(1), typeof(1L);
select typeof(cast(1.0 as float)), typeof(1.0D), typeof(1.2);
select typeof(date '1986-05-23'), typeof(timestamp '1986-05-23'), typeof(interval '23 days');
select typeof(x'ABCD'), typeof('SPARK');
select typeof(array(1, 2)), typeof(map(1, 2)), typeof(named_struct('a', 1, 'b', 'spark'));

-- SPARK-32793: Rewrite AssertTrue with RaiseError
SELECT assert_true(true), assert_true(boolean(1));
SELECT assert_true(false);
SELECT assert_true(boolean(0));
SELECT assert_true(null);
SELECT assert_true(boolean(null));
SELECT assert_true(false, 'custom error message');

CREATE TEMPORARY VIEW tbl_misc AS SELECT * FROM (VALUES (1), (8), (2)) AS T(v);
SELECT raise_error('error message');
SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc;
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 16


-- !query
select typeof(null)
-- !query schema
struct<typeof(NULL):string>
-- !query output
void


-- !query
select typeof(true)
-- !query schema
struct<typeof(true):string>
-- !query output
boolean


-- !query
select typeof(1Y), typeof(1S), typeof(1), typeof(1L)
-- !query schema
struct<typeof(1):string,typeof(1):string,typeof(1):string,typeof(1):string>
-- !query output
tinyint smallint int bigint


-- !query
select typeof(cast(1.0 as float)), typeof(1.0D), typeof(1.2)
-- !query schema
struct<typeof(CAST(1.0 AS FLOAT)):string,typeof(1.0):string,typeof(1.2):string>
-- !query output
float double decimal(2,1)


-- !query
select typeof(date '1986-05-23'), typeof(timestamp '1986-05-23'), typeof(interval '23 days')
-- !query schema
struct<typeof(DATE '1986-05-23'):string,typeof(TIMESTAMP '1986-05-23 00:00:00'):string,typeof(INTERVAL '23' DAY):string>
-- !query output
date timestamp interval day


-- !query
select typeof(x'ABCD'), typeof('SPARK')
-- !query schema
struct<typeof(X'ABCD'):string,typeof(SPARK):string>
-- !query output
binary string


-- !query
select typeof(array(1, 2)), typeof(map(1, 2)), typeof(named_struct('a', 1, 'b', 'spark'))
-- !query schema
struct<typeof(array(1, 2)):string,typeof(map(1, 2)):string,typeof(named_struct(a, 1, b, spark)):string>
-- !query output
array<int> map<int,int> struct<a:int,b:string>


-- !query
SELECT assert_true(true), assert_true(boolean(1))
-- !query schema
struct<assert_true(true, 'true' is not true!):void,assert_true(1, 'cast(1 as boolean)' is not true!):void>
-- !query output
NULL NULL


-- !query
SELECT assert_true(false)
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
'false' is not true!


-- !query
SELECT assert_true(boolean(0))
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
'cast(0 as boolean)' is not true!


-- !query
SELECT assert_true(null)
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
'null' is not true!


-- !query
SELECT assert_true(boolean(null))
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
'cast(null as boolean)' is not true!


-- !query
SELECT assert_true(false, 'custom error message')
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
custom error message


-- !query
CREATE TEMPORARY VIEW tbl_misc AS SELECT * FROM (VALUES (1), (8), (2)) AS T(v)
-- !query schema
struct<>
-- !query output



-- !query
SELECT raise_error('error message')
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
error message


-- !query
SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
too big: 8
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,11 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("sliding range between with aggregation")
.exclude("store and retrieve column stats in different time zones")
enableSuite[GlutenColumnExpressionSuite]
// Velox raise_error('errMsg') throws a velox_user_error exception with the message 'errMsg'.
// The final caught Spark exception's getCause().getMessage() contains 'errMsg' but does not
// equal 'errMsg' exactly. The following two tests will be skipped and overridden in Gluten.
.exclude("raise_error")
.exclude("assert_true")
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenGeneratorFunctionSuite]
enableSuite[GlutenDataFrameTimeWindowingSuite]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,60 @@
*/
package org.apache.spark.sql

import org.apache.spark.SparkException
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.functions.{expr, input_file_name}
import org.apache.spark.sql.functions.{assert_true, expr, input_file_name, lit, raise_error}

class GlutenColumnExpressionSuite extends ColumnExpressionSuite with GlutenSQLTestsTrait {
import testImplicits._
// Gluten override of ColumnExpressionSuite's "raise_error" test. The upstream
// test checks getCause.getMessage for exact equality, but Velox wraps the
// message, so this version only checks the cause type and that the original
// text is contained in the message (see the exclusions in VeloxTestSettings).
testGluten("raise_error") {
val strDf = Seq(("hello")).toDF("a")

// raise_error(NULL) must still fail at runtime with a RuntimeException cause.
val e1 = intercept[SparkException] {
strDf.select(raise_error(lit(null.asInstanceOf[String]))).collect()
}
assert(e1.getCause.isInstanceOf[RuntimeException])

// raise_error with a column argument: the column value must appear in the
// error message (contains, not equals — Velox adds surrounding text).
val e2 = intercept[SparkException] {
strDf.select(raise_error($"a")).collect()
}
assert(e2.getCause.isInstanceOf[RuntimeException])
assert(e2.getCause.getMessage contains "hello")
}

// Gluten override of ColumnExpressionSuite's "assert_true" test. Mirrors the
// upstream test but, like the raise_error override above it in upstream Spark,
// only asserts message containment because Velox wraps the error text.
testGluten("assert_true") {
// assert_true(condition, errMsgCol)
val booleanDf = Seq((true), (false)).toDF("cond")
// A passing assert_true evaluates to NULL (void), never to a value.
checkAnswer(
booleanDf.filter("cond = true").select(assert_true($"cond")),
Row(null) :: Nil
)
// A false condition with a NULL message column must still raise.
val e1 = intercept[SparkException] {
booleanDf.select(assert_true($"cond", lit(null.asInstanceOf[String]))).collect()
}
assert(e1.getCause.isInstanceOf[RuntimeException])

// A NULL condition counts as "not true": the row with cond = None fails and
// its custom message column ("first row") must appear in the error.
val nullDf = Seq(("first row", None), ("second row", Some(true))).toDF("n", "cond")
checkAnswer(
nullDf.filter("cond = true").select(assert_true($"cond", $"cond")),
Row(null) :: Nil
)
val e2 = intercept[SparkException] {
nullDf.select(assert_true($"cond", $"n")).collect()
}
assert(e2.getCause.isInstanceOf[RuntimeException])
assert(e2.getCause.getMessage contains "first row")

// assert_true(condition) — default error message embeds the expression text.
val intDf = Seq((0, 1)).toDF("a", "b")
checkAnswer(intDf.select(assert_true($"a" < $"b")), Row(null) :: Nil)
val e3 = intercept[SparkException] {
intDf.select(assert_true($"a" > $"b")).collect()
}
assert(e3.getCause.isInstanceOf[RuntimeException])
assert(e3.getCause.getMessage contains "'('a > 'b)' is not true!")
}

testGluten(
"input_file_name, input_file_block_start and input_file_block_length " +
"should fall back if scan falls back") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
package org.apache.spark.sql

import org.apache.gluten.GlutenConfig
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.utils.{BackendTestSettings, BackendTestUtils, SystemParameters}

import org.apache.spark.SparkConf
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand All @@ -39,6 +40,7 @@ import java.util.Locale
import scala.collection.mutable.ArrayBuffer
import scala.sys.process.{Process, ProcessLogger}
import scala.util.Try
import scala.util.control.NonFatal

/**
* End-to-end test cases for SQL queries.
Expand Down Expand Up @@ -761,4 +763,45 @@ class GlutenSQLQueryTestSuite
super.afterAll()
}
}

/**
 * Normalizes exceptions raised during query execution so they can be compared
 * against the expected golden output of the SQL query tests.
 *
 * @param result
 *   a by-name computation producing the (schema, output) pair for a query
 * @return
 *   the original pair on success, or an (empty schema, normalized error lines)
 *   pair describing the caught exception
 */
override protected def handleExceptions(
    result: => (String, Seq[String])): (String, Seq[String]) = {
  try result
  catch {
    case a: AnalysisException =>
      // Skip the logical plan tree (it embeds expression IDs) and mask any
      // remaining expression IDs in the message with the generic "#x".
      val message = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
      (emptySchema, Seq(a.getClass.getName, message.replaceAll("#\\d+", "#x")))
    case s: SparkException if s.getCause != null =>
      // A runtime SparkException's own message carries stage/task IDs and is
      // hard to match, so normalize based on the cause instead.
      s.getCause match {
        case g: GlutenException =>
          // Gluten wraps the native (Velox) error; only the text following
          // "Reason: " is stable enough to compare against golden output.
          "Reason: (.*)".r.findFirstMatchIn(g.getMessage).map(_.group(1)) match {
            case Some(reason) => (emptySchema, Seq(g.getClass.getName, reason))
            case None => (emptySchema, Seq())
          }
        case cause =>
          (emptySchema, Seq(cause.getClass.getName, cause.getMessage))
      }
    case NonFatal(e) =>
      // Any other non-fatal error: report its class name and message.
      (emptySchema, Seq(e.getClass.getName, e.getMessage))
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- test for misc functions

-- typeof
select typeof(null);
select typeof(true);
select typeof(1Y), typeof(1S), typeof(1), typeof(1L);
select typeof(cast(1.0 as float)), typeof(1.0D), typeof(1.2);
select typeof(date '1986-05-23'), typeof(timestamp '1986-05-23'), typeof(interval '23 days');
select typeof(x'ABCD'), typeof('SPARK');
select typeof(array(1, 2)), typeof(map(1, 2)), typeof(named_struct('a', 1, 'b', 'spark'));

-- SPARK-32793: Rewrite AssertTrue with RaiseError
SELECT assert_true(true), assert_true(boolean(1));
SELECT assert_true(false);
SELECT assert_true(boolean(0));
SELECT assert_true(null);
SELECT assert_true(boolean(null));
SELECT assert_true(false, 'custom error message');

CREATE TEMPORARY VIEW tbl_misc AS SELECT * FROM (VALUES (1), (8), (2)) AS T(v);
SELECT raise_error('error message');
SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc;
Loading

0 comments on commit 6f189c7

Please sign in to comment.