diff --git a/cpp/core/config/GlutenConfig.h b/cpp/core/config/GlutenConfig.h
index 31318ff0aa0c..cd5196aa8c13 100644
--- a/cpp/core/config/GlutenConfig.h
+++ b/cpp/core/config/GlutenConfig.h
@@ -38,6 +38,8 @@
 const std::string kIgnoreMissingFiles = "spark.sql.files.ignoreMissingFiles";
 
 const std::string kDefaultSessionTimezone = "spark.gluten.sql.session.timeZone.default";
 
+const std::string kSparkOverheadMemory = "spark.gluten.memoryOverhead.size.in.bytes";
+
 const std::string kSparkOffHeapMemory = "spark.gluten.memory.offHeap.size.in.bytes";
 const std::string kSparkTaskOffHeapMemory = "spark.gluten.memory.task.offHeap.size.in.bytes";
diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkResourceUtil.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkResourceUtil.scala
index faadf3a27cb3..d298ab481a0c 100644
--- a/gluten-core/src/main/scala/org/apache/spark/util/SparkResourceUtil.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkResourceUtil.scala
@@ -84,12 +84,13 @@ object SparkResourceUtil extends Logging {
 
   }
 
   def getMemoryOverheadSize(conf: SparkConf): Long = {
-    conf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse {
+    val overheadMib = conf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse {
       val executorMemMib = conf.get(EXECUTOR_MEMORY)
       val factor = conf.getDouble("spark.executor.memoryOverheadFactor", 0.1D)
       val minMib = conf.getLong("spark.executor.minMemoryOverhead", 384L)
-      ByteUnit.MiB.toBytes((executorMemMib * factor).toLong max minMib)
+      (executorMemMib * factor).toLong max minMib
     }
+    ByteUnit.MiB.toBytes(overheadMib)
   }
 }