Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Aug 22, 2024
1 parent 4db26ba commit 88fef39
Show file tree
Hide file tree
Showing 19 changed files with 80 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ case class RewriteSortMergeJoinToHashJoinRule(session: SparkSession)
logError(s"Validation failed for ShuffledHashJoinExec: ${validateResult.reason()}")
return smj
}
logDebug(s"Applied SortMergeJoin to ShuffledHashJoin")
hashJoin
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ object CHJoinValidateUtil extends Logging {
condition.isDefined && hasTwoTableColumn(leftOutputSet, rightOutputSet, condition.get)
val shouldFallback = joinStrategy match {
case SortMergeJoinStrategy(joinType) =>
joinType.sql.contains("SEMI") || joinType.sql.contains("ANTI") || joinType.toString
.contains("ExistenceJoin") || hasMixedFilterCondition
if (!joinType.isInstanceOf[ExistenceJoin] && joinType.sql.contains("INNER")) {
false
} else {
joinType.sql.contains("SEMI") || joinType.sql.contains("ANTI") || joinType.toString
.contains("ExistenceJoin") || hasMixedFilterCondition
}
case UnknownJoinStrategy(joinType) =>
throw new IllegalArgumentException(s"Unknown join type $joinStrategy")
case _ => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class GlutenClickHouseColumnarMemorySortShuffleSuite
}

test("TPCH Q21") {
runTPCHQuery(21, noFallBack = false) { df => }
runTPCHQuery(21) { df => }
}

test("TPCH Q22") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class GlutenClickHouseColumnarShuffleAQESuite
}

test("TPCH Q21") {
runTPCHQuery(21, noFallBack = false) { df => }
runTPCHQuery(21) { df => }
}

test("TPCH Q22") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class GlutenClickHouseDSV2ColumnarShuffleSuite extends GlutenClickHouseTPCHAbstr
}

test("TPCH Q21") {
runTPCHQuery(21, noFallBack = false) { df => }
runTPCHQuery(21) { df => }
}

test("TPCH Q22") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class GlutenClickHouseDSV2Suite extends GlutenClickHouseTPCHAbstractSuite {
}

test("TPCH Q21") {
runTPCHQuery(21, noFallBack = false) { df => }
runTPCHQuery(21) { df => }
}

test("TPCH Q22") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ class GlutenClickHouseDecimalSuite
decimalTPCHTables.foreach {
dt =>
{
val fallBack = (sql_num == 16 || sql_num == 21)
val fallBack = (sql_num == 16)
val compareResult = !dt._2.contains(sql_num)
val native = if (fallBack) "fallback" else "native"
val compare = if (compareResult) "compare" else "noCompare"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class GlutenClickHouseTPCHNullableColumnarShuffleSuite extends GlutenClickHouseT
}

test("TPCH Q21") {
runTPCHQuery(21, noFallBack = false) { df => }
runTPCHQuery(21) { df => }
}

test("TPCH Q22") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class GlutenClickHouseTPCHNullableSuite extends GlutenClickHouseTPCHAbstractSuit
}

test("TPCH Q21") {
runTPCHQuery(21, noFallBack = false) { df => }
runTPCHQuery(21) { df => }
}

test("TPCH Q22") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
}

test("TPCH Q21") {
runTPCHQuery(21, noFallBack = false) { df => }
runTPCHQuery(21) { df => }
}

test("TPCH Q22") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,6 @@ class GlutenClickHouseTPCDSParquetColumnarShuffleAQESuite
| LIMIT 100 ;
|""".stripMargin
// There are some BroadcastHashJoin with NOT condition
compareResultsAgainstVanillaSpark(sql, true, { df => }, false)
compareResultsAgainstVanillaSpark(sql, true, { df => })
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ class GlutenClickHouseTPCDSParquetSortMergeJoinSuite extends GlutenClickHouseTPC
.set("spark.shuffle.manager", "sort")
.set("spark.io.compression.codec", "snappy")
.set("spark.sql.shuffle.partitions", "5")
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.set("spark.memory.offHeap.size", "8g")
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.memory.offHeap.size", "12g")
.set(
"spark.gluten.sql.columnar.backend.ch.runtime_config.extra_memory_hard_limit",
"2147483648")
.set("spark.gluten.sql.columnar.forceShuffledHashJoin", "false")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ class GlutenClickHouseTPCHColumnarShuffleParquetAQESuite
}

test("TPCH Q21") {
runTPCHQuery(21, noFallBack = false) {
runTPCHQuery(21) {
df =>
val plans = collect(df.queryExecution.executedPlan) {
case scanExec: BasicScanExecTransformer => scanExec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ class GlutenClickHouseTPCHParquetAQEConcurrentSuite
.set("spark.shuffle.manager", "sort")
.set("spark.io.compression.codec", "snappy")
.set("spark.sql.shuffle.partitions", "5")
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.set("spark.sql.adaptive.enabled", "true")
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.gluten.sql.columnar.backend.ch.runtime_config.use_local_format", "true")
.set("spark.gluten.sql.columnar.backend.ch.shuffle.hash.algorithm", "sparkMurmurHash3_32")
}
override protected def runTPCHQuery(
queryNum: Int,
Expand Down Expand Up @@ -75,6 +76,47 @@ class GlutenClickHouseTPCHParquetAQEConcurrentSuite
createNotNullTPCHTablesInParquet(tablesPath)
}

test("TPCH Q21 (1)") {
runTPCHQuery(21) { df => }
}

test("TPCH Q21 (2)") {
val sql = """
|SELECT
|s_name,
|count(*) AS numwait
|FROM
|supplier, lineitem l1, orders, nation
|WHERE s_suppkey = l1.l_suppkey
|AND o_orderkey = l1.l_orderkey
|AND o_orderstatus = 'F'
|AND l1.l_receiptdate > l1.l_commitdate
|AND EXISTS (
| SELECT
| *
| FROM
| lineitem l2
| WHERE l2.l_orderkey = l1.l_orderkey
| AND l2.l_suppkey <> l1.l_suppkey)
| AND NOT EXISTS (
| SELECT
| *
| FROM
| lineitem l3
| WHERE l3.l_orderkey = l1.l_orderkey
| AND l3.l_suppkey <> l1.l_suppkey
| AND l3.l_receiptdate > l3.l_commitdate)
| AND s_nationkey = n_nationkey
| AND n_name = 'SAUDI ARABIA'
| GROUP BY
| s_name
| ORDER BY
| numwait DESC,
| s_name
| LIMIT 100""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}

