From 5c8b8f48a24a9d34fbda514bb2b0c440b0e74297 Mon Sep 17 00:00:00 2001
From: Yi Wu
Date: Wed, 15 Jan 2025 15:18:21 +0800
Subject: [PATCH 1/3] fix

---
 .../sql/execution/WholeStageCodegenExec.scala |  12 ++-
 ...ashAggregateCodegenInterruptionSuite.scala | 100 ++++++++++++++++++
 2 files changed, 111 insertions(+), 1 deletion(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 920f61574770d..a7b39147bd709 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -466,6 +466,8 @@ trait InputRDDCodegen extends CodegenSupport {
     val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
       forceInline = true)
     val row = ctx.freshName("row")
+    val numOutputRows = ctx.freshName("numOutputRows")
+    val taskInterrupted = ctx.freshName("taskInterrupted")
 
     val outputVars = if (createUnsafeProjection) {
       // creating the vars will make the parent consume add an unsafe projection.
@@ -485,8 +487,16 @@ trait InputRDDCodegen extends CodegenSupport {
       ""
     }
     s"""
-       | while ($limitNotReachedCond $input.hasNext()) {
+       | long $numOutputRows = 0;
+       | boolean $taskInterrupted = false;
+       | while ($limitNotReachedCond !$taskInterrupted && $input.hasNext()) {
        |   InternalRow $row = (InternalRow) $input.next();
+       |   $numOutputRows++;
+       |   if ($numOutputRows % 1000 == 0) {
+       |     if (org.apache.spark.TaskContext.get() != null) {
+       |       $taskInterrupted = org.apache.spark.TaskContext.get().isInterrupted();
+       |     }
+       |   }
        |   ${updateNumOutputRowsMetrics}
        |   ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim}
        |   ${shouldStopCheckCode}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala
new file mode 100644
index 0000000000000..3526791780333
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.mutable
+import scala.concurrent.Future
+import scala.concurrent.duration._
+
+import org.apache.logging.log4j.Level
+
+import org.apache.spark.{SparkConf, TaskEndReason, TaskKilled}
+import org.apache.spark.scheduler.{SparkListenerJobEnd, SparkListenerTaskEnd, SparkListenerTaskStart}
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.functions.sum
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class HashAggregateCodegenInterruptionSuite extends QueryTest with SharedSparkSession {
+
+  override def sparkConf: SparkConf = {
+    super
+      .sparkConf
+      .set("spark.task.reaper.enabled", "true")
+      .set("spark.task.reaper.killTimeout", "10s")
+  }
+
+  test("SPARK-50806: HashAggregate codegen should be interrupted on task cancellation") {
+    import testImplicits._
+    withSQLConf(
+      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> true.toString,
+      SQLConf.INTERRUPT_ON_CANCEL.key -> false.toString) {
+      var taskId = -1L
+      var isJobEnded = false
+      val taskEndReasons = new mutable.ArrayBuffer[TaskEndReason]
+      spark.sparkContext.addSparkListener(new org.apache.spark.scheduler.SparkListener {
+        override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+          taskId = taskStart.taskInfo.taskId
+        }
+
+        override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+          taskEndReasons += taskEnd.reason
+        }
+
+        override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+          isJobEnded = true
+        }
+      })
+
+      val logAppender = new LogAppender("")
+      withLogAppender(logAppender, level = Some(Level.INFO)) {
+        spark.sparkContext.setJobGroup("SPARK-50806", "SPARK-50806", false)
+        val slowDF = spark.range(1, 100000).rdd.mapPartitions { iter =>
+          new Iterator[Long] {
+            var cnt = 0
+            override def hasNext: Boolean = iter.hasNext
+            override def next(): Long = {
+              if (cnt % 1000 == 0) {
+                Thread.sleep(1000)
+              }
+              cnt += 1
+              iter.next()
+            }
+          }
+        }.toDF("id")
+        val aggDF = slowDF.selectExpr("id % 10 as key").groupBy("key").agg(sum("key"))
+        import scala.concurrent.ExecutionContext.global
+        Future {
+          aggDF.collect()
+        }(global)
+        // Leave some time for the query to start running
+        Thread.sleep(5000)
+        spark.sparkContext.cancelJobGroup("SPARK-50806")
+        eventually(timeout(1.minute)) {
+          assert(isJobEnded)
+          assert(taskEndReasons.length === 2)
+          assert(taskEndReasons.forall(_.isInstanceOf[TaskKilled]))
+          val logs = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+          assert(!logs.exists(
+            _.contains(s"Killed task $taskId could not be stopped within 10000 ms"))
+          )
+        }
+      }
+    }
+  }
+}

From f5e62a341cbd77c61db3eac343d290d4fe5f3eff Mon Sep 17 00:00:00 2001
From: Yi Wu
Date: Wed, 15 Jan 2025 16:25:54 +0800
Subject: [PATCH 2/3] update comment

---
 .../sql/execution/HashAggregateCodegenInterruptionSuite.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala
index 3526791780333..9c1ccf7748d49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala
@@ -64,6 +64,9 @@ class HashAggregateCodegenInterruptionSuite extends QueryTest with SharedSparkSe
       val logAppender = new LogAppender("")
       withLogAppender(logAppender, level = Some(Level.INFO)) {
         spark.sparkContext.setJobGroup("SPARK-50806", "SPARK-50806", false)
+        // The dataset has ~100k rows and the generated code checks for interruption every 1k
+        // rows. With two tasks of ~50k rows each, sleeping 1 second per 1k rows (~50 seconds
+        // per task), the tasks run well past `spark.task.reaper.killTimeout` (10s), which
+        // properly exercises the interruption behavior.
         val slowDF = spark.range(1, 100000).rdd.mapPartitions { iter =>
           new Iterator[Long] {
             var cnt = 0

From 05b2f64c38d82d5a4afbc50dd382aebbeb092a9b Mon Sep 17 00:00:00 2001
From: Yi Wu
Date: Thu, 23 Jan 2025 12:25:03 +0800
Subject: [PATCH 3/3] set numPartitions for range

---
 .../sql/execution/HashAggregateCodegenInterruptionSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala
index 9c1ccf7748d49..8c716dd3f3a64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HashAggregateCodegenInterruptionSuite.scala
@@ -67,7 +67,7 @@ class HashAggregateCodegenInterruptionSuite extends QueryTest with SharedSparkSe
         // The dataset has ~100k rows and the generated code checks for interruption every 1k
         // rows. With two tasks of ~50k rows each, sleeping 1 second per 1k rows (~50 seconds
         // per task), the tasks run well past `spark.task.reaper.killTimeout` (10s), which
         // properly exercises the interruption behavior.
-        val slowDF = spark.range(1, 100000).rdd.mapPartitions { iter =>
+        val slowDF = spark.range(1, 100000, 1, 2).rdd.mapPartitions { iter =>
           new Iterator[Long] {
             var cnt = 0
             override def hasNext: Boolean = iter.hasNext
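
For reference, the interruption-polling pattern that the generated loop in PATCH 1 implements can be written out as plain Scala. This is only an illustrative sketch, not part of the patch: the names InterruptionCheckSketch and consumeWithInterruptionCheck are made up, and the real change emits the equivalent Java through the whole-stage-codegen template above.

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow

object InterruptionCheckSketch {
  // Consume `input` row by row, but poll TaskContext.isInterrupted() every 1000 rows so that
  // a task killed by job cancellation (or the task reaper) exits the loop promptly instead of
  // running the whole input to completion.
  def consumeWithInterruptionCheck(input: Iterator[InternalRow])(consume: InternalRow => Unit): Unit = {
    var numOutputRows = 0L
    var taskInterrupted = false
    while (!taskInterrupted && input.hasNext) {
      val row = input.next()
      numOutputRows += 1
      // Poll only every 1000 rows to keep the per-row overhead small; TaskContext.get() is
      // null when this code runs outside of a Spark task (e.g. on the driver).
      if (numOutputRows % 1000 == 0 && TaskContext.get() != null) {
        taskInterrupted = TaskContext.get().isInterrupted()
      }
      consume(row)
    }
  }
}

Checking the flag only once per 1000 rows keeps the check cheap while still letting a killed task stop well within the 10s `spark.task.reaper.killTimeout` used by the test, since the test's slow iterator takes about 1 second per 1000 rows.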