Skip to content

Commit

Permalink
Fix conflict in VeloxMemoryManager.cc
Browse files Browse the repository at this point in the history
[GLUTEN-6736][VL] Phase 2: Minimize lock scope in ListenableArbitrator (#6783)

Closes #6736
  • Loading branch information
zhztheplayer authored and weiting-chen committed Aug 13, 2024
1 parent 34782bd commit daddb3b
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 17 deletions.
42 changes: 25 additions & 17 deletions cpp/velox/memory/VeloxMemoryManager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,15 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
const std::vector<std::shared_ptr<velox::memory::MemoryPool>>& candidatePools,
uint64_t targetBytes) override {
velox::memory::ScopedMemoryArbitrationContext ctx(pool);
VELOX_CHECK_EQ(candidatePools.size(), 1, "ListenableArbitrator should only be used within a single root pool")
auto candidate = candidatePools.back();
VELOX_CHECK(pool->root() == candidate.get(), "Illegal state in ListenableArbitrator");
velox::memory::MemoryPool* candidate;
{
std::unique_lock guard{mutex_};
VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
candidate = candidates_.begin()->first;
}
VELOX_CHECK(pool->root() == candidate, "Illegal state in ListenableArbitrator");

std::lock_guard<std::recursive_mutex> l(mutex_);
growCapacityLocked(pool->root(), targetBytes);
growCapacity0(pool->root(), targetBytes);
return true;
}

Expand All @@ -80,16 +83,18 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
bool allowAbort) override {
velox::memory::ScopedMemoryArbitrationContext ctx((const velox::memory::MemoryPool*)nullptr);
facebook::velox::exec::MemoryReclaimer::Stats status;
VELOX_CHECK_EQ(pools.size(), 1, "Gluten only has one root pool");
std::lock_guard<std::recursive_mutex> l(mutex_); // FIXME: Do we have recursive locking for this mutex?
auto pool = pools.at(0);
const uint64_t oldCapacity = pool->capacity();
velox::memory::MemoryPool* pool;
{
std::unique_lock guard{mutex_};
VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
pool = candidates_.begin()->first;
}
pool->reclaim(targetBytes, 0, status); // ignore the output
shrinkPool(pool.get(), 0);
const uint64_t newCapacity = pool->capacity();
uint64_t total = oldCapacity - newCapacity;
listener_->allocationChanged(-total);
return total;
return shrinkCapacity0(pool, 0);
}

  // Arbitrator API: shrink a specific pool's capacity by up to targetBytes.
  // Thin delegation to the shared shrinkCapacity0() helper.
  uint64_t shrinkCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) override {
    return shrinkCapacity0(pool, targetBytes);
  }

Stats stats() const override {
Expand All @@ -102,7 +107,7 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
}

private:
void growCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
void growCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
// Since
// https://github.com/facebookincubator/velox/pull/9557/files#diff-436e44b7374032f8f5d7eb45869602add6f955162daa2798d01cc82f8725724dL812-L820,
// We should pass bytes as parameter "reservationBytes" when calling ::grow.
Expand All @@ -124,14 +129,17 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
pool->toString())
}

uint64_t shrinkCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
  // Shrinks `pool` via shrinkPool() and reports the freed amount to the
  // allocation listener as a negative delta. Returns the bytes actually freed.
  // NOTE(review): `-freeBytes` negates an unsigned value before the implicit
  // conversion to the listener's signed parameter; fine for values < 2^63 but
  // a static_cast<int64_t> would make the intent explicit — confirm.
  uint64_t shrinkCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
    uint64_t freeBytes = shrinkPool(pool, bytes);
    listener_->allocationChanged(-freeBytes);
    return freeBytes;
  }

gluten::AllocationListener* listener_;
std::recursive_mutex mutex_;
const uint64_t memoryPoolInitialCapacity_; // FIXME: Unused.
const uint64_t memoryPoolTransferCapacity_;

mutable std::mutex mutex_;
inline static std::string kind_ = "GLUTEN";
};

Expand Down
270 changes: 270 additions & 0 deletions cpp/velox/tests/MemoryManagerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,274 @@ TEST_F(MemoryManagerTest, memoryAllocatorWithBlockReservation) {
ASSERT_EQ(allocator_->getBytes(), 0);
}

namespace {
// Relays allocation callbacks to a delegate listener installed after
// construction via set(). This breaks the ordering problem where the
// VeloxMemoryManager needs a listener at construction time but the real
// listener needs the manager's arbitrator to exist first.
class AllocationListenerWrapper : public AllocationListener {
 public:
  // `explicit` on a zero-argument constructor is meaningless; default it.
  AllocationListenerWrapper() = default;

