Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VL] Fix RetryOnOomMemoryTarget only spills one single consumer on retrying #8262

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public interface MemoryTargetVisitor<T> {

T visit(TreeMemoryConsumer treeMemoryConsumer);

T visit(TreeMemoryTargets.Node node);
T visit(TreeMemoryConsumer.Node node);

T visit(LoggingMemoryTarget loggingMemoryTarget);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ public static TreeMemoryTarget newConsumer(
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
final TreeMemoryConsumers.Factory factory = TreeMemoryConsumers.factory(tmm);
if (GlutenConfig.getConf().memoryIsolation()) {
return factory.newIsolatedConsumer(name, spiller, virtualChildren);
return TreeMemoryTargets.newChild(factory.isolatedRoot(), name, spiller, virtualChildren);
}
final TreeMemoryTarget consumer = factory.newLegacyConsumer(name, spiller, virtualChildren);
final TreeMemoryTarget root = factory.legacyRoot();
final TreeMemoryTarget consumer =
TreeMemoryTargets.newChild(root, name, spiller, virtualChildren);
if (SparkEnv.get() == null) {
// We are likely in test code. Return the consumer directly.
LOGGER.info("SparkEnv not found. We are likely in test code.");
Expand All @@ -86,7 +88,8 @@ public static TreeMemoryTarget newConsumer(
consumer,
() -> {
LOGGER.info("Request for spilling on consumer {}...", consumer.name());
long spilled = TreeMemoryTargets.spillTree(consumer, Long.MAX_VALUE);
// Note: Spill from root node so other consumers also get spilled.
long spilled = TreeMemoryTargets.spillTree(root, Long.MAX_VALUE);
kecookier marked this conversation as resolved.
Show resolved Hide resolved
LOGGER.info("Consumer {} spilled {} bytes.", consumer.name(), spilled);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,27 @@
package org.apache.gluten.memory.memtarget;

import org.apache.gluten.memory.MemoryUsageStatsBuilder;
import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
import org.apache.gluten.proto.MemoryUsageStats;

import com.google.common.base.Preconditions;
import org.apache.spark.util.Utils;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.stream.Collectors;

public class TreeMemoryTargets {

private TreeMemoryTargets() {
// Private constructor: prevents instantiation from outside this class.
}

public static TreeMemoryTarget newChild(
/**
* A shortcut method to create a child target of {@code parent}. The child will follow the
* parent's maximum capacity.
*/
static TreeMemoryTarget newChild(
TreeMemoryTarget parent,
String name,
long capacity,
Spiller spiller,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
return new Node(parent, name, capacity, spiller, virtualChildren);
return parent.newChild(name, TreeMemoryTarget.CAPACITY_UNLIMITED, spiller, virtualChildren);
}

public static long spillTree(TreeMemoryTarget node, final long bytes) {
Expand Down Expand Up @@ -83,146 +78,4 @@ private static long spillTree(TreeMemoryTarget node, Spiller.Phase phase, final

return bytes - remainingBytes;
}

/**
 * A node of the memory target tree that sits below a parent {@link TreeMemoryTarget}.
 *
 * <p>Non-root nodes are not Spark memory consumers: every borrow and repay is forwarded to the
 * parent, and the node only keeps a record of what it was granted in its own usage recorder.
 */
public static class Node implements TreeMemoryTarget, KnownNameAndStats {
  private final Map<String, Node> children = new HashMap<>();
  private final TreeMemoryTarget parent;
  private final String name;
  private final long capacity;
  private final Spiller spiller;
  private final Map<String, MemoryUsageStatsBuilder> virtualChildren;
  private final SimpleMemoryUsageRecorder selfRecorder = new SimpleMemoryUsageRecorder();

  /**
   * Creates a node attached to {@code parent}.
   *
   * <p>The node name is the unique name derived from {@code name}; when the capacity is limited,
   * the human-readable capacity is appended to it.
   */
  private Node(
      TreeMemoryTarget parent,
      String name,
      long capacity,
      Spiller spiller,
      Map<String, MemoryUsageStatsBuilder> virtualChildren) {
    this.parent = parent;
    this.capacity = capacity;
    this.spiller = spiller;
    this.virtualChildren = virtualChildren;
    final String unique = MemoryTargetUtil.toUniqueName(name);
    this.name =
        capacity == CAPACITY_UNLIMITED
            ? unique
            : String.format("%s, %s", unique, Utils.bytesToString(capacity));
  }

  /**
   * Borrows up to {@code size} bytes through the parent.
   *
   * @return the number of bytes the parent actually granted
   */
  @Override
  public long borrow(long size) {
    if (size != 0) {
      // Best effort: the result is intentionally ignored, the request is clamped below instead.
      ensureFreeCapacity(size);
      final long grantable = Math.min(freeBytes(), size);
      return borrow0(grantable);
    }
    return 0;
  }

  /** Bytes still available under this node's own capacity. */
  private long freeBytes() {
    final long used = usedBytes();
    return capacity - used;
  }

  /** Forwards the request to the parent and records whatever was granted. */
  private long borrow0(long size) {
    final long fromParent = parent.borrow(size);
    selfRecorder.inc(fromParent);
    return fromParent;
  }

  @Override
  public Spiller getNodeSpiller() {
    return spiller;
  }

  /**
   * Spills, starting from this node, until at least {@code bytesNeeded} bytes are free.
   *
   * @return {@code true} once enough bytes are free, {@code false} when a spill round released
   *     nothing (OOM)
   */
  private boolean ensureFreeCapacity(long bytesNeeded) {
    // FIXME (carried over from the original): should a retry limit be added?
    for (long free = freeBytes(); ; free = freeBytes()) {
      Preconditions.checkState(free >= 0);
      if (free >= bytesNeeded) {
        return true;
      }
      final long shortfall = bytesNeeded - free;
      final long released = TreeMemoryTargets.spillTree(this, shortfall);
      Preconditions.checkState(released >= 0);
      if (released == 0) {
        return false;
      }
    }
  }

  /**
   * Repays up to {@code size} bytes to the parent, never more than this node currently uses.
   *
   * @return the number of bytes the parent reported as freed
   */
  @Override
  public long repay(long size) {
    if (size != 0) {
      final long releasable = Math.min(usedBytes(), size);
      final long released = parent.repay(releasable);
      selfRecorder.inc(-released);
      return released;
    }
    return 0;
  }

  @Override
  public long usedBytes() {
    return selfRecorder.current();
  }

  @Override
  public <T> T accept(MemoryTargetVisitor<T> visitor) {
    return visitor.visit(this);
  }

  @Override
  public String name() {
    return name;
  }

  /**
   * Builds the stats of this node, including the stats of real children and virtual children.
   *
   * @throws IllegalArgumentException if a virtual child uses the name of a real child
   */
  @Override
  public MemoryUsageStats stats() {
    final Map<String, MemoryUsageStats> collected =
        children.values().stream().collect(Collectors.toMap(Node::name, Node::stats));
    final Map<String, MemoryUsageStats> childrenStats = new HashMap<>(collected);

    Preconditions.checkState(childrenStats.size() == children.size());

    // Virtual children are merged last and must not collide with a real child.
    for (Map.Entry<String, MemoryUsageStatsBuilder> virtual : virtualChildren.entrySet()) {
      final String key = virtual.getKey();
      if (childrenStats.containsKey(key)) {
        throw new IllegalArgumentException("Child stats already exists: " + key);
      }
      childrenStats.put(key, virtual.getValue().toStats());
    }
    return selfRecorder.toStats(childrenStats);
  }

  /**
   * Creates and registers a child node whose capacity is the smaller of this node's capacity and
   * the requested one.
   *
   * @throws IllegalArgumentException if a child with the same name is already registered
   */
  @Override
  public TreeMemoryTarget newChild(
      String name,
      long capacity,
      Spiller spiller,
      Map<String, MemoryUsageStatsBuilder> virtualChildren) {
    final long childCapacity = Math.min(this.capacity, capacity);
    final Node child = new Node(this, name, childCapacity, spiller, virtualChildren);
    final String childName = child.name();
    if (children.containsKey(childName)) {
      throw new IllegalArgumentException("Child already registered: " + childName);
    }
    children.put(childName, child);
    return child;
  }

  /** Returns a read-only view of the registered children, keyed by child name. */
  @Override
  public Map<String, TreeMemoryTarget> children() {
    return Collections.unmodifiableMap(children);
  }

  @Override
  public TreeMemoryTarget parent() {
    return parent;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;

import java.io.IOException;
import java.util.Collections;
Expand Down Expand Up @@ -121,8 +122,7 @@ public TreeMemoryTarget newChild(
long capacity,
Spiller spiller,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
final TreeMemoryTarget child =
TreeMemoryTargets.newChild(this, name, capacity, spiller, virtualChildren);
final TreeMemoryTarget child = new Node(this, name, capacity, spiller, virtualChildren);
if (children.containsKey(child.name())) {
throw new IllegalArgumentException("Child already registered: " + child.name());
}
Expand Down Expand Up @@ -151,4 +151,145 @@ public Spiller getNodeSpiller() {
/** Returns the {@code taskMemoryManager} held by this object. */
public TaskMemoryManager getTaskMemoryManager() {
return taskMemoryManager;
}

/**
 * A tree node that hangs below a parent {@link TreeMemoryTarget}.
 *
 * <p>Borrow and repay requests are delegated to the parent; the node tracks the bytes it was
 * granted in its own usage recorder and limits itself to {@code capacity}.
 */
public static class Node implements TreeMemoryTarget, KnownNameAndStats {
  private final Map<String, Node> children = new HashMap<>();
  private final TreeMemoryTarget parent;
  private final String name;
  private final long capacity;
  private final Spiller spiller;
  private final Map<String, MemoryUsageStatsBuilder> virtualChildren;
  private final SimpleMemoryUsageRecorder selfRecorder = new SimpleMemoryUsageRecorder();

  /**
   * Creates a node below {@code parent}. The node's name is the unique form of {@code name},
   * suffixed with the readable capacity unless the capacity is unlimited.
   */
  private Node(
      TreeMemoryTarget parent,
      String name,
      long capacity,
      Spiller spiller,
      Map<String, MemoryUsageStatsBuilder> virtualChildren) {
    final String baseName = MemoryTargetUtil.toUniqueName(name);
    final boolean unlimited = capacity == TreeMemoryTarget.CAPACITY_UNLIMITED;
    this.name =
        unlimited ? baseName : String.format("%s, %s", baseName, Utils.bytesToString(capacity));
    this.parent = parent;
    this.capacity = capacity;
    this.spiller = spiller;
    this.virtualChildren = virtualChildren;
  }

  /**
   * Tries to free enough capacity, then borrows at most {@code size} bytes from the parent.
   *
   * @return the bytes granted by the parent
   */
  @Override
  public long borrow(long size) {
    if (size == 0) {
      return 0;
    }
    // The outcome is deliberately not checked; the request is capped by the free bytes instead.
    ensureFreeCapacity(size);
    final long request = Math.min(freeBytes(), size);
    return borrow0(request);
  }

  /** Remaining bytes under this node's capacity. */
  private long freeBytes() {
    return capacity - usedBytes();
  }

  /** Borrows from the parent and accounts the granted bytes to this node. */
  private long borrow0(long size) {
    final long obtained = parent.borrow(size);
    selfRecorder.inc(obtained);
    return obtained;
  }

  @Override
  public Spiller getNodeSpiller() {
    return spiller;
  }

  /**
   * Repeatedly spills from this node until {@code bytesNeeded} bytes are free.
   *
   * @return {@code true} when the requirement is met, {@code false} if a spill attempt freed
   *     nothing (OOM)
   */
  private boolean ensureFreeCapacity(long bytesNeeded) {
    // FIXME (kept from the original): should there be a retry limit?
    long available = freeBytes();
    Preconditions.checkState(available >= 0);
    while (available < bytesNeeded) {
      final long reclaimed = TreeMemoryTargets.spillTree(this, bytesNeeded - available);
      Preconditions.checkState(reclaimed >= 0);
      if (reclaimed == 0) {
        return false;
      }
      available = freeBytes();
      Preconditions.checkState(available >= 0);
    }
    return true;
  }

  /**
   * Gives back at most {@code size} bytes (bounded by the bytes in use) to the parent.
   *
   * @return the bytes the parent reported as freed
   */
  @Override
  public long repay(long size) {
    if (size == 0) {
      return 0;
    }
    final long returned = parent.repay(Math.min(usedBytes(), size));
    selfRecorder.inc(-returned);
    return returned;
  }

  @Override
  public long usedBytes() {
    return selfRecorder.current();
  }

  @Override
  public <T> T accept(MemoryTargetVisitor<T> visitor) {
    return visitor.visit(this);
  }

  @Override
  public String name() {
    return name;
  }

  /**
   * Collects this node's stats together with those of its real and virtual children.
   *
   * @throws IllegalArgumentException if a virtual child's key matches a real child's name
   */
  @Override
  public MemoryUsageStats stats() {
    final Map<String, MemoryUsageStats> childrenStats =
        new HashMap<>(
            children.values().stream().collect(Collectors.toMap(Node::name, Node::stats)));

    Preconditions.checkState(childrenStats.size() == children.size());

    // Append virtual children; a name clash with a real child is an error.
    for (Map.Entry<String, MemoryUsageStatsBuilder> item : virtualChildren.entrySet()) {
      final String statsKey = item.getKey();
      if (childrenStats.containsKey(statsKey)) {
        throw new IllegalArgumentException("Child stats already exists: " + statsKey);
      }
      childrenStats.put(statsKey, item.getValue().toStats());
    }
    return selfRecorder.toStats(childrenStats);
  }

  /**
   * Registers a new child whose capacity is capped by this node's capacity.
   *
   * @throws IllegalArgumentException if the child's name is already registered
   */
  @Override
  public TreeMemoryTarget newChild(
      String name,
      long capacity,
      Spiller spiller,
      Map<String, MemoryUsageStatsBuilder> virtualChildren) {
    final long effectiveCapacity = Math.min(this.capacity, capacity);
    final Node created = new Node(this, name, effectiveCapacity, spiller, virtualChildren);
    final String key = created.name();
    if (children.containsKey(key)) {
      throw new IllegalArgumentException("Child already registered: " + key);
    }
    children.put(key, created);
    return created;
  }

  /** Returns an unmodifiable view of the children map. */
  @Override
  public Map<String, TreeMemoryTarget> children() {
    return Collections.unmodifiableMap(children);
  }

  @Override
  public TreeMemoryTarget parent() {
    return parent;
  }
}
}
Loading
Loading