diff --git a/cpp/core/jni/JniCommon.h b/cpp/core/jni/JniCommon.h index 8f8002b2c216..1d784f3a5eda 100644 --- a/cpp/core/jni/JniCommon.h +++ b/cpp/core/jni/JniCommon.h @@ -403,17 +403,25 @@ class SparkAllocationListener final : public gluten::AllocationListener { env->CallLongMethod(jListenerGlobalRef_, jReserveMethod_, size); checkException(env); } - // atomic operation is enough here, no need to use mutex - bytesReserved_.fetch_add(size); - maxBytesReserved_.store(std::max(bytesReserved_.load(), maxBytesReserved_.load())); + usedBytes_ += size; + while (true) { + int64_t savedPeakBytes = peakBytes_; + if (usedBytes_ <= savedPeakBytes) { + break; + } + // usedBytes_ > savedPeakBytes, update peak + if (peakBytes_.compare_exchange_weak(savedPeakBytes, usedBytes_)) { + break; + } + } } int64_t currentBytes() override { - return bytesReserved_; + return usedBytes_; } int64_t peakBytes() override { - return maxBytesReserved_; + return peakBytes_; } private: @@ -421,8 +429,8 @@ class SparkAllocationListener final : public gluten::AllocationListener { jobject jListenerGlobalRef_; const jmethodID jReserveMethod_; const jmethodID jUnreserveMethod_; - std::atomic_int64_t bytesReserved_{0L}; - std::atomic_int64_t maxBytesReserved_{0L}; + std::atomic_int64_t usedBytes_{0L}; + std::atomic_int64_t peakBytes_{0L}; }; class BacktraceAllocationListener final : public gluten::AllocationListener { diff --git a/cpp/core/memory/AllocationListener.h b/cpp/core/memory/AllocationListener.h index 695552cefbe3..41797641fe14 100644 --- a/cpp/core/memory/AllocationListener.h +++ b/cpp/core/memory/AllocationListener.h @@ -50,32 +50,18 @@ class AllocationListener { // The class must be thread safe class BlockAllocationListener final : public AllocationListener { public: - BlockAllocationListener(AllocationListener* delegated, uint64_t blockSize) + BlockAllocationListener(AllocationListener* delegated, int64_t blockSize) : delegated_(delegated), blockSize_(blockSize) {} void allocationChanged(int64_t diff) override { if (diff == 0) { return; } - std::unique_lock guard{mutex_}; - if (diff > 0) { - if (reservationBytes_ - usedBytes_ < diff) { - auto roundSize = (diff + (blockSize_ - 1)) / blockSize_ * blockSize_; - reservationBytes_ += roundSize; - peakBytes_ = std::max(peakBytes_, reservationBytes_); - guard.unlock(); - // unnecessary to lock the delegated listener, assume it's thread safe - delegated_->allocationChanged(roundSize); - } - usedBytes_ += diff; - } else { - usedBytes_ += diff; - auto unreservedSize = (reservationBytes_ - usedBytes_) / blockSize_ * blockSize_; - reservationBytes_ -= unreservedSize; - guard.unlock(); - // unnecessary to lock the delegated listener - delegated_->allocationChanged(-unreservedSize); + int64_t granted = reserve(diff); + if (granted == 0) { + return; } + delegated_->allocationChanged(granted); } int64_t currentBytes() override { @@ -87,11 +73,28 @@ class BlockAllocationListener final : public AllocationListener { } private: + inline int64_t reserve(int64_t diff) { + std::lock_guard lock(mutex_); + usedBytes_ += diff; + int64_t newBlockCount; + if (usedBytes_ == 0) { + newBlockCount = 0; + } else { + // ceil to get the required block number + newBlockCount = (usedBytes_ - 1) / blockSize_ + 1; + } + int64_t bytesGranted = (newBlockCount - blocksReserved_) * blockSize_; + blocksReserved_ = newBlockCount; + peakBytes_ = std::max(peakBytes_, usedBytes_); + return bytesGranted; + } + AllocationListener* const delegated_; const uint64_t blockSize_; - uint64_t usedBytes_{0L}; - uint64_t peakBytes_{0L}; - uint64_t reservationBytes_{0L}; + int64_t blocksReserved_{0L}; + int64_t usedBytes_{0L}; + int64_t peakBytes_{0L}; + int64_t reservationBytes_{0L}; mutable std::mutex mutex_; }; diff --git a/cpp/core/memory/MemoryAllocator.cc b/cpp/core/memory/MemoryAllocator.cc index 01818636aa52..c637c6a9c13d 100644 --- a/cpp/core/memory/MemoryAllocator.cc +++ b/cpp/core/memory/MemoryAllocator.cc @@ -92,8 +92,17 @@ int64_t ListenableMemoryAllocator::peakBytes() const { void ListenableMemoryAllocator::updateUsage(int64_t size) { listener_->allocationChanged(size); - usedBytes_.fetch_add(size); - peakBytes_.store(std::max(peakBytes_.load(), usedBytes_.load())); + usedBytes_ += size; + while (true) { + int64_t savedPeakBytes = peakBytes_; + if (usedBytes_ <= savedPeakBytes) { + break; + } + // usedBytes_ > savedPeakBytes, update peak + if (peakBytes_.compare_exchange_weak(savedPeakBytes, usedBytes_)) { + break; + } + } } bool StdMemoryAllocator::allocate(int64_t size, void** out) {