diff --git a/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocator.java b/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocator.java index 382dc6d1a..1bc394234 100644 --- a/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocator.java +++ b/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocator.java @@ -52,22 +52,25 @@ public static synchronized CometShuffleMemoryAllocatorTrait getInstance( (boolean) CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_UNIFIED_MEMORY_ALLOCATOR_IN_TEST().get(); - if (INSTANCE == null) { - if (isSparkTesting && !useUnifiedMemAllocator) { + if (isSparkTesting && !useUnifiedMemAllocator) { + if (INSTANCE == null) { + // CometTestShuffleMemoryAllocator handles pages by itself so it can be a singleton. INSTANCE = new CometTestShuffleMemoryAllocator(conf, taskMemoryManager, pageSize); - } else { - if (taskMemoryManager.getTungstenMemoryMode() != MemoryMode.OFF_HEAP) { - throw new IllegalArgumentException( - "CometShuffleMemoryAllocator should be used with off-heap " - + "memory mode, but got " - + taskMemoryManager.getTungstenMemoryMode()); - } - - INSTANCE = new CometShuffleMemoryAllocator(taskMemoryManager, pageSize); } - } - return INSTANCE; + return INSTANCE; + } else { + if (taskMemoryManager.getTungstenMemoryMode() != MemoryMode.OFF_HEAP) { + throw new IllegalArgumentException( + "CometShuffleMemoryAllocator should be used with off-heap " + + "memory mode, but got " + + taskMemoryManager.getTungstenMemoryMode()); + } + + // CometShuffleMemoryAllocator stores pages in TaskMemoryManager which is not singleton, + // but one instance per task. So we need to create a new instance for each task. + return new CometShuffleMemoryAllocator(taskMemoryManager, pageSize); + } } CometShuffleMemoryAllocator(TaskMemoryManager taskMemoryManager, long pageSize) {