From 6321824149eac85fe409a04f87622d9f78d1f0b5 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 18 Dec 2024 09:41:54 +0800 Subject: [PATCH] [VL] Fix `RetryOnOomMemoryTarget` only spills one single consumer on retrying --- .../memory/memtarget/MemoryTargetVisitor.java | 2 +- .../memory/memtarget/MemoryTargets.java | 9 +- .../memory/memtarget/TreeMemoryTargets.java | 159 +----------------- .../memtarget/spark/TreeMemoryConsumer.java | 145 +++++++++++++++- .../memtarget/spark/TreeMemoryConsumers.java | 22 +-- .../apache/spark/memory/SparkMemoryUtil.scala | 2 +- .../spark/TreeMemoryConsumerTest.java | 38 ++++- 7 files changed, 193 insertions(+), 184 deletions(-) diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargetVisitor.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargetVisitor.java index a42a51e0ce4e..f6ef49a78920 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargetVisitor.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargetVisitor.java @@ -28,7 +28,7 @@ public interface MemoryTargetVisitor { T visit(TreeMemoryConsumer treeMemoryConsumer); - T visit(TreeMemoryTargets.Node node); + T visit(TreeMemoryConsumer.Node node); T visit(LoggingMemoryTarget loggingMemoryTarget); diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java index 1997ce61d205..c6f5b59de8c2 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java @@ -64,9 +64,11 @@ public static TreeMemoryTarget newConsumer( Map virtualChildren) { final TreeMemoryConsumers.Factory factory = TreeMemoryConsumers.factory(tmm); if (GlutenConfig.getConf().memoryIsolation()) { - return factory.newIsolatedConsumer(name, spiller, virtualChildren); + return TreeMemoryTargets.newChild(factory.isolatedRoot(), name, spiller, virtualChildren); } - final TreeMemoryTarget consumer = factory.newLegacyConsumer(name, spiller, virtualChildren); + final TreeMemoryTarget root = factory.legacyRoot(); + final TreeMemoryTarget consumer = + TreeMemoryTargets.newChild(root, name, spiller, virtualChildren); if (SparkEnv.get() == null) { // We are likely in test code. Return the consumer directly. LOGGER.info("SparkEnv not found. We are likely in test code."); @@ -86,7 +88,8 @@ public static TreeMemoryTarget newConsumer( consumer, () -> { LOGGER.info("Request for spilling on consumer {}...", consumer.name()); - long spilled = TreeMemoryTargets.spillTree(consumer, Long.MAX_VALUE); + // Note: Spill from root node so other consumers also get spilled. + long spilled = TreeMemoryTargets.spillTree(root, Long.MAX_VALUE); LOGGER.info("Consumer {} spilled {} bytes.", consumer.name(), spilled); }); } diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java index 598317a3c46a..6d94e7206959 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java @@ -17,18 +17,10 @@ package org.apache.gluten.memory.memtarget; import org.apache.gluten.memory.MemoryUsageStatsBuilder; -import org.apache.gluten.memory.SimpleMemoryUsageRecorder; -import org.apache.gluten.proto.MemoryUsageStats; -import com.google.common.base.Preconditions; -import org.apache.spark.util.Utils; - -import java.util.Collections; -import java.util.HashMap; import java.util.Map; import java.util.PriorityQueue; import java.util.Queue; -import java.util.stream.Collectors; public class TreeMemoryTargets { @@ -36,13 +28,16 @@ private TreeMemoryTargets() { // enclose factory ctor } - public static TreeMemoryTarget newChild( + /** + * A short-cut method to create a child target of `parent`. The child will follow the parent's + * maximum capacity. + */ + static TreeMemoryTarget newChild( TreeMemoryTarget parent, String name, - long capacity, Spiller spiller, Map virtualChildren) { - return new Node(parent, name, capacity, spiller, virtualChildren); + return parent.newChild(name, TreeMemoryTarget.CAPACITY_UNLIMITED, spiller, virtualChildren); } public static long spillTree(TreeMemoryTarget node, final long bytes) { @@ -83,146 +78,4 @@ private static long spillTree(TreeMemoryTarget node, Spiller.Phase phase, final return bytes - remainingBytes; } - - // non-root nodes are not Spark memory consumer - public static class Node implements TreeMemoryTarget, KnownNameAndStats { - private final Map children = new HashMap<>(); - private final TreeMemoryTarget parent; - private final String name; - private final long capacity; - private final Spiller spiller; - private final Map virtualChildren; - private final SimpleMemoryUsageRecorder selfRecorder = new SimpleMemoryUsageRecorder(); - - private Node( - TreeMemoryTarget parent, - String name, - long capacity, - Spiller spiller, - Map virtualChildren) { - this.parent = parent; - this.capacity = capacity; - final String uniqueName = MemoryTargetUtil.toUniqueName(name); - if (capacity == CAPACITY_UNLIMITED) { - this.name = uniqueName; - } else { - this.name = String.format("%s, %s", uniqueName, Utils.bytesToString(capacity)); - } - this.spiller = spiller; - this.virtualChildren = virtualChildren; - } - - @Override - public long borrow(long size) { - if (size == 0) { - return 0; - } - ensureFreeCapacity(size); - return borrow0(Math.min(freeBytes(), size)); - } - - private long freeBytes() { - return capacity - usedBytes(); - } - - private long borrow0(long size) { - long granted = parent.borrow(size); - selfRecorder.inc(granted); - return granted; - } - - @Override - public Spiller getNodeSpiller() { - return spiller; - } - - private boolean ensureFreeCapacity(long bytesNeeded) { - while (true) { // FIXME should we add retry limit? - long freeBytes = freeBytes(); - Preconditions.checkState(freeBytes >= 0); - if (freeBytes >= bytesNeeded) { - // free bytes fit requirement - return true; - } - // spill - long bytesToSpill = bytesNeeded - freeBytes; - long spilledBytes = TreeMemoryTargets.spillTree(this, bytesToSpill); - Preconditions.checkState(spilledBytes >= 0); - if (spilledBytes == 0) { - // OOM - return false; - } - } - } - - @Override - public long repay(long size) { - if (size == 0) { - return 0; - } - long toFree = Math.min(usedBytes(), size); - long freed = parent.repay(toFree); - selfRecorder.inc(-freed); - return freed; - } - - @Override - public long usedBytes() { - return selfRecorder.current(); - } - - @Override - public T accept(MemoryTargetVisitor visitor) { - return visitor.visit(this); - } - - @Override - public String name() { - return name; - } - - @Override - public MemoryUsageStats stats() { - final Map childrenStats = - new HashMap<>( - children.entrySet().stream() - .collect(Collectors.toMap(e -> e.getValue().name(), e -> e.getValue().stats()))); - - Preconditions.checkState(childrenStats.size() == children.size()); - - // add virtual children - for (Map.Entry entry : virtualChildren.entrySet()) { - if (childrenStats.containsKey(entry.getKey())) { - throw new IllegalArgumentException("Child stats already exists: " + entry.getKey()); - } - childrenStats.put(entry.getKey(), entry.getValue().toStats()); - } - return selfRecorder.toStats(childrenStats); - } - - @Override - public TreeMemoryTarget newChild( - String name, - long capacity, - Spiller spiller, - Map virtualChildren) { - final Node child = - new Node(this, name, Math.min(this.capacity, capacity), spiller, virtualChildren); - if (children.containsKey(child.name())) { - throw new IllegalArgumentException("Child already registered: " + child.name()); - } - children.put(child.name(), child); - return child; - } - - @Override - public Map children() { - return Collections.unmodifiableMap(children); - } - - @Override - public TreeMemoryTarget parent() { - return parent; - } - } } diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java index 1289a01c349e..38ac7d9733b6 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java @@ -25,6 +25,7 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.MemoryMode; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; import java.io.IOException; import java.util.Collections; @@ -121,8 +122,7 @@ public TreeMemoryTarget newChild( long capacity, Spiller spiller, Map virtualChildren) { - final TreeMemoryTarget child = - TreeMemoryTargets.newChild(this, name, capacity, spiller, virtualChildren); + final TreeMemoryTarget child = new Node(this, name, capacity, spiller, virtualChildren); if (children.containsKey(child.name())) { throw new IllegalArgumentException("Child already registered: " + child.name()); } @@ -151,4 +151,145 @@ public Spiller getNodeSpiller() { public TaskMemoryManager getTaskMemoryManager() { return taskMemoryManager; } + + public static class Node implements TreeMemoryTarget, KnownNameAndStats { + private final Map children = new HashMap<>(); + private final TreeMemoryTarget parent; + private final String name; + private final long capacity; + private final Spiller spiller; + private final Map virtualChildren; + private final SimpleMemoryUsageRecorder selfRecorder = new SimpleMemoryUsageRecorder(); + + private Node( + TreeMemoryTarget parent, + String name, + long capacity, + Spiller spiller, + Map virtualChildren) { + this.parent = parent; + this.capacity = capacity; + final String uniqueName = MemoryTargetUtil.toUniqueName(name); + if (capacity == TreeMemoryTarget.CAPACITY_UNLIMITED) { + this.name = uniqueName; + } else { + this.name = String.format("%s, %s", uniqueName, Utils.bytesToString(capacity)); + } + this.spiller = spiller; + this.virtualChildren = virtualChildren; + } + + @Override + public long borrow(long size) { + if (size == 0) { + return 0; + } + ensureFreeCapacity(size); + return borrow0(Math.min(freeBytes(), size)); + } + + private long freeBytes() { + return capacity - usedBytes(); + } + + private long borrow0(long size) { + long granted = parent.borrow(size); + selfRecorder.inc(granted); + return granted; + } + + @Override + public Spiller getNodeSpiller() { + return spiller; + } + + private boolean ensureFreeCapacity(long bytesNeeded) { + while (true) { // FIXME should we add retry limit? + long freeBytes = freeBytes(); + Preconditions.checkState(freeBytes >= 0); + if (freeBytes >= bytesNeeded) { + // free bytes fit requirement + return true; + } + // spill + long bytesToSpill = bytesNeeded - freeBytes; + long spilledBytes = TreeMemoryTargets.spillTree(this, bytesToSpill); + Preconditions.checkState(spilledBytes >= 0); + if (spilledBytes == 0) { + // OOM + return false; + } + } + } + + @Override + public long repay(long size) { + if (size == 0) { + return 0; + } + long toFree = Math.min(usedBytes(), size); + long freed = parent.repay(toFree); + selfRecorder.inc(-freed); + return freed; + } + + @Override + public long usedBytes() { + return selfRecorder.current(); + } + + @Override + public T accept(MemoryTargetVisitor visitor) { + return visitor.visit(this); + } + + @Override + public String name() { + return name; + } + + @Override + public MemoryUsageStats stats() { + final Map childrenStats = + new HashMap<>( + children.entrySet().stream() + .collect(Collectors.toMap(e -> e.getValue().name(), e -> e.getValue().stats()))); + + Preconditions.checkState(childrenStats.size() == children.size()); + + // add virtual children + for (Map.Entry entry : virtualChildren.entrySet()) { + if (childrenStats.containsKey(entry.getKey())) { + throw new IllegalArgumentException("Child stats already exists: " + entry.getKey()); + } + childrenStats.put(entry.getKey(), entry.getValue().toStats()); + } + return selfRecorder.toStats(childrenStats); + } + + @Override + public TreeMemoryTarget newChild( + String name, + long capacity, + Spiller spiller, + Map virtualChildren) { + final Node child = + new Node(this, name, Math.min(this.capacity, capacity), spiller, virtualChildren); + if (children.containsKey(child.name())) { + throw new IllegalArgumentException("Child already registered: " + child.name()); + } + children.put(child.name(), child); + return child; + } + + @Override + public Map children() { + return Collections.unmodifiableMap(children); + } + + @Override + public TreeMemoryTarget parent() { + return parent; + } + } } diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java index e8bfb5cf7569..a11a4a3e4a19 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java @@ -17,8 +17,6 @@ package org.apache.gluten.memory.memtarget.spark; import org.apache.gluten.GlutenConfig; -import org.apache.gluten.memory.MemoryUsageStatsBuilder; -import org.apache.gluten.memory.memtarget.Spiller; import org.apache.gluten.memory.memtarget.Spillers; import org.apache.gluten.memory.memtarget.TreeMemoryTarget; @@ -61,22 +59,12 @@ private TreeMemoryTarget ofCapacity(long capacity) { Collections.emptyMap())); } - private TreeMemoryTarget legacyRoot() { - return ofCapacity(TreeMemoryTarget.CAPACITY_UNLIMITED); - } - - private TreeMemoryTarget isolatedRoot() { - return ofCapacity(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize()); - } - /** * This works as a legacy Spark memory consumer which grants as much as possible of memory * capacity to each task. */ - public TreeMemoryTarget newLegacyConsumer( - String name, Spiller spiller, Map virtualChildren) { - final TreeMemoryTarget parent = legacyRoot(); - return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren); + public TreeMemoryTarget legacyRoot() { + return ofCapacity(TreeMemoryTarget.CAPACITY_UNLIMITED); } /** @@ -88,10 +76,8 @@ public TreeMemoryTarget newLegacyConsumer( * *

See GLUTEN-3030 */ - public TreeMemoryTarget newIsolatedConsumer( - String name, Spiller spiller, Map virtualChildren) { - final TreeMemoryTarget parent = isolatedRoot(); - return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren); + public TreeMemoryTarget isolatedRoot() { + return ofCapacity(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize()); } } } diff --git a/gluten-core/src/main/scala/org/apache/spark/memory/SparkMemoryUtil.scala b/gluten-core/src/main/scala/org/apache/spark/memory/SparkMemoryUtil.scala index 637ef8b22fd4..338854cf086c 100644 --- a/gluten-core/src/main/scala/org/apache/spark/memory/SparkMemoryUtil.scala +++ b/gluten-core/src/main/scala/org/apache/spark/memory/SparkMemoryUtil.scala @@ -111,7 +111,7 @@ object SparkMemoryUtil { collectFromTaskMemoryManager(treeMemoryConsumer.getTaskMemoryManager) } - override def visit(node: TreeMemoryTargets.Node): String = { + override def visit(node: TreeMemoryConsumer.Node): String = { node.parent().accept(this) // walk up to find the one bound with task memory manager } diff --git a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java index 934300a1acd7..6cb38fe8d5d3 100644 --- a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java +++ b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java @@ -52,7 +52,13 @@ public void testIsolated() { final TreeMemoryConsumers.Factory factory = TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()); final TreeMemoryTarget consumer = - factory.newIsolatedConsumer("FOO", Spillers.NOOP, Collections.emptyMap()); + factory + .isolatedRoot() + .newChild( + "FOO", + TreeMemoryTarget.CAPACITY_UNLIMITED, + Spillers.NOOP, + Collections.emptyMap()); Assert.assertEquals(20, consumer.borrow(20)); Assert.assertEquals(70, consumer.borrow(70)); Assert.assertEquals(10, consumer.borrow(20)); @@ -67,7 +73,13 @@ public void testLegacy() { final TreeMemoryConsumers.Factory factory = TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()); final TreeMemoryTarget consumer = - factory.newLegacyConsumer("FOO", Spillers.NOOP, Collections.emptyMap()); + factory + .legacyRoot() + .newChild( + "FOO", + TreeMemoryTarget.CAPACITY_UNLIMITED, + Spillers.NOOP, + Collections.emptyMap()); Assert.assertEquals(20, consumer.borrow(20)); Assert.assertEquals(70, consumer.borrow(70)); Assert.assertEquals(20, consumer.borrow(20)); @@ -81,11 +93,21 @@ public void testIsolatedAndLegacy() { () -> { final TreeMemoryTarget legacy = TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()) - .newLegacyConsumer("FOO", Spillers.NOOP, Collections.emptyMap()); + .legacyRoot() + .newChild( + "FOO", + TreeMemoryTarget.CAPACITY_UNLIMITED, + Spillers.NOOP, + Collections.emptyMap()); Assert.assertEquals(110, legacy.borrow(110)); final TreeMemoryTarget isolated = TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()) - .newIsolatedConsumer("FOO", Spillers.NOOP, Collections.emptyMap()); + .isolatedRoot() + .newChild( + "FOO", + TreeMemoryTarget.CAPACITY_UNLIMITED, + Spillers.NOOP, + Collections.emptyMap()); Assert.assertEquals(100, isolated.borrow(110)); }); } @@ -97,7 +119,9 @@ public void testSpill() { final Spillers.AppendableSpillerList spillers = Spillers.appendable(); final TreeMemoryTarget legacy = TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()) - .newLegacyConsumer("FOO", spillers, Collections.emptyMap()); + .legacyRoot() + .newChild( + "FOO", TreeMemoryTarget.CAPACITY_UNLIMITED, spillers, Collections.emptyMap()); final AtomicInteger numSpills = new AtomicInteger(0); final AtomicLong numSpilledBytes = new AtomicLong(0L); spillers.append( @@ -131,7 +155,9 @@ public void testOverSpill() { final Spillers.AppendableSpillerList spillers = Spillers.appendable(); final TreeMemoryTarget legacy = TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()) - .newLegacyConsumer("FOO", spillers, Collections.emptyMap()); + .legacyRoot() + .newChild( + "FOO", TreeMemoryTarget.CAPACITY_UNLIMITED, spillers, Collections.emptyMap()); final AtomicInteger numSpills = new AtomicInteger(0); final AtomicLong numSpilledBytes = new AtomicLong(0L); spillers.append(