Skip to content

Commit

Permalink
support anti/semi join with mixed join condition
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Aug 19, 2024
1 parent 9fcd488 commit 1527b4d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,12 @@ import com.google.protobuf.{Any, StringValue}
import io.substrait.proto.JoinRel

object JoinTypeTransform {
def toNativeJoinType(joinType: JoinType): JoinType = {
joinType match {
case ExistenceJoin(_) =>
LeftSemi
case _ =>
joinType
}
}

// ExistenceJoin was introduced in SPARK-14781. It returns all rows from the left table with
// a new column to indicate whether the row is matched in the right table.
// In CH, the ExistenceJoin is transformed into a left any join.
// Substrait has no left any join, so a left semi join is used instead,
// and isExistenceJoin is set to true to indicate that it is an existence join.
def toSubstraitJoinType(sparkJoin: JoinType, buildRight: Boolean): JoinRel.JoinType =
sparkJoin match {
case _: InnerLike =>
Expand Down Expand Up @@ -104,7 +101,7 @@ case class CHShuffledHashJoinExecTransformer(
override protected def doValidateInternal(): ValidationResult = {
val shouldFallback =
CHJoinValidateUtil.shouldFallback(
ShuffleHashJoinStrategy(finalJoinType),
ShuffleHashJoinStrategy(joinType),
left.outputSet,
right.outputSet,
condition)
Expand All @@ -113,7 +110,6 @@ case class CHShuffledHashJoinExecTransformer(
}
super.doValidateInternal()
}
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)

override def genJoinParameters(): Any = {
val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "")
Expand Down Expand Up @@ -226,7 +222,7 @@ case class CHBroadcastHashJoinExecTransformer(
override protected def doValidateInternal(): ValidationResult = {
val shouldFallback =
CHJoinValidateUtil.shouldFallback(
BroadcastHashJoinStrategy(finalJoinType),
BroadcastHashJoinStrategy(joinType),
left.outputSet,
right.outputSet,
condition)
Expand Down Expand Up @@ -255,7 +251,7 @@ case class CHBroadcastHashJoinExecTransformer(
val context =
BroadCastHashJoinContext(
buildKeyExprs,
finalJoinType,
joinType,
buildSide == BuildRight,
isMixedCondition(condition),
joinType.isInstanceOf[ExistenceJoin],
Expand All @@ -278,12 +274,6 @@ case class CHBroadcastHashJoinExecTransformer(
res
}

// ExistenceJoin was introduced in SPARK-14781. It returns all rows from the left table with
// a new column to indicate whether the row is matched in the right table.
// In CH, the ExistenceJoin is transformed into a left any join.
// Substrait has no left any join, so a left semi join is used instead,
// and isExistenceJoin is set to true to indicate that it is an existence join.
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType = {
JoinTypeTransform.toSubstraitJoinType(joinType, buildSide == BuildRight)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,24 @@ object CHJoinValidateUtil extends Logging {
var shouldFallback = false
val joinType = joinStrategy.joinType

if (!joinType.isInstanceOf[ExistenceJoin] && joinType.sql.contains("INNER")) {
if (joinType.sql.contains("INNER")) {
shouldFallback = false;
} else if (
condition.isDefined && hasTwoTableColumn(leftOutputSet, rightOutputSet, condition.get)
) {
shouldFallback = joinStrategy match {
case BroadcastHashJoinStrategy(joinTy) =>
joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
case SortMergeJoinStrategy(_) => true
case ShuffleHashJoinStrategy(joinTy) =>
joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
case UnknownJoinStrategy(joinTy) =>
joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
throw new IllegalArgumentException(s"Unknown join type $joinStrategy")
case _ => false
}
} else {
shouldFallback = joinStrategy match {
case SortMergeJoinStrategy(joinTy) =>
joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") || joinTy.toString.contains(
"ExistenceJoin")
case UnknownJoinStrategy(_) =>
throw new IllegalArgumentException(s"Unknown join type $joinStrategy")
case _ => false
}
}
Expand Down

0 comments on commit 1527b4d

Please sign in to comment.