diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index ee208ac74a..6afa6c7e8a 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -938,7 +938,7 @@ impl PhysicalPlanner { &join.left_join_keys, &join.right_join_keys, join.join_type, - &None, + &join.condition, )?; let sort_options = join diff --git a/native/core/src/execution/proto/operator.proto b/native/core/src/execution/proto/operator.proto index 335d425966..74a75543d5 100644 --- a/native/core/src/execution/proto/operator.proto +++ b/native/core/src/execution/proto/operator.proto @@ -104,6 +104,7 @@ message SortMergeJoin { repeated spark.spark_expression.Expr right_join_keys = 2; JoinType join_type = 3; repeated spark.spark_expression.Expr sort_options = 4; + optional spark.spark_expression.Expr condition = 5; } enum JoinType { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index da534b02ce..6b1782a54e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2728,10 +2728,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } } - // TODO: Support SortMergeJoin with join condition after new DataFusion release - if (join.condition.isDefined) { - withInfo(op, "Sort merge join with a join condition is not supported") - return None + val condition = join.condition.map { cond => + val condProto = exprToProto(cond, join.left.output ++ join.right.output) + if (condProto.isEmpty) { + withInfo(join, cond) + return None + } + condProto.get } val joinType = join.joinType match { @@ -2777,6 +2780,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim .addAllSortOptions(sortOptions.map(_.get).asJava) .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) .addAllRightJoinKeys(rightKeys.map(_.get).asJava) + condition.map(joinBuilder.setCondition) Some(result.setSortMergeJoin(joinBuilder).build()) } else { val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index 8bae2ecaa3..07bd61e7be 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -217,7 +217,6 @@ class CometJoinSuite extends CometTestBase { } } - // TODO: Add a test for SortMergeJoin with join filter after new DataFusion release test("SortMergeJoin without join filter") { withSQLConf( SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", @@ -263,4 +262,72 @@ class CometJoinSuite extends CometTestBase { } } } + + test("SortMergeJoin with join filter") { + withSQLConf( + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") { + val df1 = sql( + "SELECT * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1 AND " + + "tbl_a._1 > tbl_b._2") + df1.explain() + checkSparkAnswerAndOperator(df1) + + val df2 = sql( + "SELECT * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1 " + + "AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator(df2) + + val df3 = sql( + "SELECT * FROM tbl_b LEFT JOIN tbl_a ON tbl_a._2 = tbl_b._1 " + + "AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator(df3) + + val df4 = sql( + "SELECT * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1 " + + "AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator(df4) + + val df5 = sql( + "SELECT * FROM tbl_b RIGHT JOIN tbl_a ON tbl_a._2 = tbl_b._1 " + + "AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator(df5) + + val df6 = sql( + "SELECT * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1 " + + "AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator(df6) + + val df7 = sql( + "SELECT * FROM tbl_b FULL JOIN tbl_a ON tbl_a._2 = tbl_b._1 " + + "AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator(df7) + + /* + + val left = sql("SELECT * FROM tbl_a") + val right = sql("SELECT * FROM tbl_b") + + val df8 = + left.join(right, left("_2") === right("_1") && left("_2") >= right("_1"), "leftsemi") + checkSparkAnswerAndOperator(df8) + + val df9 = + right.join(left, left("_2") === right("_1") && left("_2") >= right("_1"), "leftsemi") + checkSparkAnswerAndOperator(df9) + + val df10 = + left.join(right, left("_2") === right("_1") && left("_2") >= right("_1"), "leftanti") + checkSparkAnswerAndOperator(df10) + + val df11 = + right.join(left, left("_2") === right("_1") && left("_2") >= right("_1"), "leftanti") + checkSparkAnswerAndOperator(df11) + */ + } + } + } + } }