diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ConcurrentMemoryAllocationSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MassiveMemoryAllocationSuite.scala similarity index 83% rename from backends-velox/src/test/scala/org/apache/gluten/execution/ConcurrentMemoryAllocationSuite.scala rename to backends-velox/src/test/scala/org/apache/gluten/execution/MassiveMemoryAllocationSuite.scala index 1c1920111601..ebfa0e6123fd 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ConcurrentMemoryAllocationSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MassiveMemoryAllocationSuite.scala @@ -1,18 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.gluten.execution import org.apache.gluten.memory.MemoryUsageStatsBuilder -import org.apache.gluten.memory.listener.ReservationListeners +import org.apache.gluten.memory.listener.{ReservationListener, ReservationListeners} import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers} + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.TaskResources -import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.{Callable, Executors, TimeUnit} +import java.util.concurrent.atomic.AtomicLong + import scala.collection.JavaConverters._ import scala.util.Random -class ConcurrentMemoryAllocationSuite extends SparkFunSuite with SharedSparkSession { +class MassiveMemoryAllocationSuite extends SparkFunSuite with SharedSparkSession { test("concurrent allocation with spill - shared listener") { val numThreads = 50 val offHeapSize = 500 @@ -85,7 +103,8 @@ class ConcurrentMemoryAllocationSuite extends SparkFunSuite with SharedSparkSess withSQLConf("spark.memory.offHeap.size" -> s"$offHeapSize") { TaskResources.runUnsafe { val total = new AtomicLong(0L) - def newListener(id: Int) = { + + def newListener(id: Int): ReservationListener = { val spiller = Spillers.appendable() val listener = ReservationListeners.create( s"listener $id", diff --git a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java index bdcb911772c7..bbc43ba5dea9 100644 --- a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java +++ b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java @@ -118,15 +118,16 @@ public void testSpill() { Collections.emptyMap()); final AtomicInteger numSpills = new AtomicInteger(0); final AtomicLong numSpilledBytes = new AtomicLong(0L); - spillers.append(new Spiller() { - @Override - public long spill(MemoryTarget self, Phase phase, long size) { - long repaid = shared.repay(size); - numSpills.getAndIncrement(); - numSpilledBytes.getAndAdd(repaid); - return repaid; - } - }); + spillers.append( + new Spiller() { + @Override + public long spill(MemoryTarget self, Phase phase, long size) { + long repaid = shared.repay(size); + numSpills.getAndIncrement(); + numSpilledBytes.getAndAdd(repaid); + return repaid; + } + }); Assert.assertEquals(300, shared.borrow(300)); Assert.assertEquals(300, shared.borrow(300)); Assert.assertEquals(1, numSpills.get()); @@ -138,8 +139,7 @@ public long spill(MemoryTarget self, Phase phase, long size) { Assert.assertEquals(3, numSpills.get()); Assert.assertEquals(800, numSpilledBytes.get()); Assert.assertEquals(400, shared.usedBytes()); - } - ); + }); } @Test @@ -156,15 +156,16 @@ public void testOverSpill() { Collections.emptyMap()); final AtomicInteger numSpills = new AtomicInteger(0); final AtomicLong numSpilledBytes = new AtomicLong(0L); - spillers.append(new Spiller() { - @Override - public long spill(MemoryTarget self, Phase phase, long size) { - long repaid = shared.repay(Long.MAX_VALUE); - numSpills.getAndIncrement(); - numSpilledBytes.getAndAdd(repaid); - return repaid; - } - }); + spillers.append( + new Spiller() { + @Override + public long spill(MemoryTarget self, Phase phase, long size) { + long repaid = shared.repay(Long.MAX_VALUE); + numSpills.getAndIncrement(); + numSpilledBytes.getAndAdd(repaid); + return repaid; + } + }); Assert.assertEquals(300, shared.borrow(300)); Assert.assertEquals(300, shared.borrow(300)); Assert.assertEquals(1, numSpills.get()); @@ -176,8 +177,7 @@ public long spill(MemoryTarget self, Phase phase, long size) { Assert.assertEquals(3, numSpills.get()); Assert.assertEquals(900, numSpilledBytes.get()); Assert.assertEquals(300, shared.usedBytes()); - } - ); + }); } private void test(Runnable r) { diff --git a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java index 2638360b59a9..7c7fac8daacd 100644 --- a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java +++ b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java @@ -19,13 +19,10 @@ import org.apache.gluten.memory.SimpleMemoryUsageRecorder; import org.apache.gluten.memory.memtarget.MemoryTarget; -import com.google.common.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Reserve Spark managed memory. - */ +/** Reserve Spark managed memory. */ public class ManagedReservationListener implements ReservationListener { private static final Logger LOG = LoggerFactory.getLogger(ManagedReservationListener.class); @@ -38,7 +35,8 @@ public class ManagedReservationListener implements ReservationListener { // See: https://github.com/apache/incubator-gluten/issues/6622 private final Object sharedLock; - public ManagedReservationListener(MemoryTarget target, SimpleMemoryUsageRecorder sharedUsage, Object sharedLock) { + public ManagedReservationListener( + MemoryTarget target, SimpleMemoryUsageRecorder sharedUsage, Object sharedLock) { this.target = target; this.sharedUsage = sharedUsage; this.sharedLock = sharedLock; diff --git a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java index 19c57f800c7d..db5ac8426df0 100644 --- a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java +++ b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java @@ -29,8 +29,8 @@ public final class ReservationListeners { public static final ReservationListener NOOP = - new ManagedReservationListener(new NoopMemoryTarget(), new SimpleMemoryUsageRecorder(), - new Object()); + new ManagedReservationListener( + new NoopMemoryTarget(), new SimpleMemoryUsageRecorder(), new Object()); public static ReservationListener create( String name, Spiller spiller, Map mutableStats) { @@ -47,24 +47,23 @@ private static ReservationListener create0( final double overAcquiredRatio = GlutenConfig.getConf().memoryOverAcquiredRatio(); final long reservationBlockSize = GlutenConfig.getConf().memoryReservationBlockSize(); final TaskMemoryManager tmm = TaskResources.getLocalTaskContext().taskMemoryManager(); - final TreeMemoryTarget consumer = MemoryTargets.newConsumer( - tmm, - name, - Spillers.withMinSpillSize(spiller, reservationBlockSize), - mutableStats); - final MemoryTarget overConsumer = MemoryTargets.newConsumer( - tmm, - consumer.name() + ".OverAcquire", - new Spiller() { - @Override - public long spill(MemoryTarget self, Phase phase, long size) { - if (!Spillers.PHASE_SET_ALL.contains(phase)) { - return 0L; - } - return self.repay(size); - } - }, - Collections.emptyMap()); + final TreeMemoryTarget consumer = + MemoryTargets.newConsumer( + tmm, name, Spillers.withMinSpillSize(spiller, reservationBlockSize), mutableStats); + final MemoryTarget overConsumer = + MemoryTargets.newConsumer( + tmm, + consumer.name() + ".OverAcquire", + new Spiller() { + @Override + public long spill(MemoryTarget self, Phase phase, long size) { + if (!Spillers.PHASE_SET_ALL.contains(phase)) { + return 0L; + } + return self.repay(size); + } + }, + Collections.emptyMap()); final MemoryTarget target = MemoryTargets.throwOnOom( MemoryTargets.overAcquire(