  // Installs the delegate. May be called at most once: silently swapping
  // listeners mid-flight would corrupt allocation accounting.
  void set(AllocationListener* const delegate) {
    if (delegate_ != nullptr) {
      throw std::runtime_error("Invalid state");
    }
    delegate_ = delegate;
  }

  // The forwarding calls below require set() to have been invoked first;
  // calling them earlier dereferences a null pointer (fail fast in tests).
  void allocationChanged(int64_t diff) override {
    delegate_->allocationChanged(diff);
  }
  int64_t currentBytes() override {
    return delegate_->currentBytes();
  }
  int64_t peakBytes() override {
    return delegate_->peakBytes();
  }

 private:
  AllocationListener* delegate_{nullptr};
};

// Extension of AllocationListener that can also give memory back on demand.
// shrink() is expected to release free capacity; spill() to force data out to
// free memory. Both return the number of bytes actually reclaimed, which may
// be less than requested.
class SpillableAllocationListener : public AllocationListener {
 public:
  virtual uint64_t shrink(uint64_t bytes) = 0;
  virtual uint64_t spill(uint64_t bytes) = 0;
};

// Emulates Spark's TaskMemoryManager: a fixed-size budget shared by several
// listeners. acquire() tries the free budget first, then asks listeners to
// shrink, then to spill. The mutex is recursive because shrink/spill
// callbacks may re-enter release() on the same thread.
class MockSparkTaskMemoryManager {
 public:
  explicit MockSparkTaskMemoryManager(const uint64_t maxBytes);

  // Creates and owns a listener; `shrink`/`spill` are invoked on exhaustion.
  AllocationListener* newListener(std::function<uint64_t(uint64_t)> shrink, std::function<uint64_t(uint64_t)> spill);

  // Tries to acquire `bytes`; returns the number of bytes actually granted.
  uint64_t acquire(uint64_t bytes);
  // Returns `bytes` to the budget.
  void release(uint64_t bytes);
  // Currently acquired bytes. Locks so concurrent readers see a consistent
  // value instead of racing with acquire()/release() mutations.
  uint64_t currentBytes() {
    std::unique_lock l(mutex_);
    return currentBytes_;
  }

 private:
  mutable std::recursive_mutex mutex_;
  std::vector<std::unique_ptr<SpillableAllocationListener>> listeners_{};

  const uint64_t maxBytes_;
  uint64_t currentBytes_{0L};
};

// Listener backed by a MockSparkTaskMemoryManager: growth requests are
// acquired from the shared budget, releases are returned to it.
class MockSparkAllocationListener : public SpillableAllocationListener {
 public:
  explicit MockSparkAllocationListener(
      MockSparkTaskMemoryManager* const manager,
      std::function<uint64_t(uint64_t)> shrink,
      std::function<uint64_t(uint64_t)> spill)
      : manager_(manager), shrink_(std::move(shrink)), spill_(std::move(spill)) {}

  void allocationChanged(int64_t diff) override {
    if (diff == 0) {
      return;
    }
    if (diff > 0) {
      const auto granted = manager_->acquire(diff);
      if (granted < static_cast<uint64_t>(diff)) {
        // Roll the partial grant back so the manager's accounting stays
        // consistent with this listener's before surfacing the OOM.
        manager_->release(granted);
        throw std::runtime_error("OOM");
      }
      currentBytes_ += granted;
      return;
    }
    manager_->release(-diff);
    currentBytes_ -= (-diff);
  }

  // Delegates to the callbacks supplied at construction.
  uint64_t shrink(uint64_t bytes) override {
    return shrink_(bytes);
  }

  uint64_t spill(uint64_t bytes) override {
    return spill_(bytes);
  }

  int64_t currentBytes() override {
    return currentBytes_;
  }

