Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VL][1.2] Port PR #6741 #6783 #6802

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions cpp/velox/memory/VeloxMemoryManager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,20 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
return targetBytes;
}

uint64_t shrinkCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) override {
std::lock_guard<std::recursive_mutex> l(mutex_);
return shrinkCapacityLocked(pool, targetBytes);
}

bool growCapacity(
velox::memory::MemoryPool* pool,
const std::vector<std::shared_ptr<velox::memory::MemoryPool>>& candidatePools,
uint64_t targetBytes) override {
velox::memory::ScopedMemoryArbitrationContext ctx(pool);
VELOX_CHECK_EQ(candidatePools.size(), 1, "ListenableArbitrator should only be used within a single root pool")
auto candidate = candidatePools.back();
VELOX_CHECK(pool->root() == candidate.get(), "Illegal state in ListenableArbitrator");
velox::memory::MemoryPool* candidate;
{
std::unique_lock guard{mutex_};
VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
candidate = candidates_.begin()->first;
}
VELOX_CHECK(pool->root() == candidate, "Illegal state in ListenableArbitrator");

std::lock_guard<std::recursive_mutex> l(mutex_);
growCapacityLocked(pool->root(), targetBytes);
growCapacity0(pool->root(), targetBytes);
return true;
}

Expand All @@ -80,16 +78,18 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
bool allowAbort) override {
velox::memory::ScopedMemoryArbitrationContext ctx((const velox::memory::MemoryPool*)nullptr);
facebook::velox::exec::MemoryReclaimer::Stats status;
VELOX_CHECK_EQ(pools.size(), 1, "Gluten only has one root pool");
std::lock_guard<std::recursive_mutex> l(mutex_); // FIXME: Do we have recursive locking for this mutex?
auto pool = pools.at(0);
const uint64_t oldCapacity = pool->capacity();
velox::memory::MemoryPool* pool;
{
std::unique_lock guard{mutex_};
VELOX_CHECK_EQ(candidates_.size(), 1, "ListenableArbitrator should only be used within a single root pool")
pool = candidates_.begin()->first;
}
pool->reclaim(targetBytes, 0, status); // ignore the output
shrinkPool(pool.get(), 0);
const uint64_t newCapacity = pool->capacity();
uint64_t total = oldCapacity - newCapacity;
listener_->allocationChanged(-total);
return total;
return shrinkCapacity0(pool, 0);
}

uint64_t shrinkCapacity(velox::memory::MemoryPool* pool, uint64_t targetBytes) override {
return shrinkCapacity0(pool, targetBytes);
}

Stats stats() const override {
Expand All @@ -102,7 +102,7 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
}

private:
void growCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
void growCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
// Since
// https://github.com/facebookincubator/velox/pull/9557/files#diff-436e44b7374032f8f5d7eb45869602add6f955162daa2798d01cc82f8725724dL812-L820,
// We should pass bytes as parameter "reservationBytes" when calling ::grow.
Expand All @@ -124,15 +124,19 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
pool->toString())
}

uint64_t shrinkCapacityLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
uint64_t shrinkCapacity0(velox::memory::MemoryPool* pool, uint64_t bytes) {
uint64_t freeBytes = shrinkPool(pool, bytes);
listener_->allocationChanged(-freeBytes);
return freeBytes;
}

gluten::AllocationListener* listener_;
std::recursive_mutex mutex_;
const uint64_t memoryPoolInitialCapacity_; // FIXME: Unused.
const uint64_t memoryPoolTransferCapacity_;

mutable std::mutex mutex_;
inline static std::string kind_ = "GLUTEN";
std::unordered_map<velox::memory::MemoryPool*, std::weak_ptr<velox::memory::MemoryPool>> candidates_;
};

class ArbitratorFactoryRegister {
Expand Down
270 changes: 270 additions & 0 deletions cpp/velox/tests/MemoryManagerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,274 @@ TEST_F(MemoryManagerTest, memoryAllocatorWithBlockReservation) {
ASSERT_EQ(allocator_->getBytes(), 0);
}

namespace {
// A late-binding proxy for an AllocationListener. The VeloxMemoryManager needs
// a listener at construction time, but the real (Spark-side) listener can only
// be created afterwards; this wrapper is handed out first and wired up later
// via set(). All listener calls are forwarded verbatim to the delegate.
class AllocationListenerWrapper : public AllocationListener {
 public:
  explicit AllocationListenerWrapper() {}

  // Wires up the real listener. May be called exactly once; a second call is a
  // programming error.
  void set(AllocationListener* const delegate) {
    if (delegate_) {
      throw std::runtime_error("Invalid state");
    }
    delegate_ = delegate;
  }

  // Pure forwarding from here on. The delegate must have been set first.
  void allocationChanged(int64_t diff) override {
    delegate_->allocationChanged(diff);
  }

