diff --git a/cpp/velox/memory/VeloxMemoryManager.cc b/cpp/velox/memory/VeloxMemoryManager.cc
index 442090004a41..e00eb88f2f09 100644
--- a/cpp/velox/memory/VeloxMemoryManager.cc
+++ b/cpp/velox/memory/VeloxMemoryManager.cc
@@ -54,22 +54,20 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
     return targetBytes;
   }
 
-  uint64_t shrinkCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) override {
-    std::lock_guard l(mutex_);
-    return shrinkCapacityLocked(pool, targetBytes);
-  }
-
   bool growCapacity(
       velox::memory::MemoryPool* pool,
       const std::vector<std::shared_ptr<velox::memory::MemoryPool>>& candidatePools,
       uint64_t targetBytes) override {
     velox::memory::ScopedMemoryArbitrationContext ctx(pool);
-    VELOX_CHECK_EQ(candidatePools.size(), 1, "ListenableArbitrator should only be used within a single root pool")
-    auto candidate = candidatePools.back();
-    VELOX_CHECK(pool->root() == candidate.get(), "Illegal state in ListenableArbitrator");
+    velox::memory::MemoryPool* candidate;
+    {
+      std::unique_lock guard{mutex_};
+      VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
+      candidate = candidates_.begin()->first;
+    }
+    VELOX_CHECK(pool->root() == candidate, "Illegal state in ListenableArbitrator");
 
-    std::lock_guard l(mutex_);
-    growCapacityLocked(pool->root(), targetBytes);
+    growCapacity0(pool->root(), targetBytes);
     return true;
   }
 
@@ -80,16 +78,18 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
       bool allowAbort) override {
     velox::memory::ScopedMemoryArbitrationContext ctx((const velox::memory::MemoryPool*)nullptr);
     facebook::velox::exec::MemoryReclaimer::Stats status;
-    VELOX_CHECK_EQ(pools.size(), 1, "Gluten only has one root pool");
-    std::lock_guard l(mutex_); // FIXME: Do we have recursive locking for this mutex?
-    auto pool = pools.at(0);
-    const uint64_t oldCapacity = pool->capacity();
+    velox::memory::MemoryPool* pool;
+    {
+      std::unique_lock guard{mutex_};
+      VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
+      pool = candidates_.begin()->first;
+    }
     pool->reclaim(targetBytes, 0, status); // ignore the output
-    shrinkPool(pool.get(), 0);
-    const uint64_t newCapacity = pool->capacity();
-    uint64_t total = oldCapacity - newCapacity;
-    listener_->allocationChanged(-total);
-    return total;
+    return shrinkCapacity0(pool, 0);
+  }
+
+  uint64_t shrinkCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) override {
+    return shrinkCapacity0(pool, targetBytes);
   }
 
   Stats stats() const override {
@@ -102,7 +102,7 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
   }
 
  private:
-  void growCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
+  void growCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
    // Since
    // https://github.com/facebookincubator/velox/pull/9557/files#diff-436e44b7374032f8f5d7eb45869602add6f955162daa2798d01cc82f8725724dL812-L820,
    // We should pass bytes as parameter "reservationBytes" when calling ::grow.
@@ -124,15 +124,19 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
         pool->toString())
   }
 
-  uint64_t shrinkCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
+  uint64_t shrinkCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
     uint64_t freeBytes = shrinkPool(pool, bytes);
     listener_->allocationChanged(-freeBytes);
     return freeBytes;
   }
 
   gluten::AllocationListener* listener_;
-  std::recursive_mutex mutex_;
+  const uint64_t memoryPoolInitialCapacity_; // FIXME: Unused.
+  const uint64_t memoryPoolTransferCapacity_;
+
+  mutable std::mutex mutex_;
   inline static std::string kind_ = "GLUTEN";
+  std::unordered_map<velox::memory::MemoryPool*, std::weak_ptr<velox::memory::MemoryPool>> candidates_;
 };
 
 class ArbitratorFactoryRegister {
diff --git a/cpp/velox/tests/MemoryManagerTest.cc b/cpp/velox/tests/MemoryManagerTest.cc
index 52f2fa8b661c..f2a7dc46461d 100644
--- a/cpp/velox/tests/MemoryManagerTest.cc
+++ b/cpp/velox/tests/MemoryManagerTest.cc
@@ -129,4 +129,274 @@ TEST_F(MemoryManagerTest, memoryAllocatorWithBlockReservation) {
   ASSERT_EQ(allocator_->getBytes(), 0);
 }
 
+namespace {
+class AllocationListenerWrapper : public AllocationListener {
+ public:
+  explicit AllocationListenerWrapper() {}
+
+  void set(AllocationListener* const delegate) {
+    if (delegate_ != nullptr) {
+      throw std::runtime_error("Invalid state");
+    }
+    delegate_ = delegate;
+  }
+
+  void allocationChanged(int64_t diff) override {
+    delegate_->allocationChanged(diff);
+  }
+  int64_t currentBytes() override {
+    return delegate_->currentBytes();
+  }
+  int64_t peakBytes() override {
+    return delegate_->peakBytes();
+  }
+
+ private:
+  AllocationListener* delegate_{nullptr};
+};
+
+class SpillableAllocationListener : public AllocationListener {
+ public:
+  virtual uint64_t shrink(uint64_t bytes) = 0;
+  virtual uint64_t spill(uint64_t bytes) = 0;
+};
+
+class MockSparkTaskMemoryManager {
+ public:
+  explicit MockSparkTaskMemoryManager(const uint64_t maxBytes);
+
+  AllocationListener* newListener(std::function<uint64_t(uint64_t)> shrink, std::function<uint64_t(uint64_t)> spill);
+
+  uint64_t acquire(uint64_t bytes);
+  void release(uint64_t bytes);
+  uint64_t currentBytes() {
+    return currentBytes_;
+  }
+
+ private:
+  mutable std::recursive_mutex mutex_;
+  std::vector<std::unique_ptr<SpillableAllocationListener>> listeners_{};
+
+  const uint64_t maxBytes_;
+  uint64_t currentBytes_{0L};
+};
+
+class MockSparkAllocationListener : public SpillableAllocationListener {
+ public:
+  explicit MockSparkAllocationListener(
+      MockSparkTaskMemoryManager* const manager,
+      std::function<uint64_t(uint64_t)> shrink,
+      std::function<uint64_t(uint64_t)> spill)
+      : manager_(manager), shrink_(shrink), spill_(spill) {}
+
+  void allocationChanged(int64_t diff) override {
+    if (diff == 0) {
+      return;
+    }
+    if (diff > 0) {
+      auto granted = manager_->acquire(diff);
+      if (granted < diff) {
+        throw std::runtime_error("OOM");
+      }
+      currentBytes_ += granted;
+      return;
+    }
+    manager_->release(-diff);
+    currentBytes_ -= (-diff);
+  }
+
+  uint64_t shrink(uint64_t bytes) override {
+    return shrink_(bytes);
+  }
+
+  uint64_t spill(uint64_t bytes) override {
+    return spill_(bytes);
+  }
+
+  int64_t currentBytes() override {
+    return currentBytes_;
+  }
+
+ private:
+  MockSparkTaskMemoryManager* const manager_;
+  std::function<uint64_t(uint64_t)> shrink_;
+  std::function<uint64_t(uint64_t)> spill_;
+  std::atomic<int64_t> currentBytes_{0L};
+};
+
+MockSparkTaskMemoryManager::MockSparkTaskMemoryManager(const uint64_t maxBytes) : maxBytes_(maxBytes) {}
+
+AllocationListener* MockSparkTaskMemoryManager::newListener(
+    std::function<uint64_t(uint64_t)> shrink,
+    std::function<uint64_t(uint64_t)> spill) {
+  listeners_.push_back(std::make_unique<MockSparkAllocationListener>(this, shrink, spill));
+  return listeners_.back().get();
+}
+
+uint64_t MockSparkTaskMemoryManager::acquire(uint64_t bytes) {
+  std::unique_lock l(mutex_);
+  auto freeBytes = maxBytes_ - currentBytes_;
+  if (bytes <= freeBytes) {
+    currentBytes_ += bytes;
+    return bytes;
+  }
+  // Shrink listeners.
+  int64_t bytesNeeded = bytes - freeBytes;
+  for (const auto& listener : listeners_) {
+    bytesNeeded -= listener->shrink(bytesNeeded);
+    if (bytesNeeded < 0) {
+      break;
+    }
+  }
+  if (bytesNeeded > 0) {
+    for (const auto& listener : listeners_) {
+      bytesNeeded -= listener->spill(bytesNeeded);
+      if (bytesNeeded < 0) {
+        break;
+      }
+    }
+  }
+
+  if (bytesNeeded > 0) {
+    uint64_t granted = bytes - bytesNeeded;
+    currentBytes_ += granted;
+    return granted;
+  }
+
+  currentBytes_ += bytes;
+  return bytes;
+}
+
+void MockSparkTaskMemoryManager::release(uint64_t bytes) {
+  std::unique_lock l(mutex_);
+  currentBytes_ -= bytes;
+}
+
+class MockMemoryReclaimer : public facebook::velox::memory::MemoryReclaimer {
+ public:
+  explicit MockMemoryReclaimer(std::vector<void*>& buffs, int32_t size) : buffs_(buffs), size_(size) {}
+
+  bool reclaimableBytes(const memory::MemoryPool& pool, uint64_t& reclaimableBytes) const override {
+    uint64_t total = 0;
+    for (const auto& buf : buffs_) {
+      if (buf == nullptr) {
+        continue;
+      }
+      total += size_;
+    }
+    if (total == 0) {
+      return false;
+    }
+    reclaimableBytes = total;
+    return true;
+  }
+
+  uint64_t reclaim(memory::MemoryPool* pool, uint64_t targetBytes, uint64_t maxWaitMs, Stats& stats) override {
+    uint64_t total = 0;
+    for (auto& buf : buffs_) {
+      if (buf == nullptr) {
+        // When:
+        // 1. Called by allocation from the same pool so buff is not allocated yet.
+        // 2. Already called once.
+        continue;
+      }
+      pool->free(buf, size_);
+      buf = nullptr;
+      total += size_;
+    }
+    return total;
+  }
+
+ private:
+  std::vector<void*>& buffs_;
+  int32_t size_;
+};
+
+void assertCapacitiesMatch(MockSparkTaskMemoryManager& tmm, std::vector<std::unique_ptr<VeloxMemoryManager>>& vmms) {
+  uint64_t sum = 0;
+  for (const auto& vmm : vmms) {
+    if (vmm == nullptr) {
+      continue;
+    }
+    sum += vmm->getAggregateMemoryPool()->capacity();
+  }
+  if (tmm.currentBytes() != sum) {
+    ASSERT_EQ(tmm.currentBytes(), sum);
+  }
+}
+} // namespace
+
+class MultiMemoryManagerTest : public ::testing::Test {
+ protected:
+  static void SetUpTestCase() {
+    std::unordered_map<std::string, std::string> conf = {
+        {kMemoryReservationBlockSize, std::to_string(kMemoryReservationBlockSizeDefault)},
+        {kVeloxMemInitCapacity, std::to_string(kVeloxMemInitCapacityDefault)}};
+    gluten::VeloxBackend::create(conf);
+  }
+
+  std::unique_ptr<VeloxMemoryManager> newVeloxMemoryManager(std::unique_ptr<AllocationListener> listener) {
+    return std::make_unique<VeloxMemoryManager>(std::move(listener));
+  }
+};
+
+TEST_F(MultiMemoryManagerTest, spill) {
+  const uint64_t maxBytes = 200 << 20;
+  const uint32_t numThreads = 100;
+  const uint32_t numAllocations = 200;
+  const int32_t allocateSize = 10 << 20;
+
+  MockSparkTaskMemoryManager tmm{maxBytes};
+  std::vector<std::unique_ptr<VeloxMemoryManager>> vmms{};
+  std::vector<std::thread> threads{};
+  std::vector<std::vector<void*>> buffs{};
+  for (size_t i = 0; i < numThreads; ++i) {
+    buffs.push_back({});
+    vmms.emplace_back(nullptr);
+  }
+
+  // Emulate a shared lock to avoid ABBA deadlock.
+  std::recursive_mutex mutex;
+
+  for (size_t i = 0; i < numThreads; ++i) {
+    threads.emplace_back([this, i, allocateSize, &tmm, &vmms, &mutex, &buffs]() -> void {
+      auto wrapper = std::make_unique<AllocationListenerWrapper>(); // Set later.
+      auto* listener = wrapper.get();
+
+      facebook::velox::memory::MemoryPool* pool; // Set later.
+      {
+        std::unique_lock l(mutex);
+        vmms[i] = newVeloxMemoryManager(std::move(wrapper));
+        pool = vmms[i]->getLeafMemoryPool().get();
+        pool->setReclaimer(std::make_unique<MockMemoryReclaimer>(buffs[i], allocateSize));
+        listener->set(tmm.newListener(
+            [](uint64_t bytes) -> uint64_t { return 0; },
+            [i, &vmms, &mutex](uint64_t bytes) -> uint64_t {
+              std::unique_lock l(mutex);
+              return vmms[i]->getMemoryManager()->arbitrator()->shrinkCapacity(bytes);
+            }));
+      }
+      {
+        std::unique_lock l(mutex);
+        for (size_t j = 0; j < numAllocations; ++j) {
+          assertCapacitiesMatch(tmm, vmms);
+          buffs[i].push_back(pool->allocate(allocateSize));
+          assertCapacitiesMatch(tmm, vmms);
+        }
+      }
+    });
+  }
+
+  for (auto& thread : threads) {
+    thread.join();
+  }
+
+  for (auto& vmm : vmms) {
+    assertCapacitiesMatch(tmm, vmms);
+    vmm->getMemoryManager()->arbitrator()->shrinkCapacity(allocateSize * numAllocations);
+    assertCapacitiesMatch(tmm, vmms);
+  }
+
+  ASSERT_EQ(tmm.currentBytes(), 0);
+}
 } // namespace gluten
diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/SimpleMemoryUsageRecorder.java b/gluten-core/src/main/java/org/apache/gluten/memory/SimpleMemoryUsageRecorder.java
index 16b260469f5f..fb8b0d1e2b61 100644
--- a/gluten-core/src/main/java/org/apache/gluten/memory/SimpleMemoryUsageRecorder.java
+++ b/gluten-core/src/main/java/org/apache/gluten/memory/SimpleMemoryUsageRecorder.java
@@ -30,13 +30,13 @@ public class SimpleMemoryUsageRecorder implements MemoryUsageRecorder {
   @Override
   public void inc(long bytes) {
     final long total = this.current.addAndGet(bytes);
-    long prev_peak;
+    long prevPeak;
     do {
-      prev_peak = this.peak.get();
-      if (total <= prev_peak) {
+      prevPeak = this.peak.get();
+      if (total <= prevPeak) {
         break;
       }
-    } while (!this.peak.compareAndSet(prev_peak, total));
+    } while (!this.peak.compareAndSet(prevPeak, total));
   }
 
   // peak used bytes
diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
index 75e3db2e7d1f..bb1e7102b1c3 100644
--- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
+++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
@@ -50,7 +50,7 @@ public static MemoryTarget dynamicOffHeapSizingIfEnabled(MemoryTarget memoryTarg
     return memoryTarget;
   }
 
-  public static MemoryTarget newConsumer(
+  public static TreeMemoryTarget newConsumer(
       TaskMemoryManager tmm,
       String name,
       Spiller spiller,
diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/OverAcquire.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/OverAcquire.java
index ac82161ba7a5..e7321b4b7e0e 100644
--- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/OverAcquire.java
+++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/OverAcquire.java
@@ -52,31 +52,28 @@ public class OverAcquire implements MemoryTarget {
 
   @Override
   public long borrow(long size) {
-    Preconditions.checkArgument(size != 0, "Size to borrow is zero");
+    if (size == 0) {
+      return 0;
+    }
+
+    Preconditions.checkState(overTarget.usedBytes() == 0);
     long granted = target.borrow(size);
     long majorSize = target.usedBytes();
-    long expectedOverAcquired = (long) (ratio * majorSize);
-    long overAcquired = overTarget.usedBytes();
-    long diff = expectedOverAcquired - overAcquired;
-    if (diff >= 0) { // otherwise, there might be a spill happened during the last borrow() call
-      overTarget.borrow(diff); // we don't have to check the returned value
-    }
+    long overSize = (long) (ratio * majorSize);
+    long overAcquired = overTarget.borrow(overSize);
+    Preconditions.checkState(overAcquired == overTarget.usedBytes());
+    long releasedOverSize = overTarget.repay(overAcquired);
+    Preconditions.checkState(releasedOverSize == overAcquired);
+    Preconditions.checkState(overTarget.usedBytes() == 0);
     return granted;
   }
 
   @Override
   public long repay(long size) {
-    Preconditions.checkArgument(size != 0, "Size to repay is zero");
-    long freed = target.repay(size);
-    // clean up the over-acquired target
-    long overAcquired = overTarget.usedBytes();
-    long freedOverAcquired = overTarget.repay(overAcquired);
-    Preconditions.checkArgument(
-        freedOverAcquired == overAcquired,
-        "Freed over-acquired size is not equal to requested size");
-    Preconditions.checkArgument(
-        overTarget.usedBytes() == 0, "Over-acquired target was not cleaned up");
-    return freed;
+    if (size == 0) {
+      return 0;
+    }
+    Preconditions.checkState(overTarget.usedBytes() == 0);
+    return target.repay(size);
   }
 
   @Override
diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
index 24d9fc0e2d4a..98f79bfff367 100644
--- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
+++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
@@ -114,6 +114,9 @@ private Node(
 
     @Override
     public long borrow(long size) {
+      if (size == 0) {
+        return 0;
+      }
       ensureFreeCapacity(size);
       return borrow0(Math.min(freeBytes(), size));
     }
@@ -154,6 +157,9 @@ private boolean ensureFreeCapacity(long bytesNeeded) {
 
     @Override
     public long repay(long size) {
+      if (size == 0) {
+        return 0;
+      }
       long toFree = Math.min(usedBytes(), size);
       long freed = parent.repay(toFree);
       selfRecorder.inc(-freed);
diff --git a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
index db018ffe4043..62bb59f78efb 100644
--- a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
+++ b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
@@ -17,6 +17,8 @@
 package org.apache.gluten.memory.memtarget.spark;
 
 import org.apache.gluten.GlutenConfig;
+import org.apache.gluten.memory.memtarget.MemoryTarget;
+import org.apache.gluten.memory.memtarget.Spiller;
 import org.apache.gluten.memory.memtarget.Spillers;
 import org.apache.gluten.memory.memtarget.TreeMemoryTarget;
 
@@ -27,6 +29,8 @@
 import org.junit.Test;
 
 import java.util.Collections;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 
 import scala.Function0;
 
@@ -102,7 +106,83 @@ public void testIsolatedAndShared() {
         });
   }
 
-  private void test(SQLConf conf, Runnable r) {
+  @Test
+  public void testSpill() {
+    test(
+        () -> {
+          final Spillers.AppendableSpillerList spillers = Spillers.appendable();
+          final TreeMemoryTarget shared =
+              TreeMemoryConsumers.shared()
+                  .newConsumer(
+                      TaskContext.get().taskMemoryManager(),
+                      "FOO",
+                      spillers,
+                      Collections.emptyMap());
+          final AtomicInteger numSpills = new AtomicInteger(0);
+          final AtomicLong numSpilledBytes = new AtomicLong(0L);
+          spillers.append(
+              new Spiller() {
+                @Override
+                public long spill(MemoryTarget self, Phase phase, long size) {
+                  long repaid = shared.repay(size);
+                  numSpills.getAndIncrement();
+                  numSpilledBytes.getAndAdd(repaid);
+                  return repaid;
+                }
+              });
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(1, numSpills.get());
+          Assert.assertEquals(200, numSpilledBytes.get());
+          Assert.assertEquals(400, shared.usedBytes());
+
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(3, numSpills.get());
+          Assert.assertEquals(800, numSpilledBytes.get());
+          Assert.assertEquals(400, shared.usedBytes());
+        });
+  }
+
+  @Test
+  public void testOverSpill() {
+    test(
+        () -> {
+          final Spillers.AppendableSpillerList spillers = Spillers.appendable();
+          final TreeMemoryTarget shared =
+              TreeMemoryConsumers.shared()
+                  .newConsumer(
+                      TaskContext.get().taskMemoryManager(),
+                      "FOO",
+                      spillers,
+                      Collections.emptyMap());
+          final AtomicInteger numSpills = new AtomicInteger(0);
+          final AtomicLong numSpilledBytes = new AtomicLong(0L);
+          spillers.append(
+              new Spiller() {
+                @Override
+                public long spill(MemoryTarget self, Phase phase, long size) {
+                  long repaid = shared.repay(Long.MAX_VALUE);
+                  numSpills.getAndIncrement();
+                  numSpilledBytes.getAndAdd(repaid);
+                  return repaid;
+                }
+              });
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(1, numSpills.get());
+          Assert.assertEquals(300, numSpilledBytes.get());
+          Assert.assertEquals(300, shared.usedBytes());
+
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(3, numSpills.get());
+          Assert.assertEquals(900, numSpilledBytes.get());
+          Assert.assertEquals(300, shared.usedBytes());
+        });
+  }
+
+  private void test(Runnable r) {
     TaskResources$.MODULE$.runUnsafe(
         new Function0() {
           @Override
diff --git a/gluten-data/pom.xml b/gluten-data/pom.xml
index 654045b5a885..65bafdeb00c8 100644
--- a/gluten-data/pom.xml
+++ b/gluten-data/pom.xml
@@ -191,6 +191,52 @@
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <version>2.23.4</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>scalatestplus-mockito_${scala.binary.version}</artifactId>
+      <version>1.0.0-M2</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>scalatestplus-scalacheck_${scala.binary.version}</artifactId>
+      <version>3.1.0.0-RC2</version>
+      <scope>test</scope>
+    </dependency>
diff --git a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java
index b7d6ecd67589..7c7fac8daacd 100644
--- a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java
+++ b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ManagedReservationListener.java
@@ -19,7 +19,6 @@
 import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
 import org.apache.gluten.memory.memtarget.MemoryTarget;
 
-import com.google.common.base.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -29,16 +28,23 @@ public class ManagedReservationListener implements ReservationListener {
   private static final Logger LOG = LoggerFactory.getLogger(ManagedReservationListener.class);
   private final MemoryTarget target;
-  private final SimpleMemoryUsageRecorder sharedUsage; // shared task metrics
+  // Metrics shared by task.
+  private final SimpleMemoryUsageRecorder sharedUsage;
+  // Lock shared by task. Using a common lock avoids ABBA deadlock
+  // when multiple listeners created under the same TMM.
+  // See: https://github.com/apache/incubator-gluten/issues/6622
+  private final Object sharedLock;
 
-  public ManagedReservationListener(MemoryTarget target, SimpleMemoryUsageRecorder sharedUsage) {
+  public ManagedReservationListener(
+      MemoryTarget target, SimpleMemoryUsageRecorder sharedUsage, Object sharedLock) {
     this.target = target;
     this.sharedUsage = sharedUsage;
+    this.sharedLock = sharedLock;
   }
 
   @Override
   public long reserve(long size) {
-    synchronized (this) {
+    synchronized (sharedLock) {
       try {
         long granted = target.borrow(size);
         sharedUsage.inc(granted);
@@ -52,11 +58,15 @@ public long reserve(long size) {
 
   @Override
   public long unreserve(long size) {
-    synchronized (this) {
-      long freed = target.repay(size);
-      sharedUsage.inc(-freed);
-      Preconditions.checkState(freed == size);
-      return freed;
+    synchronized (sharedLock) {
+      try {
+        long freed = target.repay(size);
+        sharedUsage.inc(-freed);
+        return freed;
+      } catch (Exception e) {
+        LOG.error("Error unreserving memory from target", e);
+        throw e;
+      }
     }
   }
 
diff --git a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java
index 47b9937eb7a3..db5ac8426df0 100644
--- a/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java
+++ b/gluten-data/src/main/java/org/apache/gluten/memory/listener/ReservationListeners.java
@@ -29,7 +29,8 @@ public final class ReservationListeners {
   public static final ReservationListener NOOP =
-      new ManagedReservationListener(new NoopMemoryTarget(), new SimpleMemoryUsageRecorder());
+      new ManagedReservationListener(
+          new NoopMemoryTarget(), new SimpleMemoryUsageRecorder(), new Object());
 
   public static ReservationListener create(
       String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder> mutableStats) {
@@ -46,32 +47,31 @@ private static ReservationListener create0(
     final double overAcquiredRatio = GlutenConfig.getConf().memoryOverAcquiredRatio();
     final long reservationBlockSize = GlutenConfig.getConf().memoryReservationBlockSize();
     final TaskMemoryManager tmm = TaskResources.getLocalTaskContext().taskMemoryManager();
+    final TreeMemoryTarget consumer =
+        MemoryTargets.newConsumer(
+            tmm, name, Spillers.withMinSpillSize(spiller, reservationBlockSize), mutableStats);
+    final MemoryTarget overConsumer =
+        MemoryTargets.newConsumer(
+            tmm,
+            consumer.name() + ".OverAcquire",
+            new Spiller() {
+              @Override
+              public long spill(MemoryTarget self, Phase phase, long size) {
+                if (!Spillers.PHASE_SET_ALL.contains(phase)) {
+                  return 0L;
+                }
+                return self.repay(size);
+              }
+            },
+            Collections.emptyMap());
     final MemoryTarget target =
         MemoryTargets.throwOnOom(
             MemoryTargets.overAcquire(
-                MemoryTargets.dynamicOffHeapSizingIfEnabled(
-                    MemoryTargets.newConsumer(
-                        tmm,
-                        name,
-                        Spillers.withMinSpillSize(spiller, reservationBlockSize),
-                        mutableStats)),
-                MemoryTargets.dynamicOffHeapSizingIfEnabled(
-                    MemoryTargets.newConsumer(
-                        tmm,
-                        "OverAcquire.DummyTarget",
-                        new Spiller() {
-                          @Override
-                          public long spill(MemoryTarget self, Spiller.Phase phase, long size) {
-                            if (!Spillers.PHASE_SET_ALL.contains(phase)) {
-                              return 0L;
-                            }
-                            return self.repay(size);
-                          }
-                        },
-                        Collections.emptyMap())),
+                MemoryTargets.dynamicOffHeapSizingIfEnabled(consumer),
+                MemoryTargets.dynamicOffHeapSizingIfEnabled(overConsumer),
                 overAcquiredRatio));
 
     // Listener.
-    return new ManagedReservationListener(target, TaskResources.getSharedUsage());
+    return new ManagedReservationListener(target, TaskResources.getSharedUsage(), tmm);
   }
 }
diff --git a/gluten-data/src/test/scala/org/apache/gluten/execution/MassiveMemoryAllocationSuite.scala b/gluten-data/src/test/scala/org/apache/gluten/execution/MassiveMemoryAllocationSuite.scala
new file mode 100644
index 000000000000..ebfa0e6123fd
--- /dev/null
+++ b/gluten-data/src/test/scala/org/apache/gluten/execution/MassiveMemoryAllocationSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.memory.MemoryUsageStatsBuilder
+import org.apache.gluten.memory.listener.{ReservationListener, ReservationListeners}
+import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.TaskResources
+
+import java.util.concurrent.{Callable, Executors, TimeUnit}
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+class MassiveMemoryAllocationSuite extends SparkFunSuite with SharedSparkSession {
+  test("concurrent allocation with spill - shared listener") {
+    val numThreads = 50
+    val offHeapSize = 500
+    val minExtraSpillSize = 2
+    val maxExtraSpillSize = 5
+    val numAllocations = 100
+    val minAllocationSize = 40
+    val maxAllocationSize = 100
+    val minAllocationDelayMs = 0
+    val maxAllocationDelayMs = 0
+    withSQLConf("spark.memory.offHeap.size" -> s"$offHeapSize") {
+      val total = new AtomicLong(0L)
+      TaskResources.runUnsafe {
+        val spiller = Spillers.appendable()
+        val listener = ReservationListeners.create(
+          s"listener",
+          spiller,
+          Map[String, MemoryUsageStatsBuilder]().asJava)
+        spiller.append(new Spiller() {
+          override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
+            val extraSpillSize = randomInt(minExtraSpillSize, maxExtraSpillSize)
+            val spillSize = size + extraSpillSize
+            val released = listener.unreserve(spillSize)
+            assert(released <= spillSize)
+            total.getAndAdd(-released)
+            spillSize
+          }
+        })
+        val pool = Executors.newFixedThreadPool(numThreads)
+        val tasks = (0 until numThreads).map {
+          _ =>
+            new Callable[Unit]() {
+              override def call(): Unit = {
+                (0 until numAllocations).foreach {
+                  _ =>
+                    val allocSize =
+                      randomInt(minAllocationSize, maxAllocationSize)
+                    val granted = listener.reserve(allocSize)
+                    assert(granted == allocSize)
+                    total.getAndAdd(granted)
+                    val sleepMs =
+                      randomInt(minAllocationDelayMs, maxAllocationDelayMs)
+                    Thread.sleep(sleepMs)
+                }
+              }
+            }
+        }.toList
+        val futures = pool.invokeAll(tasks.asJava)
+        pool.shutdown()
+        pool.awaitTermination(60, TimeUnit.SECONDS)
+        futures.forEach(_.get())
+        val totalBytes = total.get()
+        val released = listener.unreserve(totalBytes)
+        assert(released == totalBytes)
+        assert(listener.getUsedBytes == 0)
+      }
+    }
+  }
+
+  test("concurrent allocation with spill - dedicated listeners") {
+    val numThreads = 50
+    val offHeapSize = 500
+    val minExtraSpillSize = 2
+    val maxExtraSpillSize = 5
+    val numAllocations = 100
+    val minAllocationSize = 40
+    val maxAllocationSize = 100
+    val minAllocationDelayMs = 0
+    val maxAllocationDelayMs = 0
+    withSQLConf("spark.memory.offHeap.size" -> s"$offHeapSize") {
+      TaskResources.runUnsafe {
+        val total = new AtomicLong(0L)
+
+        def newListener(id: Int): ReservationListener = {
+          val spiller = Spillers.appendable()
+          val listener = ReservationListeners.create(
+            s"listener $id",
+            spiller,
+            Map[String, MemoryUsageStatsBuilder]().asJava)
+          spiller.append(new Spiller() {
+            override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
+              val extraSpillSize = randomInt(minExtraSpillSize, maxExtraSpillSize)
+              val spillSize = size + extraSpillSize
+              val released = listener.unreserve(spillSize)
+              assert(released <= spillSize)
+              total.getAndAdd(-released)
+              spillSize
+            }
+          })
+          listener
+        }
+
+        val listeners = (0 until numThreads).map(newListener).toList
+        val pool = Executors.newFixedThreadPool(numThreads)
+        val tasks = (0 until numThreads).map {
+          i =>
+            new Callable[Unit]() {
+              override def call(): Unit = {
+                val listener = listeners(i)
+                (0 until numAllocations).foreach {
+                  _ =>
+                    val allocSize =
+                      randomInt(minAllocationSize, maxAllocationSize)
+                    val granted = listener.reserve(allocSize)
+                    assert(granted == allocSize)
+                    total.getAndAdd(granted)
+                    val sleepMs =
+                      randomInt(minAllocationDelayMs, maxAllocationDelayMs)
+                    Thread.sleep(sleepMs)
+                }
+              }
+            }
+        }.toList
+        val futures = pool.invokeAll(tasks.asJava)
+        pool.shutdown()
+        pool.awaitTermination(60, TimeUnit.SECONDS)
+        futures.forEach(_.get())
+        val totalBytes = total.get()
+        val remaining = listeners.foldLeft(totalBytes) {
+          case (remainingBytes, listener) =>
+            assert(remainingBytes >= 0)
+            val unreserved = listener.unreserve(remainingBytes)
+            remainingBytes - unreserved
+        }
+        assert(remaining == 0)
+        assert(listeners.map(_.getUsedBytes).sum == 0)
+      }
+    }
+  }
+
+  private def randomInt(from: Int, to: Int): Int = {
+    from + Random.nextInt(to - from + 1)
+  }
+}
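
The locking change in this patch (the sharedLock passed into ManagedReservationListener, the tmm used as that lock in ReservationListeners.create0, and the "Emulate a shared lock to avoid ABBA deadlock" mutex in MultiMemoryManagerTest) rests on one idea: every listener created under the same Spark TaskMemoryManager synchronizes on a single shared monitor, so a reserve on one listener that triggers a spill through a sibling listener can never take two locks in opposite orders. The sketch below is a minimal, self-contained model of that discipline, not the Gluten API; the SharedLockSketch and Listener names, the spillTarget wiring, and the iteration counts are illustrative assumptions.

public class SharedLockSketch {
  // One lock object per task; every listener of that task shares it.
  // In the patch, the TaskMemoryManager instance itself plays this role.
  static final class Listener {
    private final Object sharedLock;
    private long used;
    private Listener spillTarget; // a sibling listener of the same task

    Listener(Object sharedLock) {
      this.sharedLock = sharedLock;
    }

    void setSpillTarget(Listener other) {
      this.spillTarget = other;
    }

    long reserve(long size) {
      synchronized (sharedLock) { // with per-listener locks, this is where lock order A->B starts
        if (spillTarget != null) {
          // Emulates the memory manager spilling a sibling consumer while the
          // current listener still holds its lock.
          spillTarget.unreserve(Math.min(spillTarget.used, size));
        }
        used += size;
        return size;
      }
    }

    long unreserve(long size) {
      synchronized (sharedLock) { // same monitor: reentrant, no lock-order inversion
        long freed = Math.min(used, size);
        used -= freed;
        return freed;
      }
    }
  }

  public static void main(String[] args) throws InterruptedException {
    Object taskLock = new Object(); // shared per task, like passing `tmm` in the patch
    Listener a = new Listener(taskLock);
    Listener b = new Listener(taskLock);
    a.setSpillTarget(b);
    b.setSpillTarget(a);

    Thread t1 = new Thread(() -> { for (int i = 0; i < 10_000; i++) a.reserve(64); });
    Thread t2 = new Thread(() -> { for (int i = 0; i < 10_000; i++) b.reserve(64); });
    t1.start();
    t2.start();
    t1.join(); // with two distinct lock objects instead of taskLock, this join can hang (ABBA)
    t2.join();
    System.out.println("a.used=" + a.used + " b.used=" + b.used);
  }
}

Passing the TaskMemoryManager itself as the monitor, as create0 now does with tmm, gives every listener of a task the same lock object without introducing any new per-task state.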