From a31df44bd6833ee5d10960bfe319b41c9f0c9691 Mon Sep 17 00:00:00 2001 From: Nicholas Jiang Date: Mon, 15 Jul 2024 17:40:37 +0800 Subject: [PATCH] [CH][CELEBORN] CHCelebornColumnarBatchSerializer uses an AtomicBoolean to track whether close() has already been called, so that close() cannot run twice (#6455) [CH][CELEBORN] CHCelebornColumnarBatchSerializer uses an AtomicBoolean to track whether close() has already been called, so that close() cannot run twice --- .../CHCelebornColumnarBatchSerializer.scala | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala index 39aefb01c2ae..3619855f74ed 100644 --- a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala +++ b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala @@ -32,6 +32,7 @@ import org.apache.celeborn.client.read.CelebornInputStream import java.io._ import java.nio.ByteBuffer import java.util.Locale +import java.util.concurrent.atomic.AtomicBoolean import scala.reflect.ClassTag @@ -74,7 +75,8 @@ private class CHCelebornColumnarBatchSerializerInstance( private var numBatchesTotal: Long = _ private var numRowsTotal: Long = _ - private var isClosed: Boolean = false + // Otherwise calling close() twice would cause replication of metrics. + private val closeCalled: AtomicBoolean = new AtomicBoolean(false) override def asIterator: Iterator[Any] = { // This method is never called by shuffle code. 
@@ -153,18 +155,18 @@ private class CHCelebornColumnarBatchSerializerInstance( } override def close(): Unit = { - if (!isClosed) { - if (numBatchesTotal > 0) { - readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal) - } - numOutputRows += numRowsTotal - if (cb != null) { - cb.close() - cb = null - } - closeReader() - isClosed = true + if (!closeCalled.compareAndSet(false, true)) { + return + } + if (numBatchesTotal > 0) { + readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal) + } + numOutputRows += numRowsTotal + if (cb != null) { + cb.close() + cb = null } + closeReader() } def getReader: CHStreamReader = {