  int64_t currentBytes() override {
    return delegate_->currentBytes();
  }

  int64_t peakBytes() override {
    return delegate_->peakBytes();
  }

 private:
  AllocationListener* delegate_{nullptr}; // Not owned; set once via set().
};

// An AllocationListener that can additionally give memory back on demand:
// shrink() releases capacity that is free already, spill() forces data out of
// memory to make room. Both return the number of bytes actually reclaimed.
// NOTE(review): relies on the AllocationListener base for a virtual
// destructor — confirm the base declares one.
class SpillableAllocationListener : public AllocationListener {
 public:
  // Release unused capacity, up to `bytes`. Returns bytes actually freed.
  virtual uint64_t shrink(uint64_t bytes) = 0;
  // Spill live data to free memory, up to `bytes`. Returns bytes actually freed.
  virtual uint64_t spill(uint64_t bytes) = 0;
};

// Emulates Spark's TaskMemoryManager: a single pool of `maxBytes` shared by
// multiple listeners (one per Velox memory manager). acquire() grants memory,
// asking registered listeners to shrink and then spill when the pool is
// exhausted; release() returns memory to the pool.
class MockSparkTaskMemoryManager {
 public:
  explicit MockSparkTaskMemoryManager(const uint64_t maxBytes);

  // Registers and owns a new listener built from the given shrink/spill callbacks.
  AllocationListener* newListener(std::function<uint64_t(uint64_t)> shrink, std::function<uint64_t(uint64_t)> spill);

  // Tries to reserve `bytes`; returns the number of bytes actually granted.
  uint64_t acquire(uint64_t bytes);
  // Returns `bytes` to the pool.
  void release(uint64_t bytes);
  // Currently reserved bytes.
  uint64_t currentBytes() {
    // Fix: read under the lock. currentBytes_ is mutated by acquire()/release()
    // while holding mutex_, so an unsynchronized read here was a data race.
    std::unique_lock l(mutex_);
    return currentBytes_;
  }

 private:
  mutable std::recursive_mutex mutex_; // Guards listeners_ and currentBytes_.
  std::vector<std::unique_ptr<SpillableAllocationListener>> listeners_{};

  const uint64_t maxBytes_;     // Hard cap for the whole task.
  uint64_t currentBytes_{0L};   // Bytes currently granted; <= maxBytes_.
};

// Listener that forwards every allocation delta to the shared
// MockSparkTaskMemoryManager, throwing on failed grants, and delegates
// shrink/spill requests to caller-supplied callbacks.
class MockSparkAllocationListener : public SpillableAllocationListener {
 public:
  explicit MockSparkAllocationListener(
      MockSparkTaskMemoryManager* const manager,
      std::function<uint64_t(uint64_t)> shrink,
      std::function<uint64_t(uint64_t)> spill)
      : manager_(manager), shrink_(shrink), spill_(spill) {}

  void allocationChanged(int64_t diff) override {
    if (diff > 0) {
      // Growing: ask the task memory manager for the extra bytes.
      const auto granted = manager_->acquire(diff);
      if (granted < static_cast<uint64_t>(diff)) {
        throw std::runtime_error("OOM");
      }
      currentBytes_ += granted;
    } else if (diff < 0) {
      // Shrinking: hand the bytes back.
      const uint64_t freed = -diff;
      manager_->release(freed);
      currentBytes_ -= freed;
    }
    // diff == 0 is a no-op.
  }

  uint64_t shrink(uint64_t bytes) override {
    return shrink_(bytes);
  }

  uint64_t spill(uint64_t bytes) override {
    return spill_(bytes);
  }

  int64_t currentBytes() override {
    return currentBytes_;
  }

