Skip to content

Commit

Permalink
fixed missing columns when there is mixed join conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Jun 6, 2024
1 parent a76c92e commit 7ddef62
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ private static native long nativeBuild(
long rowCount,
String joinKeys,
int joinType,
boolean hasMixedFiltCondition,
byte[] namedStruct);

private StorageJoinBuilder() {}
Expand Down Expand Up @@ -79,6 +80,7 @@ public static long build(
rowCount,
joinKey,
SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(),
broadCastContext.hasMixedFiltCondition(),
toNameStruct(output).toByteArray());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ case class CHBroadcastBuildSideRDD(
case class BroadCastHashJoinContext(
buildSideJoinKeys: Seq[Expression],
joinType: JoinType,
hasMixedFiltCondition: Boolean,
buildSideStructure: Seq[Attribute],
buildHashTableId: String)

Expand Down Expand Up @@ -139,9 +140,26 @@ case class CHBroadcastHashJoinExecTransformer(
}
val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
val context =
BroadCastHashJoinContext(buildKeyExprs, joinType, buildPlan.output, buildHashTableId)
BroadCastHashJoinContext(
buildKeyExprs,
joinType,
isMixedCondition(condition),
buildPlan.output,
buildHashTableId)
val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context)
// FIXME: Do we have to make build side a RDD?
streamedRDD :+ broadcastRDD
}

