From daddb3b1707e62781bd646413f722eb1b975552d Mon Sep 17 00:00:00 2001
From: Hongze Zhang <hongze.zhang@intel.com>
Date: Mon, 12 Aug 2024 17:02:13 +0800
Subject: [PATCH] Fix conflict in VeloxMemoryManager.cc

[GLUTEN-6736][VL] Phase 2: Minimize lock scope in ListenableArbitrator (#6783)

Closes #6736
---
 cpp/velox/memory/VeloxMemoryManager.cc |  42 ++--
 cpp/velox/tests/MemoryManagerTest.cc   | 270 +++++++++++++++++++++++++
 2 files changed, 295 insertions(+), 17 deletions(-)

diff --git a/cpp/velox/memory/VeloxMemoryManager.cc b/cpp/velox/memory/VeloxMemoryManager.cc
index 442090004a41..8c39841cadd2 100644
--- a/cpp/velox/memory/VeloxMemoryManager.cc
+++ b/cpp/velox/memory/VeloxMemoryManager.cc
@@ -64,12 +64,15 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
       const std::vector<std::shared_ptr<velox::memory::MemoryPool>>& candidatePools,
       uint64_t targetBytes) override {
     velox::memory::ScopedMemoryArbitrationContext ctx(pool);
-    VELOX_CHECK_EQ(candidatePools.size(), 1, "ListenableArbitrator should only be used within a single root pool")
-    auto candidate = candidatePools.back();
-    VELOX_CHECK(pool->root() == candidate.get(), "Illegal state in ListenableArbitrator");
+    velox::memory::MemoryPool* candidate;
+    {
+      std::unique_lock guard{mutex_};
+      VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
+      candidate = candidates_.begin()->first;
+    }
+    VELOX_CHECK(pool->root() == candidate, "Illegal state in ListenableArbitrator");
 
-    std::lock_guard<std::recursive_mutex> l(mutex_);
-    growCapacityLocked(pool->root(), targetBytes);
+    growCapacity0(pool->root(), targetBytes);
     return true;
   }
 
@@ -80,16 +83,18 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
       bool allowAbort) override {
     velox::memory::ScopedMemoryArbitrationContext ctx((const velox::memory::MemoryPool*)nullptr);
     facebook::velox::exec::MemoryReclaimer::Stats status;
-    VELOX_CHECK_EQ(pools.size(), 1, "Gluten only has one root pool");
-    std::lock_guard<std::recursive_mutex> l(mutex_); // FIXME: Do we have recursive locking for this mutex?
-    auto pool = pools.at(0);
-    const uint64_t oldCapacity = pool->capacity();
+    velox::memory::MemoryPool* pool;
+    {
+      std::unique_lock guard{mutex_};
+      VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
+      pool = candidates_.begin()->first;
+    }
     pool->reclaim(targetBytes, 0, status); // ignore the output
-    shrinkPool(pool.get(), 0);
-    const uint64_t newCapacity = pool->capacity();
-    uint64_t total = oldCapacity - newCapacity;
-    listener_->allocationChanged(-total);
-    return total;
+    return shrinkCapacity0(pool, 0);
+  }
+
+  uint64_t shrinkCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) override {
+    return shrinkCapacity0(pool, targetBytes);
   }
 
   Stats stats() const override {
@@ -102,7 +107,7 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
   }
 
  private:
-  void growCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
+  void growCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
     // Since
     // https://github.com/facebookincubator/velox/pull/9557/files#diff-436e44b7374032f8f5d7eb45869602add6f955162daa2798d01cc82f8725724dL812-L820,
     // We should pass bytes as parameter "reservationBytes" when calling ::grow.
@@ -124,14 +129,17 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
         pool->toString())
   }
 
-  uint64_t shrinkCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
+  uint64_t shrinkCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
     uint64_t freeBytes = shrinkPool(pool, bytes);
     listener_->allocationChanged(-freeBytes);
     return freeBytes;
   }
 
   gluten::AllocationListener* listener_;
-  std::recursive_mutex mutex_;
+  const uint64_t memoryPoolInitialCapacity_; // FIXME: Unused.
+  const uint64_t memoryPoolTransferCapacity_;
+
+  mutable std::mutex mutex_;
   inline static std::string kind_ = "GLUTEN";
 };
 
diff --git a/cpp/velox/tests/MemoryManagerTest.cc b/cpp/velox/tests/MemoryManagerTest.cc
index 52f2fa8b661c..f2a7dc46461d 100644
--- a/cpp/velox/tests/MemoryManagerTest.cc
+++ b/cpp/velox/tests/MemoryManagerTest.cc
@@ -129,4 +129,274 @@ TEST_F(MemoryManagerTest, memoryAllocatorWithBlockReservation) {
   ASSERT_EQ(allocator_->getBytes(), 0);
 }
 