 private:
  MockSparkTaskMemoryManager* const manager_; // Not owned.
  std::function<uint64_t(uint64_t)> shrink_;
  std::function<uint64_t(uint64_t)> spill_;
  std::atomic<uint64_t> currentBytes_{0L};
};

MockSparkTaskMemoryManager::MockSparkTaskMemoryManager(const uint64_t maxBytes) : maxBytes_(maxBytes) {}

// Creates, registers and owns a new spillable listener; returns a borrowed
// pointer whose lifetime is tied to this manager.
AllocationListener* MockSparkTaskMemoryManager::newListener(
    std::function<uint64_t(uint64_t)> shrink,
    std::function<uint64_t(uint64_t)> spill) {
  // Fix: take the lock. listeners_ is iterated by acquire() under mutex_, so
  // an unsynchronized push_back here could race with that traversal and
  // invalidate its iterators. mutex_ is recursive, so this is safe even if a
  // caller already holds it.
  std::unique_lock l(mutex_);
  listeners_.push_back(std::make_unique<MockSparkAllocationListener>(this, shrink, spill));
  return listeners_.back().get();
}

// Grants up to `bytes` from the shared pool. When free headroom is
// insufficient, listeners are asked to shrink (cheap) and then to spill
// (expensive) until the deficit is covered. Returns the bytes granted, which
// may be less than requested.
uint64_t MockSparkTaskMemoryManager::acquire(uint64_t bytes) {
  std::unique_lock l(mutex_);
  const auto freeBytes = maxBytes_ - currentBytes_;
  if (bytes <= freeBytes) {
    // Fast path: enough headroom already.
    currentBytes_ += bytes;
    return bytes;
  }
  // Phase 1: shrink listeners (reclaims capacity without touching data).
  int64_t bytesNeeded = bytes - freeBytes;
  for (const auto& listener : listeners_) {
    bytesNeeded -= listener->shrink(bytesNeeded);
    // Fix: was `< 0`, which kept issuing pointless shrink(0) calls to the
    // remaining listeners once the deficit was exactly met.
    if (bytesNeeded <= 0) {
      break;
    }
  }
  // Phase 2: spill listeners only if shrinking was not enough.
  if (bytesNeeded > 0) {
    for (const auto& listener : listeners_) {
      bytesNeeded -= listener->spill(bytesNeeded);
      if (bytesNeeded <= 0) { // Fix: same off-by-one as the shrink loop.
        break;
      }
    }
  }

  if (bytesNeeded > 0) {
    // Still short: grant only the portion that could be made available.
    const uint64_t granted = bytes - bytesNeeded;
    currentBytes_ += granted;
    return granted;
  }

  currentBytes_ += bytes;
  return bytes;
}

// Returns `bytes` to the shared pool.
void MockSparkTaskMemoryManager::release(uint64_t bytes) {
  const std::lock_guard<std::recursive_mutex> guard(mutex_);
  currentBytes_ -= bytes;
}

// Test reclaimer: tracks a caller-owned list of equally-sized buffers and, on
// reclaim, frees every buffer that is still live. Freed slots are left as
// nullptr so a second reclaim is a no-op for them.
class MockMemoryReclaimer : public facebook::velox::memory::MemoryReclaimer {
 public:
  explicit MockMemoryReclaimer(std::vector<void*>& buffs, int32_t size) : buffs_(buffs), size_(size) {}

  bool reclaimableBytes(const memory::MemoryPool& pool, uint64_t& reclaimableBytes) const override {
    // Sum the sizes of all still-live buffers.
    uint64_t live = 0;
    for (const auto& buf : buffs_) {
      if (buf != nullptr) {
        live += size_;
      }
    }
    if (live == 0) {
      return false;
    }
    reclaimableBytes = live;
    return true;
  }

  uint64_t reclaim(memory::MemoryPool* pool, uint64_t targetBytes, uint64_t maxWaitMs, Stats& stats) override {
    uint64_t freed = 0;
    for (auto& buf : buffs_) {
      // A slot may be null when:
      // 1. Reclaim was triggered by an allocation from the same pool, so this
      //    buffer has not been allocated yet.
      // 2. It was already reclaimed by an earlier call.
      if (buf == nullptr) {
        continue;
      }
      pool->free(buf, size_);
      buf = nullptr;
      freed += size_;
    }
    return freed;
  }

 private:
  std::vector<void*>& buffs_; // Caller-owned; entries are nulled as they are freed.
  int32_t size_;              // Size in bytes of every buffer in buffs_.
};

// Asserts that the Spark-side task memory manager's accounting equals the sum
// of the capacities of all live Velox memory managers (nullptr slots are
// managers not yet created by their thread and are skipped).
void assertCapacitiesMatch(MockSparkTaskMemoryManager& tmm, std::vector<std::unique_ptr<VeloxMemoryManager>>& vmms) {
  uint64_t sum = 0;
  for (const auto& vmm : vmms) {
    if (vmm == nullptr) {
      continue;
    }
    sum += vmm->getAggregateMemoryPool()->capacity();
  }
  // Fix: drop the redundant `if (tmm.currentBytes() != sum)` wrapper — it also
  // read currentBytes() twice. ASSERT_EQ performs the identical check and
  // reports both values on failure.
  ASSERT_EQ(tmm.currentBytes(), sum);
}
} // namespace

// Fixture for tests that run several VeloxMemoryManagers side by side on top
// of one shared backend.
class MultiMemoryManagerTest : public ::testing::Test {
 protected:
  // One-time backend initialization for the whole test case.
  static void SetUpTestCase() {
    std::unordered_map<std::string, std::string> backendConf = {
        {kMemoryReservationBlockSize, std::to_string(kMemoryReservationBlockSizeDefault)},
        {kVeloxMemInitCapacity, std::to_string(kVeloxMemInitCapacityDefault)}};
    gluten::VeloxBackend::create(backendConf);
  }

