Skip to content

Commit

Permalink
[GLUTEN-341][CH] Support BHJ + isNullAwareAntiJoin for the CH backend
Browse files Browse the repository at this point in the history
for example: TPCH Q16

Close #341.
  • Loading branch information
zzcclp committed Sep 2, 2024
1 parent 376167e commit 087f53e
Show file tree
Hide file tree
Showing 28 changed files with 395 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ private static native long nativeBuild(
int joinType,
boolean hasMixedFiltCondition,
boolean isExistenceJoin,
byte[] namedStruct);
byte[] namedStruct,
boolean isNullAwareAntiJoin);

private StorageJoinBuilder() {}

Expand Down Expand Up @@ -94,7 +95,8 @@ public static long build(
joinType,
broadCastContext.hasMixedFiltCondition(),
broadCastContext.isExistenceJoin(),
toNameStruct(output).toByteArray());
toNameStruct(output).toByteArray(),
broadCastContext.isNullAwareAntiJoin());
}

/** create table named struct */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ case class BroadCastHashJoinContext(
hasMixedFiltCondition: Boolean,
isExistenceJoin: Boolean,
buildSideStructure: Seq[Attribute],
buildHashTableId: String)
buildHashTableId: String,
isNullAwareAntiJoin: Boolean = false)

case class CHBroadcastHashJoinExecTransformer(
leftKeys: Seq[Expression],
Expand Down Expand Up @@ -230,9 +231,6 @@ case class CHBroadcastHashJoinExecTransformer(
if (shouldFallback) {
return ValidationResult.failed("ch join validate fail")
}
if (isNullAwareAntiJoin) {
return ValidationResult.failed("ch does not support NAAJ")
}
super.doValidateInternal()
}

Expand All @@ -256,7 +254,9 @@ case class CHBroadcastHashJoinExecTransformer(
isMixedCondition(condition),
joinType.isInstanceOf[ExistenceJoin],
buildPlan.output,
buildHashTableId)
buildHashTableId,
isNullAwareAntiJoin
)
val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context)
// FIXME: Do we have to make build side a RDD?
streamedRDD :+ broadcastRDD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class GlutenClickHouseColumnarMemorySortShuffleSuite
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
runTPCHQuery(16) { df => }
}

test("TPCH Q17") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class GlutenClickHouseColumnarShuffleAQESuite
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
runTPCHQuery(16) { df => }
}

test("TPCH Q17") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class GlutenClickHouseDSV2ColumnarShuffleSuite extends GlutenClickHouseTPCHAbstr
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
runTPCHQuery(16) { df => }
}

test("TPCH Q17") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class GlutenClickHouseDSV2Suite extends GlutenClickHouseTPCHAbstractSuite {
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
runTPCHQuery(16) { df => }
}