+namespace {
+class AllocationListenerWrapper : public AllocationListener {
+ public:
+  explicit AllocationListenerWrapper() {}
+
+  void set(AllocationListener* const delegate) {
+    if (delegate_ != nullptr) {
+      throw std::runtime_error("Invalid state");
+    }
+    delegate_ = delegate;
+  }
+
+  void allocationChanged(int64_t diff) override {
+    delegate_->allocationChanged(diff);
+  }
+  int64_t currentBytes() override {
+    return delegate_->currentBytes();
+  }
+  int64_t peakBytes() override {
+    return delegate_->peakBytes();
+  }
+
+ private:
+  AllocationListener* delegate_{nullptr};
+};
+
+class SpillableAllocationListener : public AllocationListener {
+ public:
+  virtual uint64_t shrink(uint64_t bytes) = 0;
+  virtual uint64_t spill(uint64_t bytes) = 0;
+};
+
+class MockSparkTaskMemoryManager {
+ public:
+  explicit MockSparkTaskMemoryManager(const uint64_t maxBytes);
+
+  AllocationListener* newListener(std::function<uint64_t(uint64_t)> shrink, std::function<uint64_t(uint64_t)> spill);
+
+  uint64_t acquire(uint64_t bytes);
+  void release(uint64_t bytes);
+  uint64_t currentBytes() {
+    return currentBytes_;
+  }
+
+ private:
+  mutable std::recursive_mutex mutex_;
+  std::vector<std::unique_ptr<SpillableAllocationListener>> listeners_{};
+
+  const uint64_t maxBytes_;
+  uint64_t currentBytes_{0L};
+};
+
+class MockSparkAllocationListener : public SpillableAllocationListener {
+ public:
+  explicit MockSparkAllocationListener(
+      MockSparkTaskMemoryManager* const manager,
+      std::function<uint64_t(uint64_t)> shrink,
+      std::function<uint64_t(uint64_t)> spill)
+      : manager_(manager), shrink_(shrink), spill_(spill) {}
+
+  void allocationChanged(int64_t diff) override {
+    if (diff == 0) {
+      return;
+    }
+    if (diff > 0) {
+      auto granted = manager_->acquire(diff);
+      if (granted < diff) {
+        throw std::runtime_error("OOM");
+      }
+      currentBytes_ += granted;
+      return;
+    }
+    manager_->release(-diff);
+    currentBytes_ -= (-diff);
+  }
+
+  uint64_t shrink(uint64_t bytes) override {
+    return shrink_(bytes);
+  }
+
+  uint64_t spill(uint64_t bytes) override {
+    return spill_(bytes);
+  }
+
+  int64_t currentBytes() override {
+    return currentBytes_;
+  }
+
+ private:
+  MockSparkTaskMemoryManager* const manager_;
+  std::function<uint64_t(uint64_t)> shrink_;
+  std::function<uint64_t(uint64_t)> spill_;
+  std::atomic<int64_t> currentBytes_{0L};
+};
+
+MockSparkTaskMemoryManager::MockSparkTaskMemoryManager(const uint64_t maxBytes) : maxBytes_(maxBytes) {}
+
+AllocationListener* MockSparkTaskMemoryManager::newListener(
+    std::function<uint64_t(uint64_t)> shrink,
+    std::function<uint64_t(uint64_t)> spill) {
+  listeners_.push_back(std::make_unique<MockSparkAllocationListener>(this, shrink, spill));
+  return listeners_.back().get();
+}
+
+uint64_t MockSparkTaskMemoryManager::acquire(uint64_t bytes) {
+  std::unique_lock l(mutex_);
+  auto freeBytes = maxBytes_ - currentBytes_;
+  if (bytes <= freeBytes) {
+    currentBytes_ += bytes;
+    return bytes;
+  }
+  // Shrink listeners.
+  int64_t bytesNeeded = bytes - freeBytes;
+  for (const auto& listener : listeners_) {
+    bytesNeeded -= listener->shrink(bytesNeeded);
+    if (bytesNeeded < 0) {
+      break;
+    }
+  }
+  if (bytesNeeded > 0) {
+    for (const auto& listener : listeners_) {
+      bytesNeeded -= listener->spill(bytesNeeded);
+      if (bytesNeeded < 0) {
+        break;
+      }
+    }
+  }
+
+  if (bytesNeeded > 0) {
+    uint64_t granted = bytes - bytesNeeded;
+    currentBytes_ += granted;
+    return granted;
+  }
+
+  currentBytes_ += bytes;
+  return bytes;
+}
+
+void MockSparkTaskMemoryManager::release(uint64_t bytes) {
+  std::unique_lock l(mutex_);
+  currentBytes_ -= bytes;
+}
+
+class MockMemoryReclaimer : public facebook::velox::memory::MemoryReclaimer {
+ public:
+  explicit MockMemoryReclaimer(std::vector<void*>& buffs, int32_t size) : buffs_(buffs), size_(size) {}
+
+  bool reclaimableBytes(const memory::MemoryPool& pool, uint64_t& reclaimableBytes) const override {
+    uint64_t total = 0;
+    for (const auto& buf : buffs_) {
+      if (buf == nullptr) {
+        continue;
+      }
+      total += size_;
+    }
+    if (total == 0) {
+      return false;
+    }
+    reclaimableBytes = total;
+    return true;
+  }
+
+  uint64_t reclaim(memory::MemoryPool* pool, uint64_t targetBytes, uint64_t maxWaitMs, Stats& stats) override {
+    uint64_t total = 0;
+    for (auto& buf : buffs_) {
+      if (buf == nullptr) {
+        // When:
+        // 1. Called by allocation from the same pool so buff is not allocated yet.
+        // 2. Already called once.
+        continue;
+      }
+      pool->free(buf, size_);
+      buf = nullptr;
+      total += size_;
+    }
+    return total;
+  }
+
+ private:
+  std::vector<void*>& buffs_;
+  int32_t size_;
+};
+
+void assertCapacitiesMatch(MockSparkTaskMemoryManager& tmm, std::vector<std::unique_ptr<VeloxMemoryManager>>& vmms) {
+  uint64_t sum = 0;
+  for (const auto& vmm : vmms) {
+    if (vmm == nullptr) {
+      continue;
+    }
+    sum += vmm->getAggregateMemoryPool()->capacity();
+  }
+  if (tmm.currentBytes() != sum) {
+    ASSERT_EQ(tmm.currentBytes(), sum);
+  }
+}
+} // namespace
+
+class MultiMemoryManagerTest : public ::testing::Test {
+ protected:
+  static void SetUpTestCase() {
+    std::unordered_map<std::string, std::string> conf = {
+        {kMemoryReservationBlockSize, std::to_string(kMemoryReservationBlockSizeDefault)},
+        {kVeloxMemInitCapacity, std::to_string(kVeloxMemInitCapacityDefault)}};
+    gluten::VeloxBackend::create(conf);
+  }
+
+  std::unique_ptr<VeloxMemoryManager> newVeloxMemoryManager(std::unique_ptr<AllocationListener> listener) {
+    return std::make_unique<VeloxMemoryManager>(std::move(listener));
+  }
+};
+
+TEST_F(MultiMemoryManagerTest, spill) {
+  const uint64_t maxBytes = 200 << 20;
+  const uint32_t numThreads = 100;
+  const uint32_t numAllocations = 200;
+  const int32_t allocateSize = 10 << 20;
+
+  MockSparkTaskMemoryManager tmm{maxBytes};
+  std::vector<std::unique_ptr<VeloxMemoryManager>> vmms{};
+  std::vector<std::thread> threads{};
+  std::vector<std::vector<void*>> buffs{};
+  for (size_t i = 0; i < numThreads; ++i) {
+    buffs.push_back({});
+    vmms.emplace_back(nullptr);
+  }
+
+  // Emulate a shared lock to avoid ABBA deadlock.
+  std::recursive_mutex mutex;
+
+  for (size_t i = 0; i < numThreads; ++i) {
+    threads.emplace_back([this, i, allocateSize, &tmm, &vmms, &mutex, &buffs]() -> void {
+      auto wrapper = std::make_unique<AllocationListenerWrapper>(); // Set later.
+      auto* listener = wrapper.get();
+
+      facebook::velox::memory::MemoryPool* pool; // Set later.
+      {
+        std::unique_lock l(mutex);
+        vmms[i] = newVeloxMemoryManager(std::move(wrapper));
+        pool = vmms[i]->getLeafMemoryPool().get();
+        pool->setReclaimer(std::make_unique<MockMemoryReclaimer>(buffs[i], allocateSize));
+        listener->set(tmm.newListener(
+            [](uint64_t bytes) -> uint64_t { return 0; },
+            [i, &vmms, &mutex](uint64_t bytes) -> uint64_t {
+              std::unique_lock l(mutex);
+              return vmms[i]->getMemoryManager()->arbitrator()->shrinkCapacity(bytes);
+            }));
+      }
+      {
+        std::unique_lock l(mutex);
+        for (size_t j = 0; j < numAllocations; ++j) {
+          assertCapacitiesMatch(tmm, vmms);
+          buffs[i].push_back(pool->allocate(allocateSize));
+          assertCapacitiesMatch(tmm, vmms);
+        }
+      }
+    });
+  }
+
+  for (auto& thread : threads) {
+    thread.join();
+  }
+
+  for (auto& vmm : vmms) {
+    assertCapacitiesMatch(tmm, vmms);
+    vmm->getMemoryManager()->arbitrator()->shrinkCapacity(allocateSize * numAllocations);
+    assertCapacitiesMatch(tmm, vmms);
+  }
+
+  ASSERT_EQ(tmm.currentBytes(), 0);
+}
 } // namespace gluten