From 1527b4d2e72416dd08560aa542b3578ca43b1b97 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Mon, 19 Aug 2024 09:41:53 +0800 Subject: [PATCH] support anti/semi join with mixed join condition --- .../execution/CHHashJoinExecTransformer.scala | 26 ++++++------------- .../gluten/utils/CHJoinValidateUtil.scala | 11 ++++---- 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 2dd45281e4169..9b6b2958ccc71 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -34,15 +34,12 @@ import com.google.protobuf.{Any, StringValue} import io.substrait.proto.JoinRel object JoinTypeTransform { - def toNativeJoinType(joinType: JoinType): JoinType = { - joinType match { - case ExistenceJoin(_) => - LeftSemi - case _ => - joinType - } - } + // ExistenceJoin is introduced in #SPARK-14781. It returns all rows from the left table with + // a new column to indecate whether the row is matched in the right table. + // Indeed, the ExistenceJoin is transformed into left any join in CH. + // We don't have left any join in substrait, so use left semi join instead. + // and isExistenceJoin is set to true to indicate that it is an existence join. def toSubstraitJoinType(sparkJoin: JoinType, buildRight: Boolean): JoinRel.JoinType = sparkJoin match { case _: InnerLike => @@ -104,7 +101,7 @@ case class CHShuffledHashJoinExecTransformer( override protected def doValidateInternal(): ValidationResult = { val shouldFallback = CHJoinValidateUtil.shouldFallback( - ShuffleHashJoinStrategy(finalJoinType), + ShuffleHashJoinStrategy(joinType), left.outputSet, right.outputSet, condition) @@ -113,7 +110,6 @@ case class CHShuffledHashJoinExecTransformer( } super.doValidateInternal() } - private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override def genJoinParameters(): Any = { val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "") @@ -226,7 +222,7 @@ case class CHBroadcastHashJoinExecTransformer( override protected def doValidateInternal(): ValidationResult = { val shouldFallback = CHJoinValidateUtil.shouldFallback( - BroadcastHashJoinStrategy(finalJoinType), + BroadcastHashJoinStrategy(joinType), left.outputSet, right.outputSet, condition) @@ -255,7 +251,7 @@ case class CHBroadcastHashJoinExecTransformer( val context = BroadCastHashJoinContext( buildKeyExprs, - finalJoinType, + joinType, buildSide == BuildRight, isMixedCondition(condition), joinType.isInstanceOf[ExistenceJoin], @@ -278,12 +274,6 @@ case class CHBroadcastHashJoinExecTransformer( res } - // ExistenceJoin is introduced in #SPARK-14781. It returns all rows from the left table with - // a new column to indecate whether the row is matched in the right table. - // Indeed, the ExistenceJoin is transformed into left any join in CH. - // We don't have left any join in substrait, so use left semi join instead. - // and isExistenceJoin is set to true to indicate that it is an existence join. - private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override protected lazy val substraitJoinType: JoinRel.JoinType = { JoinTypeTransform.toSubstraitJoinType(joinType, buildSide == BuildRight) } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala index 0f5b5e2c4fd5a..c52d14fcd89aa 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala @@ -55,25 +55,24 @@ object CHJoinValidateUtil extends Logging { var shouldFallback = false val joinType = joinStrategy.joinType - if (!joinType.isInstanceOf[ExistenceJoin] && joinType.sql.contains("INNER")) { + if (joinType.sql.contains("INNER")) { shouldFallback = false; } else if ( condition.isDefined && hasTwoTableColumn(leftOutputSet, rightOutputSet, condition.get) ) { shouldFallback = joinStrategy match { - case BroadcastHashJoinStrategy(joinTy) => - joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") case SortMergeJoinStrategy(_) => true - case ShuffleHashJoinStrategy(joinTy) => - joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") case UnknownJoinStrategy(joinTy) => - joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") + throw new IllegalArgumentException(s"Unknown join type $joinStrategy") + case _ => false } } else { shouldFallback = joinStrategy match { case SortMergeJoinStrategy(joinTy) => joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") || joinTy.toString.contains( "ExistenceJoin") + case UnknownJoinStrategy(_) => + throw new IllegalArgumentException(s"Unknown join type $joinStrategy") case _ => false } }