test("TPCH Q17") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,22 +343,18 @@ class GlutenClickHouseDecimalSuite
decimalTPCHTables.foreach {
dt =>
{
val fallBack = (sql_num == 16)
val compareResult = !dt._2.contains(sql_num)
val native = if (fallBack) "fallback" else "native"
val compare = if (compareResult) "compare" else "noCompare"
val PrecisionLoss = s"allowPrecisionLoss=$allowPrecisionLoss"
val decimalType = dt._1
test(s"""TPCH Decimal(${decimalType.precision},${decimalType.scale})
| Q$sql_num[$PrecisionLoss,$native,$compare]""".stripMargin) {
| Q$sql_num[$PrecisionLoss,native,$compare]""".stripMargin) {
spark.sql(s"use decimal_${decimalType.precision}_${decimalType.scale}")
withSQLConf(
(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key, allowPrecisionLoss)) {
runTPCHQuery(
sql_num,
tpchQueries,
compareResult = compareResult,
noFallBack = !fallBack) { _ => {} }
runTPCHQuery(sql_num, tpchQueries, compareResult = compareResult) {
_ => {}
}
}
spark.sql(s"use default")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class GlutenClickHouseTPCHNullableColumnarShuffleSuite extends GlutenClickHouseT
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
runTPCHQuery(16) { df => }
}

test("TPCH Q17") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class GlutenClickHouseTPCHNullableSuite extends GlutenClickHouseTPCHAbstractSuit
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
runTPCHQuery(16) { df => }
}

test("TPCH Q17") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
runTPCHQuery(16) { df => }
}

test("TPCH Q17") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ class GlutenClickHouseTPCHColumnarShuffleParquetAQESuite
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
runTPCHQuery(16) { df => }
}

test("TPCH Q17") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class GlutenClickHouseTPCHParquetAQESuite
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
runTPCHQuery(16) { df => }
}

test("TPCH Q17") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr

// see issue https://github.com/Kyligence/ClickHouse/issues/93
test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
runTPCHQuery(16) { df => }
}

test("TPCH Q17") {
Expand Down Expand Up @@ -2797,5 +2797,144 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}

test("GLUTEN-341: Support BHJ + isNullAwareAntiJoin for the CH backend") {
def checkBHJWithIsNullAwareAntiJoin(df: DataFrame): Unit = {
val bhjs = df.queryExecution.executedPlan.collect {
case bhj: CHBroadcastHashJoinExecTransformer if bhj.isNullAwareAntiJoin => true
}
assert(bhjs.size == 1)
}

val sql =
s"""
|SELECT
| p_brand,
| p_type,
| p_size,
| count(DISTINCT ps_suppkey) AS supplier_cnt
|FROM
| partsupp,
| part
|WHERE
| p_partkey = ps_partkey
| AND p_brand <> 'Brand#45'
| AND p_type NOT LIKE 'MEDIUM POLISHED%'
| AND p_size IN (49, 14, 23, 45, 19, 3, 36, 9)
| AND ps_suppkey NOT IN (
| SELECT
| s_suppkey
| FROM
| supplier
| WHERE
| s_comment is null)
|GROUP BY
| p_brand,
| p_type,
| p_size
|ORDER BY
| supplier_cnt DESC,
| p_brand,
| p_type,
| p_size;
|""".stripMargin
compareResultsAgainstVanillaSpark(
sql,
true,
df => {
checkBHJWithIsNullAwareAntiJoin(df)
})

val sql1 =
s"""
|SELECT
| p_brand,
| p_type,
| p_size,
| count(DISTINCT ps_suppkey) AS supplier_cnt
|FROM
| partsupp,
| part
|WHERE
| p_partkey = ps_partkey
| AND p_brand <> 'Brand#45'
| AND p_type NOT LIKE 'MEDIUM POLISHED%'
| AND p_size IN (49, 14, 23, 45, 19, 3, 36, 9)
| AND ps_suppkey NOT IN (
| SELECT
| s_suppkey
| FROM
| supplier
| WHERE
| s_comment LIKE '%Customer%Complaints11%')
|GROUP BY
| p_brand,
| p_type,
| p_size
|ORDER BY
| supplier_cnt DESC,
| p_brand,
| p_type,
| p_size;
|""".stripMargin
compareResultsAgainstVanillaSpark(
sql1,
true,
df => {
checkBHJWithIsNullAwareAntiJoin(df)
})

val sql2 =
s"""
|select * from partsupp
|where
|ps_suppkey NOT IN (SELECT suppkey FROM VALUES (50), (null) sub(suppkey))
|""".stripMargin
compareResultsAgainstVanillaSpark(
sql2,
true,
df => {
checkBHJWithIsNullAwareAntiJoin(df)
})

val sql3 =
s"""
|select * from partsupp
|where
|ps_suppkey NOT IN (SELECT suppkey FROM VALUES (50) sub(suppkey) WHERE suppkey > 100)
|""".stripMargin
compareResultsAgainstVanillaSpark(
sql3,
true,
df => {
checkBHJWithIsNullAwareAntiJoin(df)
})

val sql4 =
s"""
|select * from partsupp
|where
|ps_suppkey NOT IN (SELECT suppkey FROM VALUES (50), (60) sub(suppkey))
|""".stripMargin
compareResultsAgainstVanillaSpark(
sql4,
true,
df => {
checkBHJWithIsNullAwareAntiJoin(df)
})

val sql5 =
s"""
|select * from partsupp
|where
|ps_suppkey NOT IN (SELECT suppkey FROM VALUES (null) sub(suppkey))
|""".stripMargin
compareResultsAgainstVanillaSpark(
sql5,
true,
df => {
checkBHJWithIsNullAwareAntiJoin(df)
})
}
}
// scalastyle:on line.size.limit
6 changes: 4 additions & 2 deletions cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
jint join_type,
bool has_mixed_join_condition,
bool is_existence_join,
const std::string & named_struct)
const std::string & named_struct,
bool is_null_aware_anti_join)
{
auto join_key_list = Poco::StringTokenizer(join_keys, ",");
Names key_names;
Expand Down Expand Up @@ -191,7 +192,8 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
columns_description,
ConstraintsDescription(),
key,
true);
true,
is_null_aware_anti_join);
}

void init(JNIEnv * env)
Expand Down
3 changes: 2 additions & 1 deletion cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
jint join_type,
bool has_mixed_join_condition,
bool is_existence_join,
const std::string & named_struct);
const std::string & named_struct,
bool is_null_aware_anti_join);
void cleanBuildHashTable(const std::string & hash_table_id, jlong instance);
std::shared_ptr<StorageJoinFromReadBuffer> getJoin(const std::string & hash_table_id);

Expand Down
Loading

0 comments on commit 087f53e

Please sign in to comment.