From e9882d192b6a6ef49ddd12fe680870ea3f72fe96 Mon Sep 17 00:00:00 2001 From: Jin Chengcheng Date: Tue, 13 Aug 2024 16:59:29 +0800 Subject: [PATCH] [VL] Fix parquet write sort spill OOM (#6480) --- package/pom.xml | 6 + .../apache/spark/memory/MemoryConsumer.java | 154 +++ .../spark/memory/TaskMemoryManager.java | 479 +++++++++ .../unsafe/sort/UnsafeExternalSorter.java | 914 ++++++++++++++++++ 4 files changed, 1553 insertions(+) create mode 100644 shims/spark32/src/main/java/org/apache/spark/memory/MemoryConsumer.java create mode 100644 shims/spark32/src/main/java/org/apache/spark/memory/TaskMemoryManager.java create mode 100644 shims/spark32/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java diff --git a/package/pom.xml b/package/pom.xml index ae1432770f8a..0436889c090b 100644 --- a/package/pom.xml +++ b/package/pom.xml @@ -330,6 +330,12 @@ org.apache.spark.sql.execution.datasources.WriterBucketSpec$ org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand$ + org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter$ + org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter$SpillableIterator + org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter$ChainedIterator + org.apache.spark.memory.MemoryConsumer + org.apache.spark.memory.TaskMemoryManager com.google.protobuf.* diff --git a/shims/spark32/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/shims/spark32/src/main/java/org/apache/spark/memory/MemoryConsumer.java new file mode 100644 index 000000000000..bfe699c13a35 --- /dev/null +++ b/shims/spark32/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.memory; + +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; + +import java.io.IOException; + +/** + * A memory consumer of {@link TaskMemoryManager} that supports spilling. + * + *
Note: this only supports allocation / spilling of Tungsten memory. + */ +public abstract class MemoryConsumer { + + protected final TaskMemoryManager taskMemoryManager; + private final long pageSize; + private final MemoryMode mode; + protected long used; + + protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize, MemoryMode mode) { + this.taskMemoryManager = taskMemoryManager; + this.pageSize = pageSize; + this.mode = mode; + } + + protected MemoryConsumer(TaskMemoryManager taskMemoryManager, MemoryMode mode) { + this(taskMemoryManager, taskMemoryManager.pageSizeBytes(), mode); + } + + public long getTaskAttemptId() { + return this.taskMemoryManager.getTaskAttemptId(); + } + + /** Returns the memory mode, {@link MemoryMode#ON_HEAP} or {@link MemoryMode#OFF_HEAP}. */ + public MemoryMode getMode() { + return mode; + } + + /** Returns the size of used memory in bytes. */ + public long getUsed() { + return used; + } + + /** Force spill during building. */ + public void spill() throws IOException { + spill(Long.MAX_VALUE, this); + } + + /** + * Spill some data to disk to release memory, which will be called by TaskMemoryManager when there + * is not enough memory for the task. + * + *
This should be implemented by subclasses.
 *
Note: In order to avoid possible deadlock, implementations should not call acquireMemory() from spill().
 *
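 * A minimal sketch of an override (illustrative only; writeBuffersToDiskAndFree() is a
 * hypothetical helper that persists the in-memory data and returns the bytes freed):
 *
 *   @Override
 *   public long spill(long size, MemoryConsumer trigger) throws IOException {
 *     if (getUsed() == 0) {
 *       return 0L; // nothing to release
 *     }
 *     return writeBuffersToDiskAndFree(); // hypothetical helper
 *   }
 *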
Note: today, this only frees Tungsten-managed pages. + * + * @param size the amount of memory should be released + * @param trigger the MemoryConsumer that trigger this spilling + * @return the amount of released memory in bytes + */ + public abstract long spill(long size, MemoryConsumer trigger) throws IOException; + + public long forceSpill(long size, MemoryConsumer trigger) throws IOException { + return 0; + } + + /** + * Allocates a LongArray of `size`. Note that this method may throw `SparkOutOfMemoryError` if + * Spark doesn't have enough memory for this allocation, or throw `TooLargePageException` if this + * `LongArray` is too large to fit in a single page. The caller side should take care of these two + * exceptions, or make sure the `size` is small enough that won't trigger exceptions. + * + * @throws SparkOutOfMemoryError + * @throws TooLargePageException + */ + public LongArray allocateArray(long size) { + long required = size * 8L; + MemoryBlock page = taskMemoryManager.allocatePage(required, this); + if (page == null || page.size() < required) { + throwOom(page, required); + } + used += required; + return new LongArray(page); + } + + /** Frees a LongArray. */ + public void freeArray(LongArray array) { + freePage(array.memoryBlock()); + } + + /** + * Allocate a memory block with at least `required` bytes. + * + * @throws SparkOutOfMemoryError + */ + protected MemoryBlock allocatePage(long required) { + MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); + if (page == null || page.size() < required) { + throwOom(page, required); + } + used += page.size(); + return page; + } + + /** Free a memory block. */ + protected void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } + + /** Allocates memory of `size`. */ + public long acquireMemory(long size) { + long granted = taskMemoryManager.acquireExecutionMemory(size, this); + used += granted; + return granted; + } + + /** Release N bytes of memory. */ + public void freeMemory(long size) { + taskMemoryManager.releaseExecutionMemory(size, this); + used -= size; + } + + private void throwOom(final MemoryBlock page, final long required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } + taskMemoryManager.showMemoryUsage(); + // checkstyle.off: RegexpSinglelineJava + throw new SparkOutOfMemoryError( + "UNABLE_TO_ACQUIRE_MEMORY", new String[] {Long.toString(required), Long.toString(got)}); + // checkstyle.on: RegexpSinglelineJava + } +} diff --git a/shims/spark32/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/shims/spark32/src/main/java/org/apache/spark/memory/TaskMemoryManager.java new file mode 100644 index 000000000000..57a92110c06a --- /dev/null +++ b/shims/spark32/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -0,0 +1,479 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.memory; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.GuardedBy; + +import java.io.IOException; +import java.nio.channels.ClosedByInterruptException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.BitSet; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +/** + * Manages the memory allocated by an individual task. + * + *
Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit + * longs. In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, + * memory is addressed by the combination of a base Object reference and a 64-bit offset within that + * object. This is a problem when we want to store pointers to data structures inside of other + * structures, such as record pointers inside hashmaps or sorting buffers. Even if we decided to use + * 128 bits to address memory, we can't just store the address of the base object since it's not + * guaranteed to remain stable as the heap gets reorganized due to GC. + * + *
Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap + * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to + * store a "page number" and the lower 51 bits to store an offset within this page. These page + * numbers are used to index into a "page table" array inside of the MemoryManager in order to + * retrieve the base object. + * + *
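 * As a concrete sketch (matching encodePageNumberAndOffset / decodePageNumber defined below,
 * where OFFSET_BITS = 51):
 *
 *   long addr = (((long) pageNumber) << 51) | (offsetInPage & 0x7FFFFFFFFFFFFL);
 *   int page = (int) (addr >>> 51);          // recovers pageNumber
 *   long offset = addr & 0x7FFFFFFFFFFFFL;   // recovers offsetInPage
 *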
This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the + * maximum size of a long[] array, allowing us to address 8192 * (2^31 - 1) * 8 bytes, which is + * approximately 140 terabytes of memory. + */ +public class TaskMemoryManager { + + private static final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); + + /** The number of bits used to address the page table. */ + private static final int PAGE_NUMBER_BITS = 13; + + /** The number of bits used to encode offsets in data pages. */ + @VisibleForTesting static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51 + + /** The number of entries in the page table. */ + private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; + + /** + * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is + * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's + * maximum page size is limited by the maximum amount of data that can be stored in a long[] + * array, which is (2^31 - 1) * 8 bytes (or about 17 gigabytes). Therefore, we cap this at 17 + * gigabytes. + */ + public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; + + /** + * Similar to an operating system's page table, this array maps page numbers into base object + * pointers, allowing us to translate between the hashtable's internal 64-bit address + * representation and the baseObject+offset representation which we use to support both on- and + * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`. + * When using an on-heap allocator, the entries in this map will point to pages' base objects. + * Entries are added to this map as new data pages are allocated. + */ + private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE]; + + /** Bitmap for tracking free pages. */ + private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); + + private final MemoryManager memoryManager; + + private final long taskAttemptId; + + /** + * Tracks whether we're on-heap or off-heap. For off-heap, we short-circuit most of these methods + * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, + * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. + */ + final MemoryMode tungstenMemoryMode; + + /** Tracks spillable memory consumers. */ + @GuardedBy("this") + private final HashSet consumers; + + /** The amount of memory that is acquired but not used. */ + private volatile long acquiredButNotUsed = 0L; + + /** Construct a new TaskMemoryManager. */ + public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { + this.tungstenMemoryMode = memoryManager.tungstenMemoryMode(); + this.memoryManager = memoryManager; + this.taskAttemptId = taskAttemptId; + this.consumers = new HashSet<>(); + } + + public long getTaskAttemptId() { + return taskAttemptId; + } + + public long acquireExecutionMemory(long required, MemoryConsumer consumer) { + long got = acquireExecutionMemory(required, consumer, false); + if (got < required) { + got += acquireExecutionMemory(required, consumer, true); + } + return got; + } + + /** + * Acquire N bytes of memory for a consumer. If there is no enough memory, it will call spill() of + * consumers to release more memory. 
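 *
 * Callers must be prepared for a partial grant. For example (sketch, mirroring
 * MemoryConsumer.acquireMemory() above):
 *
 *   long got = taskMemoryManager.acquireExecutionMemory(required, consumer);
 *   if (got < required) {
 *     // spilling could not free enough memory; proceed with the smaller grant or retry
 *   }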
+ * + * @return number of bytes successfully granted (<= N). + */ + public long acquireExecutionMemory(long required, MemoryConsumer consumer, boolean force) { + assert (required >= 0); + assert (consumer != null); + MemoryMode mode = consumer.getMode(); + // If we are allocating Tungsten pages off-heap and receive a request to allocate on-heap + // memory here, then it may not make sense to spill since that would only end up freeing + // off-heap memory. This is subject to change, though, so it may be risky to make this + // optimization now in case we forget to undo it late when making changes. + synchronized (this) { + long got = memoryManager.acquireExecutionMemory(required, taskAttemptId, mode); + + // Try to release memory from other consumers first, then we can reduce the frequency of + // spilling, avoid to have too many spilled files. + if (got < required) { + // Call spill() on other consumers to release memory + // Sort the consumers according their memory usage. So we avoid spilling the same consumer + // which is just spilled in last few times and re-spilling on it will produce many small + // spill files. + TreeMap> sortedConsumers = new TreeMap<>(); + for (MemoryConsumer c : consumers) { + if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { + long key = c.getUsed(); + List list = + sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); + list.add(c); + } + } + while (!sortedConsumers.isEmpty()) { + // Get the consumer using the least memory more than the remaining required memory. + Map.Entry> currentEntry = + sortedConsumers.ceilingEntry(required - got); + // No consumer has used memory more than the remaining required memory. + // Get the consumer of largest used memory. + if (currentEntry == null) { + currentEntry = sortedConsumers.lastEntry(); + } + List cList = currentEntry.getValue(); + MemoryConsumer c = cList.get(cList.size() - 1); + try { + long released = + force ? c.forceSpill(required - got, consumer) : c.spill(required - got, consumer); + if (released > 0) { + logger.debug( + "Task {} released {} from {} for {}", + taskAttemptId, + Utils.bytesToString(released), + c, + consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); + if (got >= required) { + break; + } + } else { + cList.remove(cList.size() - 1); + if (cList.isEmpty()) { + sortedConsumers.remove(currentEntry.getKey()); + } + } + } catch (ClosedByInterruptException e) { + // This called by user to kill a task (e.g: speculative task). + logger.error("error while calling spill() on " + c, e); + throw new RuntimeException(e.getMessage()); + } catch (IOException e) { + logger.error("error while calling spill() on " + c, e); + // checkstyle.off: RegexpSinglelineJava + throw new SparkOutOfMemoryError( + "error while calling spill() on " + c + " : " + e.getMessage()); + // checkstyle.on: RegexpSinglelineJava + } + } + } + + // Attempt to free up memory by self-spilling. + // + // When our spill handler releases memory, `ExecutionMemoryPool#releaseMemory()` will + // immediately notify other tasks that memory has been freed, and they may acquire the + // newly-freed memory before we have a chance to do so (SPARK-35486). In that case, we will + // try again in the next loop iteration. + while (got < required) { + try { + long released = + force + ? 
consumer.forceSpill(required - got, consumer) + : consumer.spill(required - got, consumer); + if (released > 0) { + logger.debug( + "Task {} released {} from itself ({})", + taskAttemptId, + Utils.bytesToString(released), + consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); + } else { + // Self-spilling could not free up any more memory. + break; + } + } catch (ClosedByInterruptException e) { + // This called by user to kill a task (e.g: speculative task). + logger.error("error while calling spill() on " + consumer, e); + throw new RuntimeException(e.getMessage()); + } catch (IOException e) { + logger.error("error while calling spill() on " + consumer, e); + // checkstyle.off: RegexpSinglelineJava + throw new SparkOutOfMemoryError( + "error while calling spill() on " + consumer + " : " + e.getMessage()); + // checkstyle.on: RegexpSinglelineJava + } + } + + consumers.add(consumer); + logger.debug("Task {} acquired {} for {}", taskAttemptId, Utils.bytesToString(got), consumer); + return got; + } + } + + /** Release N bytes of execution memory for a MemoryConsumer. */ + public void releaseExecutionMemory(long size, MemoryConsumer consumer) { + logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer); + memoryManager.releaseExecutionMemory(size, taskAttemptId, consumer.getMode()); + } + + /** Dump the memory usage of all consumers. */ + public void showMemoryUsage() { + logger.info("Memory used in task " + taskAttemptId); + synchronized (this) { + long memoryAccountedForByConsumers = 0; + for (MemoryConsumer c : consumers) { + long totalMemUsage = c.getUsed(); + memoryAccountedForByConsumers += totalMemUsage; + if (totalMemUsage > 0) { + logger.info("Acquired by " + c + ": " + Utils.bytesToString(totalMemUsage)); + } + } + long memoryNotAccountedFor = + memoryManager.getExecutionMemoryUsageForTask(taskAttemptId) + - memoryAccountedForByConsumers; + logger.info( + "{} bytes of memory were used by task {} but are not associated with specific consumers", + memoryNotAccountedFor, + taskAttemptId); + logger.info( + "{} bytes of memory are used for execution and {} bytes of memory are used for storage", + memoryManager.executionMemoryUsed(), + memoryManager.storageMemoryUsed()); + } + } + + /** Return the page size in bytes. */ + public long pageSizeBytes() { + return memoryManager.pageSizeBytes(); + } + + /** + * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is + * intended for allocating large blocks of Tungsten memory that will be shared between operators. + * + *
Returns `null` if there was not enough memory to allocate the page. May return a page that + * contains fewer bytes than requested, so callers should verify the size of returned pages. + * + * @throws TooLargePageException + */ + public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { + assert (consumer != null); + assert (consumer.getMode() == tungstenMemoryMode); + if (size > MAXIMUM_PAGE_SIZE_BYTES) { + throw new TooLargePageException(size); + } + + long acquired = acquireExecutionMemory(size, consumer); + if (acquired <= 0) { + return null; + } + + final int pageNumber; + synchronized (this) { + pageNumber = allocatedPages.nextClearBit(0); + if (pageNumber >= PAGE_TABLE_SIZE) { + releaseExecutionMemory(acquired, consumer); + throw new IllegalStateException( + "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); + } + allocatedPages.set(pageNumber); + } + MemoryBlock page = null; + try { + page = memoryManager.tungstenMemoryAllocator().allocate(acquired); + } catch (OutOfMemoryError e) { + logger.warn("Failed to allocate a page ({} bytes), try again.", acquired); + // there is no enough memory actually, it means the actual free memory is smaller than + // MemoryManager thought, we should keep the acquired memory. + synchronized (this) { + acquiredButNotUsed += acquired; + allocatedPages.clear(pageNumber); + } + // this could trigger spilling to free some pages. + return allocatePage(size, consumer); + } + page.pageNumber = pageNumber; + pageTable[pageNumber] = page; + if (logger.isTraceEnabled()) { + logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); + } + return page; + } + + /** Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ + public void freePage(MemoryBlock page, MemoryConsumer consumer) { + assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) + : "Called freePage() on memory that wasn't allocated with allocatePage()"; + assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) + : "Called freePage() on a memory block that has already been freed"; + assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) + : "Called freePage() on a memory block that has already been freed"; + assert (allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; + synchronized (this) { + allocatedPages.clear(page.pageNumber); + } + if (logger.isTraceEnabled()) { + logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + } + long pageSize = page.size(); + // Clear the page number before passing the block to the MemoryAllocator's free(). + // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed + // page has been inappropriately directly freed without calling TMM.freePage(). + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + memoryManager.tungstenMemoryAllocator().free(page); + releaseExecutionMemory(pageSize, consumer); + } + + /** + * Given a memory page and offset within that page, encode this address into a 64-bit long. This + * address will remain valid as long as the corresponding page has not been freed. + * + * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}/ + * @param offsetInPage an offset in this page which incorporates the base offset. In other words, + * this should be the value that you would pass as the base offset into an UNSAFE call (e.g. + * page.baseOffset() + something). + * @return an encoded page address. 
+ */ + public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + if (tungstenMemoryMode == MemoryMode.OFF_HEAP) { + // In off-heap mode, an offset is an absolute address that may require a full 64 bits to + // encode. Due to our page size limitation, though, we can convert this into an offset that's + // relative to the page's base offset; this relative offset will fit in 51 bits. + offsetInPage -= page.getBaseOffset(); + } + return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + } + + @VisibleForTesting + public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { + assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page"; + return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } + + @VisibleForTesting + public static int decodePageNumber(long pagePlusOffsetAddress) { + return (int) (pagePlusOffsetAddress >>> OFFSET_BITS); + } + + private static long decodeOffset(long pagePlusOffsetAddress) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + } + + /** + * Get the page associated with an address encoded by {@link + * TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public Object getPage(long pagePlusOffsetAddress) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + assert (page.getBaseObject() != null); + return page.getBaseObject(); + } else { + return null; + } + } + + /** + * Get the offset associated with an address encoded by {@link + * TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public long getOffsetInPage(long pagePlusOffsetAddress) { + final long offsetInPage = decodeOffset(pagePlusOffsetAddress); + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { + return offsetInPage; + } else { + // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we + // converted the absolute address into a relative address. Here, we invert that operation: + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + return page.getBaseOffset() + offsetInPage; + } + } + + /** + * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return + * value can be used to detect memory leaks. + */ + public long cleanUpAllAllocatedMemory() { + synchronized (this) { + for (MemoryConsumer c : consumers) { + if (c != null && c.getUsed() > 0) { + // In case of failed task, it's normal to see leaked memory + logger.debug("unreleased " + Utils.bytesToString(c.getUsed()) + " memory from " + c); + } + } + consumers.clear(); + + for (MemoryBlock page : pageTable) { + if (page != null) { + logger.debug("unreleased page: " + page + " in task " + taskAttemptId); + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + memoryManager.tungstenMemoryAllocator().free(page); + } + } + Arrays.fill(pageTable, null); + } + + // release the memory that is not used by any consumer (acquired for pages in tungsten mode). + memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode); + + return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); + } + + /** Returns the memory consumption, in bytes, for the current task. 
*/ + public long getMemoryConsumptionForThisTask() { + return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId); + } + + /** Returns Tungsten memory mode */ + public MemoryMode getTungstenMemoryMode() { + return tungstenMemoryMode; + } +} diff --git a/shims/spark32/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/shims/spark32/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java new file mode 100644 index 000000000000..0f3bb8cc7d2c --- /dev/null +++ b/shims/spark32/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -0,0 +1,914 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util.collection.unsafe.sort; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TooLargePageException; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; +import java.util.function.Supplier; + +/** External sorter based on {@link UnsafeInMemorySorter}. */ +public final class UnsafeExternalSorter extends MemoryConsumer { + + private static final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + + @Nullable private final PrefixComparator prefixComparator; + + /** + * {@link RecordComparator} may probably keep the reference to the records they compared last + * time, so we should not keep a {@link RecordComparator} instance inside {@link + * UnsafeExternalSorter}, because {@link UnsafeExternalSorter} is referenced by {@link + * TaskContext} and thus can not be garbage collected until the end of the task. 
+ */ + @Nullable private final Supplier recordComparatorSupplier; + + private final TaskMemoryManager taskMemoryManager; + private final BlockManager blockManager; + private final SerializerManager serializerManager; + private final TaskContext taskContext; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** Force this sorter to spill when there are this many elements in memory. */ + private final int numElementsForSpillThreshold; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList<>(); + + private final LinkedList spillWriters = new LinkedList<>(); + + // These variables are reset after spilling: + @Nullable private volatile UnsafeInMemorySorter inMemSorter; + + private MemoryBlock currentPage = null; + private long pageCursor = -1; + private long peakMemoryUsedBytes = 0; + private long totalSpillBytes = 0L; + private long totalSortTimeNanos = 0L; + private volatile SpillableIterator readingIterator = null; + + public static UnsafeExternalSorter createWithExistingInMemorySorter( + TaskMemoryManager taskMemoryManager, + BlockManager blockManager, + SerializerManager serializerManager, + TaskContext taskContext, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes, + int numElementsForSpillThreshold, + UnsafeInMemorySorter inMemorySorter, + long existingMemoryConsumption) + throws IOException { + UnsafeExternalSorter sorter = + new UnsafeExternalSorter( + taskMemoryManager, + blockManager, + serializerManager, + taskContext, + recordComparatorSupplier, + prefixComparator, + initialSize, + pageSizeBytes, + numElementsForSpillThreshold, + inMemorySorter, + false /* ignored */); + sorter.spill(Long.MAX_VALUE, sorter); + taskContext.taskMetrics().incMemoryBytesSpilled(existingMemoryConsumption); + sorter.totalSpillBytes += existingMemoryConsumption; + // The external sorter will be used to insert records, in-memory sorter is not needed. 
+ sorter.inMemSorter = null; + return sorter; + } + + public static UnsafeExternalSorter create( + TaskMemoryManager taskMemoryManager, + BlockManager blockManager, + SerializerManager serializerManager, + TaskContext taskContext, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes, + int numElementsForSpillThreshold, + boolean canUseRadixSort) { + return new UnsafeExternalSorter( + taskMemoryManager, + blockManager, + serializerManager, + taskContext, + recordComparatorSupplier, + prefixComparator, + initialSize, + pageSizeBytes, + numElementsForSpillThreshold, + null, + canUseRadixSort); + } + + private UnsafeExternalSorter( + TaskMemoryManager taskMemoryManager, + BlockManager blockManager, + SerializerManager serializerManager, + TaskContext taskContext, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes, + int numElementsForSpillThreshold, + @Nullable UnsafeInMemorySorter existingInMemorySorter, + boolean canUseRadixSort) { + super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); + this.taskMemoryManager = taskMemoryManager; + this.blockManager = blockManager; + this.serializerManager = serializerManager; + this.taskContext = taskContext; + this.recordComparatorSupplier = recordComparatorSupplier; + this.prefixComparator = prefixComparator; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 + this.fileBufferSizeBytes = 32 * 1024; + + if (existingInMemorySorter == null) { + RecordComparator comparator = null; + if (recordComparatorSupplier != null) { + comparator = recordComparatorSupplier.get(); + } + this.inMemSorter = + new UnsafeInMemorySorter( + this, taskMemoryManager, comparator, prefixComparator, initialSize, canUseRadixSort); + } else { + this.inMemSorter = existingInMemorySorter; + } + this.peakMemoryUsedBytes = getMemoryUsage(); + this.numElementsForSpillThreshold = numElementsForSpillThreshold; + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). + taskContext.addTaskCompletionListener( + context -> { + cleanupResources(); + }); + } + + /** + * Marks the current page as no-more-space-available, and as a result, either allocate a new page + * or spill when we see the next record. + */ + @VisibleForTesting + public void closeCurrentPage() { + if (currentPage != null) { + pageCursor = currentPage.getBaseOffset() + currentPage.size(); + } + } + + @Override + public long forceSpill(long size, MemoryConsumer trigger) throws IOException { + if (trigger != this && readingIterator != null) { + return readingIterator.spill(); + } + if (getTaskAttemptId() != trigger.getTaskAttemptId()) { + return 0; // fail + } + + if (inMemSorter == null || inMemSorter.numRecords() <= 0) { + // There could still be some memory allocated when there are no records in the in-memory + // sorter. We will not spill it however, to ensure that we can always process at least one + // record before spilling. See the comments in `allocateMemoryForRecordIfNecessary` for why + // this is necessary. 
+ return 0L; + } + + logger.info( + "Thread {} force spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ? " times" : " time"); + + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); + + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter( + blockManager, fileBufferSizeBytes, writeMetrics, inMemSorter.numRecords()); + spillWriters.add(spillWriter); + spillIterator(inMemSorter.getSortedIterator(), spillWriter); + + final long spillSize = freeMemory(); + // Note that this is more-or-less going to be a multiple of the page size, so wasted space in + // pages will currently be counted as memory spilled even though that space isn't actually + // written to disk. This also counts the space needed to store the sorter's pointer array. + inMemSorter.freeMemory(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory + // pages, we might not be able to get memory for the pointer array. + + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); + totalSpillBytes += spillSize; + return spillSize; + } + + /** Sort and spill the current records in response to memory pressure. */ + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + if (trigger != this) { + if (readingIterator != null) { + return readingIterator.spill(); + } + return 0L; // this should throw exception + } + + if (inMemSorter == null || inMemSorter.numRecords() <= 0) { + // There could still be some memory allocated when there are no records in the in-memory + // sorter. We will not spill it however, to ensure that we can always process at least one + // record before spilling. See the comments in `allocateMemoryForRecordIfNecessary` for why + // this is necessary. + return 0L; + } + + logger.info( + "Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ? " times" : " time"); + + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); + + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter( + blockManager, fileBufferSizeBytes, writeMetrics, inMemSorter.numRecords()); + spillWriters.add(spillWriter); + spillIterator(inMemSorter.getSortedIterator(), spillWriter); + + final long spillSize = freeMemory(); + // Note that this is more-or-less going to be a multiple of the page size, so wasted space in + // pages will currently be counted as memory spilled even though that space isn't actually + // written to disk. This also counts the space needed to store the sorter's pointer array. + inMemSorter.freeMemory(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory + // pages, we might not be able to get memory for the pointer array. + + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); + totalSpillBytes += spillSize; + return spillSize; + } + + /** + * Return the total memory usage of this sorter, including the data pages and the sorter's pointer + * array. 
+ */ + private long getMemoryUsage() { + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** Return the peak memory used so far, in bytes. */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + /** @return the total amount of time spent sorting data (in-memory only). */ + public long getSortTimeNanos() { + UnsafeInMemorySorter sorter = inMemSorter; + if (sorter != null) { + return sorter.getSortTimeNanos(); + } + return totalSortTimeNanos; + } + + /** Return the total number of bytes that has been spilled into disk so far. */ + public long getSpillSize() { + return totalSpillBytes; + } + + @VisibleForTesting + public int getNumberOfAllocatedPages() { + return allocatedPages.size(); + } + + /** + * Free this sorter's data pages. + * + * @return the number of bytes freed. + */ + private long freeMemory() { + List pagesToFree = clearAndGetAllocatedPagesToFree(); + long memoryFreed = 0; + for (MemoryBlock block : pagesToFree) { + memoryFreed += block.size(); + freePage(block); + } + return memoryFreed; + } + + /** + * Clear the allocated pages and return the list of allocated pages to let the caller free the + * page. This is to prevent the deadlock by nested locks if the caller locks the + * UnsafeExternalSorter and call freePage which locks the TaskMemoryManager and cause nested + * locks. + * + * @return list of allocated pages to free + */ + private List clearAndGetAllocatedPagesToFree() { + updatePeakMemoryUsed(); + List pagesToFree = new LinkedList<>(allocatedPages); + allocatedPages.clear(); + currentPage = null; + pageCursor = 0; + return pagesToFree; + } + + /** Deletes any spill files created by this sorter. */ + private void deleteSpillFiles() { + for (UnsafeSorterSpillWriter spill : spillWriters) { + File file = spill.getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + } + + /** Frees this sorter's in-memory data structures and cleans up its spill files. */ + public void cleanupResources() { + // To avoid deadlocks, we can't call methods that lock the TaskMemoryManager + // (such as various free() methods) while synchronizing on the UnsafeExternalSorter. + // Instead, we will manipulate UnsafeExternalSorter state inside the synchronized + // lock and perform the actual free() calls outside it. + UnsafeInMemorySorter inMemSorterToFree = null; + List pagesToFree = null; + try { + synchronized (this) { + deleteSpillFiles(); + pagesToFree = clearAndGetAllocatedPagesToFree(); + if (inMemSorter != null) { + inMemSorterToFree = inMemSorter; + inMemSorter = null; + } + } + } finally { + for (MemoryBlock pageToFree : pagesToFree) { + freePage(pageToFree); + } + if (inMemSorterToFree != null) { + inMemSorterToFree.freeMemory(); + } + } + } + + /** + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. 
+ */ + private void growPointerArrayIfNecessary() throws IOException { + assert (inMemSorter != null); + if (!inMemSorter.hasSpaceForAnotherRecord()) { + if (inMemSorter.numRecords() <= 0) { + // Spilling was triggered just before this method was called. The pointer array was freed + // during the spill, so a new pointer array needs to be allocated here. + LongArray array = allocateArray(inMemSorter.getInitialSize()); + inMemSorter.expandPointerArray(array); + return; + } + + long used = inMemSorter.getMemoryUsage(); + LongArray array = null; + try { + // could trigger spilling + array = allocateArray(used / 8 * 2); + } catch (TooLargePageException e) { + // The pointer array is too big to fix in a single page, spill. + spill(); + } catch (SparkOutOfMemoryError e) { + if (inMemSorter.numRecords() > 0) { + logger.error("Unable to grow the pointer array"); + throw e; + } + // The new array could not be allocated, but that is not an issue as it is longer needed, + // as all records were spilled. + } + + if (inMemSorter.numRecords() <= 0) { + // Spilling was triggered while trying to allocate the new array. + if (array != null) { + // We succeeded in allocating the new array, but, since all records were spilled, a + // smaller array would also suffice. + freeArray(array); + } + // The pointer array was freed during the spill, so a new pointer array needs to be + // allocated here. + array = allocateArray(inMemSorter.getInitialSize()); + } + inMemSorter.expandPointerArray(array); + } + } + + /** + * Allocates an additional page in order to insert an additional record. This will request + * additional memory from the memory manager and spill if the requested memory can not be + * obtained. + * + * @param required the required space in the data page, in bytes, including space for storing the + * record size. + */ + private void acquireNewPageIfNecessary(int required) { + if (currentPage == null + || pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) { + // TODO: try to find space on previous pages + currentPage = allocatePage(required); + pageCursor = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); + } + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the memory manager and spill if the requested memory can not be obtained. + * + * @param required the required space in the data page, in bytes, including space for storing the + * record size. + */ + private void allocateMemoryForRecordIfNecessary(int required) throws IOException { + // Step 1: + // Ensure that the pointer array has space for another record. This may cause a spill. + growPointerArrayIfNecessary(); + // Step 2: + // Ensure that the last page has space for another record. This may cause a spill. + acquireNewPageIfNecessary(required); + // Step 3: + // The allocation in step 2 could have caused a spill, which would have freed the pointer + // array allocated in step 1. Therefore we need to check again whether we have to allocate + // a new pointer array. + // + // If the allocation in this step causes a spill event then it will not cause the page + // allocated in the previous step to be freed. The function `spill` only frees memory if at + // least one record has been inserted in the in-memory sorter. This will not be the case if + // we have spilled in the previous step. 
+ // + // If we did not spill in the previous step then `growPointerArrayIfNecessary` will be a + // no-op that does not allocate any memory, and therefore can't cause a spill event. + // + // Thus there is no need to call `acquireNewPageIfNecessary` again after this step. + growPointerArrayIfNecessary(); + } + + /** Write a record to the sorter. */ + public void insertRecord( + Object recordBase, long recordOffset, int length, long prefix, boolean prefixIsNull) + throws IOException { + + assert (inMemSorter != null); + if (inMemSorter.numRecords() >= numElementsForSpillThreshold) { + logger.info( + "Spilling data because number of spilledRecords crossed the threshold " + + numElementsForSpillThreshold); + spill(); + } + + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + // Need 4 or 8 bytes to store the record length. + final int required = length + uaoSize; + allocateMemoryForRecordIfNecessary(required); + + final Object base = currentPage.getBaseObject(); + final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); + pageCursor += length; + inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull); + } + + /** + * Write a key-value record to the sorter. The key and value will be put together in-memory, using + * the following format: + * + *
record length (4 bytes), key length (4 bytes), key data, value data + * + *
record length = key length + value length + 4 + */ + public void insertKVRecord( + Object keyBase, + long keyOffset, + int keyLen, + Object valueBase, + long valueOffset, + int valueLen, + long prefix, + boolean prefixIsNull) + throws IOException { + + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + final int required = keyLen + valueLen + (2 * uaoSize); + allocateMemoryForRecordIfNecessary(required); + + final Object base = currentPage.getBaseObject(); + final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); + UnsafeAlignedOffset.putSize(base, pageCursor, keyLen + valueLen + uaoSize); + pageCursor += uaoSize; + UnsafeAlignedOffset.putSize(base, pageCursor, keyLen); + pageCursor += uaoSize; + Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen); + pageCursor += keyLen; + Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen); + pageCursor += valueLen; + + assert (inMemSorter != null); + inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull); + } + + /** Merges another UnsafeExternalSorters into this one, the other one will be emptied. */ + public void merge(UnsafeExternalSorter other) throws IOException { + other.spill(); + totalSpillBytes += other.totalSpillBytes; + spillWriters.addAll(other.spillWriters); + // remove them from `spillWriters`, or the files will be deleted in `cleanupResources`. + other.spillWriters.clear(); + other.cleanupResources(); + } + + /** + * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()` after + * consuming this iterator. + */ + public UnsafeSorterIterator getSortedIterator() throws IOException { + assert (recordComparatorSupplier != null); + if (spillWriters.isEmpty()) { + assert (inMemSorter != null); + readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); + return readingIterator; + } else { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger( + recordComparatorSupplier.get(), prefixComparator, spillWriters.size()); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager)); + } + if (inMemSorter != null) { + readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); + spillMerger.addSpillIfNotEmpty(readingIterator); + } + return spillMerger.getSortedIterator(); + } + } + + @VisibleForTesting + boolean hasSpaceForAnotherRecord() { + return inMemSorter.hasSpaceForAnotherRecord(); + } + + private static void spillIterator( + UnsafeSorterIterator inMemIterator, UnsafeSorterSpillWriter spillWriter) throws IOException { + while (inMemIterator.hasNext()) { + inMemIterator.loadNext(); + final Object baseObject = inMemIterator.getBaseObject(); + final long baseOffset = inMemIterator.getBaseOffset(); + final int recordLength = inMemIterator.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); + } + spillWriter.close(); + } + + /** An UnsafeSorterIterator that support spilling. 
*/ + class SpillableIterator extends UnsafeSorterIterator { + private UnsafeSorterIterator upstream; + private MemoryBlock lastPage = null; + private boolean loaded = false; + private int numRecords; + + private Object currentBaseObject; + private long currentBaseOffset; + private int currentRecordLength; + private long currentKeyPrefix; + + SpillableIterator(UnsafeSorterIterator inMemIterator) { + this.upstream = inMemIterator; + this.numRecords = inMemIterator.getNumRecords(); + } + + @Override + public int getNumRecords() { + return numRecords; + } + + @Override + public long getCurrentPageNumber() { + throw new UnsupportedOperationException(); + } + + public long spill() throws IOException { + UnsafeInMemorySorter inMemSorterToFree = null; + List pagesToFree = new LinkedList<>(); + try { + synchronized (this) { + if (inMemSorter == null) { + return 0L; + } + + long currentPageNumber = upstream.getCurrentPageNumber(); + + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); + if (numRecords > 0) { + // Iterate over the records that have not been returned and spill them. + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter( + blockManager, fileBufferSizeBytes, writeMetrics, numRecords); + spillIterator(upstream, spillWriter); + spillWriters.add(spillWriter); + upstream = spillWriter.getReader(serializerManager); + } else { + // Nothing to spill as all records have been read already, but do not return yet, as the + // memory still has to be freed. + upstream = null; + } + + long released = 0L; + synchronized (UnsafeExternalSorter.this) { + // release the pages except the one that is used. There can still be a caller that + // is accessing the current record. We free this page in that caller's next loadNext() + // call. + for (MemoryBlock page : allocatedPages) { + if (!loaded || page.pageNumber != currentPageNumber) { + released += page.size(); + // Do not free the page, while we are locking `SpillableIterator`. The `freePage` + // method locks the `TaskMemoryManager`, and it's not a good idea to lock 2 objects + // in sequence. We may hit dead lock if another thread locks `TaskMemoryManager` + // and `SpillableIterator` in sequence, which may happen in + // `TaskMemoryManager.acquireExecutionMemory`. + pagesToFree.add(page); + } else { + lastPage = page; + } + } + allocatedPages.clear(); + if (lastPage != null) { + // Add the last page back to the list of allocated pages to make sure it gets freed in + // case loadNext() never gets called again. + allocatedPages.add(lastPage); + } + } + + // in-memory sorter will not be used after spilling + assert (inMemSorter != null); + released += inMemSorter.getMemoryUsage(); + totalSortTimeNanos += inMemSorter.getSortTimeNanos(); + // Do not free the sorter while we are locking `SpillableIterator`, + // as this can cause a deadlock. 
+ inMemSorterToFree = inMemSorter; + inMemSorter = null; + taskContext.taskMetrics().incMemoryBytesSpilled(released); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); + totalSpillBytes += released; + return released; + } + } finally { + for (MemoryBlock pageToFree : pagesToFree) { + freePage(pageToFree); + } + if (inMemSorterToFree != null) { + inMemSorterToFree.freeMemory(); + } + } + } + + @Override + public boolean hasNext() { + return numRecords > 0; + } + + @Override + public void loadNext() throws IOException { + assert upstream != null; + MemoryBlock pageToFree = null; + try { + synchronized (this) { + loaded = true; + // Just consumed the last record from the in-memory iterator. + if (lastPage != null) { + // Do not free the page here, while we are locking `SpillableIterator`. The `freePage` + // method locks the `TaskMemoryManager`, and it's a bad idea to lock 2 objects in + // sequence. We may hit dead lock if another thread locks `TaskMemoryManager` and + // `SpillableIterator` in sequence, which may happen in + // `TaskMemoryManager.acquireExecutionMemory`. + pageToFree = lastPage; + allocatedPages.clear(); + lastPage = null; + } + numRecords--; + upstream.loadNext(); + + // Keep track of the current base object, base offset, record length, and key prefix, + // so that the current record can still be read in case a spill is triggered and we + // switch to the spill writer's iterator. + currentBaseObject = upstream.getBaseObject(); + currentBaseOffset = upstream.getBaseOffset(); + currentRecordLength = upstream.getRecordLength(); + currentKeyPrefix = upstream.getKeyPrefix(); + } + } finally { + if (pageToFree != null) { + freePage(pageToFree); + } + } + } + + @Override + public Object getBaseObject() { + return currentBaseObject; + } + + @Override + public long getBaseOffset() { + return currentBaseOffset; + } + + @Override + public int getRecordLength() { + return currentRecordLength; + } + + @Override + public long getKeyPrefix() { + return currentKeyPrefix; + } + } + + /** + * Returns an iterator starts from startIndex, which will return the rows in the order as + * inserted. + * + *
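 * For example, with 10 inserted rows of which the first 6 were spilled, getIterator(7)
 * (illustrative numbers) skips past whole spill files where possible, steps over the remainder
 * with moveOver(), and chains the remaining spilled rows with the in-memory rows.
 *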
It is the caller's responsibility to call `cleanupResources()` after consuming this + * iterator. + * + *
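 * A typical consumption loop (sketch; error handling omitted):
 *
 *   UnsafeSorterIterator it = sorter.getIterator(0);
 *   while (it.hasNext()) {
 *     it.loadNext();
 *     Object base = it.getBaseObject();
 *     long offset = it.getBaseOffset();
 *     int length = it.getRecordLength();
 *     // read the record at (base, offset, length)
 *   }
 *   sorter.cleanupResources();
 *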
TODO: support forced spilling + */ + public UnsafeSorterIterator getIterator(int startIndex) throws IOException { + if (spillWriters.isEmpty()) { + assert (inMemSorter != null); + UnsafeSorterIterator iter = inMemSorter.getSortedIterator(); + moveOver(iter, startIndex); + return iter; + } else { + LinkedList queue = new LinkedList<>(); + int i = 0; + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + if (i + spillWriter.recordsSpilled() > startIndex) { + UnsafeSorterIterator iter = spillWriter.getReader(serializerManager); + moveOver(iter, startIndex - i); + queue.add(iter); + } + i += spillWriter.recordsSpilled(); + } + if (inMemSorter != null && inMemSorter.numRecords() > 0) { + UnsafeSorterIterator iter = inMemSorter.getSortedIterator(); + moveOver(iter, startIndex - i); + queue.add(iter); + } + return new ChainedIterator(queue); + } + } + + private void moveOver(UnsafeSorterIterator iter, int steps) throws IOException { + if (steps > 0) { + for (int i = 0; i < steps; i++) { + if (iter.hasNext()) { + iter.loadNext(); + } else { + throw new ArrayIndexOutOfBoundsException( + "Failed to move the iterator " + steps + " steps forward"); + } + } + } + } + + /** Chain multiple UnsafeSorterIterator together as single one. */ + static class ChainedIterator extends UnsafeSorterIterator { + + private final Queue iterators; + private UnsafeSorterIterator current; + private int numRecords; + private final int[] iteratorsLength; + + ChainedIterator(Queue iterators) { + assert iterators.size() > 0; + this.numRecords = 0; + this.iteratorsLength = new int[iterators.size()]; + int i = 0; + for (UnsafeSorterIterator iter : iterators) { + this.numRecords += iter.getNumRecords(); + iteratorsLength[i++] = iter.getNumRecords(); + } + this.iterators = iterators; + this.current = iterators.remove(); + } + + int[] numRecordForEach() { + return iteratorsLength; + } + + @Override + public int getNumRecords() { + return numRecords; + } + + @Override + public long getCurrentPageNumber() { + return current.getCurrentPageNumber(); + } + + @Override + public boolean hasNext() { + while (!current.hasNext() && !iterators.isEmpty()) { + current = iterators.remove(); + } + return current.hasNext(); + } + + @Override + public void loadNext() throws IOException { + while (!current.hasNext() && !iterators.isEmpty()) { + current = iterators.remove(); + } + current.loadNext(); + } + + @Override + public Object getBaseObject() { + return current.getBaseObject(); + } + + @Override + public long getBaseOffset() { + return current.getBaseOffset(); + } + + @Override + public int getRecordLength() { + return current.getRecordLength(); + } + + @Override + public long getKeyPrefix() { + return current.getKeyPrefix(); + } + } +}