Skip to content

Commit

Permalink
add bnlj
Browse files Browse the repository at this point in the history
  • Loading branch information
loneylee committed Jul 2, 2024
1 parent 6f3fa01 commit 70fb353
Show file tree
Hide file tree
Showing 18 changed files with 122 additions and 403 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,20 @@ public static long build(
return converter.genColumnNameWithExprId(attr);
})
.collect(Collectors.joining(","));

int joinType;
if (broadCastContext.buildHashTableId().startsWith("BuiltBNLJBroadcastTable-")) {
joinType = SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal();
} else {
joinType = SubstraitUtil.toCrossRelSubstrait(broadCastContext.joinType()).ordinal();
}

return nativeBuild(
broadCastContext.buildHashTableId(),
batches,
rowCount,
joinKey,
SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(),
joinType,
broadCastContext.hasMixedFiltCondition(),
toNameStruct(output).toByteArray());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, DenseRank, Expression, Lag, Lead, Literal, NamedExpression, Rank, RowNumber}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
Expand Down Expand Up @@ -297,4 +298,9 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
}

override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true

override def supportBroadcastNestedJoinJoinType: JoinType => Boolean = {
case _: InnerLike | LeftOuter | RightOuter | LeftSemi | FullOuter => true
case _ => false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashJoin}
import org.apache.spark.sql.types._
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.{Any, StringValue}
Expand All @@ -44,31 +43,7 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
condition
) {
// Unique ID for builded table
lazy val buildBroadcastTableId: String = "BuiltBroadcastTable-" + buildPlan.id

lazy val (buildKeyExprs, streamedKeyExprs) = {
require(
leftKeys.length == rightKeys.length &&
leftKeys
.map(_.dataType)
.zip(rightKeys.map(_.dataType))
.forall(types => sameType(types._1, types._2)),
"Join keys from two sides should have same length and types"
)
// Spark has an improvement which would patch integer joins keys to a Long value.
// But this improvement would cause add extra project before hash join in velox,
// disabling this improvement as below would help reduce the project.
val (lkeys, rkeys) = if (BackendsApiManager.getSettings.enableJoinKeysRewrite()) {
(HashJoin.rewriteKeyExpr(leftKeys), HashJoin.rewriteKeyExpr(rightKeys))
} else {
(leftKeys, rightKeys)
}
if (needSwitchChildren) {
(lkeys, rkeys)
} else {
(rkeys, lkeys)
}
}
lazy val buildBroadcastTableId: String = "BuiltBNLJBroadcastTable-" + buildPlan.id

override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
val streamedRDD = getColumnarInputRDDs(streamedPlan)
Expand Down Expand Up @@ -106,27 +81,6 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
res
}

def sameType(from: DataType, to: DataType): Boolean = {
(from, to) match {
case (ArrayType(fromElement, _), ArrayType(toElement, _)) =>
sameType(fromElement, toElement)

case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
sameType(fromKey, toKey) &&
sameType(fromValue, toValue)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall {
case (l, r) =>
l.name.equalsIgnoreCase(r.name) &&
sameType(l.dataType, r.dataType)
}

case (fromDataType, toDataType) => fromDataType == toDataType
}
}

