diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 64d14a00e..0f698d8aa 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -129,6 +129,11 @@ case class CometBroadcastExchangeExec(
           case AQEShuffleReadExec(s: ShuffleQueryStageExec, _)
               if s.plan.isInstanceOf[CometPlan] =>
             CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
+          case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] =>
+            CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
+          case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), _)
+              if plan.isInstanceOf[CometPlan] =>
+            CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
           case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) =>
             throw new CometRuntimeException(
               "Child of CometBroadcastExchangeExec should be CometExec, " +
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 8bae2ecaa..b2e225b15 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -30,6 +30,7 @@ import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
 
 class CometJoinSuite extends CometTestBase {
+  import testImplicits._
 
   override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
       pos: Position): Unit = {
@@ -40,6 +41,17 @@ class CometJoinSuite extends CometTestBase {
     }
   }
 
+  test("join - self join") {
+    val df1 = testData.select(testData("key")).as("df1")
+    val df2 = testData.select(testData("key")).as("df2")
+
+    checkAnswer(
+      df1.join(df2, $"df1.key" === $"df2.key"),
+      sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
+        .collect()
+        .toSeq)
+  }
+
   test("SortMergeJoin with unsupported key type should fall back to Spark") {
     withSQLConf(
       SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",