Skip to content

Commit

Permalink
fix raise_error, assert_true sql run ut
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyangxiaozhu committed Jul 9, 2024
1 parent 369170e commit de1c00b
Show file tree
Hide file tree
Showing 15 changed files with 823 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.gluten.datasource.ArrowConvertorRule
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.expression.ExpressionNames.{TRANSFORM_KEYS, TRANSFORM_VALUES}
import org.apache.gluten.expression.ExpressionNames.{RAISE_ERROR, TRANSFORM_KEYS, TRANSFORM_VALUES}
import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar.FallbackTags
Expand Down Expand Up @@ -837,6 +837,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
Sig[TransformKeys](TRANSFORM_KEYS),
Sig[TransformValues](TRANSFORM_VALUES),
Sig[RaiseError](RAISE_ERROR),
// For test purpose.
Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,8 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
sql("""SELECT assert_true(l_orderkey >= 100), l_orderkey from
| lineitem limit 100""".stripMargin).collect()
}
assert(e.getMessage.contains("l_orderkey"))
assert(e.getCause.isInstanceOf[RuntimeException])
assert(e.getMessage.contains("'(l_orderkey#76L >= cast(100 as bigint))' is not true"))
}

test("Test spark_partition_id function") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ object ExpressionMappings {
Sig[SparkPartitionID](SPARK_PARTITION_ID),
Sig[WidthBucket](WIDTH_BUCKET),
Sig[ReplicateRows](REPLICATE_ROWS),
Sig[RaiseError](RAISE_ERROR),
// Decimal
Sig[UnscaledValue](UNSCALED_VALUE),
// Generator function
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- test for misc functions
-- Input script for the SQL query test suite: each statement below is run as one query and its
-- schema/output is compared with the recorded golden result, so the statement text itself
-- must stay unchanged.

-- typeof
-- Covers the null literal, boolean, integral, fractional, datetime/interval, binary/string
-- and nested (array, map, struct) arguments.
select typeof(null);
select typeof(true);
select typeof(1Y), typeof(1S), typeof(1), typeof(1L);
select typeof(cast(1.0 as float)), typeof(1.0D), typeof(1.2);
select typeof(date '1986-05-23'), typeof(timestamp '1986-05-23'), typeof(interval '23 days');
select typeof(x'ABCD'), typeof('SPARK');
select typeof(array(1, 2)), typeof(map(1, 2)), typeof(named_struct('a', 1, 'b', 'spark'));

-- Spark-32793: Rewrite AssertTrue with RaiseError
-- Conditions that hold: the query is expected to succeed.
SELECT assert_true(true), assert_true(boolean(1));
-- Conditions that are false or null: each query is expected to fail with the default
-- "'<condition>' is not true!" message.
SELECT assert_true(false);
SELECT assert_true(boolean(0));
SELECT assert_true(null);
SELECT assert_true(boolean(null));
-- With a second argument, that custom text is expected as the error message instead.
SELECT assert_true(false, 'custom error message');

-- raise_error: once unconditionally, then only in the branch taken by rows with v > 5
-- (the row v = 8 of tbl_misc).
CREATE TEMPORARY VIEW tbl_misc AS SELECT * FROM (VALUES (1), (8), (2)) AS T(v);
SELECT raise_error('error message');
SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc;
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 16


-- !query
select typeof(null)
-- !query schema
struct<typeof(NULL):string>
-- !query output
void


-- !query
select typeof(true)
-- !query schema
struct<typeof(true):string>
-- !query output
boolean


-- !query
select typeof(1Y), typeof(1S), typeof(1), typeof(1L)
-- !query schema
struct<typeof(1):string,typeof(1):string,typeof(1):string,typeof(1):string>
-- !query output
tinyint smallint int bigint


-- !query
select typeof(cast(1.0 as float)), typeof(1.0D), typeof(1.2)
-- !query schema
struct<typeof(CAST(1.0 AS FLOAT)):string,typeof(1.0):string,typeof(1.2):string>
-- !query output
float double decimal(2,1)


-- !query
select typeof(date '1986-05-23'), typeof(timestamp '1986-05-23'), typeof(interval '23 days')
-- !query schema
struct<typeof(DATE '1986-05-23'):string,typeof(TIMESTAMP '1986-05-23 00:00:00'):string,typeof(INTERVAL '23' DAY):string>
-- !query output
date timestamp interval day


-- !query
select typeof(x'ABCD'), typeof('SPARK')
-- !query schema
struct<typeof(X'ABCD'):string,typeof(SPARK):string>
-- !query output
binary string


-- !query
select typeof(array(1, 2)), typeof(map(1, 2)), typeof(named_struct('a', 1, 'b', 'spark'))
-- !query schema
struct<typeof(array(1, 2)):string,typeof(map(1, 2)):string,typeof(named_struct(a, 1, b, spark)):string>
-- !query output
array<int> map<int,int> struct<a:int,b:string>


-- !query
SELECT assert_true(true), assert_true(boolean(1))
-- !query schema
struct<assert_true(true, 'true' is not true!):void,assert_true(1, 'cast(1 as boolean)' is not true!):void>
-- !query output
NULL NULL


-- !query
SELECT assert_true(false)
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
'false' is not true!


-- !query
SELECT assert_true(boolean(0))
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
'cast(0 as boolean)' is not true!


-- !query
SELECT assert_true(null)
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
'null' is not true!


-- !query
SELECT assert_true(boolean(null))
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
'cast(null as boolean)' is not true!


-- !query
SELECT assert_true(false, 'custom error message')
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
custom error message


-- !query
CREATE TEMPORARY VIEW tbl_misc AS SELECT * FROM (VALUES (1), (8), (2)) AS T(v)
-- !query schema
struct<>
-- !query output



-- !query
SELECT raise_error('error message')
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
error message


-- !query
SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc
-- !query schema
struct<>
-- !query output
org.apache.gluten.exception.GlutenException
too big: 8
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
package org.apache.spark.sql

import org.apache.gluten.GlutenConfig
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.utils.{BackendTestSettings, BackendTestUtils, SystemParameters}

import org.apache.spark.SparkConf
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand All @@ -39,6 +40,7 @@ import java.util.Locale
import scala.collection.mutable.ArrayBuffer
import scala.sys.process.{Process, ProcessLogger}
import scala.util.Try
import scala.util.control.NonFatal

/**
* End-to-end test cases for SQL queries.
Expand Down Expand Up @@ -761,4 +763,45 @@ class GlutenSQLQueryTestSuite
super.afterAll()
}
}

/**
 * Runs a query and turns any exception it throws into a (schema, output) pair, so that failing
 * queries can be compared against the expected output just like successful ones.
 *
 * @param result
 *   a by-name function that returns a pair of schema and output
 * @return
 *   the pair produced by `result`; or, when `result` throws, `emptySchema` together with the
 *   exception class name followed by its (normalized) message
 */
override protected def handleExceptions(
    result: => (String, Seq[String])): (String, Seq[String]) = {
  try {
    result
  } catch {
    case a: AnalysisException =>
      // Do not output the logical plan tree which contains expression IDs.
      // Also implement a crude way of masking expression IDs in the error message
      // with the generic pattern "#x".
      val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
      (emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
    case s: SparkException if s.getCause != null =>
      // For a runtime exception, it is hard to match because its message contains
      // information of stage, task ID, etc.
      // To make result matching simpler, here we match the cause of the exception if it exists.
      s.getCause match {
        case e: GlutenException =>
          // Keep only the text following "Reason: " when the message carries one. If it does
          // not (or the message is null), fall back to the full message rather than dropping
          // the failure from the output, which would hide the error from the comparison.
          val reasonPattern = "Reason: (.*)".r
          val reason = Option(e.getMessage)
            .flatMap(m => reasonPattern.findFirstMatchIn(m))
            .map(_.group(1))
            .getOrElse(e.getMessage)
          (emptySchema, Seq(e.getClass.getName, reason))
        case cause =>
          (emptySchema, Seq(cause.getClass.getName, cause.getMessage))
      }
    case NonFatal(e) =>
      // If there is an exception, put the exception class followed by the message.
      (emptySchema, Seq(e.getClass.getName, e.getMessage))
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- test for misc functions
-- Input script for the SQL query test suite: each statement below is run as one query and its
-- schema/output is compared with the recorded golden result, so the statement text itself
-- must stay unchanged.

-- typeof
-- Covers the null literal, boolean, integral, fractional, datetime/interval, binary/string
-- and nested (array, map, struct) arguments.
select typeof(null);
select typeof(true);
select typeof(1Y), typeof(1S), typeof(1), typeof(1L);
select typeof(cast(1.0 as float)), typeof(1.0D), typeof(1.2);
select typeof(date '1986-05-23'), typeof(timestamp '1986-05-23'), typeof(interval '23 days');
select typeof(x'ABCD'), typeof('SPARK');
select typeof(array(1, 2)), typeof(map(1, 2)), typeof(named_struct('a', 1, 'b', 'spark'));

-- Spark-32793: Rewrite AssertTrue with RaiseError
-- Conditions that hold: the query is expected to succeed.
SELECT assert_true(true), assert_true(boolean(1));
-- Conditions that are false or null: each query is expected to fail with the default
-- "'<condition>' is not true!" message.
SELECT assert_true(false);
SELECT assert_true(boolean(0));
SELECT assert_true(null);
SELECT assert_true(boolean(null));
-- With a second argument, that custom text is expected as the error message instead.
SELECT assert_true(false, 'custom error message');

-- raise_error: once unconditionally, then only in the branch taken by rows with v > 5
-- (the row v = 8 of tbl_misc).
CREATE TEMPORARY VIEW tbl_misc AS SELECT * FROM (VALUES (1), (8), (2)) AS T(v);
SELECT raise_error('error message');
SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc;
Loading

0 comments on commit de1c00b

Please sign in to comment.