[CORE] Minor code cleanups for TreeMemoryConsumer #8254

Merged · 11 commits · Dec 18, 2024
MemoryTargets.java
@@ -24,10 +24,13 @@
 import org.apache.spark.annotation.Experimental;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.SparkResourceUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 import java.util.Map;

 public final class MemoryTargets {
+  private static final Logger LOGGER = LoggerFactory.getLogger(MemoryTargets.class);

   private MemoryTargets() {
     // enclose factory ctor
@@ -45,14 +48,6 @@ public static MemoryTarget overAcquire(
     return new OverAcquire(target, overTarget, overAcquiredRatio);
   }

-  public static TreeMemoryTarget retrySpillOnOom(TreeMemoryTarget target) {
-    SparkEnv env = SparkEnv.get();
-    if (env != null && env.conf() != null && SparkResourceUtil.getTaskSlots(env.conf()) > 1) {
-      return new RetryOnOomMemoryTarget(target);
-    }
-    return target;
-  }
-
   @Experimental
   public static MemoryTarget dynamicOffHeapSizingIfEnabled(MemoryTarget memoryTarget) {
     if (GlutenConfig.getConf().dynamicOffHeapSizingEnabled()) {
@@ -67,14 +62,32 @@ public static TreeMemoryTarget newConsumer(
       String name,
       Spiller spiller,
       Map<String, MemoryUsageStatsBuilder> virtualChildren) {
-    final TreeMemoryConsumers.Factory factory;
+    final TreeMemoryConsumers.Factory factory = TreeMemoryConsumers.factory(tmm);
     if (GlutenConfig.getConf().memoryIsolation()) {
-      return TreeMemoryConsumers.isolated().newConsumer(tmm, name, spiller, virtualChildren);
-    } else {
-      // Retry of spilling is needed in shared mode because the maxMemoryPerTask of Vanilla Spark
-      // ExecutionMemoryPool is dynamic when with multi-slot config.
-      return MemoryTargets.retrySpillOnOom(
-          TreeMemoryConsumers.shared().newConsumer(tmm, name, spiller, virtualChildren));
+      return factory.newIsolatedConsumer(name, spiller, virtualChildren);
     }
+    final TreeMemoryTarget consumer = factory.newLegacyConsumer(name, spiller, virtualChildren);
+    if (SparkEnv.get() == null) {
+      // We are likely in test code. Return the consumer directly.
+      LOGGER.info("SparkEnv not found. We are likely in test code.");
+      return consumer;
+    }
+    final int taskSlots = SparkResourceUtil.getTaskSlots(SparkEnv.get().conf());
+    if (taskSlots == 1) {
+      // No need to retry on OOM when a single task occupies the whole executor.
+      return consumer;
+    }
+    // Since https://github.com/apache/incubator-gluten/pull/8132.
+    // Retrying the spill is needed in multi-slot legacy mode (formerly named shared mode)
+    // because the maxMemoryPerTask defined by vanilla Spark's ExecutionMemoryPool is dynamic.
+    //
+    // See the original issue https://github.com/apache/incubator-gluten/issues/8128.
+    return new RetryOnOomMemoryTarget(
+        consumer,
+        () -> {
+          LOGGER.info("Request for spilling on consumer {}...", consumer.name());
+          long spilled = TreeMemoryTargets.spillTree(consumer, Long.MAX_VALUE);
+          LOGGER.info("Consumer {} spilled {} bytes.", consumer.name(), spilled);
+        });
   }
 }
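
Taken together, newConsumer now funnels both modes through a per-task Factory and only wraps the legacy consumer in RetryOnOomMemoryTarget when several task slots share the executor. A minimal call-site sketch, assuming tmm is the current task's TaskMemoryManager and reusing the no-op spiller that appears elsewhere in this PR; the operator name is made up and this is not code from the PR:

  // Hedged sketch of a caller of the new API.
  static TreeMemoryTarget exampleConsumer(TaskMemoryManager tmm) {
    TreeMemoryTarget target =
        MemoryTargets.newConsumer(tmm, "ExampleOp", Spillers.NOOP, Collections.emptyMap());
    long granted = target.borrow(64L << 20); // ask for 64 MiB; may be partially granted
    target.repay(granted); // return whatever was granted once done
    return target;
  }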
RetryOnOomMemoryTarget.java
@@ -27,39 +27,30 @@
 public class RetryOnOomMemoryTarget implements TreeMemoryTarget {
   private static final Logger LOGGER = LoggerFactory.getLogger(RetryOnOomMemoryTarget.class);
   private final TreeMemoryTarget target;
+  private final Runnable onRetry;

-  RetryOnOomMemoryTarget(TreeMemoryTarget target) {
+  RetryOnOomMemoryTarget(TreeMemoryTarget target, Runnable onRetry) {
     this.target = target;
+    this.onRetry = onRetry;
   }

   @Override
   public long borrow(long size) {
     long granted = target.borrow(size);
     if (granted < size) {
-      LOGGER.info("Retrying spill require:{} got:{}", size, granted);
-      final long spilled = retryingSpill(Long.MAX_VALUE);
+      LOGGER.info("Granted size {} is less than requested size {}, retrying...", granted, size);
       final long remaining = size - granted;
-      if (spilled >= remaining) {
-        granted += target.borrow(remaining);
-      }
-      LOGGER.info("Retrying spill spilled:{} final granted:{}", spilled, granted);
+      // Invoke the `onRetry` callback, then retry borrowing.
+      // It's usually expected to run extra spilling logic in
+      // the `onRetry` callback so we may get enough memory space
+      // to allocate the remaining bytes.
+      onRetry.run();
+      granted += target.borrow(remaining);
+      LOGGER.info("Newest granted size after retrying: {}, requested size {}.", granted, size);
     }
     return granted;
   }

-  private long retryingSpill(long size) {
-    TreeMemoryTarget rootTarget = target;
-    while (true) {
-      try {
-        rootTarget = rootTarget.parent();
-      } catch (IllegalStateException e) {
-        // Reached the root node
-        break;
-      }
-    }
-    return TreeMemoryTargets.spillTree(rootTarget, size);
-  }

   @Override
   public long repay(long size) {
     return target.repay(size);
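
The wrapper's contract is now explicit: if the first borrow comes up short, run the onRetry callback (which is expected to spill), then borrow the remainder exactly once. The same pattern in isolation, as a self-contained sketch; all names here are illustrative, not from the PR:

  // Minimal retry-on-shortfall pattern, trimmed to a single borrow method.
  interface Borrower {
    long borrow(long size); // returns the number of bytes actually granted
  }

  final class RetryingBorrower implements Borrower {
    private final Borrower delegate;
    private final Runnable onRetry; // e.g., triggers spilling to free memory

    RetryingBorrower(Borrower delegate, Runnable onRetry) {
      this.delegate = delegate;
      this.onRetry = onRetry;
    }

    @Override
    public long borrow(long size) {
      long granted = delegate.borrow(size);
      if (granted < size) {
        onRetry.run(); // free space, then retry once for the remainder
        granted += delegate.borrow(size - granted);
      }
      return granted;
    }
  }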
@@ -206,7 +206,8 @@ public TreeMemoryTarget newChild(
       long capacity,
       Spiller spiller,
       Map<String, MemoryUsageStatsBuilder> virtualChildren) {
-    final Node child = new Node(this, name, capacity, spiller, virtualChildren);
+    final Node child =
+        new Node(this, name, Math.min(this.capacity, capacity), spiller, virtualChildren);
     if (children.containsKey(child.name())) {
       throw new IllegalArgumentException("Child already registered: " + child.name());
     }
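
The Math.min clamp ensures a newly created child never advertises more capacity than its parent can actually grant. A quick illustration with made-up numbers:

  long parentCapacity = 2L << 30; // parent capped at 2 GiB
  long requested = 8L << 30;      // caller asks for an 8 GiB child
  long effective = Math.min(parentCapacity, requested); // child is created with 2 GiB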
TreeMemoryConsumers.java
@@ -24,71 +24,74 @@

 import org.apache.commons.collections.map.ReferenceMap;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;

 import java.util.Collections;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;

 public final class TreeMemoryConsumers {
-  private static final Map<Long, Factory> FACTORIES = new ConcurrentHashMap<>();
+  private static final ReferenceMap FACTORIES = new ReferenceMap();

   private TreeMemoryConsumers() {}

-  private static Factory createOrGetFactory(long perTaskCapacity) {
-    return FACTORIES.computeIfAbsent(perTaskCapacity, Factory::new);
+  @SuppressWarnings("unchecked")
+  public static Factory factory(TaskMemoryManager tmm) {
+    synchronized (FACTORIES) {
+      return (Factory) FACTORIES.computeIfAbsent(tmm, m -> new Factory((TaskMemoryManager) m));
+    }
   }

-  /**
-   * A hub to provide memory target instances whose shared size (in the same task) is limited to X,
-   * X = executor memory / task slots.
-   *
-   * <p>Using this to prevent OOMs if the delegated memory target could possibly hold large memory
-   * blocks that are not spillable.
-   *
-   * <p>See <a href="https://github.com/oap-project/gluten/issues/3030">GLUTEN-3030</a>
-   */
-  public static Factory isolated() {
-    return createOrGetFactory(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize());
-  }
-
-  /**
-   * This works as a legacy Spark memory consumer which grants as much as possible of memory
-   * capacity to each task.
-   */
-  public static Factory shared() {
-    return createOrGetFactory(TreeMemoryTarget.CAPACITY_UNLIMITED);
-  }
-
   public static class Factory {
-    private final ReferenceMap map = new ReferenceMap(ReferenceMap.WEAK, ReferenceMap.WEAK);
-    private final long perTaskCapacity;
+    private final TreeMemoryConsumer sparkConsumer;
+    private final Map<Long, TreeMemoryTarget> roots = new ConcurrentHashMap<>();

-    private Factory(long perTaskCapacity) {
-      this.perTaskCapacity = perTaskCapacity;
+    private Factory(TaskMemoryManager tmm) {
+      this.sparkConsumer = new TreeMemoryConsumer(tmm);
     }

-    @SuppressWarnings("unchecked")
-    private TreeMemoryTarget getSharedAccount(TaskMemoryManager tmm) {
-      synchronized (map) {
-        return (TreeMemoryTarget)
-            map.computeIfAbsent(
-                tmm,
-                m -> {
-                  TreeMemoryTarget tmc = new TreeMemoryConsumer((TaskMemoryManager) m);
-                  return tmc.newChild(
-                      "root", perTaskCapacity, Spillers.NOOP, Collections.emptyMap());
-                });
-      }
+    private TreeMemoryTarget ofCapacity(long capacity) {
+      return roots.computeIfAbsent(
+          capacity,
+          cap ->
+              sparkConsumer.newChild(
+                  String.format("Capacity[%s]", Utils.bytesToString(cap)),
+                  cap,
+                  Spillers.NOOP,
+                  Collections.emptyMap()));
     }

+    private TreeMemoryTarget legacyRoot() {
+      return ofCapacity(TreeMemoryTarget.CAPACITY_UNLIMITED);
+    }
+
+    private TreeMemoryTarget isolatedRoot() {
+      return ofCapacity(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize());
+    }
+
+    /**
+     * This works as a legacy Spark memory consumer which grants as much memory capacity as
+     * possible to each task.
+     */
+    public TreeMemoryTarget newLegacyConsumer(
+        String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder> virtualChildren) {
+      final TreeMemoryTarget parent = legacyRoot();
+      return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren);
+    }

-    public TreeMemoryTarget newConsumer(
-        TaskMemoryManager tmm,
-        String name,
-        Spiller spiller,
-        Map<String, MemoryUsageStatsBuilder> virtualChildren) {
-      final TreeMemoryTarget account = getSharedAccount(tmm);
-      return account.newChild(
-          name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren);
+    /**
+     * A hub to provide memory target instances whose shared size (in the same task) is limited to
+     * X, X = executor memory / task slots.
+     *
+     * <p>Using this to prevent OOMs if the delegated memory target could possibly hold large
+     * memory blocks that are not spillable.
+     *
+     * <p>See <a href="https://github.com/oap-project/gluten/issues/3030">GLUTEN-3030</a>
+     */
+    public TreeMemoryTarget newIsolatedConsumer(
+        String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder> virtualChildren) {
+      final TreeMemoryTarget parent = isolatedRoot();
+      return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, virtualChildren);
     }
   }
 }
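
With this shape, callers obtain one Factory per TaskMemoryManager via TreeMemoryConsumers.factory(tmm), and both consumer kinds hang off per-capacity roots cached in the roots map. A hedged usage sketch, with tmm assumed to be the task's TaskMemoryManager and the operator names made up:

  TreeMemoryConsumers.Factory factory = TreeMemoryConsumers.factory(tmm);
  // Legacy: unlimited root capacity; vanilla Spark arbitrates between tasks.
  TreeMemoryTarget legacy =
      factory.newLegacyConsumer("op-legacy", Spillers.NOOP, Collections.emptyMap());
  // Isolated: root capped at the conservative per-task off-heap size (GLUTEN-3030).
  TreeMemoryTarget isolated =
      factory.newIsolatedConsumer("op-isolated", Spillers.NOOP, Collections.emptyMap());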
TaskResources.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.task

+import org.apache.gluten.GlutenConfig
 import org.apache.gluten.task.TaskListener

 import org.apache.spark.{TaskContext, TaskFailedReason, TaskKilledException, UnknownReason}
@@ -65,8 +66,8 @@ object TaskResources extends TaskListener with Logging {
         properties.put(key, value)
       case _ =>
     }
-    properties.setIfMissing("spark.memory.offHeap.enabled", "true")
-    properties.setIfMissing("spark.memory.offHeap.size", "1TB")
+    properties.setIfMissing(GlutenConfig.SPARK_OFFHEAP_ENABLED, "true")
+    properties.setIfMissing(GlutenConfig.SPARK_OFFHEAP_SIZE_KEY, "1TB")
     TaskContext.setTaskContext(newUnsafeTaskContext(properties))
   }