def isMixedCondition(cond: Option[Expression]): Boolean = {
val res = if (cond.isDefined) {
val leftOutputSet = left.outputSet
val rightOutputSet = right.outputSet
val allReferences = cond.get.references
!(allReferences.subsetOf(leftOutputSet) || allReferences.subsetOf(rightOutputSet))
} else {
false
}
res
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2570,13 +2570,21 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
spark.sql("create table ineq_join_t2 (key bigint, value bigint) using parquet");
spark.sql("insert into ineq_join_t1 values(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)");
spark.sql("insert into ineq_join_t2 values(2, 2), (2, 1), (3, 3), (4, 6), (5, 3)");
val sql =
val sql1 =
"""
| select t1.key, t1.value, t2.key, t2.value from ineq_join_t1 as t1
| left join ineq_join_t2 as t2
| on t1.key = t2.key and t1.value > t2.value
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
compareResultsAgainstVanillaSpark(sql1, true, { _ => })

val sql2 =
"""
| select t1.key, t1.value from ineq_join_t1 as t1
| left join ineq_join_t2 as t2
| on t1.key = t2.key and t1.value > t2.value and t1.value > t2.key
|""".stripMargin
compareResultsAgainstVanillaSpark(sql2, true, { _ => })
spark.sql("drop table ineq_join_t1")
spark.sql("drop table ineq_join_t2")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ object CHHashBuildBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark w
(
countsAndBytes.flatMap(_._2),
countsAndBytes.map(_._1).sum,
BroadCastHashJoinContext(Seq(child.output.head), Inner, child.output, "")
BroadCastHashJoinContext(Seq(child.output.head), Inner, false, child.output, "")
)
}
}
2 changes: 2 additions & 0 deletions cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
jlong row_count,
const std::string & join_keys,
substrait::JoinRel_JoinType join_type,
bool has_mixed_join_condition,
const std::string & named_struct)
{
auto join_key_list = Poco::StringTokenizer(join_keys, ",");
Expand All @@ -105,6 +106,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
true,
kind,
strictness,
has_mixed_join_condition,
columns_description,
ConstraintsDescription(),
key,
Expand Down
1 change: 1 addition & 0 deletions cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
jlong row_count,
const std::string & join_keys,
substrait::JoinRel_JoinType join_type,
bool has_mixed_join_condition,
const std::string & named_struct);
void cleanBuildHashTable(const std::string & hash_table_id, jlong instance);
std::shared_ptr<StorageJoinFromReadBuffer> getJoin(const std::string & hash_table_id);
Expand Down
54 changes: 52 additions & 2 deletions cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer(
bool use_nulls_,
DB::JoinKind kind,
DB::JoinStrictness strictness,
bool has_mixed_join_condition,
const ColumnsDescription & columns,
const ConstraintsDescription & constraints,
const String & comment,
Expand All @@ -91,7 +92,11 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer(
key_names.push_back(RIHGT_COLUMN_PREFIX + name);
auto table_join = std::make_shared<DB::TableJoin>(SizeLimits(), true, kind, strictness, key_names);
right_sample_block = rightSampleBlock(use_nulls, storage_metadata, table_join->kind());
buildJoin(in, right_sample_block, table_join);
/// If there is mixed join conditions, need to build the hash join lazily, which rely on the real table join.
if (!has_mixed_join_condition)
buildJoin(in, right_sample_block, table_join);
else
collectAllInputs(in, right_sample_block);
}

/// The column names may be different in two blocks.
Expand Down Expand Up @@ -135,6 +140,51 @@ void StorageJoinFromReadBuffer::buildJoin(DB::ReadBuffer & in, const Block heade
}
}

void StorageJoinFromReadBuffer::collectAllInputs(DB::ReadBuffer & in, const DB::Block header)
{
local_engine::NativeReader block_stream(in);
ProfileInfo info;
while (Block block = block_stream.read())
{
DB::ColumnsWithTypeAndName columns;
for (size_t i = 0; i < block.columns(); ++i)
{
const auto & column = block.getByPosition(i);
columns.emplace_back(convertColumnAsNecessary(column, header.getByPosition(i)));
}
DB::Block final_block(columns);
info.update(final_block);
input_blocks.emplace_back(std::move(final_block));
}
}

void StorageJoinFromReadBuffer::buildJoinLazily(DB::Block header, std::shared_ptr<DB::TableJoin> analyzed_join)
{
{
std::shared_lock lock(join_mutex);
if (join)
return;
}
std::unique_lock lock(join_mutex);
if (join)
return;
join = std::make_shared<HashJoin>(analyzed_join, header, overwrite, row_count);
while(!input_blocks.empty())
{
auto & block = *input_blocks.begin();
DB::ColumnsWithTypeAndName columns;
for (size_t i = 0; i < block.columns(); ++i)
{
const auto & column = block.getByPosition(i);
columns.emplace_back(convertColumnAsNecessary(column, header.getByPosition(i)));
}
DB::Block final_block(columns);
join->addBlockToJoin(final_block, true);
input_blocks.pop_front();
}
}


/// The column names of 'rgiht_header' could be different from the ones in `input_blocks`, and we must
/// use 'right_header' to build the HashJoin. Otherwise, it will cause exceptions with name mismatches.
///
Expand All @@ -148,7 +198,7 @@ DB::JoinPtr StorageJoinFromReadBuffer::getJoinLocked(std::shared_ptr<DB::TableJo
ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN,
"Table {} needs the same join_use_nulls setting as present in LEFT or FULL JOIN",
storage_metadata.comment);

buildJoinLazily(getRightSampleBlock(), analyzed_join);
HashJoinPtr join_clone = std::make_shared<HashJoin>(analyzed_join, right_sample_block);
/// reuseJoinedData will set the flag `HashJoin::from_storage_join` which is required by `FilledStep`
join_clone->reuseJoinedData(static_cast<const HashJoin &>(*join));
Expand Down
6 changes: 6 additions & 0 deletions cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* limitations under the License.
*/
#pragma once
#include <shared_mutex>
#include <Interpreters/JoinUtils.h>
#include <Storages/StorageInMemoryMetadata.h>

Expand All @@ -40,6 +41,7 @@ class StorageJoinFromReadBuffer
bool use_nulls_,
DB::JoinKind kind,
DB::JoinStrictness strictness,
bool has_mixed_join_condition,
const DB::ColumnsDescription & columns_,
const DB::ConstraintsDescription & constraints_,
const String & comment,
Expand All @@ -58,9 +60,13 @@ class StorageJoinFromReadBuffer
size_t row_count;
bool overwrite;
DB::Block right_sample_block;
std::shared_mutex join_mutex;
std::list<DB::Block> input_blocks;
std::shared_ptr<DB::HashJoin> join = nullptr;

void readAllBlocksFromInput(DB::ReadBuffer & in);
void buildJoin(DB::ReadBuffer & in, const DB::Block header, std::shared_ptr<DB::TableJoin> analyzed_join);
void collectAllInputs(DB::ReadBuffer & in, const DB::Block header);
void buildJoinLazily(DB::Block header, std::shared_ptr<DB::TableJoin> analyzed_join);
};
}
14 changes: 11 additions & 3 deletions cpp-ch/local-engine/local_engine_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,15 @@ JNIEXPORT jobject Java_org_apache_spark_sql_execution_datasources_CHDatasourceJn
}

JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild(
JNIEnv * env, jclass, jstring key, jbyteArray in, jlong row_count_, jstring join_key_, jint join_type_, jbyteArray named_struct)
JNIEnv * env,
jclass,
jstring key,
jbyteArray in,
jlong row_count_,
jstring join_key_,
jint join_type_,
jboolean has_mixed_join_condition,
jbyteArray named_struct)
{
LOCAL_ENGINE_JNI_METHOD_START
const auto hash_table_id = jstring2string(env, key);
Expand All @@ -1276,8 +1284,8 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild
local_engine::ReadBufferFromByteArray read_buffer_from_java_array(in, length);
DB::CompressedReadBuffer input(read_buffer_from_java_array);
local_engine::configureCompressedReadBuffer(input);
const auto * obj
= make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(hash_table_id, input, row_count_, join_key, join_type, struct_string));
const auto * obj = make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(
hash_table_id, input, row_count_, join_key, join_type, has_mixed_join_condition, struct_string));
env->ReleaseByteArrayElements(named_struct, struct_address, JNI_ABORT);
return obj->instance();
LOCAL_ENGINE_JNI_METHOD_END(env, 0)
Expand Down

0 comments on commit 7ddef62

Please sign in to comment.