diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
index e90a3821a41ba..f33e767e13e0f 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala
@@ -353,9 +353,9 @@ class CollectMetricIterator(
     if (!metricsUpdated) {
       val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
       if (wholeStagePipeline) {
-        outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows())
+        outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows)
         outputVectorCount =
-          Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches())
+          Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches)
       }
       nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
       updateNativeMetrics(nativeMetrics)
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala
index 53f85d84672b9..b0595fd05dd71 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala
@@ -22,7 +22,7 @@ import org.apache.gluten.execution.ColumnarNativeIterator
 import org.apache.gluten.memory.CHThreadGroup
 import org.apache.gluten.vectorized._
 
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -189,34 +189,32 @@ class CHColumnarShuffleWriter[K, V](
 
 }
 
+/** Row and batch totals taken from a CHSplitResult. */
+case class OutputMetrics(totalRows: Long, totalBatches: Long)
+
 object CHColumnarShuffleWriter {
 
-  private val TOTAL_OUTPUT_ROWS = "total_output_rows"
-  private val TOTAL_OUTPUT_BATCHES = "total_output_batches"
+  // The ThreadLocal reference itself is never reassigned, so it is declared as a val.
+  private val metric = new ThreadLocal[OutputMetrics]()
 
   // Pass the statistics of the last operator before shuffle to CollectMetricIterator.
   def setOutputMetrics(splitResult: CHSplitResult): Unit = {
-    TaskContext
-      .get()
-      .getLocalProperties
-      .setProperty(TOTAL_OUTPUT_ROWS, splitResult.getTotalRows.toString)
-    TaskContext
-      .get()
-      .getLocalProperties
-      .setProperty(TOTAL_OUTPUT_BATCHES, splitResult.getTotalBatches.toString)
+    metric.set(OutputMetrics(splitResult.getTotalRows, splitResult.getTotalBatches))
   }
 
-  def getTotalOutputRows(): Long = {
-    val output_rows = TaskContext.get().getLocalProperty(TOTAL_OUTPUT_ROWS)
-    var output_rows_value = 0L
-    if (output_rows != null && output_rows.nonEmpty) output_rows_value = output_rows.toLong
-    output_rows_value
+  def getTotalOutputRows: Long = {
+    if (metric.get() == null) {
+      0
+    } else {
+      metric.get().totalRows
+    }
   }
 
-  def getTotalOutputBatches(): Long = {
-    val output_batches = TaskContext.get().getLocalProperty(TOTAL_OUTPUT_BATCHES)
-    var output_batches_value = 0L
-    if (output_batches != null) output_batches_value = output_batches.toLong
-    output_batches_value
+  def getTotalOutputBatches: Long = {
+    if (metric.get() == null) {
+      0
+    } else {
+      metric.get().totalBatches
+    }
   }
 }
diff --git a/cpp-ch/local-engine/Parser/LocalExecutor.h b/cpp-ch/local-engine/Parser/LocalExecutor.h
index 7b73dc7da0f7c..c2f4d2e309e2d 100644
--- a/cpp-ch/local-engine/Parser/LocalExecutor.h
+++ b/cpp-ch/local-engine/Parser/LocalExecutor.h
@@ -35,7 +35,14 @@ struct SparkBuffer
 class LocalExecutor : public BlockIterator
 {
 public:
-    static LocalExecutor * getCurrentExecutor() { return current_executor; }
+    /// Returns the current executor, or std::nullopt when current_executor is null.
+    /// The null pointer must not be wrapped directly: an optional holding nullptr is still engaged.
+    static std::optional<LocalExecutor *> getCurrentExecutor()
+    {
+        if (current_executor)
+            return current_executor;
+        return std::nullopt;
+    }
     static void resetCurrentExecutor() { current_executor = nullptr; }
     LocalExecutor(DB::QueryPlanPtr query_plan, DB::QueryPipelineBuilderPtr pipeline, bool dump_pipeline_ = false);
     ~LocalExecutor();
diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp
index 50bc6a04ea634..f21d9d0037e7e 100644
--- a/cpp-ch/local-engine/local_engine_jni.cpp
+++ b/cpp-ch/local-engine/local_engine_jni.cpp
@@ -544,11 +544,12 @@ local_engine::SplitterHolder * buildAndExecuteShuffle(JNIEnv * env,
     jobject rss_pusher = nullptr
     )
 {
-    auto * current_executor = local_engine::LocalExecutor::getCurrentExecutor();
-    chassert(current_executor);
+    auto current_executor = local_engine::LocalExecutor::getCurrentExecutor();
     local_engine::SplitterHolder * splitter = nullptr;
-    // handle fallback, whole stage fallback or partial fallback
-    if (!current_executor || current_executor->fallbackMode())
+    // There are two modes of fallback, one is full fallback but uses columnar shuffle,
+    // and the other is partial fallback that creates one or more LocalExecutor.
+    // In full fallback, the current executor does not exist.
+    if (!current_executor.has_value() || current_executor.value()->fallbackMode())
     {
         auto first_block = local_engine::SourceFromJavaIter::peekBlock(env, iter);
         if (first_block.has_value())
@@ -574,9 +575,9 @@ local_engine::SplitterHolder * buildAndExecuteShuffle(JNIEnv * env,
-        splitter = new local_engine::SplitterHolder{.exchange_manager = std::make_unique<local_engine::SparkExchangeManager>(current_executor->getHeader().cloneEmpty(), name, options, rss_pusher)};
+        splitter = new local_engine::SplitterHolder{.exchange_manager = std::make_unique<local_engine::SparkExchangeManager>(current_executor.value()->getHeader().cloneEmpty(), name, options, rss_pusher)};
         // TODO support multiple sinks
         splitter->exchange_manager->initSinks(1);
-        current_executor->setSinks([&](auto & pipeline_builder) { splitter->exchange_manager->setSinksToPipeline(pipeline_builder);});
+        current_executor.value()->setSinks([&](auto & pipeline_builder) { splitter->exchange_manager->setSinksToPipeline(pipeline_builder);});
         // execute pipeline
-        current_executor->execute();
+        current_executor.value()->execute();
     }
     return splitter;
 }