From 48e6e1ba85797950030a923f13ce0bdec726685d Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Mon, 19 Aug 2024 09:41:53 +0800 Subject: [PATCH] support anti/semi join with mixed join condition --- .../execution/CHHashJoinExecTransformer.scala | 24 ++++++------------- .../gluten/utils/CHJoinValidateUtil.scala | 11 ++++----- 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 2dd45281e4169..40ac7e1cb358f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -34,15 +34,12 @@ import com.google.protobuf.{Any, StringValue} import io.substrait.proto.JoinRel object JoinTypeTransform { - def toNativeJoinType(joinType: JoinType): JoinType = { - joinType match { - case ExistenceJoin(_) => - LeftSemi - case _ => - joinType - } - } + // ExistenceJoin is introduced in #SPARK-14781. It returns all rows from the left table with + // a new column to indecate whether the row is matched in the right table. + // Indeed, the ExistenceJoin is transformed into left any join in CH. + // We don't have left any join in substrait, so use left semi join instead. + // and isExistenceJoin is set to true to indicate that it is an existence join. def toSubstraitJoinType(sparkJoin: JoinType, buildRight: Boolean): JoinRel.JoinType = sparkJoin match { case _: InnerLike => @@ -104,7 +101,7 @@ case class CHShuffledHashJoinExecTransformer( override protected def doValidateInternal(): ValidationResult = { val shouldFallback = CHJoinValidateUtil.shouldFallback( - ShuffleHashJoinStrategy(finalJoinType), + ShuffleHashJoinStrategy(joinType), left.outputSet, right.outputSet, condition) @@ -113,7 +110,6 @@ case class CHShuffledHashJoinExecTransformer( } super.doValidateInternal() } - private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override def genJoinParameters(): Any = { val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "") @@ -226,7 +222,7 @@ case class CHBroadcastHashJoinExecTransformer( override protected def doValidateInternal(): ValidationResult = { val shouldFallback = CHJoinValidateUtil.shouldFallback( - BroadcastHashJoinStrategy(finalJoinType), + BroadcastHashJoinStrategy(joinType), left.outputSet, right.outputSet, condition) @@ -278,12 +274,6 @@ case class CHBroadcastHashJoinExecTransformer( res } - // ExistenceJoin is introduced in #SPARK-14781. It returns all rows from the left table with - // a new column to indecate whether the row is matched in the right table. - // Indeed, the ExistenceJoin is transformed into left any join in CH. - // We don't have left any join in substrait, so use left semi join instead. - // and isExistenceJoin is set to true to indicate that it is an existence join. - private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override protected lazy val substraitJoinType: JoinRel.JoinType = { JoinTypeTransform.toSubstraitJoinType(joinType, buildSide == BuildRight) } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala index 0f5b5e2c4fd5a..c52d14fcd89aa 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala @@ -55,25 +55,24 @@ object CHJoinValidateUtil extends Logging { var shouldFallback = false val joinType = joinStrategy.joinType - if (!joinType.isInstanceOf[ExistenceJoin] && joinType.sql.contains("INNER")) { + if (joinType.sql.contains("INNER")) { shouldFallback = false; } else if ( condition.isDefined && hasTwoTableColumn(leftOutputSet, rightOutputSet, condition.get) ) { shouldFallback = joinStrategy match { - case BroadcastHashJoinStrategy(joinTy) => - joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") case SortMergeJoinStrategy(_) => true - case ShuffleHashJoinStrategy(joinTy) => - joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") case UnknownJoinStrategy(joinTy) => - joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") + throw new IllegalArgumentException(s"Unknown join type $joinStrategy") + case _ => false } } else { shouldFallback = joinStrategy match { case SortMergeJoinStrategy(joinTy) => joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") || joinTy.toString.contains( "ExistenceJoin") + case UnknownJoinStrategy(_) => + throw new IllegalArgumentException(s"Unknown join type $joinStrategy") case _ => false } }