Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Yohahaha committed May 13, 2024
1 parent c8c17dd commit 5063841
Showing 1 changed file with 113 additions and 6 deletions.
119 changes: 113 additions & 6 deletions cpp/velox/memory/VeloxMemoryManager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,24 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
}

uint64_t shrinkCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) override {
std::lock_guard<std::recursive_mutex> l(mutex_);
std::lock_guard<std::mutex> l(mutex_);

return shrinkCapacityLocked(pool, targetBytes);
}

bool growCapacity(
velox::memory::MemoryPool* pool,
const std::vector<std::shared_ptr<velox::memory::MemoryPool>>& candidatePools,
uint64_t targetBytes) override {
velox::memory::ScopedMemoryArbitrationContext ctx(pool);
ArbitrationOperation op(pool, targetBytes, candidatePools);
ScopedArbitration scopedArbitration(this, &op);
// velox::memory::ScopedMemoryArbitrationContext ctx(pool);
VELOX_CHECK_EQ(candidatePools.size(), 1, "ListenableArbitrator should only be used within a single root pool")
auto candidate = candidatePools.back();
VELOX_CHECK(pool->root() == candidate.get(), "Illegal state in ListenableArbitrator");

std::lock_guard<std::recursive_mutex> l(mutex_);
std::lock_guard<std::mutex> l(mutex_);

growCapacityLocked(pool->root(), targetBytes);
return true;
}
Expand All @@ -68,10 +72,13 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
uint64_t targetBytes,
bool allowSpill,
bool allowAbort) override {
velox::memory::ScopedMemoryArbitrationContext ctx(nullptr);
ArbitrationOperation op(targetBytes, pools);
ScopedArbitration scopedArbitration(this, &op);
// velox::memory::ScopedMemoryArbitrationContext ctx(nullptr);
facebook::velox::exec::MemoryReclaimer::Stats status;
VELOX_CHECK_EQ(pools.size(), 1, "Gluten only has one root pool");
std::lock_guard<std::recursive_mutex> l(mutex_); // FIXME: Do we have recursive locking for this mutex?
std::lock_guard<std::mutex> l(mutex_);

auto pool = pools.at(0);
const uint64_t oldCapacity = pool->capacity();
pool->reclaim(targetBytes, 0, status); // ignore the output
Expand All @@ -92,6 +99,96 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
}

private:
struct ArbitrationOperation {
velox::memory::MemoryPool* const requestPool;
velox::memory::MemoryPool* const requestRoot;
const uint64_t targetBytes;

ArbitrationOperation(
uint64_t targetBytes,
const std::vector<std::shared_ptr<velox::memory::MemoryPool>>& candidatePools)
: ArbitrationOperation(nullptr, targetBytes, candidatePools) {}

ArbitrationOperation(
velox::memory::MemoryPool* _requestor,
uint64_t _targetBytes,
const std::vector<std::shared_ptr<velox::memory::MemoryPool>>& _candidatePools)
: requestPool(_requestor),
requestRoot(_requestor == nullptr ? nullptr : _requestor->root()),
targetBytes(_targetBytes) {}

void enterArbitration() {
if (requestPool != nullptr) {
requestPool->enterArbitration();
}
}
void leaveArbitration() {
if (requestPool != nullptr) {
requestPool->leaveArbitration();
}
}
};

class ScopedArbitration {
public:
ScopedArbitration(ListenableArbitrator* arbitrator, ArbitrationOperation* op)
: operation_(op), arbitrator_(arbitrator), arbitrationCtx_(op->requestPool) {
VELOX_CHECK_NOT_NULL(arbitrator_);
VELOX_CHECK_NOT_NULL(operation_);
operation_->enterArbitration();
if (arbitrator_->arbitrationStateCheckCb_ != nullptr && operation_->requestPool != nullptr) {
arbitrator_->arbitrationStateCheckCb_(*operation_->requestPool);
}
arbitrator_->startArbitration(operation_);
}

~ScopedArbitration() {
operation_->leaveArbitration();
arbitrator_->finishArbitration(operation_);
}

private:
ArbitrationOperation* const operation_;
ListenableArbitrator* const arbitrator_;
const velox::memory::ScopedMemoryArbitrationContext arbitrationCtx_;
};

void startArbitration(ArbitrationOperation* op) {
velox::ContinueFuture waitPromise{velox::ContinueFuture::makeEmpty()};
{
std::lock_guard<std::mutex> l(mutex_);
if (running_) {
waitPromises_.emplace_back(
fmt::format("Wait for arbitration {}/{}", op->requestPool->name(), op->requestRoot->name()));
waitPromise = waitPromises_.back().getSemiFuture();
} else {
VELOX_CHECK(waitPromises_.empty());
running_ = true;
}
}

if (waitPromise.valid()) {
waitPromise.wait();
}
}

void finishArbitration(ArbitrationOperation* op) {
velox::ContinuePromise resumePromise{velox::ContinuePromise::makeEmpty()};
{
std::lock_guard<std::mutex> l(mutex_);
VELOX_CHECK(running_);
if (!waitPromises_.empty()) {
resumePromise = std::move(waitPromises_.back());
waitPromises_.pop_back();
} else {
running_ = false;
}
}
if (resumePromise.valid()) {
resumePromise.setValue();
}
}

void growCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
// Since
// https://github.com/facebookincubator/velox/pull/9557/files#diff-436e44b7374032f8f5d7eb45869602add6f955162daa2798d01cc82f8725724dL812-L820,
Expand Down Expand Up @@ -121,8 +218,15 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
}

gluten::AllocationListener* listener_;
std::recursive_mutex mutex_;
// std::recursive_mutex mutex_;
inline static std::string kind_ = "GLUTEN";

mutable std::mutex mutex_;
// Indicates if there is a running arbitration request or not.
bool running_{false};
// The promises of the arbitration requests waiting for the serialized
// execution.
std::vector<velox::ContinuePromise> waitPromises_;
};

class ArbitratorFactoryRegister {
Expand Down Expand Up @@ -171,6 +275,9 @@ VeloxMemoryManager::VeloxMemoryManager(
.memoryPoolInitCapacity = 0,
.memoryPoolTransferCapacity = 32 << 20,
.memoryReclaimWaitMs = 0};
if (name_ == "WholeStageIterator") {
mmOptions.arbitrationStateCheckCb = velox::exec::memoryArbitrationStateCheck;
}
veloxMemoryManager_ = std::make_unique<velox::memory::MemoryManager>(mmOptions);

veloxAggregatePool_ = veloxMemoryManager_->addRootPool(
Expand Down

0 comments on commit 5063841

Please sign in to comment.