[CORE] Minor code cleanups for TreeMemoryConsumer (#8254)
zhztheplayer authored Dec 18, 2024
1 parent 3dbbd82 · commit ef7ccad
Showing 6 changed files with 128 additions and 141 deletions.
MemoryTargets.java
@@ -24,10 +24,13 @@
 import org.apache.spark.annotation.Experimental;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.SparkResourceUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 import java.util.Map;

 public final class MemoryTargets {
+  private static final Logger LOGGER = LoggerFactory.getLogger(MemoryTargets.class);

   private MemoryTargets() {
     // enclose factory ctor
@@ -45,14 +48,6 @@ public static MemoryTarget overAcquire(
     return new OverAcquire(target, overTarget, overAcquiredRatio);
   }

-  public static TreeMemoryTarget retrySpillOnOom(TreeMemoryTarget target) {
-    SparkEnv env = SparkEnv.get();
-    if (env != null && env.conf() != null && SparkResourceUtil.getTaskSlots(env.conf()) > 1) {
-      return new RetryOnOomMemoryTarget(target);
-    }
-    return target;
-  }
-
   @Experimental
   public static MemoryTarget dynamicOffHeapSizingIfEnabled(MemoryTarget memoryTarget) {
     if (GlutenConfig.getConf().dynamicOffHeapSizingEnabled()) {
@@ -67,14 +62,32 @@ public static TreeMemoryTarget newConsumer(
       String name,
       Spiller spiller,
       Map<String, MemoryUsageStatsBuilder> virtualChildren) {
-    final TreeMemoryConsumers.Factory factory;
+    final TreeMemoryConsumers.Factory factory = TreeMemoryConsumers.factory(tmm);
     if (GlutenConfig.getConf().memoryIsolation()) {
-      return TreeMemoryConsumers.isolated().newConsumer(tmm, name, spiller, virtualChildren);
-    } else {
-      // Retry of spilling is needed in shared mode because the maxMemoryPerTask of Vanilla Spark
-      // ExecutionMemoryPool is dynamic when with multi-slot config.
-      return MemoryTargets.retrySpillOnOom(
-          TreeMemoryConsumers.shared().newConsumer(tmm, name, spiller, virtualChildren));
-    }
+      return factory.newIsolatedConsumer(name, spiller, virtualChildren);
+    }
+    final TreeMemoryTarget consumer = factory.newLegacyConsumer(name, spiller, virtualChildren);
+    if (SparkEnv.get() == null) {
+      // We are likely in test code. Return the consumer directly.
+      LOGGER.info("SparkEnv not found. We are likely in test code.");
+      return consumer;
+    }
+    final int taskSlots = SparkResourceUtil.getTaskSlots(SparkEnv.get().conf());
+    if (taskSlots == 1) {
+      // We don't need to retry on OOM in the case one single task occupies the whole executor.
+      return consumer;
+    }
+    // Since https://github.com/apache/incubator-gluten/pull/8132.
+    // Retry of spilling is needed in multi-slot and legacy mode (formerly named as share mode)
+    // because the maxMemoryPerTask defined by vanilla Spark's ExecutionMemoryPool is dynamic.
+    //
+    // See the original issue https://github.com/apache/incubator-gluten/issues/8128.
+    return new RetryOnOomMemoryTarget(
+        consumer,
+        () -> {
+          LOGGER.info("Request for spilling on consumer {}...", consumer.name());
+          long spilled = TreeMemoryTargets.spillTree(consumer, Long.MAX_VALUE);
+          LOGGER.info("Consumer {} spilled {} bytes.", consumer.name(), spilled);
+        });
   }
 }

zhztheplayer (Author, Member) commented on the spillTree call above, Dec 18, 2024:

    My mistake here. The code may not match the original intention. I'll do another follow-up.
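A note on the taskSlots == 1 shortcut above: task slots are the number of tasks that can run concurrently in one executor. The sketch below shows the conventional Spark arithmetic; whether SparkResourceUtil.getTaskSlots computes exactly this is an assumption here, and TaskSlotsSketch is a hypothetical name.

import org.apache.spark.SparkConf;

// Hypothetical mirror of SparkResourceUtil.getTaskSlots (an assumption, not the
// verified Gluten implementation): concurrent tasks per executor.
public final class TaskSlotsSketch {
  static int taskSlots(SparkConf conf) {
    final int executorCores = conf.getInt("spark.executor.cores", 1);
    final int taskCpus = conf.getInt("spark.task.cpus", 1);
    // More than one slot means tasks compete for the executor's shared memory pool,
    // which is what makes the retry-on-OOM wrapping above worthwhile.
    return executorCores / taskCpus;
  }

  public static void main(String[] args) {
    final SparkConf conf = new SparkConf(false).set("spark.executor.cores", "4");
    System.out.println(taskSlots(conf)); // prints 4 -> newConsumer wraps with retry
  }
}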
RetryOnOomMemoryTarget.java
@@ -27,39 +27,30 @@
 public class RetryOnOomMemoryTarget implements TreeMemoryTarget {
   private static final Logger LOGGER = LoggerFactory.getLogger(RetryOnOomMemoryTarget.class);
   private final TreeMemoryTarget target;
+  private final Runnable onRetry;

-  RetryOnOomMemoryTarget(TreeMemoryTarget target) {
+  RetryOnOomMemoryTarget(TreeMemoryTarget target, Runnable onRetry) {
     this.target = target;
+    this.onRetry = onRetry;
   }

   @Override
   public long borrow(long size) {
     long granted = target.borrow(size);
     if (granted < size) {
-      LOGGER.info("Retrying spill require:{} got:{}", size, granted);
-      final long spilled = retryingSpill(Long.MAX_VALUE);
+      LOGGER.info("Granted size {} is less than requested size {}, retrying...", granted, size);
       final long remaining = size - granted;
-      if (spilled >= remaining) {
-        granted += target.borrow(remaining);
-      }
-      LOGGER.info("Retrying spill spilled:{} final granted:{}", spilled, granted);
+      // Invoke the `onRetry` callback, then retry borrowing.
+      // It's usually expected to run extra spilling logics in
+      // the `onRetry` callback so we may get enough memory space
+      // to allocate the remaining bytes.
+      onRetry.run();
+      granted += target.borrow(remaining);
+      LOGGER.info("Newest granted size after retrying: {}, requested size {}.", granted, size);
     }
     return granted;
   }

-  private long retryingSpill(long size) {
-    TreeMemoryTarget rootTarget = target;
-    while (true) {
-      try {
-        rootTarget = rootTarget.parent();
-      } catch (IllegalStateException e) {
-        // Reached the root node
-        break;
-      }
-    }
-    return TreeMemoryTargets.spillTree(rootTarget, size);
-  }
-
   @Override
   public long repay(long size) {
     return target.repay(size);
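The constructor change turns the retry policy into a caller-supplied callback: RetryOnOomMemoryTarget no longer walks up to the root and spills it itself; whoever constructs it decides what a retry does. Note also that the new code always retries the borrow after running the callback, where the old code retried only when enough bytes had been spilled. A self-contained toy model of that control flow follows; the Borrowable interface, the withRetry helper, and all numbers are hypothetical stand-ins, not the Gluten API.

import java.util.concurrent.atomic.AtomicLong;

// Toy model (not the Gluten API): decorate a borrowable resource so that a short
// grant triggers one caller-supplied "spill" callback, then retries the remainder.
public final class RetryOnOomSketch {
  interface Borrowable {
    long borrow(long size);
  }

  static Borrowable withRetry(Borrowable target, Runnable onRetry) {
    return size -> {
      long granted = target.borrow(size);
      if (granted < size) {
        onRetry.run(); // expected to free memory, e.g. by spilling a consumer tree
        granted += target.borrow(size - granted);
      }
      return granted;
    };
  }

  public static void main(String[] args) {
    final AtomicLong free = new AtomicLong(60);
    final Borrowable pool = size -> {
      final long got = Math.min(size, free.get());
      free.addAndGet(-got);
      return got;
    };
    // The callback stands in for the spillTree lambda wired up in MemoryTargets.newConsumer.
    final Borrowable retrying = withRetry(pool, () -> free.addAndGet(100));
    System.out.println(retrying.borrow(100)); // 60 granted at first, 40 more after the retry -> 100
  }
}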
TreeMemoryTargets.java
@@ -206,7 +206,8 @@ public TreeMemoryTarget newChild(
       long capacity,
       Spiller spiller,
       Map<String, MemoryUsageStatsBuilder> virtualChildren) {
-    final Node child = new Node(this, name, capacity, spiller, virtualChildren);
+    final Node child =
+        new Node(this, name, Math.min(this.capacity, capacity), spiller, virtualChildren);
     if (children.containsKey(child.name())) {
       throw new IllegalArgumentException("Child already registered: " + child.name());
     }
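The Math.min clamp gives newChild a simple invariant: a child's capacity never exceeds its parent's. That matters in TreeMemoryConsumers below, where consumers are registered with CAPACITY_UNLIMITED under capacity-limited roots. A tiny sketch, assuming CAPACITY_UNLIMITED equals Long.MAX_VALUE as the name suggests (the class name is hypothetical):

public final class CapacityClampSketch {
  public static void main(String[] args) {
    final long parentCapacity = 1L << 30;      // e.g. an isolated root capped at 1 GiB
    final long requested = Long.MAX_VALUE;     // a child asking for "unlimited" capacity
    final long effective = Math.min(parentCapacity, requested);
    System.out.println(effective == parentCapacity); // true: a child never outgrows its parent
  }
}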
TreeMemoryConsumers.java
@@ -24,71 +24,74 @@

 import org.apache.commons.collections.map.ReferenceMap;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;

 import java.util.Collections;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;

 public final class TreeMemoryConsumers {
-  private static final Map<Long, Factory> FACTORIES = new ConcurrentHashMap<>();
+  private static final ReferenceMap FACTORIES = new ReferenceMap();

   private TreeMemoryConsumers() {}

-  private static Factory createOrGetFactory(long perTaskCapacity) {
-    return FACTORIES.computeIfAbsent(perTaskCapacity, Factory::new);
+  @SuppressWarnings("unchecked")
+  public static Factory factory(TaskMemoryManager tmm) {
+    synchronized (FACTORIES) {
+      return (Factory) FACTORIES.computeIfAbsent(tmm, m -> new Factory((TaskMemoryManager) m));
+    }
   }

-  /**
-   * A hub to provide memory target instances whose shared size (in the same task) is limited to X,
-   * X = executor memory / task slots.
-   *
-   * <p>Using this to prevent OOMs if the delegated memory target could possibly hold large memory
-   * blocks that are not spillable.
-   *
-   * <p>See <a href="https://github.com/oap-project/gluten/issues/3030">GLUTEN-3030</a>
-   */
-  public static Factory isolated() {
-    return createOrGetFactory(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize());
-  }
+  public static class Factory {
+    private final TreeMemoryConsumer sparkConsumer;
+    private final Map<Long, TreeMemoryTarget> roots = new ConcurrentHashMap<>();

-  /**
-   * This works as a legacy Spark memory consumer which grants as much as possible of memory
-   * capacity to each task.
-   */
-  public static Factory shared() {
-    return createOrGetFactory(TreeMemoryTarget.CAPACITY_UNLIMITED);
-  }
+    private Factory(TaskMemoryManager tmm) {
+      this.sparkConsumer = new TreeMemoryConsumer(tmm);
+    }

-  public static class Factory {
-    private final ReferenceMap map = new ReferenceMap(ReferenceMap.WEAK, ReferenceMap.WEAK);
-    private final long perTaskCapacity;
+    private TreeMemoryTarget ofCapacity(long capacity) {
+      return roots.computeIfAbsent(
+          capacity,
+          cap ->
+              sparkConsumer.newChild(
+                  String.format("Capacity[%s]", Utils.bytesToString(cap)),
+                  cap,
+                  Spillers.NOOP,
+                  Collections.emptyMap()));
+    }

-    private Factory(long perTaskCapacity) {
-      this.perTaskCapacity = perTaskCapacity;
+    private TreeMemoryTarget legacyRoot() {
+      return ofCapacity(TreeMemoryTarget.CAPACITY_UNLIMITED);
+    }
+
+    private TreeMemoryTarget isolatedRoot() {
+      return ofCapacity(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize());
     }

-    @SuppressWarnings("unchecked")
-    private TreeMemoryTarget getSharedAccount(TaskMemoryManager tmm) {
-      synchronized (map) {
-        return (TreeMemoryTarget)
-            map.computeIfAbsent(
-                tmm,
-                m -> {
-                  TreeMemoryTarget tmc = new TreeMemoryConsumer((TaskMemoryManager) m);
-                  return tmc.newChild(
-                      "root", perTaskCapacity, Spillers.NOOP, Collections.emptyMap());
-                });
-      }
+    /**
+     * This works as a legacy Spark memory consumer which grants as much as possible of memory
+     * capacity to each task.
+     */
+    public TreeMemoryTarget newLegacyConsumer(
+        String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder> virtualChildren) {
+      final TreeMemoryTarget parent = legacyRoot();
+      return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren);
     }

-    public TreeMemoryTarget newConsumer(
-        TaskMemoryManager tmm,
-        String name,
-        Spiller spiller,
-        Map<String, MemoryUsageStatsBuilder> virtualChildren) {
-      final TreeMemoryTarget account = getSharedAccount(tmm);
-      return account.newChild(
-          name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren);
+    /**
+     * A hub to provide memory target instances whose shared size (in the same task) is limited to
+     * X, X = executor memory / task slots.
+     *
+     * <p>Using this to prevent OOMs if the delegated memory target could possibly hold large memory
+     * blocks that are not spill-able.
+     *
+     * <p>See <a href="https://github.com/oap-project/gluten/issues/3030">GLUTEN-3030</a>
+     */
+    public TreeMemoryTarget newIsolatedConsumer(
+        String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder> virtualChildren) {
+      final TreeMemoryTarget parent = isolatedRoot();
+      return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren);
     }
   }
 }
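The class-level FACTORIES cache leans on two details worth calling out: java.util.Map's default computeIfAbsent works against the raw commons-collections ReferenceMap, and ReferenceMap's no-arg constructor (per the commons-collections documentation) holds keys with hard references and values with soft references, so an idle Factory can be reclaimed under memory pressure. A self-contained sketch of the same idiom, with a plain Object standing in for TaskMemoryManager:

import org.apache.commons.collections.map.ReferenceMap;

// Sketch of the factory-caching idiom: one cached factory per key, created on first
// request. ReferenceMap is not thread-safe, hence the synchronized block, matching
// the usage in TreeMemoryConsumers.factory above.
public final class FactoryCacheSketch {
  private static final ReferenceMap FACTORIES = new ReferenceMap(); // hard keys, soft values

  @SuppressWarnings("unchecked")
  static Object factoryFor(Object key) {
    synchronized (FACTORIES) {
      return FACTORIES.computeIfAbsent(key, k -> new Object());
    }
  }

  public static void main(String[] args) {
    final Object tmm = new Object(); // stand-in for a TaskMemoryManager
    System.out.println(factoryFor(tmm) == factoryFor(tmm)); // true: cached per key
  }
}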
TaskResources.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.task

+import org.apache.gluten.GlutenConfig
 import org.apache.gluten.task.TaskListener

 import org.apache.spark.{TaskContext, TaskFailedReason, TaskKilledException, UnknownReason}
@@ -65,8 +66,8 @@ object TaskResources extends TaskListener with Logging {
         properties.put(key, value)
       case _ =>
     }
-    properties.setIfMissing("spark.memory.offHeap.enabled", "true")
-    properties.setIfMissing("spark.memory.offHeap.size", "1TB")
+    properties.setIfMissing(GlutenConfig.SPARK_OFFHEAP_ENABLED, "true")
+    properties.setIfMissing(GlutenConfig.SPARK_OFFHEAP_SIZE_KEY, "1TB")
     TaskContext.setTaskContext(newUnsafeTaskContext(properties))
   }

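For reference, the two GlutenConfig constants presumably just name the literal keys they replace here; a hypothetical Java rendering (the real definitions live in Gluten's Scala GlutenConfig):

// Hypothetical equivalents of the GlutenConfig constants used above; values inferred
// from the string literals they replace in this diff.
public final class OffHeapConfKeys {
  public static final String SPARK_OFFHEAP_ENABLED = "spark.memory.offHeap.enabled";
  public static final String SPARK_OFFHEAP_SIZE_KEY = "spark.memory.offHeap.size";
}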
(The diff for the sixth changed file did not load.)