diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java index c0f74c7990d1..1997ce61d205 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java @@ -24,10 +24,13 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.SparkResourceUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.Map; public final class MemoryTargets { + private static final Logger LOGGER = LoggerFactory.getLogger(MemoryTargets.class); private MemoryTargets() { // enclose factory ctor @@ -45,14 +48,6 @@ public static MemoryTarget overAcquire( return new OverAcquire(target, overTarget, overAcquiredRatio); } - public static TreeMemoryTarget retrySpillOnOom(TreeMemoryTarget target) { - SparkEnv env = SparkEnv.get(); - if (env != null && env.conf() != null && SparkResourceUtil.getTaskSlots(env.conf()) > 1) { - return new RetryOnOomMemoryTarget(target); - } - return target; - } - @Experimental public static MemoryTarget dynamicOffHeapSizingIfEnabled(MemoryTarget memoryTarget) { if (GlutenConfig.getConf().dynamicOffHeapSizingEnabled()) { @@ -67,14 +62,32 @@ public static TreeMemoryTarget newConsumer( String name, Spiller spiller, Map virtualChildren) { - final TreeMemoryConsumers.Factory factory; + final TreeMemoryConsumers.Factory factory = TreeMemoryConsumers.factory(tmm); if (GlutenConfig.getConf().memoryIsolation()) { - return TreeMemoryConsumers.isolated().newConsumer(tmm, name, spiller, virtualChildren); - } else { - // Retry of spilling is needed in shared mode because the maxMemoryPerTask of Vanilla Spark - // ExecutionMemoryPool is dynamic when with multi-slot config. 
- return MemoryTargets.retrySpillOnOom( - TreeMemoryConsumers.shared().newConsumer(tmm, name, spiller, virtualChildren)); + return factory.newIsolatedConsumer(name, spiller, virtualChildren); + } + final TreeMemoryTarget consumer = factory.newLegacyConsumer(name, spiller, virtualChildren); + if (SparkEnv.get() == null) { + // We are likely in test code. Return the consumer directly. + LOGGER.info("SparkEnv not found. We are likely in test code."); + return consumer; + } + final int taskSlots = SparkResourceUtil.getTaskSlots(SparkEnv.get().conf()); + if (taskSlots == 1) { + // We don't need to retry on OOM in the case where a single task occupies the whole executor. + return consumer; } + // Since https://github.com/apache/incubator-gluten/pull/8132. + // Retry of spilling is needed in multi-slot and legacy mode (formerly named shared mode) + // because the maxMemoryPerTask defined by vanilla Spark's ExecutionMemoryPool is dynamic. + // + // See the original issue https://github.com/apache/incubator-gluten/issues/8128. 
+ return new RetryOnOomMemoryTarget( + consumer, + () -> { + LOGGER.info("Request for spilling on consumer {}...", consumer.name()); + long spilled = TreeMemoryTargets.spillTree(consumer, Long.MAX_VALUE); + LOGGER.info("Consumer {} spilled {} bytes.", consumer.name(), spilled); + }); } } diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java index 1a5388d0d187..b564bbcaa41c 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java @@ -27,39 +27,30 @@ public class RetryOnOomMemoryTarget implements TreeMemoryTarget { private static final Logger LOGGER = LoggerFactory.getLogger(RetryOnOomMemoryTarget.class); private final TreeMemoryTarget target; + private final Runnable onRetry; - RetryOnOomMemoryTarget(TreeMemoryTarget target) { + RetryOnOomMemoryTarget(TreeMemoryTarget target, Runnable onRetry) { this.target = target; + this.onRetry = onRetry; } @Override public long borrow(long size) { long granted = target.borrow(size); if (granted < size) { - LOGGER.info("Retrying spill require:{} got:{}", size, granted); - final long spilled = retryingSpill(Long.MAX_VALUE); + LOGGER.info("Granted size {} is less than requested size {}, retrying...", granted, size); final long remaining = size - granted; - if (spilled >= remaining) { - granted += target.borrow(remaining); - } - LOGGER.info("Retrying spill spilled:{} final granted:{}", spilled, granted); + // Invoke the `onRetry` callback, then retry borrowing. + // It's usually expected to run extra spilling logics in + // the `onRetry` callback so we may get enough memory space + // to allocate the remaining bytes. 
+ onRetry.run(); + granted += target.borrow(remaining); + LOGGER.info("Newest granted size after retrying: {}, requested size {}.", granted, size); } return granted; } - private long retryingSpill(long size) { - TreeMemoryTarget rootTarget = target; - while (true) { - try { - rootTarget = rootTarget.parent(); - } catch (IllegalStateException e) { - // Reached the root node - break; - } - } - return TreeMemoryTargets.spillTree(rootTarget, size); - } - @Override public long repay(long size) { return target.repay(size); diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java index 26c6ea48008a..598317a3c46a 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java @@ -206,7 +206,8 @@ public TreeMemoryTarget newChild( long capacity, Spiller spiller, Map virtualChildren) { - final Node child = new Node(this, name, capacity, spiller, virtualChildren); + final Node child = + new Node(this, name, Math.min(this.capacity, capacity), spiller, virtualChildren); if (children.containsKey(child.name())) { throw new IllegalArgumentException("Child already registered: " + child.name()); } diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java index 7ab05bd3a2e7..e8bfb5cf7569 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java @@ -24,71 +24,74 @@ import org.apache.commons.collections.map.ReferenceMap; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; import java.util.Collections; import java.util.Map; 
import java.util.concurrent.ConcurrentHashMap; public final class TreeMemoryConsumers { - private static final Map FACTORIES = new ConcurrentHashMap<>(); + private static final ReferenceMap FACTORIES = new ReferenceMap(); private TreeMemoryConsumers() {} - private static Factory createOrGetFactory(long perTaskCapacity) { - return FACTORIES.computeIfAbsent(perTaskCapacity, Factory::new); + @SuppressWarnings("unchecked") + public static Factory factory(TaskMemoryManager tmm) { + synchronized (FACTORIES) { + return (Factory) FACTORIES.computeIfAbsent(tmm, m -> new Factory((TaskMemoryManager) m)); + } } - /** - * A hub to provide memory target instances whose shared size (in the same task) is limited to X, - * X = executor memory / task slots. - * - *

Using this to prevent OOMs if the delegated memory target could possibly hold large memory - * blocks that are not spillable. - * - *

See GLUTEN-3030 - */ - public static Factory isolated() { - return createOrGetFactory(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize()); - } + public static class Factory { + private final TreeMemoryConsumer sparkConsumer; + private final Map roots = new ConcurrentHashMap<>(); - /** - * This works as a legacy Spark memory consumer which grants as much as possible of memory - * capacity to each task. - */ - public static Factory shared() { - return createOrGetFactory(TreeMemoryTarget.CAPACITY_UNLIMITED); - } + private Factory(TaskMemoryManager tmm) { + this.sparkConsumer = new TreeMemoryConsumer(tmm); + } - public static class Factory { - private final ReferenceMap map = new ReferenceMap(ReferenceMap.WEAK, ReferenceMap.WEAK); - private final long perTaskCapacity; + private TreeMemoryTarget ofCapacity(long capacity) { + return roots.computeIfAbsent( + capacity, + cap -> + sparkConsumer.newChild( + String.format("Capacity[%s]", Utils.bytesToString(cap)), + cap, + Spillers.NOOP, + Collections.emptyMap())); + } + + private TreeMemoryTarget legacyRoot() { + return ofCapacity(TreeMemoryTarget.CAPACITY_UNLIMITED); + } - private Factory(long perTaskCapacity) { - this.perTaskCapacity = perTaskCapacity; + private TreeMemoryTarget isolatedRoot() { + return ofCapacity(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize()); } - @SuppressWarnings("unchecked") - private TreeMemoryTarget getSharedAccount(TaskMemoryManager tmm) { - synchronized (map) { - return (TreeMemoryTarget) - map.computeIfAbsent( - tmm, - m -> { - TreeMemoryTarget tmc = new TreeMemoryConsumer((TaskMemoryManager) m); - return tmc.newChild( - "root", perTaskCapacity, Spillers.NOOP, Collections.emptyMap()); - }); - } + /** + * This works as a legacy Spark memory consumer which grants as much as possible of memory + * capacity to each task. 
+ */ + public TreeMemoryTarget newLegacyConsumer( + String name, Spiller spiller, Map virtualChildren) { + final TreeMemoryTarget parent = legacyRoot(); + return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren); } - public TreeMemoryTarget newConsumer( - TaskMemoryManager tmm, - String name, - Spiller spiller, - Map virtualChildren) { - final TreeMemoryTarget account = getSharedAccount(tmm); - return account.newChild( - name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren); + /** + * A hub to provide memory target instances whose shared size (in the same task) is limited to + * X, X = executor memory / task slots. + * + *

Using this to prevent OOMs if the delegated memory target could possibly hold large memory + * blocks that are not spillable. + * + *

See GLUTEN-3030 + */ + public TreeMemoryTarget newIsolatedConsumer( + String name, Spiller spiller, Map virtualChildren) { + final TreeMemoryTarget parent = isolatedRoot(); + return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren); } } } diff --git a/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala b/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala index df5917125b64..2f609b026db3 100644 --- a/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala +++ b/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.task +import org.apache.gluten.GlutenConfig import org.apache.gluten.task.TaskListener import org.apache.spark.{TaskContext, TaskFailedReason, TaskKilledException, UnknownReason} @@ -65,8 +66,8 @@ object TaskResources extends TaskListener with Logging { properties.put(key, value) case _ => } - properties.setIfMissing("spark.memory.offHeap.enabled", "true") - properties.setIfMissing("spark.memory.offHeap.size", "1TB") + properties.setIfMissing(GlutenConfig.SPARK_OFFHEAP_ENABLED, "true") + properties.setIfMissing(GlutenConfig.SPARK_OFFHEAP_SIZE_KEY, "1TB") TaskContext.setTaskContext(newUnsafeTaskContext(properties)) } diff --git a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java index befe449186e7..934300a1acd7 100644 --- a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java +++ b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java @@ -49,13 +49,10 @@ public void setUp() throws Exception { public void testIsolated() { test( () -> { - final TreeMemoryConsumers.Factory factory = TreeMemoryConsumers.isolated(); + final TreeMemoryConsumers.Factory factory = + 
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()); final TreeMemoryTarget consumer = - factory.newConsumer( - TaskContext.get().taskMemoryManager(), - "FOO", - Spillers.NOOP, - Collections.emptyMap()); + factory.newIsolatedConsumer("FOO", Spillers.NOOP, Collections.emptyMap()); Assert.assertEquals(20, consumer.borrow(20)); Assert.assertEquals(70, consumer.borrow(70)); Assert.assertEquals(10, consumer.borrow(20)); @@ -64,16 +61,13 @@ public void testIsolated() { } @Test - public void testShared() { + public void testLegacy() { test( () -> { - final TreeMemoryConsumers.Factory factory = TreeMemoryConsumers.shared(); + final TreeMemoryConsumers.Factory factory = + TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()); final TreeMemoryTarget consumer = - factory.newConsumer( - TaskContext.get().taskMemoryManager(), - "FOO", - Spillers.NOOP, - Collections.emptyMap()); + factory.newLegacyConsumer("FOO", Spillers.NOOP, Collections.emptyMap()); Assert.assertEquals(20, consumer.borrow(20)); Assert.assertEquals(70, consumer.borrow(70)); Assert.assertEquals(20, consumer.borrow(20)); @@ -82,24 +76,16 @@ public void testShared() { } @Test - public void testIsolatedAndShared() { + public void testIsolatedAndLegacy() { test( () -> { - final TreeMemoryTarget shared = - TreeMemoryConsumers.shared() - .newConsumer( - TaskContext.get().taskMemoryManager(), - "FOO", - Spillers.NOOP, - Collections.emptyMap()); - Assert.assertEquals(110, shared.borrow(110)); + final TreeMemoryTarget legacy = + TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()) + .newLegacyConsumer("FOO", Spillers.NOOP, Collections.emptyMap()); + Assert.assertEquals(110, legacy.borrow(110)); final TreeMemoryTarget isolated = - TreeMemoryConsumers.isolated() - .newConsumer( - TaskContext.get().taskMemoryManager(), - "FOO", - Spillers.NOOP, - Collections.emptyMap()); + TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()) + .newIsolatedConsumer("FOO", 
Spillers.NOOP, Collections.emptyMap()); Assert.assertEquals(100, isolated.borrow(110)); }); } @@ -109,36 +95,32 @@ public void testSpill() { test( () -> { final Spillers.AppendableSpillerList spillers = Spillers.appendable(); - final TreeMemoryTarget shared = - TreeMemoryConsumers.shared() - .newConsumer( - TaskContext.get().taskMemoryManager(), - "FOO", - spillers, - Collections.emptyMap()); + final TreeMemoryTarget legacy = + TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()) + .newLegacyConsumer("FOO", spillers, Collections.emptyMap()); final AtomicInteger numSpills = new AtomicInteger(0); final AtomicLong numSpilledBytes = new AtomicLong(0L); spillers.append( new Spiller() { @Override public long spill(MemoryTarget self, Phase phase, long size) { - long repaid = shared.repay(size); + long repaid = legacy.repay(size); numSpills.getAndIncrement(); numSpilledBytes.getAndAdd(repaid); return repaid; } }); - Assert.assertEquals(300, shared.borrow(300)); - Assert.assertEquals(300, shared.borrow(300)); + Assert.assertEquals(300, legacy.borrow(300)); + Assert.assertEquals(300, legacy.borrow(300)); Assert.assertEquals(1, numSpills.get()); Assert.assertEquals(200, numSpilledBytes.get()); - Assert.assertEquals(400, shared.usedBytes()); + Assert.assertEquals(400, legacy.usedBytes()); - Assert.assertEquals(300, shared.borrow(300)); - Assert.assertEquals(300, shared.borrow(300)); + Assert.assertEquals(300, legacy.borrow(300)); + Assert.assertEquals(300, legacy.borrow(300)); Assert.assertEquals(3, numSpills.get()); Assert.assertEquals(800, numSpilledBytes.get()); - Assert.assertEquals(400, shared.usedBytes()); + Assert.assertEquals(400, legacy.usedBytes()); }); } @@ -147,36 +129,32 @@ public void testOverSpill() { test( () -> { final Spillers.AppendableSpillerList spillers = Spillers.appendable(); - final TreeMemoryTarget shared = - TreeMemoryConsumers.shared() - .newConsumer( - TaskContext.get().taskMemoryManager(), - "FOO", - spillers, - 
Collections.emptyMap()); + final TreeMemoryTarget legacy = + TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager()) + .newLegacyConsumer("FOO", spillers, Collections.emptyMap()); final AtomicInteger numSpills = new AtomicInteger(0); final AtomicLong numSpilledBytes = new AtomicLong(0L); spillers.append( new Spiller() { @Override public long spill(MemoryTarget self, Phase phase, long size) { - long repaid = shared.repay(Long.MAX_VALUE); + long repaid = legacy.repay(Long.MAX_VALUE); numSpills.getAndIncrement(); numSpilledBytes.getAndAdd(repaid); return repaid; } }); - Assert.assertEquals(300, shared.borrow(300)); - Assert.assertEquals(300, shared.borrow(300)); + Assert.assertEquals(300, legacy.borrow(300)); + Assert.assertEquals(300, legacy.borrow(300)); Assert.assertEquals(1, numSpills.get()); Assert.assertEquals(300, numSpilledBytes.get()); - Assert.assertEquals(300, shared.usedBytes()); + Assert.assertEquals(300, legacy.usedBytes()); - Assert.assertEquals(300, shared.borrow(300)); - Assert.assertEquals(300, shared.borrow(300)); + Assert.assertEquals(300, legacy.borrow(300)); + Assert.assertEquals(300, legacy.borrow(300)); Assert.assertEquals(3, numSpills.get()); Assert.assertEquals(900, numSpilledBytes.get()); - Assert.assertEquals(300, shared.usedBytes()); + Assert.assertEquals(300, legacy.usedBytes()); }); }