 private:
  MockSparkTaskMemoryManager* const manager_;
  std::function<uint64_t(uint64_t)> shrink_;
  std::function<uint64_t(uint64_t)> spill_;
  std::atomic<uint64_t> currentBytes_{0L};
};

// Construction only records the budget; listeners are registered lazily.
MockSparkTaskMemoryManager::MockSparkTaskMemoryManager(const uint64_t maxBytes) : maxBytes_(maxBytes) {}

// Registers and returns a new spillable listener bound to this manager.
// Takes the lock: `listeners_` is iterated concurrently by acquire(), so an
// unlocked push_back would race with those readers.
AllocationListener* MockSparkTaskMemoryManager::newListener(
    std::function<uint64_t(uint64_t)> shrink,
    std::function<uint64_t(uint64_t)> spill) {
  std::unique_lock l(mutex_);
  listeners_.push_back(std::make_unique<MockSparkAllocationListener>(this, std::move(shrink), std::move(spill)));
  return listeners_.back().get();
}

// Grants up to `bytes` from the budget. Order of attempts:
//   1. free budget,
//   2. ask each listener to shrink,
//   3. ask each listener to spill.
// Returns the number of bytes actually granted (may be < bytes on OOM).
uint64_t MockSparkTaskMemoryManager::acquire(uint64_t bytes) {
  std::unique_lock l(mutex_);
  const auto freeBytes = maxBytes_ - currentBytes_;
  if (bytes <= freeBytes) {
    currentBytes_ += bytes;
    return bytes;
  }
  // Not enough free budget: ask listeners to shrink. Stop as soon as the
  // deficit reaches zero — the original `< 0` kept iterating and invoked
  // shrink(0)/spill(0) on the remaining listeners, and a target of 0 is
  // treated downstream as "release everything", over-reclaiming memory.
  int64_t bytesNeeded = bytes - freeBytes;
  for (const auto& listener : listeners_) {
    bytesNeeded -= listener->shrink(bytesNeeded);
    if (bytesNeeded <= 0) {
      break;
    }
  }
  if (bytesNeeded > 0) {
    // Shrinking was not enough: force listeners to spill live data.
    for (const auto& listener : listeners_) {
      bytesNeeded -= listener->spill(bytesNeeded);
      if (bytesNeeded <= 0) {
        break;
      }
    }
  }

  if (bytesNeeded > 0) {
    // Still short after spilling: grant only what could be covered.
    uint64_t granted = bytes - bytesNeeded;
    currentBytes_ += granted;
    return granted;
  }

  currentBytes_ += bytes;
  return bytes;
}

void MockSparkTaskMemoryManager::release(uint64_t bytes) {
std::unique_lock l(mutex_);
currentBytes_ -= bytes;
}

// Test-only reclaimer: tracks a set of externally owned buffers (all of the
// same size) and frees every live one when the pool asks for memory back.
class MockMemoryReclaimer : public facebook::velox::memory::MemoryReclaimer {
 public:
  explicit MockMemoryReclaimer(std::vector<void*>& buffs, int32_t size) : buffs_(buffs), size_(size) {}

  // Reports size_ bytes for every buffer still allocated; returns false when
  // nothing is reclaimable (the out-param is left untouched in that case).
  bool reclaimableBytes(const memory::MemoryPool& pool, uint64_t& reclaimableBytes) const override {
    uint64_t reclaimable = 0;
    for (const auto* buf : buffs_) {
      if (buf != nullptr) {
        reclaimable += size_;
      }
    }
    if (reclaimable == 0) {
      return false;
    }
    reclaimableBytes = reclaimable;
    return true;
  }

  // Frees every live buffer back to the pool, regardless of targetBytes, and
  // returns the total number of bytes released.
  uint64_t reclaim(memory::MemoryPool* pool, uint64_t targetBytes, uint64_t maxWaitMs, Stats& stats) override {
    uint64_t freed = 0;
    for (auto& buf : buffs_) {
      if (buf == nullptr) {
        // Either the buffer is not allocated yet (reclaim triggered by an
        // allocation from this same pool) or it was reclaimed already.
        continue;
      }
      pool->free(buf, size_);
      buf = nullptr;
      freed += size_;
    }
    return freed;
  }

 private:
  std::vector<void*>& buffs_; // Borrowed; owned by the test body.
  int32_t size_; // Size in bytes of every tracked buffer.
};

// Asserts that the task memory manager's accounting equals the sum of the
// capacities of all live Velox memory managers.
void assertCapacitiesMatch(MockSparkTaskMemoryManager& tmm, std::vector<std::unique_ptr<VeloxMemoryManager>>& vmms) {
  uint64_t sum = 0;
  for (const auto& vmm : vmms) {
    if (vmm == nullptr) {
      // Slot not yet populated by its worker thread.
      continue;
    }
    sum += vmm->getAggregateMemoryPool()->capacity();
  }
  // Read the manager's counter once (avoids a racy double read) and assert
  // unconditionally — the previous `if (...) ASSERT_EQ(...)` wrapper was
  // redundant and could compare a different value than the one it guarded on.
  const uint64_t tracked = tmm.currentBytes();
  ASSERT_EQ(tracked, sum);
}
} // namespace

