diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index 092612ea7340..5e3027b8d03b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -24,7 +24,7 @@ import org.apache.gluten.utils.SubstraitUtil import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} -import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.BaseJoinExec @@ -79,6 +79,10 @@ abstract class BroadcastNestedLoopJoinExecTransformer( left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case LeftExistence(_) => + left.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case x => throw new IllegalArgumentException(s"${getClass.getSimpleName} not take $x as the JoinType") } @@ -145,6 +149,22 @@ abstract class BroadcastNestedLoopJoinExecTransformer( inputBuildOutput) } + private def validateJoinTypeAndBuildSide(): ValidationResult = { + val result = joinType match { + case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok + case _ => + ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin") + } + if (!result.isValid) { + return result + } + (joinType, buildSide) match { + case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) => + ValidationResult.notOk(s"$joinType join is not supported with $buildSide") + case _ => ValidationResult.ok // continue + } + } + override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportBroadcastNestedLoopJoinExec()) { return ValidationResult.notOk("Broadcast Nested Loop join is not supported in this backend") @@ -152,10 +172,9 @@ abstract class BroadcastNestedLoopJoinExecTransformer( if (substraitJoinType == CrossRel.JoinType.UNRECOGNIZED) { return ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin") } - (joinType, buildSide) match { - case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) => - return ValidationResult.notOk(s"$joinType join is not supported with $buildSide") - case _ => // continue + val validateResult = validateJoinTypeAndBuildSide() + if (!validateResult.isValid) { + return validateResult } val substraitContext = new SubstraitContext