  // Builds a fresh memory manager wired to the given allocation listener.
  std::unique_ptr<VeloxMemoryManager> newVeloxMemoryManager(std::unique_ptr<AllocationListener> listener) {
    return std::make_unique<VeloxMemoryManager>(std::move(listener));
  }
};

// Stress test: 100 threads each own a VeloxMemoryManager whose allocations are
// funneled through listeners into one MockSparkTaskMemoryManager capped at
// 200 MiB. Each thread tries to allocate 200 x 10 MiB (2 GiB total demand vs.
// a 200 MiB cap), so grants can only succeed by spilling other managers via
// Arbitrator::shrinkCapacity. After every step the Spark-side byte accounting
// must match the sum of Velox pool capacities, and everything must drain to 0.
TEST_F(MultiMemoryManagerTest, spill) {
  const uint64_t maxBytes = 200 << 20;       // Shared task cap: 200 MiB.
  const uint32_t numThreads = 100;
  const uint32_t numAllocations = 200;       // Per-thread allocation count.
  const int32_t allocateSize = 10 << 20;     // 10 MiB per allocation.

  MockSparkTaskMemoryManager tmm{maxBytes};
  std::vector<std::unique_ptr<VeloxMemoryManager>> vmms{};
  std::vector<std::thread> threads{};
  std::vector<std::vector<void*>> buffs{};
  // Pre-size the per-thread slots so threads can index them without resizing.
  for (size_t i = 0; i < numThreads; ++i) {
    buffs.push_back({});
    vmms.emplace_back(nullptr);
  }

  // Emulate a shared lock to avoid ABBA deadlock.
  std::recursive_mutex mutex;

  for (size_t i = 0; i < numThreads; ++i) {
    threads.emplace_back([this, i, allocateSize, &tmm, &vmms, &mutex, &buffs]() -> void {
      auto wrapper = std::make_unique<AllocationListenerWrapper>(); // Set later.
      auto* listener = wrapper.get();

      facebook::velox::memory::MemoryPool* pool; // Set later.
      {
        // Set-up phase under the shared lock: create this thread's memory
        // manager, install the mock reclaimer, and wire the listener to the
        // shared task memory manager.
        std::unique_lock<std::recursive_mutex> l(mutex);
        vmms[i] = newVeloxMemoryManager(std::move(wrapper));
        pool = vmms[i]->getLeafMemoryPool().get();
        pool->setReclaimer(std::make_unique<MockMemoryReclaimer>(buffs[i], allocateSize));
        listener->set(tmm.newListener(
            // shrink: reclaim nothing, forcing the manager to fall back to spilling.
            [](uint64_t bytes) -> uint64_t { return 0; },
            // spill: shrink this thread's own Velox manager (recursive mutex
            // allows re-entry when the spill is triggered from under the lock).
            [i, &vmms, &mutex](uint64_t bytes) -> uint64_t {
              std::unique_lock<std::recursive_mutex> l(mutex);
              return vmms[i]->getMemoryManager()->arbitrator()->shrinkCapacity(bytes);
            }));
      }
      {
        // Allocation phase: hold the shared lock so the capacity invariant can
        // be checked atomically around every allocation.
        std::unique_lock<std::recursive_mutex> l(mutex);
        for (size_t j = 0; j < numAllocations; ++j) {
          assertCapacitiesMatch(tmm, vmms);
          buffs[i].push_back(pool->allocate(allocateSize));
          assertCapacitiesMatch(tmm, vmms);
        }
      }
    });
  }

  for (auto& thread : threads) {
    thread.join();
  }

  // Tear-down: shrink each manager in turn; the invariant must hold before and
  // after every shrink, and the task manager must end up fully drained.
  for (auto& vmm : vmms) {
    assertCapacitiesMatch(tmm, vmms);
    vmm->getMemoryManager()->arbitrator()->shrinkCapacity(allocateSize * numAllocations);
    assertCapacitiesMatch(tmm, vmms);
  }

  ASSERT_EQ(tmm.currentBytes(), 0);
}
} // namespace gluten
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ public class SimpleMemoryUsageRecorder implements MemoryUsageRecorder {
@Override
public void inc(long bytes) {
final long total = this.current.addAndGet(bytes);
long prev_peak;
long prevPeak;
do {
prev_peak = this.peak.get();
if (total <= prev_peak) {
prevPeak = this.peak.get();
if (total <= prevPeak) {
break;
}
} while (!this.peak.compareAndSet(prev_peak, total));
} while (!this.peak.compareAndSet(prevPeak, total));
}

// peak used bytes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static MemoryTarget dynamicOffHeapSizingIfEnabled(MemoryTarget memoryTarg
return memoryTarget;
}

public static MemoryTarget newConsumer(
public static TreeMemoryTarget newConsumer(
TaskMemoryManager tmm,
String name,
Spiller spiller,
Expand Down
Loading
Loading