override def genJoinParameters(): Any = {
val joinParametersStr = new StringBuffer("JoinParameters:")
joinParametersStr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,6 @@ object VeloxBackendSettings extends BackendSettingsApi {

override def supportCartesianProductExec(): Boolean = true

override def supportBroadcastNestedLoopJoinExec(): Boolean = true

override def supportSampleExec(): Boolean = true

override def supportColumnarArrowUdf(): Boolean = true
Expand Down
51 changes: 50 additions & 1 deletion cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include <Parser/RelParser.h>
#include <Parser/SerializedPlanParser.h>
#include <Processors/Chunk.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <QueryPipeline/printPipeline.h>
Expand All @@ -60,7 +61,6 @@
#include <boost/algorithm/string/case_conv.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/wrappers.pb.h>
#include <sys/resource.h>
#include <Poco/Logger.h>
#include <Poco/Util/MapConfiguration.h>
Expand Down Expand Up @@ -1077,4 +1077,53 @@ UInt64 MemoryUtil::getMemoryRSS()
return rss * sysconf(_SC_PAGESIZE);
}


void JoinUtil::reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols)
{
ActionsDAGPtr project = std::make_shared<ActionsDAG>(plan.getCurrentDataStream().header.getNamesAndTypesList());
NamesWithAliases project_cols;
for (const auto & col : cols)
{
project_cols.emplace_back(NameWithAlias(col, col));
}
project->project(project_cols);
QueryPlanStepPtr project_step = std::make_unique<ExpressionStep>(plan.getCurrentDataStream(), project);
project_step->setStepDescription("Reorder Join Output");
plan.addStep(std::move(project_step));
}

std::pair<DB::JoinKind, DB::JoinStrictness> JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type)
{
switch (join_type)
{
case substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
return {DB::JoinKind::Inner, DB::JoinStrictness::All};
case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
return {DB::JoinKind::Left, DB::JoinStrictness::Semi};
case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI:
return {DB::JoinKind::Left, DB::JoinStrictness::Anti};
case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
return {DB::JoinKind::Left, DB::JoinStrictness::All};
case substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
return {DB::JoinKind::Right, DB::JoinStrictness::All};
case substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
return {DB::JoinKind::Full, DB::JoinStrictness::All};
default:
throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type));
}
}

std::pair<DB::JoinKind, DB::JoinStrictness> JoinUtil::getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type)
{
switch (join_type)
{
case substrait::CrossRel_JoinType_JOIN_TYPE_INNER:
case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT:
case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER:
return {DB::JoinKind::Cross, DB::JoinStrictness::All};
default:
throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type));
}
}

}
11 changes: 11 additions & 0 deletions cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* limitations under the License.
*/
#pragma once

#include <filesystem>
#include <Core/Block.h>
#include <Core/ColumnWithTypeAndName.h>
Expand All @@ -25,6 +26,8 @@
#include <Interpreters/Context.h>
#include <Processors/Chunk.h>
#include <base/types.h>
#include <google/protobuf/wrappers.pb.h>
#include <substrait/algebra.pb.h>
#include <Common/CurrentThread.h>

namespace DB
Expand Down Expand Up @@ -302,4 +305,12 @@ class ConcurrentDeque
mutable std::mutex mtx;
};

class JoinUtil
{
public:
static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols);
static std::pair<DB::JoinKind, DB::JoinStrictness> getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type);
static std::pair<DB::JoinKind, DB::JoinStrictness> getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type);
};

}
8 changes: 6 additions & 2 deletions cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
DB::ReadBuffer & input,
jlong row_count,
const std::string & join_keys,
substrait::JoinRel_JoinType join_type,
jint join_type,
bool has_mixed_join_condition,
const std::string & named_struct)
{
Expand All @@ -109,7 +109,11 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
DB::JoinKind kind;
DB::JoinStrictness strictness;

std::tie(kind, strictness) = getJoinKindAndStrictness(join_type);
if (key.starts_with("BuiltBNLJBroadcastTable-"))
std::tie(kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness(static_cast<substrait::CrossRel_JoinType>(join_type));
else
std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast<substrait::JoinRel_JoinType>(join_type));


substrait::NamedStruct substrait_struct;
substrait_struct.ParseFromString(named_struct);
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
DB::ReadBuffer & input,
jlong row_count,
const std::string & join_keys,
substrait::JoinRel_JoinType join_type,
jint join_type,
bool has_mixed_join_condition,
const std::string & named_struct);
void cleanBuildHashTable(const std::string & hash_table_id, jlong instance);
Expand Down
Loading

0 comments on commit 70fb353

Please sign in to comment.