test("fix race condition at the global variable of ColumnarOverrideRules::isAdaptiveContext") {

val queries = ParVector((1 to 22) ++ (1 to 22) ++ (1 to 22) ++ (1 to 22): _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class GlutenClickHouseTPCHParquetAQESuite
}

test("TPCH Q21") {
runTPCHQuery(21, noFallBack = false) { df => }
runTPCHQuery(21) { df => }
}

test("TPCH Q22") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
}

test("TPCH Q21") {
runTPCHQuery(21, noFallBack = false) { df => }
runTPCHQuery(21) { df => }
}

test("GLUTEN-2115: Fix wrong number of records shuffle written") {
Expand Down
6 changes: 5 additions & 1 deletion cpp-ch/local-engine/Common/QueryContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
#include <sstream>
#include <iomanip>

#include <Poco/Logger.h>
#include <Common/logger_useful.h>


namespace DB
{
Expand Down Expand Up @@ -70,6 +73,7 @@ int64_t QueryContextManager::initializeQuery()

query_context->thread_group->memory_tracker.setSoftLimit(memory_limit);
query_context->thread_group->memory_tracker.setHardLimit(memory_limit + config.extra_memory_hard_limit);
LOG_INFO(getLogger("QueryContextManager"), "xxx memory limit: {} {}", memory_limit, config.extra_memory_hard_limit);
int64_t id = reinterpret_cast<int64_t>(query_context->thread_group.get());
query_map.insert(id, query_context);
return id;
Expand Down Expand Up @@ -172,4 +176,4 @@ double currentThreadGroupMemoryUsageRatio()
}
return static_cast<double>(CurrentThread::getGroup()->memory_tracker.get()) / CurrentThread::getGroup()->memory_tracker.getSoftLimit();
}
}
}
3 changes: 3 additions & 0 deletions cpp-ch/local-engine/Parser/JoinRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ std::shared_ptr<DB::TableJoin> createDefaultTableJoin(substrait::JoinRel_JoinTyp
std::pair<DB::JoinKind, DB::JoinStrictness> kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type, is_existence_join);
table_join->setKind(kind_and_strictness.first);
table_join->setStrictness(kind_and_strictness.second);
LOG_ERROR(getLogger("JoinRelParser"), "xxx join type: {} {}", table_join->kind(), table_join->strictness());
return table_join;
}

Expand Down Expand Up @@ -204,6 +205,8 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & righ

DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right)
{
LOG_ERROR(getLogger("JoinRelParser"), "xxx left header: {}", left->getCurrentDataStream().header.dumpStructure());
LOG_ERROR(getLogger("JoinRelParser"), "xxx right header: {}", right->getCurrentDataStream().header.dumpStructure());
auto join_config = JoinConfig::loadFromContext(getContext());
google::protobuf::StringValue optimization_info;
optimization_info.ParseFromString(join.advanced_extension().optimization().value());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ case class Table(name: String, partitionColumns: Seq[String])
abstract class WholeStageTransformerSuite
extends GlutenQueryTest
with SharedSparkSession
with AdaptiveSparkPlanHelper {
with AdaptiveSparkPlanHelper
with Logging {

protected val resourcePath: String
protected val fileFormat: String
Expand Down Expand Up @@ -120,6 +121,9 @@ abstract class WholeStageTransformerSuite
val queryResultStr =
Arm.withResource(Source.fromFile(new File(queriesResults + "/" + sqlNum + ".out"), "UTF-8"))(
_.mkString)
if (!queryResultStr.equals(resultStr.toString())) {
logError(s"Results are mismatched. $sqlNum \n$queryResultStr vs. \n${resultStr.toString()}")
}
assert(queryResultStr.equals(resultStr.toString))
}

Expand Down

0 comments on commit 88fef39

Please sign in to comment.