Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer committed Aug 7, 2024
1 parent fa1f313 commit 0df063d
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.gluten.memory.MemoryUsageStatsBuilder
import org.apache.gluten.memory.listener.ReservationListeners
import org.apache.gluten.memory.listener.{ReservationListener, ReservationListeners}
import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.TaskResources

import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Callable, Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicLong

import scala.collection.JavaConverters._
import scala.util.Random

class ConcurrentMemoryAllocationSuite extends SparkFunSuite with SharedSparkSession {
class MassiveMemoryAllocationSuite extends SparkFunSuite with SharedSparkSession {
test("concurrent allocation with spill - shared listener") {
val numThreads = 50
val offHeapSize = 500
Expand Down Expand Up @@ -85,7 +103,8 @@ class ConcurrentMemoryAllocationSuite extends SparkFunSuite with SharedSparkSess
withSQLConf("spark.memory.offHeap.size" -> s"$offHeapSize") {
TaskResources.runUnsafe {
val total = new AtomicLong(0L)
def newListener(id: Int) = {

def newListener(id: Int): ReservationListener = {
val spiller = Spillers.appendable()
val listener = ReservationListeners.create(
s"listener $id",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,16 @@ public void testSpill() {
Collections.emptyMap());
final AtomicInteger numSpills = new AtomicInteger(0);
final AtomicLong numSpilledBytes = new AtomicLong(0L);
spillers.append(new Spiller() {
@Override
public long spill(MemoryTarget self, Phase phase, long size) {
long repaid = shared.repay(size);
numSpills.getAndIncrement();
numSpilledBytes.getAndAdd(repaid);
return repaid;
}
});
spillers.append(
new Spiller() {
@Override
public long spill(MemoryTarget self, Phase phase, long size) {
long repaid = shared.repay(size);
numSpills.getAndIncrement();
numSpilledBytes.getAndAdd(repaid);
return repaid;
}
});
Assert.assertEquals(300, shared.borrow(300));
Assert.assertEquals(300, shared.borrow(300));
Assert.assertEquals(1, numSpills.get());
Expand All @@ -138,8 +139,7 @@ public long spill(MemoryTarget self, Phase phase, long size) {
Assert.assertEquals(3, numSpills.get());
Assert.assertEquals(800, numSpilledBytes.get());
Assert.assertEquals(400, shared.usedBytes());
}
);
});
}

@Test
Expand All @@ -156,15 +156,16 @@ public void testOverSpill() {
Collections.emptyMap());
final AtomicInteger numSpills = new AtomicInteger(0);
final AtomicLong numSpilledBytes = new AtomicLong(0L);
spillers.append(new Spiller() {
@Override
public long spill(MemoryTarget self, Phase phase, long size) {
long repaid = shared.repay(Long.MAX_VALUE);
numSpills.getAndIncrement();
numSpilledBytes.getAndAdd(repaid);
return repaid;
}
});
spillers.append(
new Spiller() {
@Override
public long spill(MemoryTarget self, Phase phase, long size) {
long repaid = shared.repay(Long.MAX_VALUE);
numSpills.getAndIncrement();
numSpilledBytes.getAndAdd(repaid);
return repaid;
}
});
Assert.assertEquals(300, shared.borrow(300));
Assert.assertEquals(300, shared.borrow(300));
Assert.assertEquals(1, numSpills.get());
Expand All @@ -176,8 +177,7 @@ public long spill(MemoryTarget self, Phase phase, long size) {
Assert.assertEquals(3, numSpills.get());
Assert.assertEquals(900, numSpilledBytes.get());
Assert.assertEquals(300, shared.usedBytes());
}
);
});
}

private void test(Runnable r) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@
import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
import org.apache.gluten.memory.memtarget.MemoryTarget;

import com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Reserve Spark managed memory.
*/
/** Reserve Spark managed memory. */
public class ManagedReservationListener implements ReservationListener {

private static final Logger LOG = LoggerFactory.getLogger(ManagedReservationListener.class);
Expand All @@ -38,7 +35,8 @@ public class ManagedReservationListener implements ReservationListener {
// See: https://github.com/apache/incubator-gluten/issues/6622
private final Object sharedLock;

public ManagedReservationListener(MemoryTarget target, SimpleMemoryUsageRecorder sharedUsage, Object sharedLock) {
public ManagedReservationListener(
MemoryTarget target, SimpleMemoryUsageRecorder sharedUsage, Object sharedLock) {
this.target = target;
this.sharedUsage = sharedUsage;
this.sharedLock = sharedLock;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@

public final class ReservationListeners {
public static final ReservationListener NOOP =
new ManagedReservationListener(new NoopMemoryTarget(), new SimpleMemoryUsageRecorder(),
new Object());
new ManagedReservationListener(
new NoopMemoryTarget(), new SimpleMemoryUsageRecorder(), new Object());

public static ReservationListener create(
String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder> mutableStats) {
Expand All @@ -47,24 +47,23 @@ private static ReservationListener create0(
final double overAcquiredRatio = GlutenConfig.getConf().memoryOverAcquiredRatio();
final long reservationBlockSize = GlutenConfig.getConf().memoryReservationBlockSize();
final TaskMemoryManager tmm = TaskResources.getLocalTaskContext().taskMemoryManager();
final TreeMemoryTarget consumer = MemoryTargets.newConsumer(
tmm,
name,
Spillers.withMinSpillSize(spiller, reservationBlockSize),
mutableStats);
final MemoryTarget overConsumer = MemoryTargets.newConsumer(
tmm,
consumer.name() + ".OverAcquire",
new Spiller() {
@Override
public long spill(MemoryTarget self, Phase phase, long size) {
if (!Spillers.PHASE_SET_ALL.contains(phase)) {
return 0L;
}
return self.repay(size);
}
},
Collections.emptyMap());
final TreeMemoryTarget consumer =
MemoryTargets.newConsumer(
tmm, name, Spillers.withMinSpillSize(spiller, reservationBlockSize), mutableStats);
final MemoryTarget overConsumer =
MemoryTargets.newConsumer(
tmm,
consumer.name() + ".OverAcquire",
new Spiller() {
@Override
public long spill(MemoryTarget self, Phase phase, long size) {
if (!Spillers.PHASE_SET_ALL.contains(phase)) {
return 0L;
}
return self.repay(size);
}
},
Collections.emptyMap());
final MemoryTarget target =
MemoryTargets.throwOnOom(
MemoryTargets.overAcquire(
Expand Down

0 comments on commit 0df063d

Please sign in to comment.