diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/BatchIterator.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/BatchIterator.java index 1fbb6053a2af..f08e833ab976 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/BatchIterator.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/BatchIterator.java @@ -19,6 +19,7 @@ import org.apache.gluten.metrics.IMetrics; import org.apache.gluten.metrics.NativeMetrics; +import org.apache.spark.TaskContext import org.apache.spark.sql.execution.utils.CHExecUtil; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -30,8 +31,8 @@ public class BatchIterator extends GeneralOutIterator { private final long handle; private final AtomicBoolean cancelled = new AtomicBoolean(false); - public BatchIterator(long handle) { - super(); + public BatchIterator(long handle, TaskContext context) { + super(context); this.handle = handle; } diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHNativeExpressionEvaluator.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHNativeExpressionEvaluator.java index b8b4138dc8c0..3cc9c739b133 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHNativeExpressionEvaluator.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHNativeExpressionEvaluator.java @@ -25,7 +25,7 @@ import org.apache.gluten.substrait.extensions.ExtensionBuilder; import org.apache.gluten.substrait.plan.PlanBuilder; -import org.apache.spark.SparkConf; +import org.apache.spark.{SparkConf, TaskContext}; import org.apache.spark.sql.internal.SQLConf; import java.util.Arrays; @@ -115,6 +115,6 @@ public BatchIterator createKernelWithBatchIterator( } private BatchIterator createBatchIterator(long nativeHandle) { - return new BatchIterator(nativeHandle); + return new BatchIterator(nativeHandle, TaskContext.get()); } } diff --git a/backends-velox/src/main/java/org/apache/gluten/utils/VeloxBatchAppender.java b/backends-velox/src/main/java/org/apache/gluten/utils/VeloxBatchAppender.java index 32b2289471f9..37bb3a007f0e 100644 --- a/backends-velox/src/main/java/org/apache/gluten/utils/VeloxBatchAppender.java +++ b/backends-velox/src/main/java/org/apache/gluten/utils/VeloxBatchAppender.java @@ -21,6 +21,7 @@ import org.apache.gluten.vectorized.ColumnarBatchInIterator; import org.apache.gluten.vectorized.ColumnarBatchOutIterator; +import org.apache.spark.TaskContext; import org.apache.spark.sql.vectorized.ColumnarBatch; import java.util.Iterator; @@ -32,6 +33,6 @@ public static ColumnarBatchOutIterator create( long outHandle = VeloxBatchAppenderJniWrapper.create(runtime) .create(minOutputBatchSize, new ColumnarBatchInIterator(in)); - return new ColumnarBatchOutIterator(runtime, outHandle); + return new ColumnarBatchOutIterator(runtime, outHandle, TaskContext.get()); } } diff --git a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala index 19305e95dbec..6158cc581a18 100644 --- a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala +++ b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala @@ -22,7 +22,7 @@ import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.utils.ArrowAbiUtil import org.apache.gluten.vectorized._ -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.SHUFFLE_COMPRESS import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} @@ -121,7 +121,8 @@ private class CelebornColumnarBatchSerializerInstance( runtime, ShuffleReaderJniWrapper .create(runtime) - .readStream(shuffleReaderHandle, byteIn)) + .readStream(shuffleReaderHandle, byteIn), + TaskContext.get()) private var cb: ColumnarBatch = _ diff --git a/gluten-core/src/main/java/org/apache/gluten/vectorized/GeneralOutIterator.java b/gluten-core/src/main/java/org/apache/gluten/vectorized/GeneralOutIterator.java index b82b0179295d..8514297ab472 100644 --- a/gluten-core/src/main/java/org/apache/gluten/vectorized/GeneralOutIterator.java +++ b/gluten-core/src/main/java/org/apache/gluten/vectorized/GeneralOutIterator.java @@ -19,6 +19,7 @@ import org.apache.gluten.exception.GlutenException; import org.apache.gluten.metrics.IMetrics; +import org.apache.spark.TaskContext; import org.apache.spark.sql.vectorized.ColumnarBatch; import java.io.Serializable; @@ -28,12 +29,16 @@ public abstract class GeneralOutIterator implements AutoCloseable, Serializable, Iterator { protected final AtomicBoolean closed = new AtomicBoolean(false); + protected final TaskContext context; - public GeneralOutIterator() {} + public GeneralOutIterator(TaskContext context) { + this.context = context; + } @Override public final boolean hasNext() { try { + context.killTaskIfInterrupted(); return hasNextInternal(); } catch (Exception e) { throw new GlutenException(e); @@ -43,6 +48,7 @@ public final boolean hasNext() { @Override public final ColumnarBatch next() { try { + context.killTaskIfInterrupted(); return nextInternal(); } catch (Exception e) { throw new GlutenException(e); diff --git a/gluten-data/src/main/java/org/apache/gluten/vectorized/ColumnarBatchOutIterator.java b/gluten-data/src/main/java/org/apache/gluten/vectorized/ColumnarBatchOutIterator.java index 9dd0404384ad..00c85f5b16fd 100644 --- a/gluten-data/src/main/java/org/apache/gluten/vectorized/ColumnarBatchOutIterator.java +++ b/gluten-data/src/main/java/org/apache/gluten/vectorized/ColumnarBatchOutIterator.java @@ -21,6 +21,7 @@ import org.apache.gluten.exec.RuntimeAware; import org.apache.gluten.metrics.IMetrics; +import org.apache.spark.TaskContext; import org.apache.spark.sql.vectorized.ColumnarBatch; import java.io.IOException; @@ -29,8 +30,8 @@ public class ColumnarBatchOutIterator extends GeneralOutIterator implements Runt private final Runtime runtime; private final long iterHandle; - public ColumnarBatchOutIterator(Runtime runtime, long iterHandle) { - super(); + public ColumnarBatchOutIterator(Runtime runtime, long iterHandle, TaskContext context) { + super(context); this.runtime = runtime; this.iterHandle = iterHandle; } diff --git a/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java b/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java index e5eea029b2b3..f64bf3e7c1a0 100644 --- a/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java +++ b/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java @@ -76,7 +76,7 @@ public GeneralOutIterator createKernelWithBatchIterator( TaskContext.get().taskAttemptId(), DebugUtil.saveInputToFile(), BackendsApiManager.getSparkPlanExecApiInstance().rewriteSpillPath(spillDirPath)); - final ColumnarBatchOutIterator out = createOutIterator(runtime, itrHandle); + final ColumnarBatchOutIterator out = createOutIterator(runtime, itrHandle, TaskContext.get()); runtime.addSpiller( new Spiller() { @Override @@ -90,7 +90,8 @@ public long spill(MemoryTarget self, Spiller.Phase phase, long size) { return out; } - private ColumnarBatchOutIterator createOutIterator(Runtime runtime, long itrHandle) { - return new ColumnarBatchOutIterator(runtime, itrHandle); + private ColumnarBatchOutIterator createOutIterator( + Runtime runtime, long itrHandle, TaskContext context) { + return new ColumnarBatchOutIterator(runtime, itrHandle, context); } } diff --git a/gluten-data/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala b/gluten-data/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala index e75abe41e4e8..ec69270a7103 100644 --- a/gluten-data/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala +++ b/gluten-data/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala @@ -21,7 +21,7 @@ import org.apache.gluten.exec.Runtimes import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.utils.ArrowAbiUtil -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} import org.apache.spark.shuffle.GlutenShuffleUtils @@ -139,7 +139,8 @@ private class ColumnarBatchSerializerInstance( runtime, ShuffleReaderJniWrapper .create(runtime) - .readStream(shuffleReaderHandle, byteIn)) + .readStream(shuffleReaderHandle, byteIn), + TaskContext.get()) private var cb: ColumnarBatch = _