// Fixture for tests that run several VeloxMemoryManager instances against a
// single shared backend.
class MultiMemoryManagerTest : public ::testing::Test {
 protected:
  // One-time backend initialization shared by every test in this suite.
  static void SetUpTestCase() {
    std::unordered_map<std::string, std::string> conf;
    conf.emplace(kMemoryReservationBlockSize, std::to_string(kMemoryReservationBlockSizeDefault));
    conf.emplace(kVeloxMemInitCapacity, std::to_string(kVeloxMemInitCapacityDefault));
    gluten::VeloxBackend::create(conf);
  }

  // Convenience factory for a memory manager driven by the given listener.
  std::unique_ptr<VeloxMemoryManager> newVeloxMemoryManager(std::unique_ptr<AllocationListener> listener) {
    return std::make_unique<VeloxMemoryManager>(std::move(listener));
  }
};

// Stress test: 100 threads each allocate 200 x 10 MiB through their own
// VeloxMemoryManager against one shared 200 MiB MockSparkTaskMemoryManager
// budget — far less than requested — forcing cross-manager spills, and checks
// after every allocation that global accounting stays consistent.
TEST_F(MultiMemoryManagerTest, spill) {
  const uint64_t maxBytes = 200 << 20;
  const uint32_t numThreads = 100;
  const uint32_t numAllocations = 200;
  // NOTE(review): allocateSize * numAllocations below is int32 arithmetic;
  // 10 MiB * 200 = ~2.0e9 still fits in int32_t, but barely — confirm if
  // these constants ever grow.
  const int32_t allocateSize = 10 << 20;

  MockSparkTaskMemoryManager tmm{maxBytes};
  std::vector<std::unique_ptr<VeloxMemoryManager>> vmms{};
  std::vector<std::thread> threads{};
  std::vector<std::vector<void*>> buffs{};
  // Pre-size per-thread slots so worker threads never resize the vectors.
  for (size_t i = 0; i < numThreads; ++i) {
    buffs.push_back({});
    vmms.emplace_back(nullptr);
  }

  // Emulate a shared lock to avoid ABBA deadlock.
  std::recursive_mutex mutex;

  for (size_t i = 0; i < numThreads; ++i) {
    threads.emplace_back([this, i, allocateSize, &tmm, &vmms, &mutex, &buffs]() -> void {
      auto wrapper = std::make_unique<AllocationListenerWrapper>(); // Set later.
      auto* listener = wrapper.get();

      facebook::velox::memory::MemoryPool* pool; // Set later.
      {
        std::unique_lock<std::recursive_mutex> l(mutex);
        vmms[i] = newVeloxMemoryManager(std::move(wrapper));
        pool = vmms[i]->getLeafMemoryPool().get();
        pool->setReclaimer(std::make_unique<MockMemoryReclaimer>(buffs[i], allocateSize));
        // Shrink callback always yields 0 so every shortfall is resolved by
        // spilling; the spill callback shrinks this thread's own arbitrator
        // under the shared lock.
        listener->set(tmm.newListener(
            [](uint64_t bytes) -> uint64_t { return 0; },
            [i, &vmms, &mutex](uint64_t bytes) -> uint64_t {
              std::unique_lock<std::recursive_mutex> l(mutex);
              return vmms[i]->getMemoryManager()->arbitrator()->shrinkCapacity(bytes);
            }));
      }
      {
        // Allocations also run under the shared lock; the recursive mutex
        // lets the spill callback re-lock on the same thread.
        std::unique_lock<std::recursive_mutex> l(mutex);
        for (size_t j = 0; j < numAllocations; ++j) {
          assertCapacitiesMatch(tmm, vmms);
          buffs[i].push_back(pool->allocate(allocateSize));
          assertCapacitiesMatch(tmm, vmms);
        }
      }
    });
  }

  for (auto& thread : threads) {
    thread.join();
  }

  // Drain every manager and verify that all acquired bytes are returned.
  for (auto& vmm : vmms) {
    assertCapacitiesMatch(tmm, vmms);
    vmm->getMemoryManager()->arbitrator()->shrinkCapacity(allocateSize * numAllocations);
    assertCapacitiesMatch(tmm, vmms);
  }

  ASSERT_EQ(tmm.currentBytes(), 0);
}
} // namespace gluten

0 comments on commit daddb3b

Please sign in to comment.