// Patch summary (preamble text ahead of the `diff --git` header; not part of any hunk).
//
// File touched: cpp/velox/memory/VeloxMemoryManager.cc (class ListenableArbitrator and the
// VeloxMemoryManager constructor).
//
// What the patch does:
//  * Replaces the member `std::recursive_mutex mutex_` with `mutable std::mutex mutex_` and adds
//    `running_` / `waitPromises_` so that arbitration requests are executed one at a time.
//  * Adds `ArbitrationOperation` (request pool, its root, and target bytes; the request pool may be
//    nullptr) and `ScopedArbitration`, whose constructor calls enterArbitration() on the request
//    pool (if any), invokes `arbitrationStateCheckCb_` when both the callback and the request pool
//    are non-null, and then calls startArbitration(); its destructor calls leaveArbitration() and
//    finishArbitration().
//  * startArbitration(): under `mutex_`, if `running_` is already set it queues a promise in
//    `waitPromises_` and, after releasing the lock, waits on that promise's future; otherwise it
//    sets `running_ = true`.
//  * finishArbitration(): under `mutex_`, if `waitPromises_` is non-empty it takes the promise at
//    the back (the most recently queued waiter) and fulfils it after releasing the lock; otherwise
//    it clears `running_`.
//  * growCapacity() and shrinkCapacity(pools, ...) now create an ArbitrationOperation plus a
//    ScopedArbitration instead of constructing ScopedMemoryArbitrationContext directly.
//  * The VeloxMemoryManager constructor sets `mmOptions.arbitrationStateCheckCb` only when
//    `name_ == "WholeStageIterator"`.
//
// Fix applied in this revision of the patch:
//  * shrinkCapacity(pools, ...) builds its ArbitrationOperation with a null request pool, so
//    `requestPool` and `requestRoot` are both nullptr. The wait-message formatting in
//    startArbitration() dereferenced both unconditionally whenever another arbitration was already
//    running. The two name() calls are now guarded and fall back to "<none>".
//
// NOTE(review): shrinkCapacity(pools, ...) holds the now non-recursive `mutex_` across
// `pool->reclaim(...)`; the FIXME removed by this patch asked whether recursive locking happens on
// this mutex. Confirm that reclaim cannot re-enter a method that locks `mutex_`.
// NOTE(review): angle-bracket template arguments look stripped in this copy of the patch
// (e.g. `std::vector>`); left untouched here -- verify against the real file before applying.
diff --git a/cpp/velox/memory/VeloxMemoryManager.cc b/cpp/velox/memory/VeloxMemoryManager.cc index f49beaccd264..78b137e32f94 100644 --- a/cpp/velox/memory/VeloxMemoryManager.cc +++ b/cpp/velox/memory/VeloxMemoryManager.cc @@ -45,7 +45,8 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator { } uint64_t shrinkCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) override { - std::lock_guard l(mutex_); + std::lock_guard l(mutex_); + return shrinkCapacityLocked(pool, targetBytes); } @@ -53,12 +54,15 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator { velox::memory::MemoryPool* pool, const std::vector>& candidatePools, uint64_t targetBytes) override { - velox::memory::ScopedMemoryArbitrationContext ctx(pool); + ArbitrationOperation op(pool, targetBytes, candidatePools); + ScopedArbitration scopedArbitration(this, &op); + // velox::memory::ScopedMemoryArbitrationContext ctx(pool); VELOX_CHECK_EQ(candidatePools.size(), 1, "ListenableArbitrator should only be used within a single root pool") auto candidate = candidatePools.back(); VELOX_CHECK(pool->root() == candidate.get(), "Illegal state in ListenableArbitrator"); - std::lock_guard l(mutex_); + std::lock_guard l(mutex_); + growCapacityLocked(pool->root(), targetBytes); return true; } @@ -68,10 +72,13 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator { uint64_t targetBytes, bool allowSpill, bool allowAbort) override { - velox::memory::ScopedMemoryArbitrationContext ctx(nullptr); + ArbitrationOperation op(targetBytes, pools); + ScopedArbitration scopedArbitration(this, &op); + // velox::memory::ScopedMemoryArbitrationContext ctx(nullptr); facebook::velox::exec::MemoryReclaimer::Stats status; VELOX_CHECK_EQ(pools.size(), 1, "Gluten only has one root pool"); - std::lock_guard l(mutex_); // FIXME: Do we have recursive locking for this mutex? 
+ std::lock_guard l(mutex_); + auto pool = pools.at(0); const uint64_t oldCapacity = pool->capacity(); pool->reclaim(targetBytes, 0, status); // ignore the output @@ -92,6 +99,96 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator { } private: + struct ArbitrationOperation { + velox::memory::MemoryPool* const requestPool; + velox::memory::MemoryPool* const requestRoot; + const uint64_t targetBytes; + + ArbitrationOperation( + uint64_t targetBytes, + const std::vector>& candidatePools) + : ArbitrationOperation(nullptr, targetBytes, candidatePools) {} + + ArbitrationOperation( + velox::memory::MemoryPool* _requestor, + uint64_t _targetBytes, + const std::vector>& _candidatePools) + : requestPool(_requestor), + requestRoot(_requestor == nullptr ? nullptr : _requestor->root()), + targetBytes(_targetBytes) {} + + void enterArbitration() { + if (requestPool != nullptr) { + requestPool->enterArbitration(); + } + } + void leaveArbitration() { + if (requestPool != nullptr) { + requestPool->leaveArbitration(); + } + } + }; + + class ScopedArbitration { + public: + ScopedArbitration(ListenableArbitrator* arbitrator, ArbitrationOperation* op) + : operation_(op), arbitrator_(arbitrator), arbitrationCtx_(op->requestPool) { + VELOX_CHECK_NOT_NULL(arbitrator_); + VELOX_CHECK_NOT_NULL(operation_); + operation_->enterArbitration(); + if (arbitrator_->arbitrationStateCheckCb_ != nullptr && operation_->requestPool != nullptr) { + arbitrator_->arbitrationStateCheckCb_(*operation_->requestPool); + } + arbitrator_->startArbitration(operation_); + } + + ~ScopedArbitration() { + operation_->leaveArbitration(); + arbitrator_->finishArbitration(operation_); + } + + private: + ArbitrationOperation* const operation_; + ListenableArbitrator* const arbitrator_; + const velox::memory::ScopedMemoryArbitrationContext arbitrationCtx_; + }; + + void startArbitration(ArbitrationOperation* op) { + velox::ContinueFuture waitPromise{velox::ContinueFuture::makeEmpty()}; + { +
std::lock_guard l(mutex_); + if (running_) { + waitPromises_.emplace_back( + fmt::format("Wait for arbitration {}/{}", op->requestPool == nullptr ? "<none>" : op->requestPool->name(), op->requestRoot == nullptr ? "<none>" : op->requestRoot->name())); + waitPromise = waitPromises_.back().getSemiFuture(); + } else { + VELOX_CHECK(waitPromises_.empty()); + running_ = true; + } + } + + if (waitPromise.valid()) { + waitPromise.wait(); + } + } + + void finishArbitration(ArbitrationOperation* op) { + velox::ContinuePromise resumePromise{velox::ContinuePromise::makeEmpty()}; + { + std::lock_guard l(mutex_); + VELOX_CHECK(running_); + if (!waitPromises_.empty()) { + resumePromise = std::move(waitPromises_.back()); + waitPromises_.pop_back(); + } else { + running_ = false; + } + } + if (resumePromise.valid()) { + resumePromise.setValue(); + } + } + void growCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) { // Since // https://github.com/facebookincubator/velox/pull/9557/files#diff-436e44b7374032f8f5d7eb45869602add6f955162daa2798d01cc82f8725724dL812-L820, @@ -121,8 +218,15 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator { } gluten::AllocationListener* listener_; - std::recursive_mutex mutex_; + // std::recursive_mutex mutex_; inline static std::string kind_ = "GLUTEN"; + + mutable std::mutex mutex_; + // Indicates if there is a running arbitration request or not. + bool running_{false}; + // The promises of the arbitration requests waiting for the serialized + // execution. + std::vector waitPromises_; }; class ArbitratorFactoryRegister { @@ -171,6 +275,9 @@ VeloxMemoryManager::VeloxMemoryManager( .memoryPoolInitCapacity = 0, .memoryPoolTransferCapacity = 32 << 20, .memoryReclaimWaitMs = 0}; + if (name_ == "WholeStageIterator") { + mmOptions.arbitrationStateCheckCb = velox::exec::memoryArbitrationStateCheck; + } veloxMemoryManager_ = std::make_unique(mmOptions); veloxAggregatePool_ = veloxMemoryManager_->addRootPool(