diff --git a/velox/common/memory/ArbitrationOperation.cpp b/velox/common/memory/ArbitrationOperation.cpp
new file mode 100644
index 000000000000..2a6ca38981fd
--- /dev/null
+++ b/velox/common/memory/ArbitrationOperation.cpp
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "velox/common/memory/ArbitrationOperation.h"
+#include <mutex>
+
+#include "velox/common/base/Exceptions.h"
+#include "velox/common/base/RuntimeMetrics.h"
+#include "velox/common/memory/Memory.h"
+#include "velox/common/testutil/TestValue.h"
+#include "velox/common/time/Timer.h"
+
+using facebook::velox::common::testutil::TestValue;
+
+namespace facebook::velox::memory {
+using namespace facebook::velox::memory;
+
+ArbitrationOperation::ArbitrationOperation(
+    ScopedArbitrationParticipant&& participant,
+    uint64_t requestBytes,
+    uint64_t timeoutMs)
+    : requestBytes_(requestBytes),
+      timeoutMs_(timeoutMs),
+      createTimeMs_(getCurrentTimeMs()),
+      participant_(std::move(participant)) {
+  VELOX_CHECK_GT(requestBytes_, 0);
+}
+
+ArbitrationOperation::~ArbitrationOperation() {
+  VELOX_CHECK_NE(
+      state_,
+      State::kRunning,
+      "Unexpected arbitration operation state on destruction");
+  VELOX_CHECK(
+      allocatedBytes_ == 0 || allocatedBytes_ >= requestBytes_,
+      "Unexpected allocatedBytes_ {} vs requestBytes_ {}",
+      succinctBytes(allocatedBytes_),
+      succinctBytes(requestBytes_));
+}
+
+std::string ArbitrationOperation::stateName(State state) {
+  switch (state) {
+    case State::kInit:
+      return "init";
+    case State::kWaiting:
+      return "waiting";
+    case State::kRunning:
+      return "running";
+    case State::kFinished:
+      return "finished";
+    default:
+      return fmt::format("unknown state: {}", static_cast<int>(state));
+  }
+}
+
+void ArbitrationOperation::setState(State state) {
+  // Legal transitions: kInit -> kWaiting, kInit/kWaiting -> kRunning and
+  // kRunning -> kFinished. Any other target state trips a check below.
+  switch (state) {
+    case State::kWaiting:
+      VELOX_CHECK_EQ(state_, State::kInit);
+      break;
+    case State::kRunning:
+      VELOX_CHECK(state_ == State::kWaiting || state_ == State::kInit);
+      break;
+    case State::kFinished:
+      VELOX_CHECK_EQ(state_, State::kRunning);
+      break;
+    default:
+      VELOX_UNREACHABLE(
+          "Unexpected state transition from {} to {}", state_, state);
+      break;
+  }
+  state_ = state;
+}
+
+void ArbitrationOperation::start() {
+  VELOX_CHECK_EQ(state_, State::kInit);
+  participant_->startArbitration(this);
+  setState(ArbitrationOperation::State::kRunning);
+}
+
+void ArbitrationOperation::finish() {
+  setState(State::kFinished);
+  VELOX_CHECK_EQ(finishTimeMs_, 0);
+  finishTimeMs_ = getCurrentTimeMs();
+  participant_->finishArbitration(this);
+}
+
+bool ArbitrationOperation::aborted() const {
+  return participant_->aborted();
+}
+
+size_t ArbitrationOperation::executionTimeMs() const {
+  if (state_ == State::kFinished) {
+    VELOX_CHECK_GE(finishTimeMs_, createTimeMs_);
+    return finishTimeMs_ - createTimeMs_;
+  } else {
+    const auto currentTimeMs = getCurrentTimeMs();
+    VELOX_CHECK_GE(currentTimeMs, createTimeMs_);
+    return currentTimeMs - createTimeMs_;
+  }
+}
+
+bool ArbitrationOperation::hasTimeout() const {
+  // 'timeoutMs()' is unsigned and saturates at zero once the time budget is
+  // used up, so an exact comparison against zero is sufficient.
+  return state_ != State::kFinished && timeoutMs() == 0;
+}
+
+size_t ArbitrationOperation::timeoutMs() const {
+  if (state_ == State::kFinished) {
+    return 0;
+  }
+  // Remaining budget is the configured timeout minus the time since creation.
+  const auto execTimeMs = executionTimeMs();
+  if (execTimeMs >= timeoutMs_) {
+    return 0;
+  }
+  return timeoutMs_ - execTimeMs;
+}
+
+void ArbitrationOperation::setGrowTargets() {
+  // We shall only set grow targets once after start execution.
+  VELOX_CHECK_EQ(state_, State::kRunning);
+  VELOX_CHECK(
+      maxGrowBytes_ == 0 && minGrowBytes_ == 0,
+      "Arbitration operation grow targets have already been set: {}/{}",
+      succinctBytes(maxGrowBytes_),
+      succinctBytes(minGrowBytes_));
+  participant_->getGrowTargets(requestBytes_, maxGrowBytes_, minGrowBytes_);
+  VELOX_CHECK_LE(requestBytes_, maxGrowBytes_);
+}
+
+std::ostream& operator<<(std::ostream& out, ArbitrationOperation::State state) {
+  out << ArbitrationOperation::stateName(state);
+  return out;
+}
+} // namespace facebook::velox::memory
diff --git a/velox/common/memory/ArbitrationOperation.h b/velox/common/memory/ArbitrationOperation.h
new file mode 100644
index 000000000000..884d601b6027
--- /dev/null
+++ b/velox/common/memory/ArbitrationOperation.h
@@ -0,0 +1,176 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "velox/common/base/Counters.h"
+#include "velox/common/base/GTestMacros.h"
+#include "velox/common/base/StatsReporter.h"
+#include "velox/common/future/VeloxPromise.h"
+#include "velox/common/memory/ArbitrationParticipant.h"
+#include "velox/common/memory/Memory.h"
+
+namespace facebook::velox::memory {
+
+/// Manages the execution of a memory arbitration request within the arbitrator.
+class ArbitrationOperation {
+ public:
+  ArbitrationOperation(
+      ScopedArbitrationParticipant&& participant,
+      uint64_t requestBytes,
+      uint64_t timeoutMs);
+
+  ~ArbitrationOperation();
+
+  enum class State {
+    kInit,
+    kWaiting,
+    kRunning,
+    kFinished,
+  };
+
+  State state() const {
+    return state_;
+  }
+
+  static std::string stateName(State state);
+
+  /// Returns the corresponding arbitration participant.
+  const ScopedArbitrationParticipant& participant() const {
+    return participant_;
+  }
+
+  /// Invoked to start arbitration execution on the arbitration participant. The
+  /// latter ensures the serialized execution of arbitration operations from the
+  /// same query with one at a time. So this method blocks until all the prior
+  /// arbitration operations finish.
+  void start();
+
+  /// Invoked to finish arbitration execution on the arbitration participant. It
+  /// also resumes the next waiting arbitration operation to execute if there is
+  /// one.
+  void finish();
+
+  /// Returns true if the corresponding arbitration participant has been
+  /// aborted.
+  bool aborted() const;
+
+  /// Invoked to set the grow targets for this arbitration operation based on
+  /// the request size.
+  ///
+  /// NOTE: this should be called once after the arbitration operation is
+  /// started.
+  void setGrowTargets();
+
+  uint64_t requestBytes() const {
+    return requestBytes_;
+  }
+
+  /// Returns the max grow bytes for this arbitration operation which could be
+  /// larger than the request bytes for exponential growth.
+  uint64_t maxGrowBytes() const {
+    return maxGrowBytes_;
+  }
+
+  /// Returns the min grow bytes for this arbitration operation to ensure the
+  /// arbitration participant has the minimum amount of memory capacity. The
+  /// arbitrator might allocate memory from the reserved memory capacity pool
+  /// for the min grow bytes.
+  uint64_t minGrowBytes() const {
+    return minGrowBytes_;
+  }
+
+  /// Returns a mutable reference to the bytes allocated by this operation.
+  uint64_t& allocatedBytes() {
+    return allocatedBytes_;
+  }
+
+  /// Returns the remaining execution time for this operation before time out.
+  /// If the operation has already finished, this returns zero.
+  size_t timeoutMs() const;
+
+  /// Returns true if this operation has timed out.
+  bool hasTimeout() const;
+
+  /// Returns the execution time of this arbitration operation since creation.
+  size_t executionTimeMs() const;
+
+  /// Getters/Setters of the wait time in (local) arbitration participant wait
+  /// queue or (global) arbitrator request wait queue.
+  void setLocalArbitrationWaitTimeUs(uint64_t waitTimeUs) {
+    VELOX_CHECK_EQ(localArbitrationWaitTimeUs_, 0);
+    VELOX_CHECK_EQ(state_, State::kWaiting);
+    localArbitrationWaitTimeUs_ = waitTimeUs;
+  }
+
+  uint64_t localArbitrationWaitTimeUs() const {
+    return localArbitrationWaitTimeUs_;
+  }
+
+  void setGlobalArbitrationWaitTimeUs(uint64_t waitTimeUs) {
+    VELOX_CHECK_EQ(globalArbitrationWaitTimeUs_, 0);
+    VELOX_CHECK_EQ(state_, State::kRunning);
+    globalArbitrationWaitTimeUs_ = waitTimeUs;
+  }
+
+  uint64_t globalArbitrationWaitTimeUs() const {
+    return globalArbitrationWaitTimeUs_;
+  }
+
+ private:
+  void setState(State state);
+
+  const uint64_t requestBytes_;
+  const uint64_t timeoutMs_;
+
+  // The creation time of this arbitration operation.
+  const uint64_t createTimeMs_;
+  const ScopedArbitrationParticipant participant_;
+
+  State state_{State::kInit};
+
+  uint64_t finishTimeMs_{0};
+
+  uint64_t maxGrowBytes_{0};
+  uint64_t minGrowBytes_{0};
+
+  // The actual bytes allocated from arbitrator based on the request bytes and
+  // grow targets. It is either zero on failure or between 'requestBytes_' and
+  // 'maxGrowBytes_' on success.
+  uint64_t allocatedBytes_{0};
+
+  // The time spent waiting in the local arbitration queue.
+  uint64_t localArbitrationWaitTimeUs_{0};
+
+  // The time spent waiting in the global arbitration queue.
+  uint64_t globalArbitrationWaitTimeUs_{0};
+
+  friend class ArbitrationParticipant;
+};
+
+std::ostream& operator<<(std::ostream& out, ArbitrationOperation::State state);
+} // namespace facebook::velox::memory
+
+template <>
+struct fmt::formatter<facebook::velox::memory::ArbitrationOperation::State>
+    : formatter<std::string> {
+  auto format(
+      facebook::velox::memory::ArbitrationOperation::State state,
+      format_context& ctx) {
+    return formatter<std::string>::format(
+        facebook::velox::memory::ArbitrationOperation::stateName(state), ctx);
+  }
+};
diff --git a/velox/common/memory/ArbitrationParticipant.cpp b/velox/common/memory/ArbitrationParticipant.cpp
new file mode 100644
index 000000000000..ceaea9dd888a
--- /dev/null
+++ b/velox/common/memory/ArbitrationParticipant.cpp
@@ -0,0 +1,390 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "velox/common/memory/ArbitrationParticipant.h"
+#include <mutex>
+
+#include "velox/common/base/Exceptions.h"
+#include "velox/common/base/RuntimeMetrics.h"
+#include "velox/common/memory/ArbitrationOperation.h"
+#include "velox/common/memory/Memory.h"
+#include "velox/common/testutil/TestValue.h"
+#include "velox/common/time/Timer.h"
+
+using facebook::velox::common::testutil::TestValue;
+
+namespace facebook::velox::memory {
+using namespace facebook::velox::memory;
+
+std::string ArbitrationParticipant::Config::toString() const {
+  return fmt::format(
+      "minCapacity {}, fastExponentialGrowthCapacityLimit {}, slowCapacityGrowRatio {}, minFreeCapacity {}, minFreeCapacityRatio {}",
+      minCapacity,
+      fastExponentialGrowthCapacityLimit,
+      slowCapacityGrowRatio,
+      minFreeCapacity,
+      minFreeCapacityRatio);
+}
+
+ArbitrationParticipant::Config::Config(
+    uint64_t _minCapacity,
+    uint64_t _fastExponentialGrowthCapacityLimit,
+    double _slowCapacityGrowRatio,
+    uint64_t _minFreeCapacity,
+    double _minFreeCapacityRatio)
+    : minCapacity(_minCapacity),
+      fastExponentialGrowthCapacityLimit(_fastExponentialGrowthCapacityLimit),
+      slowCapacityGrowRatio(_slowCapacityGrowRatio),
+      minFreeCapacity(_minFreeCapacity),
+      minFreeCapacityRatio(_minFreeCapacityRatio) {
+  VELOX_CHECK_GE(slowCapacityGrowRatio, 0);
+  VELOX_CHECK_EQ(
+      fastExponentialGrowthCapacityLimit == 0,
+      slowCapacityGrowRatio == 0,
+      "fastExponentialGrowthCapacityLimit {} and slowCapacityGrowRatio {} "
+      "both need to be set (non-zero) at the same time to enable growth capacity "
+      "adjustment.",
+      fastExponentialGrowthCapacityLimit,
+      slowCapacityGrowRatio);
+
+  VELOX_CHECK_GE(minFreeCapacityRatio, 0);
+  VELOX_CHECK_LE(minFreeCapacityRatio, 1);
+  VELOX_CHECK_EQ(
+      minFreeCapacity == 0,
+      minFreeCapacityRatio == 0,
+      "minFreeCapacity {} and minFreeCapacityRatio {} both "
+      "need to be set (non-zero) at the same time to enable shrink capacity "
+      "adjustment.",
+      minFreeCapacity,
+      minFreeCapacityRatio);
+}
+
+std::shared_ptr<ArbitrationParticipant> ArbitrationParticipant::create(
+    uint64_t id,
+    const std::shared_ptr<MemoryPool>& pool,
+    const Config* config) {
+  return std::shared_ptr<ArbitrationParticipant>(
+      new ArbitrationParticipant(id, pool, config));
+}
+
+ArbitrationParticipant::ArbitrationParticipant(
+    uint64_t id,
+    const std::shared_ptr<MemoryPool>& pool,
+    const Config* config)
+    : id_(id),
+      poolWeakPtr_(pool),
+      pool_(pool.get()),
+      config_(config),
+      maxCapacity_(pool_->maxCapacity()),
+      createTimeUs_(getCurrentTimeMicro()) {
+  VELOX_CHECK_LE(
+      config_->minCapacity,
+      maxCapacity_,
+      "The min capacity is larger than the max capacity for memory pool {}.",
+      pool_->name());
+}
+
+ArbitrationParticipant::~ArbitrationParticipant() {
+  VELOX_CHECK_NULL(runningOp_);
+  VELOX_CHECK(waitOps_.empty());
+}
+
+std::optional<ScopedArbitrationParticipant> ArbitrationParticipant::lock() {
+  auto sharedPtr = poolWeakPtr_.lock();
+  if (sharedPtr == nullptr) {
+    return {};
+  }
+  return ScopedArbitrationParticipant(shared_from_this(), std::move(sharedPtr));
+}
+
+uint64_t ArbitrationParticipant::maxGrowCapacity() const {
+  const auto capacity = pool_->capacity();
+  VELOX_CHECK_LE(capacity, maxCapacity_);
+  return maxCapacity_ - capacity;
+}
+
+uint64_t ArbitrationParticipant::minGrowCapacity() const {
+  const auto capacity = pool_->capacity();
+  if (capacity >= config_->minCapacity) {
+    return 0;
+  }
+  return config_->minCapacity - capacity;
+}
+
+bool ArbitrationParticipant::inactivePool() const {
+  // Checks if a query memory pool is actively used by query execution or not.
+  // If not, then we don't have to respect the memory pool min limit or reserved
+  // capacity check.
+  //
+  // NOTE: for query system like Prestissimo, it holds a finished query
+  // state in minutes for query stats fetch request from the Presto
+  // coordinator.
+  return pool_->reservedBytes() == 0 && pool_->peakBytes() != 0;
+}
+
+uint64_t ArbitrationParticipant::reclaimableFreeCapacity() const {
+  return std::min(maxShrinkCapacity(), maxReclaimableCapacity());
+}
+
+uint64_t ArbitrationParticipant::maxReclaimableCapacity() const {
+  if (inactivePool()) {
+    return pool_->capacity();
+  }
+  const uint64_t capacityBytes = pool_->capacity();
+  if (capacityBytes < config_->minCapacity) {
+    return 0;
+  }
+  return capacityBytes - config_->minCapacity;
+}
+
+uint64_t ArbitrationParticipant::reclaimableUsedCapacity() const {
+  const auto maxReclaimableBytes = maxReclaimableCapacity();
+  const auto reclaimableBytes = pool_->reclaimableBytes();
+  return std::min(maxReclaimableBytes, reclaimableBytes.value_or(0));
+}
+
+uint64_t ArbitrationParticipant::maxShrinkCapacity() const {
+  const uint64_t capacity = pool_->capacity();
+  const uint64_t freeBytes = pool_->freeBytes();
+  if (config_->minFreeCapacity != 0 && !inactivePool()) {
+    const uint64_t minFreeBytes = std::min(
+        static_cast<uint64_t>(capacity * config_->minFreeCapacityRatio),
+        config_->minFreeCapacity);
+    if (freeBytes <= minFreeBytes) {
+      return 0;
+    } else {
+      return freeBytes - minFreeBytes;
+    }
+  } else {
+    return freeBytes;
+  }
+}
+
+bool ArbitrationParticipant::checkCapacityGrowth(uint64_t requestBytes) const {
+  return maxGrowCapacity() >= requestBytes;
+}
+
+void ArbitrationParticipant::getGrowTargets(
+    uint64_t requestBytes,
+    uint64_t& maxGrowBytes,
+    uint64_t& minGrowBytes) const {
+  const uint64_t capacity = pool_->capacity();
+  if (config_->fastExponentialGrowthCapacityLimit == 0 &&
+      config_->slowCapacityGrowRatio == 0) {
+    maxGrowBytes = requestBytes;
+  } else {
+    if (capacity * 2 <= config_->fastExponentialGrowthCapacityLimit) {
+      maxGrowBytes = capacity;
+    } else {
+      maxGrowBytes = capacity * config_->slowCapacityGrowRatio;
+    }
+  }
+  maxGrowBytes = std::max(requestBytes, maxGrowBytes);
+  minGrowBytes = minGrowCapacity();
+  maxGrowBytes = std::max(maxGrowBytes, minGrowBytes);
+  maxGrowBytes = std::min(maxGrowCapacity(), maxGrowBytes);
+
+  VELOX_CHECK_LE(minGrowBytes, maxGrowBytes);
+  VELOX_CHECK_LE(requestBytes, maxGrowBytes);
+}
+
+void ArbitrationParticipant::startArbitration(ArbitrationOperation* op) {
+  ContinueFuture waitPromise{ContinueFuture::makeEmpty()};
+  {
+    std::lock_guard<std::mutex> l(stateLock_);
+    ++numRequests_;
+    if (runningOp_ != nullptr) {
+      op->setState(ArbitrationOperation::State::kWaiting);
+      WaitOp waitOp{
+          op,
+          ContinuePromise{fmt::format(
+              "Wait for arbitration on {}", op->participant()->name())}};
+      waitPromise = waitOp.waitPromise.getSemiFuture();
+      waitOps_.emplace_back(std::move(waitOp));
+    } else {
+      runningOp_ = op;
+    }
+  }
+
+  TestValue::adjust(
+      "facebook::velox::memory::ArbitrationParticipant::startArbitration",
+      this);
+
+  if (waitPromise.valid()) {
+    uint64_t waitTimeUs{0};
+    {
+      MicrosecondTimer timer(&waitTimeUs);
+      waitPromise.wait();
+    }
+    op->setLocalArbitrationWaitTimeUs(waitTimeUs);
+  }
+}
+
+void ArbitrationParticipant::finishArbitration(ArbitrationOperation* op) {
+  ContinuePromise resumePromise{ContinuePromise::makeEmpty()};
+  {
+    std::lock_guard<std::mutex> l(stateLock_);
+    VELOX_CHECK_EQ(static_cast<void*>(op), static_cast<void*>(runningOp_));
+    if (!waitOps_.empty()) {
+      resumePromise = std::move(waitOps_.front().waitPromise);
+      runningOp_ = waitOps_.front().op;
+      waitOps_.pop_front();
+    } else {
+      runningOp_ = nullptr;
+    }
+  }
+  if (resumePromise.valid()) {
+    resumePromise.setValue();
+  }
+}
+
+uint64_t ArbitrationParticipant::reclaim(
+    uint64_t targetBytes,
+    uint64_t maxWaitTimeMs) noexcept {
+  if (targetBytes == 0) {
+    return 0;
+  }
+  std::lock_guard<std::timed_mutex> l(reclaimLock_);
+  TestValue::adjust(
+      "facebook::velox::memory::ArbitrationParticipant::reclaim", this);
+  uint64_t reclaimedBytes{0};
+  MemoryReclaimer::Stats reclaimStats;
+  try {
+    ++numReclaims_;
+    pool_->reclaim(targetBytes, maxWaitTimeMs, reclaimStats);
+    reclaimedBytes = shrink();
+  } catch (const std::exception& e) {
+    VELOX_MEM_LOG(ERROR) << "Failed to reclaim from memory pool "
+                         << pool_->name() << ", aborting it: " << e.what();
+    reclaimedBytes = abortLocked(std::current_exception());
+    reclaimedBytes += shrink(/*reclaimAll=*/true);
+  }
+  return reclaimedBytes;
+}
+
+bool ArbitrationParticipant::grow(
+    uint64_t growBytes,
+    uint64_t reservationBytes) {
+  std::lock_guard<std::mutex> l(stateLock_);
+  ++numGrows_;
+  const bool success = pool_->grow(growBytes, reservationBytes);
+  if (success) {
+    growBytes_ += growBytes;
+  }
+  return success;
+}
+
+uint64_t ArbitrationParticipant::shrink(bool reclaimAll) {
+  std::lock_guard<std::mutex> l(stateLock_);
+  ++numShrinks_;
+
+  uint64_t reclaimedBytes{0};
+  if (reclaimAll) {
+    reclaimedBytes = pool_->shrink(0);
+  } else {
+    const uint64_t reclaimTargetBytes = reclaimableFreeCapacity();
+    if (reclaimTargetBytes > 0) {
+      reclaimedBytes = pool_->shrink(reclaimTargetBytes);
+    }
+  }
+  reclaimedBytes_ += reclaimedBytes;
+  return reclaimedBytes;
+}
+
+uint64_t ArbitrationParticipant::abort(
+    const std::exception_ptr& error) noexcept {
+  std::lock_guard<std::timed_mutex> l(reclaimLock_);
+  return abortLocked(error);
+}
+
+uint64_t ArbitrationParticipant::abortLocked(
+    const std::exception_ptr& error) noexcept {
+  TestValue::adjust(
+      "facebook::velox::memory::ArbitrationParticipant::abortLocked", this);
+  {
+    std::lock_guard<std::mutex> l(stateLock_);
+    if (aborted_) {
+      return 0;
+    }
+    aborted_ = true;
+  }
+  try {
+    pool_->abort(error);
+  } catch (const std::exception& e) {
+    VELOX_MEM_LOG(WARNING) << "Failed to abort memory pool "
+                           << pool_->toString() << ", error: " << e.what();
+  }
+  // NOTE: no matter query memory pool abort throws or not, it should have been
+  // marked as aborted to prevent any new memory arbitration operations.
+  VELOX_CHECK(pool_->aborted());
+  return shrink(/*reclaimAll=*/true);
+}
+
+bool ArbitrationParticipant::waitForReclaimOrAbort(
+    uint64_t maxWaitTimeMs) const {
+  std::unique_lock<std::timed_mutex> l(
+      reclaimLock_, std::chrono::milliseconds(maxWaitTimeMs));
+  return l.owns_lock();
+}
+
+bool ArbitrationParticipant::hasRunningOp() const {
+  std::lock_guard<std::mutex> l(stateLock_);
+  return runningOp_ != nullptr;
+}
+
+size_t ArbitrationParticipant::numWaitingOps() const {
+  std::lock_guard<std::mutex> l(stateLock_);
+  return waitOps_.size();
+}
+
+std::string ArbitrationParticipant::Stats::toString() const {
+  return fmt::format(
+      "numRequests: {}, numReclaims: {}, numShrinks: {}, numGrows: {}, reclaimedBytes: {}, growBytes: {}, aborted: {}, duration: {}",
+      numRequests,
+      numReclaims,
+      numShrinks,
+      numGrows,
+      succinctBytes(reclaimedBytes),
+      succinctBytes(growBytes),
+      aborted,
+      succinctMicros(durationUs));
+}
+
+ScopedArbitrationParticipant::ScopedArbitrationParticipant(
+    std::shared_ptr<ArbitrationParticipant> participant,
+    std::shared_ptr<MemoryPool> memPool)
+    : participant_(std::move(participant)),
+      pool_(std::move(memPool)) {
+  VELOX_CHECK_NOT_NULL(participant_);
+  VELOX_CHECK_NOT_NULL(pool_);
+}
+
+ArbitrationCandidate::ArbitrationCandidate(
+    ScopedArbitrationParticipant&& _participant,
+    bool freeCapacityOnly)
+    : participant(std::move(_participant)),
+      reclaimableUsedCapacity(
+          freeCapacityOnly ? 0 : participant->reclaimableUsedCapacity()),
+      reclaimableFreeCapacity(participant->reclaimableFreeCapacity()) {}
+
+std::string ArbitrationCandidate::toString() const {
+  return fmt::format(
+      "{} RECLAIMABLE_USED_CAPACITY {} RECLAIMABLE_FREE_CAPACITY {}",
+      participant->name(),
+      succinctBytes(reclaimableUsedCapacity),
+      succinctBytes(reclaimableFreeCapacity));
+}
+} // namespace facebook::velox::memory
diff --git a/velox/common/memory/ArbitrationParticipant.h b/velox/common/memory/ArbitrationParticipant.h
new file mode 100644
index 000000000000..b014a3c3327f
--- /dev/null
+++ b/velox/common/memory/ArbitrationParticipant.h
@@ -0,0 +1,353 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "velox/common/memory/MemoryArbitrator.h"
+
+#include "velox/common/base/Counters.h"
+#include "velox/common/base/GTestMacros.h"
+#include "velox/common/base/StatsReporter.h"
+#include "velox/common/future/VeloxPromise.h"
+#include "velox/common/memory/Memory.h"
+
+namespace facebook::velox::memory {
+
+class ArbitrationOperation;
+class ScopedArbitrationParticipant;
+
+/// Manages the memory arbitration operations on a query memory pool. It also
+/// tracks the arbitration stats during the query memory pool's lifecycle.
+class ArbitrationParticipant
+    : public std::enable_shared_from_this<ArbitrationParticipant> {
+ public:
+  struct Config {
+    /// The minimum capacity of a query memory pool.
+    uint64_t minCapacity;
+
+    /// When growing a query memory pool capacity, the growth bytes will be
+    /// adjusted in the following way:
+    ///  - If 2 * current capacity is less than or equal to
+    ///    'fastExponentialGrowthCapacityLimit', grow through fast path by at
+    ///    least doubling the current capacity, when conditions allow (see below
+    ///    NOTE section).
+    ///  - If 2 * current capacity is greater than
+    ///    'fastExponentialGrowthCapacityLimit', grow through slow path by
+    ///    growing capacity by at least 'slowCapacityGrowRatio' * current
+    ///    capacity if allowed (see below NOTE section).
+    ///
+    /// NOTE: if original requested growth bytes is larger than the adjusted
+    /// growth bytes or adjusted growth bytes reaches max capacity limit, the
+    /// adjusted growth bytes will not be respected.
+    ///
+    /// NOTE: capacity growth adjust is only enabled if both
+    /// 'fastExponentialGrowthCapacityLimit' and 'slowCapacityGrowRatio' are
+    /// set, otherwise it is disabled.
+    uint64_t fastExponentialGrowthCapacityLimit;
+    double slowCapacityGrowRatio;
+
+    /// When shrinking a memory pool capacity, the shrink bytes will be adjusted
+    /// in a way such that AFTER shrink, the stricter (whichever is smaller) of
+    /// the following conditions is met, in order to better fit the query memory
+    /// pool's current memory usage:
+    /// - Free capacity is greater or equal to capacity *
+    /// 'minFreeCapacityRatio'
+    /// - Free capacity is greater or equal to 'minFreeCapacity'
+    ///
+    /// NOTE: in the conditions when original requested shrink bytes ends up
+    /// with more free capacity than above 2 conditions, the adjusted shrink
+    /// bytes is not respected.
+    ///
+    /// NOTE: capacity shrink adjustment is enabled when both
+    /// 'minFreeCapacityRatio' and 'minFreeCapacity' are set.
+    uint64_t minFreeCapacity;
+    double minFreeCapacityRatio;
+
+    Config(
+        uint64_t _minCapacity,
+        uint64_t _fastExponentialGrowthCapacityLimit,
+        double _slowCapacityGrowRatio,
+        uint64_t _minFreeCapacity,
+        double _minFreeCapacityRatio);
+
+    std::string toString() const;
+  };
+
+  static std::shared_ptr<ArbitrationParticipant> create(
+      uint64_t id,
+      const std::shared_ptr<MemoryPool>& pool,
+      const Config* config);
+
+  ~ArbitrationParticipant();
+
+  /// Returns the query memory pool name of this arbitration participant.
+  std::string name() const {
+    return pool_->name();
+  }
+
+  /// Returns the id of this arbitration participant assigned by the arbitrator.
+  /// The id is monotonically increasing and unique across all the alive
+  /// arbitration participants.
+  uint64_t id() const {
+    return id_;
+  }
+
+  /// Returns the max capacity of the underlying query memory pool.
+  uint64_t maxCapacity() const {
+    return maxCapacity_;
+  }
+
+  /// Returns the min capacity of the underlying query memory pool.
+  uint64_t minCapacity() const {
+    return config_->minCapacity;
+  }
+
+  /// Returns the duration of this arbitration participant since its creation.
+  uint64_t durationUs() const {
+    const auto now = getCurrentTimeMicro();
+    VELOX_CHECK_GE(now, createTimeUs_);
+    return now - createTimeUs_;
+  }
+
+  /// Invoked to acquire a shared reference to this arbitration participant
+  /// which ensures the liveness of underlying query memory pool. If the query
+  /// memory pool is being destroyed, then this function returns std::nullopt.
+  ///
+  /// NOTE: it is not safe to directly access arbitration participant as it only
+  /// holds a weak ptr to the query memory pool. Use 'lock()' to get a scoped
+  /// arbitration participant for access.
+  std::optional<ScopedArbitrationParticipant> lock();
+
+  /// Returns the corresponding query memory pool.
+  MemoryPool* pool() const {
+    return pool_;
+  }
+
+  /// Returns the current capacity of the query memory pool.
+  uint64_t capacity() const {
+    return pool_->capacity();
+  }
+
+  /// Gets the capacity growth targets based on 'requestBytes' and the query
+  /// memory pool's current capacity. 'maxGrowBytes' is set to allow fast
+  /// exponential growth when the query memory pool is small and switch to the
+  /// slow incremental growth after the query memory pool has grown big.
+  /// 'minGrowBytes' is set to ensure the query memory pool has the minimum
+  /// capacity and certain headroom free capacity after shrink. Both targets are
+  /// set to a coarser granularity to reduce the number of unnecessary future
+  /// memory arbitration requests. The parameters used to set the targets are
+  /// defined in 'config_'.
+  void getGrowTargets(
+      uint64_t requestBytes,
+      uint64_t& maxGrowBytes,
+      uint64_t& minGrowBytes) const;
+
+  /// Returns the unused free memory capacity that can be reclaimed from the
+  /// query memory pool by shrink.
+  uint64_t reclaimableFreeCapacity() const;
+
+  /// Returns the used memory capacity that can be reclaimed from the query
+  /// memory pool through disk spilling.
+  uint64_t reclaimableUsedCapacity() const;
+
+  /// Checks if the query memory pool can grow 'requestBytes' from its current
+  /// capacity under the max capacity limit.
+  bool checkCapacityGrowth(uint64_t requestBytes) const;
+
+  /// Invoked to grow the query memory pool capacity by 'growBytes' and commit
+  /// used reservation by 'reservationBytes'. The function returns true if the
+  /// growth succeeds, otherwise false.
+  bool grow(uint64_t growBytes, uint64_t reservationBytes);
+
+  /// Invoked to release the unused memory capacity by reducing its capacity. If
+  /// 'reclaimAll' is true, the function releases all the unused memory capacity
+  /// from the query memory pool without regarding to the minimum free capacity
+  /// restriction.
+  uint64_t shrink(bool reclaimAll = false);
+
+  /// Invoked to reclaim used memory from this memory pool with specified
+  /// 'targetBytes'. The function returns the actually freed capacity.
+  uint64_t reclaim(uint64_t targetBytes, uint64_t maxWaitTimeMs) noexcept;
+
+  /// Invoked to abort the query memory pool and returns the reclaimed bytes
+  /// after abort.
+  uint64_t abort(const std::exception_ptr& error) noexcept;
+
+  /// Returns true if the query memory pool has been aborted.
+  bool aborted() const {
+    std::lock_guard<std::mutex> l(stateLock_);
+    return aborted_;
+  }
+
+  /// Invoked to wait for the pending memory reclaim or abort operation to
+  /// complete within a 'maxWaitTimeMs' time window. The function returns false
+  /// if the wait has timed out.
+  bool waitForReclaimOrAbort(uint64_t maxWaitTimeMs) const;
+
+  /// Invoked to start arbitration operation 'op'. The operation needs to wait
+  /// for the prior arbitration operations to finish first before executing to
+  /// ensure the serialized execution of arbitration operations from the same
+  /// query memory pool.
+  void startArbitration(ArbitrationOperation* op);
+
+  /// Invoked by a finished arbitration operation 'op' to kick off the next
+  /// waiting operation to start execution if there is one.
+  void finishArbitration(ArbitrationOperation* op);
+
+  /// Returns true if there is a running arbitration operation on this
+  /// participant.
+  bool hasRunningOp() const;
+
+  /// Returns the number of waiting arbitration operations on this participant.
+  size_t numWaitingOps() const;
+
+  struct Stats {
+    uint64_t durationUs{0};
+    uint32_t numRequests{0};
+    uint32_t numReclaims{0};
+    uint32_t numShrinks{0};
+    uint32_t numGrows{0};
+    uint64_t reclaimedBytes{0};
+    uint64_t growBytes{0};
+    bool aborted{false};
+
+    std::string toString() const;
+  };
+
+  Stats stats() const {
+    Stats stats;
+    stats.durationUs = durationUs();
+    stats.aborted = aborted();
+    stats.numRequests = numRequests_;
+    stats.numGrows = numGrows_;
+    stats.numShrinks = numShrinks_;
+    stats.numReclaims = numReclaims_;
+    stats.reclaimedBytes = reclaimedBytes_;
+    stats.growBytes = growBytes_;
+    return stats;
+  }
+
+ private:
+  ArbitrationParticipant(
+      uint64_t id,
+      const std::shared_ptr<MemoryPool>& pool,
+      const Config* config);
+
+  // Indicates if the query memory pool is actively used by a query execution or
+  // not.
+  bool inactivePool() const;
+
+  // Returns the max capacity to reclaim from the query memory pool assuming all
+  // the query memory is reclaimable.
+  uint64_t maxReclaimableCapacity() const;
+
+  // Returns the max capacity to shrink from the query memory pool. It ensures
+  // the memory pool having headroom free capacity after shrink as specified by
+  // 'minFreeCapacityRatio' and 'minFreeCapacity' in 'config_'. This helps to
+  // reduce the number of unnecessary memory arbitration requests.
+  uint64_t maxShrinkCapacity() const;
+
+  // Returns the max capacity that the query memory pool can still grow by,
+  // which is the gap between the max capacity of the pool and its current
+  // capacity.
+  uint64_t maxGrowCapacity() const;
+
+  // Returns the min capacity to grow the query memory pool to have the minimal
+  // capacity as specified by 'minCapacity' in 'config_'.
+  uint64_t minGrowCapacity() const;
+
+  // Aborts the query memory pool and returns the reclaimed bytes after abort.
+  uint64_t abortLocked(const std::exception_ptr& error) noexcept;
+
+  const uint64_t id_;
+  const std::weak_ptr<MemoryPool> poolWeakPtr_;
+  MemoryPool* const pool_;
+  const Config* const config_;
+  const uint64_t maxCapacity_;
+  const size_t createTimeUs_;
+
+  mutable std::mutex stateLock_;
+  bool aborted_{false};
+
+  // Points to the current running arbitration operation on this participant.
+  ArbitrationOperation* runningOp_{nullptr};
+
+  struct WaitOp {
+    ArbitrationOperation* op;
+    ContinuePromise waitPromise;
+  };
+  // The resume promises of the arbitration operations on this participant
+  // waiting for serial execution.
+  std::deque<WaitOp> waitOps_;
+
+  tsan_atomic<uint32_t> numRequests_{0};
+  tsan_atomic<uint32_t> numReclaims_{0};
+  tsan_atomic<uint32_t> numShrinks_{0};
+  tsan_atomic<uint32_t> numGrows_{0};
+  tsan_atomic<uint64_t> reclaimedBytes_{0};
+  tsan_atomic<uint64_t> growBytes_{0};
+
+  mutable std::timed_mutex reclaimLock_;
+
+  friend class ScopedArbitrationParticipant;
+};
+
+/// The wrapper of the arbitration participant which holds a shared reference to
+/// the query memory pool to ensure its liveness during memory arbitration
+/// execution.
+class ScopedArbitrationParticipant {
+ public:
+  ScopedArbitrationParticipant(
+      std::shared_ptr<ArbitrationParticipant> participant,
+      std::shared_ptr<MemoryPool> pool);
+
+  ArbitrationParticipant* operator->() const {
+    return participant_.get();
+  }
+
+  ArbitrationParticipant& operator*() const {
+    return *participant_;
+  }
+
+  ArbitrationParticipant& operator()() const {
+    return *participant_;
+  }
+
+  ArbitrationParticipant* get() const {
+    return participant_.get();
+  }
+
+ private:
+  std::shared_ptr<ArbitrationParticipant> participant_;
+  std::shared_ptr<MemoryPool> pool_;
+};
+
+/// The candidate participant stats used by arbitrator to make arbitration
+/// decisions.
+struct ArbitrationCandidate { + ScopedArbitrationParticipant participant; + int64_t reclaimableUsedCapacity{0}; + int64_t reclaimableFreeCapacity{0}; + + /// If 'freeCapacityOnly' is true, the candidate is only used to reclaim free + /// capacity so only collects the free capacity stats. + ArbitrationCandidate( + ScopedArbitrationParticipant&& _participant, + bool freeCapacityOnly); + + std::string toString() const; +}; +} // namespace facebook::velox::memory diff --git a/velox/common/memory/MemoryPool.h b/velox/common/memory/MemoryPool.h index 1f3b33ed3a90..985c21d2f91e 100644 --- a/velox/common/memory/MemoryPool.h +++ b/velox/common/memory/MemoryPool.h @@ -39,6 +39,10 @@ namespace facebook::velox::exec { class ParallelMemoryReclaimer; } +namespace facebook::velox::memory { +class TestArbitrator; +} + namespace facebook::velox::memory { #define VELOX_MEM_POOL_CAP_EXCEEDED(errorMessage) \ _VELOX_THROW( \ @@ -558,7 +562,9 @@ class MemoryPool : public std::enable_shared_from_this { friend class velox::exec::ParallelMemoryReclaimer; friend class MemoryManager; friend class MemoryArbitrator; + friend class velox::memory::TestArbitrator; friend class ScopedMemoryPoolArbitrationCtx; + friend class ArbitrationParticipant; VELOX_FRIEND_TEST(MemoryPoolTest, shrinkAndGrowAPIs); VELOX_FRIEND_TEST(MemoryPoolTest, grow); diff --git a/velox/common/memory/tests/ArbitrationParticipantTest.cpp b/velox/common/memory/tests/ArbitrationParticipantTest.cpp new file mode 100644 index 000000000000..d546d7f6647d --- /dev/null +++ b/velox/common/memory/tests/ArbitrationParticipantTest.cpp @@ -0,0 +1,1790 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include + +#include "folly/experimental/EventCount.h" +#include "folly/futures/Barrier.h" + +#include "gmock/gmock-matchers.h" +#include "velox/common/base/SuccinctPrinter.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/memory/ArbitrationOperation.h" +#include "velox/common/memory/ArbitrationParticipant.h" +#include "velox/common/memory/MallocAllocator.h" +#include "velox/common/memory/Memory.h" +#include "velox/common/memory/MemoryArbitrator.h" +#include "velox/common/memory/SharedArbitrator.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" + +DECLARE_bool(velox_memory_leak_check_enabled); +DECLARE_bool(velox_suppress_memory_capacity_exceeding_error_message); + +using namespace ::testing; +using namespace facebook::velox::common::testutil; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +namespace facebook::velox::memory { +static const std::string arbitratorKind("TEST"); + +class TestArbitrator : public MemoryArbitrator { + public: + explicit TestArbitrator(const Config& config) + : MemoryArbitrator( + {.kind = config.kind, + .capacity = config.capacity, + .extraConfigs = config.extraConfigs}) {} + + void addPool(const std::shared_ptr& /*unused*/) override {} + + void removePool(MemoryPool* /*unused*/) override {} + + bool growCapacity(MemoryPool* memoryPool, uint64_t requestBytes) override 
{ + VELOX_CHECK_LE( + memoryPool->capacity() + requestBytes, memoryPool->maxCapacity()); + memoryPool->grow(requestBytes, requestBytes); + return true; + } + + uint64_t shrinkCapacity(uint64_t /*unused*/, bool /*unused*/, bool /*unused*/) + override { + VELOX_NYI(); + } + + uint64_t shrinkCapacity(MemoryPool* /*unused*/, uint64_t /*unused*/) + override { + VELOX_NYI(); + } + + Stats stats() const override { + VELOX_NYI(); + } + + std::string toString() const override { + VELOX_NYI(); + } + + std::string kind() const override { + return arbitratorKind; + } +}; + +namespace { +constexpr int64_t KB = 1024L; +constexpr int64_t MB = 1024L * KB; + +constexpr uint64_t kMemoryCapacity = 512 * MB; +constexpr uint64_t kMemoryPoolReservedCapacity = 64 * MB; +constexpr uint64_t kMemoryPoolMinFreeCapacity = 32 * MB; +constexpr double kMemoryPoolMinFreeCapacityRatio = 0.25; +constexpr uint64_t kFastExponentialGrowthCapacityLimit = 256 * MB; +constexpr double kSlowCapacityGrowRatio = 0.25; + +class MemoryReclaimer; + +using ReclaimInjectionCallback = + std::function; +using ArbitrationInjectionCallback = std::function; + +struct Allocation { + void* buffer{nullptr}; + size_t size{0}; +}; + +class MockTask : public std::enable_shared_from_this { + public: + MockTask(MemoryManager* manager, uint64_t capacity) + : root_(manager->addRootPool( + fmt::format("TaskPool-{}", taskId_++), + capacity)), + pool_(root_->addLeafChild("MockOperator")) {} + + ~MockTask() { + free(); + } + + class RootMemoryReclaimer : public memory::MemoryReclaimer { + public: + RootMemoryReclaimer(const std::shared_ptr& task) : task_(task) {} + + static std::unique_ptr create( + const std::shared_ptr& task) { + return std::make_unique(task); + } + + bool reclaimableBytes(const MemoryPool& pool, uint64_t& reclaimableBytes) + const override { + auto task = task_.lock(); + if (task == nullptr) { + return false; + } + return memory::MemoryReclaimer::reclaimableBytes(pool, reclaimableBytes); + } + + uint64_t 
reclaim( + MemoryPool* pool, + uint64_t targetBytes, + uint64_t maxWaitMs, + Stats& stats) override { + auto task = task_.lock(); + if (task == nullptr) { + return 0; + } + return memory::MemoryReclaimer::reclaim( + pool, targetBytes, maxWaitMs, stats); + } + + void abort(MemoryPool* pool, const std::exception_ptr& error) override { + auto task = task_.lock(); + if (task == nullptr) { + return; + } + memory::MemoryReclaimer::abort(pool, error); + } + + private: + std::weak_ptr task_; + }; + + class LeafMemoryReclaimer : public memory::MemoryReclaimer { + public: + LeafMemoryReclaimer( + std::shared_ptr task, + bool reclaimable, + ReclaimInjectionCallback reclaimInjectCb = nullptr, + ArbitrationInjectionCallback arbitrationInjectCb = nullptr) + : task_(task), + reclaimable_(reclaimable), + reclaimInjectCb_(std::move(reclaimInjectCb)), + arbitrationInjectCb_(std::move(arbitrationInjectCb)) {} + + bool reclaimableBytes(const MemoryPool& pool, uint64_t& reclaimableBytes) + const override { + if (!reclaimable_) { + return false; + } + std::shared_ptr task = task_.lock(); + VELOX_CHECK_NOT_NULL(task); + return task->reclaimableBytes(pool, reclaimableBytes); + } + + uint64_t reclaim( + MemoryPool* pool, + uint64_t targetBytes, + uint64_t /*unused*/, + Stats& stats) override { + if (!reclaimable_) { + return 0; + } + if (reclaimInjectCb_ != nullptr) { + reclaimInjectCb_(pool, targetBytes); + } + std::shared_ptr task = task_.lock(); + VELOX_CHECK_NOT_NULL(task); + const auto reclaimBytes = task->reclaim(pool, targetBytes); + stats.reclaimedBytes += reclaimBytes; + return reclaimBytes; + } + + void abort(MemoryPool* pool, const std::exception_ptr& error) override { + std::shared_ptr task = task_.lock(); + VELOX_CHECK_NOT_NULL(task); + task->abort(pool, error); + } + + private: + std::weak_ptr task_; + const bool reclaimable_; + const ReclaimInjectionCallback reclaimInjectCb_; + const ArbitrationInjectionCallback arbitrationInjectCb_; + + std::exception_ptr abortError_; + }; 
+ + void setMemoryReclaimers( + bool reclaimable, + ReclaimInjectionCallback reclaimInjectCb, + ArbitrationInjectionCallback arbitrationInjectCb) { + root_->setReclaimer(RootMemoryReclaimer::create(shared_from_this())); + pool_->setReclaimer(std::make_unique( + shared_from_this(), + reclaimable, + std::move(reclaimInjectCb), + std::move(arbitrationInjectCb))); + } + + const std::shared_ptr& pool() const { + return root_; + } + + std::exception_ptr abortError() const { + return abortError_; + } + + uint64_t capacity() const { + return root_->capacity(); + } + + void* allocate(uint64_t bytes) { + VELOX_CHECK_EQ(bytes % pool_->alignment(), 0); + + void* buffer = pool_->allocate(bytes); + std::lock_guard l(mu_); + totalBytes_ += bytes; + allocations_.emplace(buffer, bytes); + VELOX_CHECK_EQ(allocations_.count(buffer), 1); + return buffer; + } + + void free(void* buffer) { + size_t size{0}; + { + std::lock_guard l(mu_); + VELOX_CHECK_EQ(allocations_.count(buffer), 1); + size = allocations_[buffer]; + totalBytes_ -= size; + allocations_.erase(buffer); + } + pool_->free(buffer, size); + } + + bool reclaimableBytes(const MemoryPool& pool, uint64_t& reclaimableBytes) + const { + std::lock_guard l(mu_); + VELOX_CHECK_EQ(pool.name(), pool_->name()); + reclaimableBytes = totalBytes_; + return true; + } + + uint64_t reclaim(MemoryPool* pool, uint64_t targetBytes) { + VELOX_CHECK_GT(targetBytes, 0); + ++numReclaims_; + reclaimTargetBytes_.push_back(targetBytes); + uint64_t bytesReclaimed{0}; + std::vector allocationsToFree; + { + std::lock_guard l(mu_); + VELOX_CHECK_NOT_NULL(pool_); + VELOX_CHECK_EQ(pool->name(), pool_->name()); + + auto allocIt = allocations_.begin(); + while (allocIt != allocations_.end() && + ((targetBytes != 0) && (bytesReclaimed < targetBytes))) { + allocationsToFree.push_back({allocIt->first, allocIt->second}); + bytesReclaimed += allocIt->second; + allocIt = allocations_.erase(allocIt); + } + totalBytes_ -= bytesReclaimed; + } + for (const auto& 
allocation : allocationsToFree) { + pool_->free(allocation.buffer, allocation.size); + } + return bytesReclaimed; + } + + void abort(MemoryPool* pool, const std::exception_ptr& error) { + ++numAborts_; + abortError_ = error; + free(); + } + + struct Stats { + uint64_t numReclaims; + uint64_t numAborts; + std::vector reclaimTargetBytes; + }; + + Stats stats() const { + Stats stats; + stats.numReclaims = numReclaims_; + stats.reclaimTargetBytes = reclaimTargetBytes_; + stats.numAborts = numAborts_; + return stats; + } + + private: + void free() { + std::unordered_map allocationsToFree; + { + std::lock_guard l(mu_); + for (auto entry : allocations_) { + totalBytes_ -= entry.second; + } + VELOX_CHECK_EQ(totalBytes_, 0); + allocationsToFree.swap(allocations_); + } + for (auto entry : allocationsToFree) { + pool_->free(entry.first, entry.second); + } + } + + inline static std::atomic_int taskId_{0}; + + const std::shared_ptr root_; + const std::shared_ptr pool_; + + mutable std::mutex mu_; + uint64_t totalBytes_{0}; + + std::unordered_map allocations_; + std::atomic_uint64_t numReclaims_{0}; + std::atomic_uint64_t numAborts_{0}; + std::vector reclaimTargetBytes_; + std::exception_ptr abortError_{nullptr}; +}; + +class ArbitrationParticipantTest : public testing::Test { + protected: + static void SetUpTestCase() { + SharedArbitrator::registerFactory(); + FLAGS_velox_memory_leak_check_enabled = true; + TestValue::enable(); + MemoryArbitrator::Factory factory = + [](const MemoryArbitrator::Config& config) { + return std::make_unique(config); + }; + MemoryArbitrator::registerFactory(arbitratorKind, factory); + } + + void SetUp() override { + setupMemory(); + } + + void TearDown() override {} + + void setupMemory(int64_t memoryCapacity = kMemoryCapacity) { + MemoryManagerOptions options; + options.allocatorCapacity = memoryCapacity; + options.arbitratorReservedCapacity = 0; + options.arbitratorKind = arbitratorKind; + options.checkUsageLeak = true; + manager_ = 
std::make_unique(options); + } + + std::shared_ptr createTask( + int64_t capacity = 0, + bool reclaimable = true, + ReclaimInjectionCallback reclaimInjectCb = nullptr, + ArbitrationInjectionCallback arbitrationInjectCb = nullptr) { + if (capacity == 0) { + capacity = manager_->capacity(); + } + auto task = std::make_shared(manager_.get(), capacity); + task->setMemoryReclaimers( + reclaimable, reclaimInjectCb, arbitrationInjectCb); + return task; + } + + std::unique_ptr manager_; + std::unique_ptr executor_ = + std::make_unique(4); +}; + +static ArbitrationParticipant::Config arbitrationConfig( + uint64_t minCapacity = kMemoryPoolReservedCapacity, + uint64_t fastExponentialGrowthCapacityLimit = + kFastExponentialGrowthCapacityLimit, + double slowCapacityGrowRatio = kSlowCapacityGrowRatio, + uint64_t minFreeCapacity = kMemoryPoolMinFreeCapacity, + double minFreeCapacityRatio = kMemoryPoolMinFreeCapacityRatio) { + return ArbitrationParticipant::Config{ + minCapacity, + fastExponentialGrowthCapacityLimit, + slowCapacityGrowRatio, + minFreeCapacity, + minFreeCapacityRatio}; +} + +TEST_F(ArbitrationParticipantTest, config) { + struct { + uint64_t minCapacity; + uint64_t fastExponentialGrowthCapacityLimit; + double slowCapacityGrowRatio; + uint64_t minFreeCapacity; + double minFreeCapacityRatio; + bool expectedError; + std::string expectedToString; + + std::string debugString() const { + return fmt::format( + "minCapacity {}, fastExponentialGrowthCapacityLimit: {}, slowCapacityGrowRatio: {}, minFreeCapacity: {}, minFreeCapacityRatio: {}, expectedError: {}, expectedToString: {}", + succinctBytes(minCapacity), + succinctBytes(fastExponentialGrowthCapacityLimit), + slowCapacityGrowRatio, + succinctBytes(minFreeCapacity), + minFreeCapacityRatio, + expectedError, + expectedToString); + } + } testSettings[] = { + {1, + 1, + 0.1, + 1, + 0.1, + false, + "minCapacity 1, fastExponentialGrowthCapacityLimit 1, slowCapacityGrowRatio 0.1, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, 
+ {1, + 0, + 0, + 1, + 0.1, + false, + "minCapacity 1, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {1, + 0, + 0, + 0, + 0, + false, + "minCapacity 1, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 0, minFreeCapacityRatio 0"}, + {1, + 0, + 0, + 1, + 0.1, + false, + "minCapacity 1, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {0, + 1, + 0.1, + 1, + 0.1, + false, + "minCapacity 0, fastExponentialGrowthCapacityLimit 1, slowCapacityGrowRatio 0.1, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {0, + 0, + 0, + 1, + 0.1, + false, + "minCapacity 0, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {0, + 0, + 0, + 0, + 0, + false, + "minCapacity 0, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 0, minFreeCapacityRatio 0"}, + {0, + 0, + 0, + 1, + 0.1, + false, + "minCapacity 0, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {1, 0, 0.1, 1, 0.1, true, ""}, + {1, 1, 0.1, 0, 0.1, true, ""}, + {1, 1, 0.1, 1, 0, true, ""}, + {1, + 1, + 2, + 1, + 0.1, + false, + "minCapacity 1, fastExponentialGrowthCapacityLimit 1, slowCapacityGrowRatio 2, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {1, 1, -1, 1, 0.1, true, ""}, + {1, 1, 0.1, 1, 2, true, ""}, + {1, 1, 0.1, 1, -1, true, ""}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + if (testData.expectedError) { + VELOX_ASSERT_THROW( + ArbitrationParticipant::Config( + testData.minCapacity, + testData.fastExponentialGrowthCapacityLimit, + testData.slowCapacityGrowRatio, + testData.minFreeCapacity, + testData.minFreeCapacityRatio), + ""); + continue; + } + const auto config = ArbitrationParticipant::Config( + testData.minCapacity, + testData.fastExponentialGrowthCapacityLimit, 
+ testData.slowCapacityGrowRatio, + testData.minFreeCapacity, + testData.minFreeCapacityRatio); + ASSERT_EQ(testData.minCapacity, config.minCapacity); + ASSERT_EQ( + testData.fastExponentialGrowthCapacityLimit, + config.fastExponentialGrowthCapacityLimit); + ASSERT_EQ(testData.slowCapacityGrowRatio, config.slowCapacityGrowRatio); + ASSERT_EQ(testData.minFreeCapacity, config.minFreeCapacity); + ASSERT_EQ(testData.minFreeCapacityRatio, config.minFreeCapacityRatio); + ASSERT_EQ(config.toString(), testData.expectedToString); + } +} + +TEST_F(ArbitrationParticipantTest, constructor) { + auto task = createTask(); + const auto config = arbitrationConfig(); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + ASSERT_EQ(participant->id(), 10); + ASSERT_EQ(participant->name(), task->pool()->name()); + ASSERT_EQ(participant->pool(), task->pool().get()); + ASSERT_EQ(participant->maxCapacity(), kMemoryCapacity); + ASSERT_EQ(participant->minCapacity(), kMemoryPoolReservedCapacity); + ASSERT_EQ(participant->capacity(), 0); + ASSERT_FALSE(participant->hasRunningOp()); + ASSERT_EQ(participant->numWaitingOps(), 0); + ASSERT_THAT( + participant->stats().toString(), + ::testing::StartsWith( + "numRequests: 0, numReclaims: 0, numShrinks: 0, numGrows: 0, reclaimedBytes: 0B, growBytes: 0B, aborted: false")); + + { + auto scopedParticipant = participant->lock().value(); + ASSERT_EQ(scopedParticipant->id(), 10); + ASSERT_EQ(scopedParticipant->name(), task->pool()->name()); + ASSERT_EQ(scopedParticipant->pool(), task->pool().get()); + ASSERT_EQ(scopedParticipant->maxCapacity(), kMemoryCapacity); + ASSERT_EQ(scopedParticipant->minCapacity(), kMemoryPoolReservedCapacity); + ASSERT_EQ(scopedParticipant->capacity(), 0); + + task.reset(); + + ASSERT_EQ(scopedParticipant->capacity(), 0); + ASSERT_EQ(participant->capacity(), 0); + } + + ASSERT_FALSE(participant->lock().has_value()); +} + +TEST_F(ArbitrationParticipantTest, getGrowTargets) { + struct { + uint64_t 
minCapacity; + uint64_t fastExponentialGrowthCapacityLimit; + double slowCapacityGrowRatio; + uint64_t capacity; + uint64_t requestBytes; + uint64_t expectedMaxGrowTarget; + uint64_t expectedMinGrowTarget; + + std::string debugString() const { + return fmt::format( + "minCapacity {}, fastExponentialGrowthCapacityLimit {}, slowCapacityGrowRatio {}, capacity {}, requestBytes {}, expectedMaxGrowTarget {}, expectedMinGrowTarget {}", + succinctBytes(minCapacity), + succinctBytes(fastExponentialGrowthCapacityLimit), + slowCapacityGrowRatio, + succinctBytes(capacity), + succinctBytes(requestBytes), + succinctBytes(expectedMaxGrowTarget), + succinctBytes(expectedMinGrowTarget)); + } + } testSettings[] = { + // Without exponential growth. + {0, 0, 0.0, 0, 1 << 20, 1 << 20, 0}, + {0, 0, 0.0, 32 << 20, 1 << 20, 1 << 20, 0}, + {0, 0, 0.0, 32 << 20, 32 << 20, 32 << 20, 0}, + // Fast growth. + {0, 16 << 20, 1.0, 0, 1 << 20, 1 << 20, 0}, + {0, 16 << 20, 1.0, 1 << 20, 1 << 20, 1 << 20, 0}, + {0, 16 << 20, 1.0, 2 << 20, 1 << 20, 2 << 20, 0}, + {0, 16 << 20, 1.0, 4 << 20, 1 << 20, 4 << 20, 0}, + {0, 16 << 20, 1.0, 8 << 20, 1 << 20, 8 << 20, 0}, + {0, 16 << 20, 1.0, 12 << 20, 1 << 20, 12 << 20, 0}, + {0, 16 << 20, 1.0, 16 << 20, 1 << 20, 16 << 20, 0}, + {0, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 0}, + // Slow growth. + {0, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 0}, + {0, 16 << 20, 2.0, 24 << 20, 1 << 20, 48 << 20, 0}, + {0, 16 << 20, 100.0, 24 << 20, 1 << 20, kMemoryCapacity - (24 << 20), 0}, + {0, 16 << 20, 0.1, 24 << 20, 1 << 20, uint64_t((24 << 20) * 0.1), 0}, + // With min capacity. + // Without exponential growth. + {4 << 20, 0, 0.0, 0, 1 << 20, 4 << 20, 4 << 20}, + {4 << 20, 0, 0.0, 32 << 20, 1 << 20, 1 << 20, 0}, + {4 << 20, 0, 0.0, 32 << 20, 32 << 20, 32 << 20, 0}, + {64 << 20, 0, 0.0, 32 << 20, 1 << 20, 32 << 20, 32 << 20}, + {64 << 20, 0, 0.0, 32 << 20, 32 << 20, 32 << 20, 32 << 20}, + {48 << 20, 0, 0.0, 32 << 20, 32 << 20, 32 << 20, 16 << 20}, + // Fast growth. 
+ {1 << 20, 16 << 20, 1.0, 0, 1 << 20, 1 << 20, 1 << 20}, + {1 << 20, 16 << 20, 1.0, 0, 2 << 20, 2 << 20, 1 << 20}, + {4 << 20, 16 << 20, 1.0, 0, 1 << 20, 4 << 20, 4 << 20}, + {1 << 20, 16 << 20, 1.0, 1 << 20, 1 << 20, 1 << 20, 0}, + {1 << 20, 16 << 20, 1.0, 2 << 20, 1 << 20, 2 << 20, 0}, + {2 << 20, 16 << 20, 1.0, 2 << 20, 1 << 20, 2 << 20, 0}, + {4 << 20, 16 << 20, 1.0, 2 << 20, 1 << 20, 2 << 20, 2 << 20}, + {8 << 20, 16 << 20, 1.0, 2 << 20, 1 << 20, 6 << 20, 6 << 20}, + {3 << 20, 16 << 20, 1.0, 4 << 20, 1 << 20, 4 << 20, 0}, + {4 << 20, 16 << 20, 1.0, 4 << 20, 1 << 20, 4 << 20, 0}, + {5 << 20, 16 << 20, 1.0, 4 << 20, 1 << 20, 4 << 20, 1 << 20}, + {1 << 20, 16 << 20, 1.0, 12 << 20, 1 << 20, 12 << 20, 0}, + {12 << 20, 16 << 20, 1.0, 12 << 20, 1 << 20, 12 << 20, 0}, + {13 << 20, 16 << 20, 1.0, 12 << 20, 1 << 20, 12 << 20, 1 << 20}, + {24 << 20, 16 << 20, 1.0, 12 << 20, 1 << 20, 12 << 20, 12 << 20}, + {25 << 20, 16 << 20, 1.0, 12 << 20, 1 << 20, 13 << 20, 13 << 20}, + {1 << 20, 16 << 20, 1.0, 16 << 20, 1 << 20, 16 << 20, 0}, + {16 << 20, 16 << 20, 1.0, 16 << 20, 1 << 20, 16 << 20, 0}, + {17 << 20, 16 << 20, 1.0, 16 << 20, 1 << 20, 16 << 20, 1 << 20}, + {32 << 20, 16 << 20, 1.0, 16 << 20, 1 << 20, 16 << 20, 16 << 20}, + {64 << 20, 16 << 20, 1.0, 16 << 20, 1 << 20, 48 << 20, 48 << 20}, + {1 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 0}, + {12 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 0}, + {13 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 1 << 20}, + {23 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 11 << 20}, + {35 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 23 << 20}, + {48 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 36 << 20, 36 << 20}, + // Slow growth. 
+ {1 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 0}, + {24 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 0}, + {25 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 1 << 20}, + {48 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 24 << 20}, + {47 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 23 << 20}, + {64 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 40 << 20, 40 << 20}, + {1 << 20, 16 << 20, 2.0, 24 << 20, 1 << 20, 48 << 20, 0}, + {36 << 20, 16 << 20, 2.0, 24 << 20, 1 << 20, 48 << 20, 12 << 20}, + {72 << 20, 16 << 20, 2.0, 24 << 20, 1 << 20, 48 << 20, 48 << 20}, + {96 << 20, 16 << 20, 2.0, 24 << 20, 1 << 20, 72 << 20, 72 << 20}, + {1 << 20, + 16 << 20, + 100.0, + 24 << 20, + 1 << 20, + kMemoryCapacity - (24 << 20), + 0}, + {36 << 20, + 16 << 20, + 100.0, + 24 << 20, + 1 << 20, + kMemoryCapacity - (24 << 20), + 12 << 20}, + {1 << 20, + 16 << 20, + 0.1, + 24 << 20, + 1 << 20, + uint64_t((24 << 20) * 0.1), + 0}, + {24 << 20, + 16 << 20, + 0.1, + 24 << 20, + 1 << 20, + uint64_t((24 << 20) * 0.1), + 0}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig( + testData.minCapacity, + testData.fastExponentialGrowthCapacityLimit, + testData.slowCapacityGrowRatio); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + scopedParticipant->shrink(/*reclaimFromAll=*/true); + ASSERT_EQ(scopedParticipant->capacity(), 0); + void* buffer = task->allocate(testData.capacity); + SCOPE_EXIT { + task->free(buffer); + }; + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + uint64_t maxGrowBytes{0}; + uint64_t minGrowBytes{0}; + scopedParticipant->getGrowTargets( + testData.requestBytes, maxGrowBytes, minGrowBytes); + ASSERT_EQ(maxGrowBytes, testData.expectedMaxGrowTarget); + ASSERT_EQ(minGrowBytes, testData.expectedMinGrowTarget); + + // Test operation 
corresponding API. + ArbitrationOperation op( + std::move(scopedParticipant), testData.requestBytes, 1 << 30); + op.start(); + ASSERT_EQ(op.maxGrowBytes(), 0); + ASSERT_EQ(op.minGrowBytes(), 0); + ASSERT_EQ(op.requestBytes(), testData.requestBytes); + op.setGrowTargets(); + ASSERT_EQ(op.requestBytes(), testData.requestBytes); + ASSERT_EQ(op.maxGrowBytes(), testData.expectedMaxGrowTarget); + ASSERT_EQ(op.minGrowBytes(), testData.expectedMinGrowTarget); + // Can't set grow targets twice. + VELOX_ASSERT_THROW( + op.setGrowTargets(), + "Arbitration operation grow targets have already been set"); + op.finish(); + } +} + +TEST_F(ArbitrationParticipantTest, reclaimableFreeCapacityAndShrink) { + struct { + uint64_t minCapacity; + uint64_t minFreeCapacity; + double minFreeCapacityRatio; + uint64_t capacity; + uint64_t usedBytes; + uint64_t peakBytes; + uint64_t expectedFreeCapacity; + + std::string debugString() const { + return fmt::format( + "minCapacity {}, minFreeCapacity {}, minFreeCapacityRatio {}, capacity {}, usedBytes {}, peakBytes {}, expectedFreeCapacity {}", + succinctBytes(minCapacity), + succinctBytes(minFreeCapacity), + minFreeCapacityRatio, + succinctBytes(capacity), + succinctBytes(usedBytes), + succinctBytes(peakBytes), + succinctBytes(expectedFreeCapacity)); + } + } testSettings[] = { + {128 << 20, 0, 0.0, 128 << 20, 0, 0, 0}, + {128 << 20, 0, 0.0, 128 << 20, 0, 32 << 20, 128 << 20}, + {128 << 20, 0, 0.0, 128 << 20, 32 << 20, 0, 0}, + {128 << 20, 0, 0.0, 128 << 20, 128 << 20, 0, 0}, + {128 << 20, 0, 0.0, 256 << 20, 256 << 20, 0, 0}, + {128 << 20, 0, 0.0, 256 << 20, 200 << 20, 0, 56 << 20}, + {128 << 20, 0, 0.0, 256 << 20, 32 << 20, 0, 128 << 20}, + {0, 32 << 20, 0.25, 128 << 20, 0, 0, 96 << 20}, + {0, 64 << 20, 0.25, 128 << 20, 0, 0, 96 << 20}, + {0, 32 << 20, 0.25, 256 << 20, 0, 0, 224 << 20}, + {0, 32 << 20, 0.25, 256 << 20, 0, 64 << 20, 256 << 20}, + {0, 32 << 20, 0.25, 128 << 20, 64 << 20, 0, 32 << 20}, + {0, 32 << 20, 0.25, 128 << 20, 96 << 20, 0, 
0}, + {0, 32 << 20, 0.25, 128 << 20, 72 << 20, 0, 24 << 20}, + {0, 64 << 20, 0.25, 128 << 20, 64 << 20, 0, 32 << 20}, + {0, 64 << 20, 0.25, 128 << 20, 96 << 20, 0, 0}, + {0, 64 << 20, 0.25, 128 << 20, 72 << 20, 0, 24 << 20}, + {0, 32 << 20, 0.25, 256 << 20, 64 << 20, 0, 160 << 20}, + {0, 32 << 20, 0.25, 256 << 20, 96 << 20, 0, 128 << 20}, + {0, 32 << 20, 0.25, 256 << 20, 224 << 20, 0, 0}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 0, 0, 0}, + {64 << 20, 32 << 20, 0.25, 128 << 20, 0, 0, 64 << 20}, + {96 << 20, 32 << 20, 0.25, 128 << 20, 0, 0, 32 << 20}, + {64 << 20, 64 << 20, 0.25, 128 << 20, 0, 0, 64 << 20}, + {64 << 20, 32 << 20, 0.5, 128 << 20, 0, 0, 64 << 20}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 0, 32 << 20, 128 << 20}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 64 << 20, 64 << 20, 0}, + {64 << 20, 32 << 20, 0.25, 128 << 20, 64 << 20, 64 << 20, 32 << 20}, + {96 << 20, 32 << 20, 0.25, 128 << 20, 64 << 20, 64 << 20, 32 << 20}, + {96 << 20, 32 << 20, 0.25, 256 << 20, 64 << 20, 64 << 20, 160 << 20}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + for (bool reclaimAll : {false, true}) { + SCOPED_TRACE(fmt::format("reclaimAll {}", reclaimAll)); + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig( + testData.minCapacity, + 0, + 0.0, + testData.minFreeCapacity, + testData.minFreeCapacityRatio); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + ASSERT_EQ(scopedParticipant->stats().numShrinks, 0); + if (testData.peakBytes > 0) { + void* buffer = task->allocate(testData.peakBytes); + task->free(buffer); + ASSERT_EQ(scopedParticipant->pool()->peakBytes(), testData.peakBytes); + } + + scopedParticipant->shrink(/*reclaimFromAll=*/true); + scopedParticipant->grow(testData.capacity, 0); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + + void* buffer{nullptr}; + if (testData.usedBytes > 0) { + buffer 
= task->allocate(testData.usedBytes); + } + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), testData.usedBytes); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), testData.usedBytes); + + ASSERT_EQ( + scopedParticipant->reclaimableFreeCapacity(), + testData.expectedFreeCapacity); + + const uint64_t prevReclaimedBytes = + scopedParticipant->stats().reclaimedBytes; + if (reclaimAll) { + ASSERT_EQ( + scopedParticipant->shrink(reclaimAll), + testData.capacity - testData.usedBytes); + ASSERT_EQ( + scopedParticipant->stats().reclaimedBytes, + prevReclaimedBytes + testData.capacity - testData.usedBytes); + } else { + ASSERT_EQ( + scopedParticipant->shrink(reclaimAll), + testData.expectedFreeCapacity); + ASSERT_EQ( + scopedParticipant->stats().reclaimedBytes, + prevReclaimedBytes + testData.expectedFreeCapacity); + } + ASSERT_EQ(scopedParticipant->stats().numShrinks, 2); + ASSERT_EQ(scopedParticipant->stats().numReclaims, 0); + ASSERT_EQ(scopedParticipant->stats().numGrows, 1); + ASSERT_GE(scopedParticipant->stats().durationUs, 0); + ASSERT_FALSE(scopedParticipant->stats().aborted); + + if (buffer != nullptr) { + task->free(buffer); + } + } + } +} + +TEST_F(ArbitrationParticipantTest, reclaimableUsedCapacityAndReclaim) { + struct { + uint64_t minCapacity; + uint64_t minFreeCapacity; + double minFreeCapacityRatio; + uint64_t capacity; + uint64_t usedBytes; + uint64_t peakBytes; + uint64_t expectedReclaimableUsedBytes; + uint64_t expectedActualReclaimedBytes; + uint64_t expectedUsedBytes; + + std::string debugString() const { + return fmt::format( + "minCapacity {}, minFreeCapacity {}, minFreeCapacityRatio {}, capacity {}, usedBytes {}, peakBytes {}, expectedReclaimableUsedBytes {}, expectedActualReclaimedBytes {}, expectedUsedBytes {}", + succinctBytes(minCapacity), + succinctBytes(minFreeCapacity), + minFreeCapacityRatio, + succinctBytes(capacity), + succinctBytes(usedBytes), + succinctBytes(peakBytes), + succinctBytes(expectedReclaimableUsedBytes), + 
succinctBytes(expectedActualReclaimedBytes), + succinctBytes(expectedUsedBytes)); + } + } testSettings[] = { + {128 << 20, 0, 0.0, 128 << 20, 0, 0, 0, 0, 0}, + {128 << 20, 0, 0.0, 128 << 20, 0, 32 << 20, 0, 0, 0}, + {128 << 20, 0, 0.0, 128 << 20, 32 << 20, 0, 0, 0, 32 << 20}, + {64 << 20, 0, 0.0, 128 << 20, 96 << 20, 0, 64 << 20, 64 << 20, 32 << 20}, + {64 << 20, 0, 0.0, 128 << 20, 128 << 20, 0, 64 << 20, 64 << 20, 64 << 20}, + {0, 32 << 20, 0.25, 128 << 20, 0, 0, 0, 0}, + {0, 64 << 20, 0.25, 128 << 20, 0, 0, 0, 0}, + {0, 32 << 20, 0.25, 256 << 20, 0, 0, 0, 0}, + {0, 32 << 20, 0.25, 256 << 20, 0, 64 << 20, 0, 0}, + {0, 32 << 20, 0.25, 128 << 20, 96 << 20, 0, 96 << 20, 128 << 20, 0}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 0, 0, 0, 0, 0}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 64 << 20, 0, 0, 0, 64 << 20}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 128 << 20, 0, 0, 0, 128 << 20}, + {64 << 20, + 32 << 20, + 0.25, + 128 << 20, + 64 << 20, + 0, + 64 << 20, + 128 << 20, + 0}, + {128 << 20, + 32 << 20, + 0.25, + 128 << 20, + 64 << 20, + 64 << 20, + 0, + 0, + 64 << 20}, + {64 << 20, + 32 << 20, + 0.25, + 128 << 20, + 64 << 20, + 64 << 20, + 64 << 20, + 128 << 20, + 0}, + {96 << 20, + 32 << 20, + 0.25, + 128 << 20, + 64 << 20, + 64 << 20, + 32 << 20, + 32 << 20, + 32 << 20}, + {32 << 20, + 32 << 20, + 0.5, + 256 << 20, + 256 << 20, + 0, + 224 << 20, + 192 << 20, + 32 << 20}, + {32 << 20, + 64 << 20, + 0.125, + 256 << 20, + 256 << 20, + 0, + 224 << 20, + 192 << 20, + 32 << 20}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig( + testData.minCapacity, + 0, + 0.0, + testData.minFreeCapacity, + testData.minFreeCapacityRatio); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + if (testData.peakBytes > 0) { + void* buffer = task->allocate(testData.peakBytes); + 
task->free(buffer); + ASSERT_EQ(scopedParticipant->pool()->peakBytes(), testData.peakBytes); + } + + /* Reset the capacity to exactly testData.capacity: drop everything, then grow once. These two calls account for one shrink and one grow in the stats checked below. */ scopedParticipant->shrink(/*reclaimAll=*/true); + scopedParticipant->grow(testData.capacity, 0); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + + if (testData.usedBytes > 0) { + /* Allocate in 1MB pieces; the index is unsigned to match the uint64_t bound. */ for (uint64_t i = 0; i < testData.usedBytes / MB; ++i) { + task->allocate(MB); + } + } + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), testData.usedBytes); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), testData.usedBytes); + + ASSERT_EQ( + scopedParticipant->reclaimableUsedCapacity(), + testData.expectedReclaimableUsedBytes); + + const auto targetBytes = scopedParticipant->reclaimableUsedCapacity(); + const uint64_t prevReclaimedBytes = + scopedParticipant->stats().reclaimedBytes; + ASSERT_EQ( + scopedParticipant->reclaim(targetBytes, 1'000'000), + testData.expectedActualReclaimedBytes); + ASSERT_EQ( + scopedParticipant->pool()->usedBytes(), testData.expectedUsedBytes); + + /* A reclaim with a non-zero target is expected to count as one reclaim plus one additional shrink; a zero target is expected to leave both counters untouched. */ if (targetBytes != 0) { + ASSERT_EQ(scopedParticipant->stats().numShrinks, 2); + ASSERT_EQ(scopedParticipant->stats().numReclaims, 1); + ASSERT_EQ(scopedParticipant->stats().numGrows, 1); + ASSERT_FALSE(scopedParticipant->stats().aborted); + } else { + ASSERT_EQ(scopedParticipant->stats().numShrinks, 1); + ASSERT_EQ(scopedParticipant->stats().numReclaims, 0); + ASSERT_EQ(scopedParticipant->stats().numGrows, 1); + ASSERT_FALSE(scopedParticipant->stats().aborted); + } + ASSERT_EQ( + scopedParticipant->stats().reclaimedBytes, + prevReclaimedBytes + testData.expectedActualReclaimedBytes); + } +} + +/* Table-driven test of checkCapacityGrowth(): each row fixes the pool's max capacity and current capacity and states whether a request of requestBytes is expected to be allowed. */ TEST_F(ArbitrationParticipantTest, checkCapacityGrowth) { + struct { + uint64_t maxCapacity; + uint64_t capacity; + uint64_t requestBytes; + bool expectedGrowth; + + std::string debugString() const { + return fmt::format( + "maxCapacity {}, capacity {}, requestBytes {}, expectedGrowth {}", + succinctBytes(maxCapacity), + succinctBytes(capacity), + succinctBytes(requestBytes), + expectedGrowth); + } +
} testSettings[] = { + /* Columns: maxCapacity, capacity, requestBytes, expectedGrowth. In these rows growth is expected exactly when capacity + requestBytes does not exceed maxCapacity. */ {128 << 20, 32 << 20, 1 << 20, true}, + {128 << 20, 128 << 20, 1 << 20, false}, + {128 << 20, 64 << 20, 64 << 20, true}, + {128 << 20, 128 << 20, 0, true}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto task = createTask(testData.maxCapacity); + const auto config = arbitrationConfig(0); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + /* The allocation is what brings the participant's capacity up to testData.capacity, as the next assertion verifies. */ task->allocate(testData.capacity); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + ASSERT_EQ( + scopedParticipant->checkCapacityGrowth(testData.requestBytes), + testData.expectedGrowth); + } +} + +/* Table-driven test of grow(growthBytes, reservationBytes): checks the success flag, the resulting pool reservation, the resulting capacity, and the growBytes stat for each row. */ TEST_F(ArbitrationParticipantTest, grow) { + struct { + uint64_t maxCapacity; + uint64_t capacity; + uint64_t usedBytes; + uint64_t growthBytes; + uint64_t reservationBytes; + bool expectedFailure; + uint64_t expectedReservationBytes; + uint64_t expectedCapacityAfterGrowth; + + std::string debugString() const { + return fmt::format( + "maxCapacity {}, capacity {}, usedBytes {}, growthBytes {}, reservationBytes {}, expectedFailure {}, expectedReservationBytes {}, expectedCapacityAfterGrowth {}", + succinctBytes(maxCapacity), + succinctBytes(capacity), + succinctBytes(usedBytes), + succinctBytes(growthBytes), + succinctBytes(reservationBytes), + expectedFailure, + succinctBytes(expectedReservationBytes), + succinctBytes(expectedCapacityAfterGrowth)); + } + } testSettings[] = { + /* Columns: maxCapacity, capacity, usedBytes, growthBytes, reservationBytes, expectedFailure, expectedReservationBytes, expectedCapacityAfterGrowth. */ {256 << 20, 128 << 20, 0, 1 << 20, 0, false, 0, 129 << 20}, + {256 << 20, 128 << 20, 0, 256 << 20, 0, true, 0, 128 << 20}, + {256 << 20, 128 << 20, 0, 0 << 20, 192 << 20, true, 0, 128 << 20}, + {256 << 20, 128 << 20, 0, 32 << 20, 256 << 20, true, 0, 128 << 20}, + {256 << 20, 128 << 20, 0, 32 << 20, 32 << 20, false, 32 << 20, 160 << 20}, + {256 << 20, 128 << 20, 0, 32 << 20, 16 << 20, false, 16 << 20, 160 << 20}, + {256 << 20, 128 << 20, 0, 0, 16 << 20, false, 16 << 20, 128 << 20}, + {256 <<
20, 128 << 20, 0, 0, 128 << 20, false, 128 << 20, 128 << 20}, + {256 << 20, + 128 << 20, + 96 << 20, + 0, + 16 << 20, + false, + 112 << 20, + 128 << 20}, + {256 << 20, 128 << 20, 96 << 20, 0, 64 << 20, true, 96 << 20, 128 << 20}, + {256 << 20, + 128 << 20, + 96 << 20, + 8 << 20, + 64 << 20, + true, + 96 << 20, + 128 << 20}, + {256 << 20, + 128 << 20, + 96 << 20, + 128 << 20, + 64 << 20, + false, + 160 << 20, + 256 << 20}, + {256 << 20, + 128 << 20, + 96 << 20, + 256 << 20, + 64 << 20, + true, + 96 << 20, + 128 << 20}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto task = createTask(testData.maxCapacity); + const auto config = arbitrationConfig(0); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + /* Reset the capacity to exactly testData.capacity before the call under test. */ scopedParticipant->shrink(/*reclaimAll=*/true); + scopedParticipant->grow(testData.capacity, 0); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + + if (testData.usedBytes > 0) { + task->allocate(testData.usedBytes); + } + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), testData.usedBytes); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), testData.usedBytes); + + const uint64_t prevGrowBytes = scopedParticipant->stats().growBytes; + ASSERT_EQ( + !testData.expectedFailure, + scopedParticipant->grow( + testData.growthBytes, testData.reservationBytes)); + ASSERT_EQ( + testData.expectedReservationBytes, + scopedParticipant->pool()->reservedBytes()); + ASSERT_EQ( + scopedParticipant->capacity(), testData.expectedCapacityAfterGrowth); + /* Undo the reservation made by a successful grow() by setting the pool's reservation back to its pre-grow value. The cast target type had been stripped from the source text and is restored here. */ if (!testData.expectedFailure && testData.reservationBytes > 0) { + static_cast<MemoryPoolImpl*>(scopedParticipant->pool()) + ->testingSetReservation( + testData.expectedReservationBytes - testData.reservationBytes); + } + /* A failed grow must not be counted in the growBytes stat. */ if (testData.expectedFailure) { + ASSERT_EQ(scopedParticipant->stats().growBytes, prevGrowBytes); + } else { + ASSERT_EQ( + scopedParticipant->stats().growBytes, + prevGrowBytes +
testData.growthBytes); + } + } +} + +/* Table-driven test of shrink(): every row is run once with reclaimAll=false and once with reclaimAll=true on a fresh participant, checking the returned bytes and the reclaimedBytes / numShrinks stats. */ TEST_F(ArbitrationParticipantTest, shrink) { + struct { + uint64_t maxCapacity; + uint64_t minCapacity; + uint64_t capacity; + uint64_t usedBytes; + uint64_t expectedFreeCapacity; + + std::string debugString() const { + return fmt::format( + "maxCapacity {}, minCapacity {}, capacity {}, usedBytes {}, expectedFreeCapacity {}", + succinctBytes(maxCapacity), + succinctBytes(minCapacity), + succinctBytes(capacity), + succinctBytes(usedBytes), + succinctBytes(expectedFreeCapacity)); + } + } testSettings[] = { + /* Columns: maxCapacity, minCapacity, capacity, usedBytes, expectedFreeCapacity. In these rows expectedFreeCapacity is non-zero only when capacity exceeds minCapacity. */ {256 << 20, 128 << 20, 0, 0, 0}, + {256 << 20, 128 << 20, 64 << 20, 0, 0}, + {256 << 20, 128 << 20, 64 << 20, 32 << 20, 0}, + {256 << 20, 128 << 20, 64 << 20, 64 << 20, 0}, + {256 << 20, 128 << 20, 128 << 20, 64 << 20, 0}, + {256 << 20, 128 << 20, 192 << 20, 64 << 20, 64 << 20}, + {256 << 20, 128 << 20, 256 << 20, 128 << 20, 128 << 20}, + {256 << 20, 128 << 20, 256 << 20, 0, 128 << 20}, + {256 << 20, 128 << 20, 192 << 20, 0, 64 << 20}, + {256 << 20, 128 << 20, 128 << 20, 0, 0}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + for (bool reclaimAll : {false, true}) { + SCOPED_TRACE(fmt::format("reclaimAll {}", reclaimAll)); + + auto task = createTask(testData.maxCapacity); + const auto config = + arbitrationConfig(testData.minCapacity, 0, 0.0, 0, 0.0); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + /* Reset the capacity to exactly testData.capacity before the call under test. */ scopedParticipant->shrink(/*reclaimAll=*/true); + scopedParticipant->grow(testData.capacity, 0); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + + if (testData.usedBytes > 0) { + task->allocate(testData.usedBytes); + } + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), testData.usedBytes); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), testData.usedBytes); + + /* Snapshot the counters so the assertions below are relative to the setup calls above. */ const uint64_t prevFreedBytes = scopedParticipant->stats().reclaimedBytes; + const uint32_t prevNumShrunks =
scopedParticipant->stats().numShrinks; + /* With reclaimAll the whole unused capacity is expected back; otherwise only expectedFreeCapacity. */ if (reclaimAll) { + ASSERT_EQ( + scopedParticipant->shrink(reclaimAll), + testData.capacity - testData.usedBytes); + ASSERT_EQ( + prevFreedBytes + testData.capacity - testData.usedBytes, + scopedParticipant->stats().reclaimedBytes); + } else { + ASSERT_EQ( + scopedParticipant->shrink(reclaimAll), + testData.expectedFreeCapacity); + ASSERT_EQ( + prevFreedBytes + testData.expectedFreeCapacity, + scopedParticipant->stats().reclaimedBytes); + } + ASSERT_EQ(prevNumShrunks + 1, scopedParticipant->stats().numShrinks); + } + } +} + +/* Table-driven test of abort(): checks the bytes returned by the first abort, that a repeated abort returns 0, and how the shrink / reclaim counters move. */ TEST_F(ArbitrationParticipantTest, abort) { + struct { + uint64_t maxCapacity; + uint64_t minCapacity; + uint64_t capacity; + uint64_t usedBytes; + uint64_t expectedReclaimCapacity; + + std::string debugString() const { + return fmt::format( + "maxCapacity {}, minCapacity {}, capacity {}, usedBytes {}, expectedReclaimCapacity {}", + succinctBytes(maxCapacity), + succinctBytes(minCapacity), + succinctBytes(capacity), + succinctBytes(usedBytes), + succinctBytes(expectedReclaimCapacity)); + } + } testSettings[] = { + /* Columns: maxCapacity, minCapacity, capacity, usedBytes, expectedReclaimCapacity. In every row expectedReclaimCapacity equals capacity, regardless of usedBytes or minCapacity. */ {256 << 20, 128 << 20, 0, 0, 0}, + {256 << 20, 128 << 20, 128 << 20, 0, 128 << 20}, + {256 << 20, 128 << 20, 256 << 20, 0, 256 << 20}, + {256 << 20, 128 << 20, 64 << 20, 0, 64 << 20}, + {256 << 20, 128 << 20, 128 << 20, 64 << 20, 128 << 20}, + {256 << 20, 128 << 20, 128 << 20, 128 << 20, 128 << 20}, + {256 << 20, 128 << 20, 256 << 20, 128 << 20, 256 << 20}, + {256 << 20, 128 << 20, 256 << 20, 256 << 20, 256 << 20}, + {256 << 20, 128 << 20, 64 << 20, 32 << 20, 64 << 20}, + {256 << 20, 128 << 20, 64 << 20, 64 << 20, 64 << 20}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto task = createTask(testData.maxCapacity); + const auto config = arbitrationConfig(testData.minCapacity, 0, 0.0, 0, 0.0); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); +
scopedParticipant->shrink(/*reclaimAll=*/true); + scopedParticipant->grow(testData.capacity, 0); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + + if (testData.usedBytes > 0) { + task->allocate(testData.usedBytes); + } + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), testData.usedBytes); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), testData.usedBytes); + + ASSERT_FALSE(scopedParticipant->stats().aborted); + ASSERT_FALSE(scopedParticipant->aborted()); + /* Snapshot the counters so the assertions below are relative to the setup calls above. */ const uint64_t prevFreedBytes = scopedParticipant->stats().reclaimedBytes; + const uint32_t prevNumShrunks = scopedParticipant->stats().numShrinks; + const uint32_t prevNumReclaims = scopedParticipant->stats().numReclaims; + const std::string abortReason = "test abort"; + /* abort() takes a std::exception_ptr, so throw and catch a Velox error to obtain one from std::current_exception(). */ try { + VELOX_FAIL(abortReason); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ( + scopedParticipant->abort(std::current_exception()), + testData.expectedReclaimCapacity); + } + ASSERT_TRUE(task->pool()->aborted()); + ASSERT_TRUE(scopedParticipant->stats().aborted); + ASSERT_TRUE(scopedParticipant->aborted()); + ASSERT_EQ( + scopedParticipant->stats().reclaimedBytes, + prevFreedBytes + testData.expectedReclaimCapacity); + /* The first abort is expected to count as one shrink and no reclaim. */ ASSERT_EQ(scopedParticipant->stats().numShrinks, prevNumShrunks + 1); + ASSERT_EQ(scopedParticipant->stats().numReclaims, prevNumReclaims); + + /* A second abort is expected to return 0 and leave both counters unchanged. */ try { + VELOX_FAIL(abortReason); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 0); + } + ASSERT_EQ(scopedParticipant->stats().numShrinks, prevNumShrunks + 1); + ASSERT_EQ(scopedParticipant->stats().numReclaims, prevNumReclaims); + ASSERT_TRUE(scopedParticipant->aborted()); + ASSERT_EQ(scopedParticipant->capacity(), 0); + + /* reclaim() on an aborted participant is expected to return 0 but still bump numReclaims by one and numShrinks by one more. */ ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000), 0); + ASSERT_EQ(scopedParticipant->stats().numReclaims, prevNumReclaims + 1); + ASSERT_EQ(scopedParticipant->stats().numShrinks, prevNumShrunks + 2); + } +} + +/* Checks the ordering between concurrent reclaim() and abort() calls on one participant, using test-value injection points to park each call. */ DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, reclaimLock) { +
auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + const uint64_t allocatedBytes = 32 * MB; + for (int i = 0; i < 32; ++i) { + task->allocate(MB); + } + auto scopedParticipant = participant->lock().value(); + + /* Each xxxWaitFlag / xxxWait pair lets the injected callback announce that it has been reached; each xxxResumeFlag / xxxResume pair lets the test body release it again. */ std::atomic_bool reclaim1WaitFlag{false}; + folly::EventCount reclaim1Wait; + std::atomic_bool reclaim1ResumeFlag{false}; + folly::EventCount reclaim1Resume; + std::atomic_bool reclaim2WaitFlag{false}; + folly::EventCount reclaim2Wait; + std::atomic_bool reclaim2ResumeFlag{false}; + folly::EventCount reclaim2Resume; + /* Callback for the "...ArbitrationParticipant::reclaim" test-value point: the first caller parks on reclaim1Resume and the second caller parks on reclaim2Resume. The second branch previously waited on reclaim1Resume, which is already signalled by then, so reclaim2Resume was never used. */ SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::reclaim", + std::function<void(ArbitrationParticipant*)>( + ([&](ArbitrationParticipant* /*unused*/) { + if (!reclaim1WaitFlag.exchange(true)) { + reclaim1Wait.notifyAll(); + reclaim1Resume.await([&]() { return reclaim1ResumeFlag.load(); }); + return; + } + if (!reclaim2WaitFlag.exchange(true)) { + reclaim2Wait.notifyAll(); + reclaim2Resume.await([&]() { return reclaim2ResumeFlag.load(); }); + return; + } + }))); + + std::atomic_bool abortWaitFlag{false}; + folly::EventCount abortWait; + std::atomic_bool abortResumeFlag{false}; + folly::EventCount abortResume; + /* Callback for the "...ArbitrationParticipant::abortLocked" test-value point: parks the first abort on abortResume. */ SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::abortLocked", + std::function<void(ArbitrationParticipant*)>( + ([&](ArbitrationParticipant* /*unused*/) { + if (!abortWaitFlag.exchange(true)) { + abortWait.notifyAll(); + abortResume.await([&]() { return abortResumeFlag.load(); }); + return; + } + }))); + + std::atomic_bool reclaim1CompletedFlag{false}; + folly::EventCount reclaim1CompletedWait; + std::thread reclaimThread1([&]() { + ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000), 0); + reclaim1CompletedFlag = true; + reclaim1CompletedWait.notifyAll(); + }); + /* Do not start the abort until the first reclaim has reached its injection point. */ reclaim1Wait.await([&]() { return reclaim1WaitFlag.load(); }); + + std::atomic_bool abortCompletedFlag{false}; + folly::EventCount abortCompletedWait; + std::thread
abortThread([&]() { + const std::string abortReason = "test abort"; + try { + VELOX_FAIL(abortReason); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 32 * MB); + } + abortCompletedFlag = true; + abortCompletedWait.notifyAll(); + }); + + /* While the first reclaim is parked, the abort must not have reached its own injection point yet. */ std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT + ASSERT_FALSE(reclaim1CompletedFlag); + ASSERT_FALSE(abortWaitFlag); + + /* Release the first reclaim and wait for it to finish. */ reclaim1ResumeFlag = true; + reclaim1Resume.notifyAll(); + reclaim1CompletedWait.await([&]() { return reclaim1CompletedFlag.load(); }); + reclaimThread1.join(); + + /* The abort is now expected to reach its injection point and park there. */ abortWait.await([&]() { return abortWaitFlag.load(); }); + + std::atomic_bool reclaim2CompletedFlag{false}; + folly::EventCount reclaim2CompletedWait; + std::thread reclaimThread2([&]() { + ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000), 0); + reclaim2CompletedFlag = true; + reclaim2CompletedWait.notifyAll(); + }); + + /* While the abort is parked, the second reclaim must not have reached its injection point. */ std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT + ASSERT_FALSE(abortCompletedFlag); + ASSERT_FALSE(reclaim2WaitFlag); + + /* Release the abort, then the second reclaim, joining each thread in turn. */ abortResumeFlag = true; + abortResume.notifyAll(); + abortCompletedWait.await([&]() { return abortCompletedFlag.load(); }); + abortThread.join(); + + reclaim2ResumeFlag = true; + reclaim2Resume.notifyAll(); + reclaim2CompletedWait.await([&]() { return reclaim2CompletedFlag.load(); }); + reclaimThread2.join(); + + /* Final state: pool aborted and empty; two reclaims and three shrinks were counted, and the reclaimed total equals the 32MB allocated up front. */ ASSERT_TRUE(task->pool()->aborted()); + ASSERT_TRUE(task->abortError() != nullptr); + ASSERT_TRUE(scopedParticipant->aborted()); + ASSERT_EQ(scopedParticipant->capacity(), 0); + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), 0); + ASSERT_EQ(scopedParticipant->stats().numReclaims, 2); + ASSERT_EQ(scopedParticipant->stats().numShrinks, 3); + ASSERT_EQ(scopedParticipant->stats().reclaimedBytes, 32 << 20); +} + +/* Table-driven test of waitForReclaimOrAbort(): a background thread holds a reclaim or an abort in progress for reclaimWaitMs while the test thread waits for it with the given wait time. */ DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, waitForReclaimOrAbort) { + struct { + uint64_t waitTimeUs; + bool pendingReclaim; + uint64_t reclaimWaitMs{0}; + bool expectedTimeout; + + std::string
debugString() const { + return fmt::format( + "waitTime {}, pendingReclaim {}, reclaimWait {}, expectedTimeout {}", + succinctMicros(waitTimeUs), + pendingReclaim, + succinctMillis(reclaimWaitMs), + expectedTimeout); + } + } testSettings[] = { + /* Columns: waitTimeUs, pendingReclaim (true = reclaim in progress, false = abort in progress), reclaimWaitMs, expectedTimeout. NOTE(review): the last two rows are identical, so the abort path is only exercised with a zero wait; the final row was presumably meant to use pendingReclaim == false - confirm before changing. */ {0, true, 1'000, true}, + {0, false, 1'000, true}, + {1'000'000, true, 1'000, false}, + {1'000'000, true, 1'000, false}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::atomic_bool reclaimWaitFlag{false}; + folly::EventCount reclaimWait; + /* Both injected callbacks signal that the background operation has started and then sleep for reclaimWaitMs to keep it in progress. The std::function signature had been stripped from the source text and is restored here. */ SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::reclaim", + std::function<void(ArbitrationParticipant*)>( + ([&](ArbitrationParticipant* /*unused*/) { + reclaimWaitFlag = true; + reclaimWait.notifyAll(); + std::this_thread::sleep_for( + std::chrono::milliseconds(testData.reclaimWaitMs)); // NOLINT + }))); + + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::abortLocked", + std::function<void(ArbitrationParticipant*)>( + ([&](ArbitrationParticipant* /*unused*/) { + reclaimWaitFlag = true; + reclaimWait.notifyAll(); + std::this_thread::sleep_for( + std::chrono::milliseconds(testData.reclaimWaitMs)); // NOLINT + }))); + + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + task->allocate(MB); + auto scopedParticipant = participant->lock().value(); + + /* Background thread: runs either a reclaim or an abort, depending on the row. */ std::thread reclaimThread([&]() { + if (testData.pendingReclaim) { + ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000), MB); + } else { + const std::string abortReason = "test abort"; + try { + VELOX_FAIL(abortReason); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ(scopedParticipant->abort(std::current_exception()), MB); + } + } + }); + /* Only start waiting once the background operation has reached its injection point. */ reclaimWait.await([&]() { return reclaimWaitFlag.load(); }); + ASSERT_EQ( + scopedParticipant->waitForReclaimOrAbort(testData.waitTimeUs), + !testData.expectedTimeout); + reclaimThread.join(); + } +} + +TEST_F(ArbitrationParticipantTest, capacityCheck)
{ + /* Creating a participant whose configured min capacity (512MB) exceeds the pool's max capacity (256MB) is expected to throw. */ auto task = createTask(256 << 20); + const auto config = arbitrationConfig(512 << 20); + VELOX_ASSERT_THROW( + ArbitrationParticipant::create(0, task->pool(), &config), + "The min capacity is larger than the max capacity for memory pool"); +} + +/* Checks the fields and toString() output of ArbitrationCandidate when built with freeCapacityOnly=true and with freeCapacityOnly=false. */ TEST_F(ArbitrationParticipantTest, arbitrationCandidate) { + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(0, 0, 0.0, 0, 0.0); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + + auto scopedParticipant = participant->lock().value(); + /* Set up a 32MB capacity with 1MB reserved, so 31MB is expected to be free. */ scopedParticipant->shrink(/*reclaimAll=*/true); + scopedParticipant->grow(32 << 20, 0); + ASSERT_EQ(scopedParticipant->capacity(), 32 << 20); + task->allocate(MB); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), MB); + + /* With freeCapacityOnly=true the used-capacity figure is expected to be 0. */ ArbitrationCandidate candidateWithFreeCapacityOnly( + participant->lock().value(), /*freeCapacityOnly=*/true); + ASSERT_EQ( + candidateWithFreeCapacityOnly.participant->name(), + scopedParticipant->name()); + ASSERT_EQ(candidateWithFreeCapacityOnly.reclaimableUsedCapacity, 0); + ASSERT_EQ(candidateWithFreeCapacityOnly.reclaimableFreeCapacity, 31 << 20); + ASSERT_EQ( + candidateWithFreeCapacityOnly.toString(), + "TaskPool-0 RECLAIMABLE_USED_CAPACITY 0B RECLAIMABLE_FREE_CAPACITY 31.00MB"); + + /* With freeCapacityOnly=false the 1MB in use is expected to be reported as reclaimable as well. */ ArbitrationCandidate candidate( + participant->lock().value(), /*freeCapacityOnly=*/false); + ASSERT_EQ(candidate.participant->name(), scopedParticipant->name()); + ASSERT_EQ(candidate.reclaimableUsedCapacity, MB); + ASSERT_EQ(candidate.reclaimableFreeCapacity, 31 << 20); + ASSERT_EQ( + candidate.toString(), + "TaskPool-0 RECLAIMABLE_USED_CAPACITY 1.00MB RECLAIMABLE_FREE_CAPACITY 31.00MB"); +} + +/* Walks a single ArbitrationOperation through init, running and finished, checking its accessors and which setters throw in each state, then covers timeout and abort. */ TEST_F(ArbitrationParticipantTest, arbitrationOperation) { + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(0, 0, 0.0, 0, 0.0); + const int participantId{10}; + auto participant = + ArbitrationParticipant::create(participantId, task->pool(), &config); + auto scopedParticipant =
participant->lock().value(); + const int requestBytes = 1 << 20; + const int opTimeoutMs = 1'000'000; + ArbitrationOperation op( + participant->lock().value(), requestBytes, opTimeoutMs); + /* The constructor checks requestBytes > 0, so a zero-byte request must throw. */ VELOX_ASSERT_THROW( + ArbitrationOperation(participant->lock().value(), 0, opTimeoutMs), ""); + ASSERT_EQ(op.requestBytes(), requestBytes); + ASSERT_FALSE(op.aborted()); + ASSERT_FALSE(op.hasTimeout()); + ASSERT_EQ(op.allocatedBytes(), 0); + ASSERT_LE(op.timeoutMs(), opTimeoutMs); + + /* Execution time is measured from construction, so after sleeping one second the remaining timeout must have dropped by at least that much. */ std::this_thread::sleep_for(std::chrono::milliseconds(1'000)); // NOLINT + ASSERT_GE(op.executionTimeMs(), 1'000); + ASSERT_LE(op.timeoutMs(), opTimeoutMs - 1'000); + ASSERT_EQ(op.maxGrowBytes(), 0); + ASSERT_EQ(op.minGrowBytes(), 0); + ASSERT_EQ(op.localArbitrationWaitTimeUs(), 0); + ASSERT_EQ(op.globalArbitrationWaitTimeUs(), 0); + ASSERT_FALSE(op.hasTimeout()); + /* Before start(): setGrowTargets() and both wait-time setters are expected to throw and leave the values untouched. */ VELOX_ASSERT_THROW(op.setGrowTargets(), ""); + ASSERT_EQ(op.requestBytes(), requestBytes); + ASSERT_EQ(op.maxGrowBytes(), 0); + ASSERT_EQ(op.minGrowBytes(), 0); + + ASSERT_EQ(op.localArbitrationWaitTimeUs(), 0); + ASSERT_EQ(op.globalArbitrationWaitTimeUs(), 0); + + ASSERT_EQ(op.state(), ArbitrationOperation::State::kInit); + ASSERT_FALSE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + VELOX_ASSERT_THROW(op.setLocalArbitrationWaitTimeUs(2'000), ""); + VELOX_ASSERT_THROW(op.setGlobalArbitrationWaitTimeUs(2'000), ""); + /* After start(): grow targets can be set exactly once; a second call is expected to throw without changing them. */ op.start(); + op.setGrowTargets(); + ASSERT_EQ(op.requestBytes(), requestBytes); + ASSERT_EQ(op.maxGrowBytes(), requestBytes); + ASSERT_EQ(op.minGrowBytes(), 0); + VELOX_ASSERT_THROW(op.setGrowTargets(), ""); + ASSERT_EQ(op.requestBytes(), requestBytes); + ASSERT_EQ(op.maxGrowBytes(), requestBytes); + ASSERT_EQ(op.minGrowBytes(), 0); + + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + ASSERT_EQ(op.state(), ArbitrationOperation::State::kRunning); + + VELOX_ASSERT_THROW(op.setLocalArbitrationWaitTimeUs(2'000), ""); +
ASSERT_EQ(op.localArbitrationWaitTimeUs(), 0); + /* The global wait time can be set once while running; setting it again is expected to throw. */ op.setGlobalArbitrationWaitTimeUs(2'000); + ASSERT_EQ(op.globalArbitrationWaitTimeUs(), 2'000); + VELOX_ASSERT_THROW(op.setGlobalArbitrationWaitTimeUs(2'000), ""); + /* Record the request as fully granted; the destructor checks that allocatedBytes is either 0 or at least requestBytes. */ op.allocatedBytes() = op.maxGrowBytes(); + + op.finish(); + ASSERT_EQ(op.state(), ArbitrationOperation::State::kFinished); + ASSERT_FALSE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + VELOX_ASSERT_THROW(op.setLocalArbitrationWaitTimeUs(2'000), ""); + VELOX_ASSERT_THROW(op.setGlobalArbitrationWaitTimeUs(2'000), ""); + ASSERT_FALSE(op.hasTimeout()); + /* Once finished, the execution time is computed from the recorded finish time and must no longer grow. */ const auto execTimeMs = op.executionTimeMs(); + std::this_thread::sleep_for(std::chrono::milliseconds(1'000)); // NOLINT + ASSERT_EQ(op.executionTimeMs(), execTimeMs); + ASSERT_FALSE(op.hasTimeout()); + + // Operation timeout. + { + ArbitrationOperation timedOutOp(participant->lock().value(), 1 << 20, 100); + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // NOLINT + ASSERT_TRUE(timedOutOp.hasTimeout()); + + /* A finished operation never reports a timeout, even if it ran longer than its timeout. */ ArbitrationOperation noTimedoutOp( + participant->lock().value(), 1 << 20, 100); + noTimedoutOp.start(); + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // NOLINT + noTimedoutOp.finish(); + ASSERT_FALSE(noTimedoutOp.hasTimeout()); + } + + // Operation abort.
+ { + /* aborted() forwards to the participant, so an operation created before the abort and one created after it must both report true. */ ArbitrationOperation abortOp(participant->lock().value(), 1 << 20, 100); + ASSERT_FALSE(abortOp.aborted()); + try { + VELOX_FAIL("abort op"); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 0); + } + ASSERT_TRUE(abortOp.aborted()); + + ArbitrationOperation abortCheckOp( + participant->lock().value(), 1 << 20, 100); + ASSERT_TRUE(abortCheckOp.aborted()); + } +} + +/* Starts four operations on one participant from different threads and checks that only one runs at a time while the others wait, and that they finish in start order. */ TEST_F(ArbitrationParticipantTest, arbitrationOperationWait) { + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(0, 0, 0.0, 0, 0.0); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + const int requestBytes = 1 << 20; + const int opTimeoutMs = 1'000'000; + ArbitrationOperation op1( + participant->lock().value(), requestBytes, opTimeoutMs); + ArbitrationOperation op2( + participant->lock().value(), requestBytes, opTimeoutMs); + ArbitrationOperation op3( + participant->lock().value(), requestBytes, opTimeoutMs); + /* op4 gets a 1 second timeout so that it has already timed out by the time it is allowed to run. */ ArbitrationOperation op4(participant->lock().value(), requestBytes, 1'000); + ASSERT_FALSE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + + op1.start(); + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + ASSERT_EQ(op1.state(), ArbitrationOperation::State::kRunning); + + /* op2.start() returns only once op2 is the running operation; at that point op3 and op4 are expected to still be waiting. */ std::thread op2Thread([&]() { + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + op2.start(); + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + ASSERT_FALSE(op2.hasTimeout()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 2); + ASSERT_EQ(op3.state(), ArbitrationOperation::State::kWaiting); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // NOLINT + ASSERT_EQ(scopedParticipant->numWaitingOps(), 2); + ASSERT_EQ(op3.state(), ArbitrationOperation::State::kWaiting); + op2.finish(); + }); + + /* Poll until op2 has been queued, so the operations are queued in a known order. */ while (scopedParticipant->numWaitingOps() != 1) { +
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + } + + std::thread op3Thread([&]() { + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + op3.start(); + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + ASSERT_FALSE(op3.hasTimeout()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 1); + ASSERT_EQ(op4.state(), ArbitrationOperation::State::kWaiting); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // NOLINT + ASSERT_EQ(scopedParticipant->numWaitingOps(), 1); + ASSERT_EQ(op4.state(), ArbitrationOperation::State::kWaiting); + op3.finish(); + }); + + while (scopedParticipant->numWaitingOps() != 2) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + } + + std::thread op4Thread([&]() { + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + op4.start(); + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + ASSERT_TRUE(op4.hasTimeout()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // NOLINT + op4.finish(); + }); + + while (scopedParticipant->numWaitingOps() != 3) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + } + + /* All three queued operations must be in the waiting state while op1 is still running. The original asserted op2 three times. */ ASSERT_EQ(op2.state(), ArbitrationOperation::State::kWaiting); + ASSERT_EQ(op3.state(), ArbitrationOperation::State::kWaiting); + ASSERT_EQ(op4.state(), ArbitrationOperation::State::kWaiting); + + /* Keep op1 running for a second so that the execution-time lower bounds asserted below hold for every operation. */ std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT + op1.finish(); + ASSERT_EQ(op1.state(), ArbitrationOperation::State::kFinished); + ASSERT_FALSE(op1.hasTimeout()); + ASSERT_GE(op1.executionTimeMs(), 1'000); + + op2Thread.join(); + ASSERT_EQ(op2.state(), ArbitrationOperation::State::kFinished); + ASSERT_GE(op2.executionTimeMs(), 1'000 + 500); + + op3Thread.join(); + ASSERT_EQ(op3.state(), ArbitrationOperation::State::kFinished); + ASSERT_GE(op3.executionTimeMs(), 1'000 + 500 + 500); + + op4Thread.join(); + ASSERT_EQ(op4.state(), ArbitrationOperation::State::kFinished); + ASSERT_GE(op4.executionTimeMs(),
1'000 + 500 + 500); + + ASSERT_FALSE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + + ASSERT_EQ(scopedParticipant->stats().numRequests, 4); +} + +/* Runs numThreads threads that each start and finish numOpsPerThread operations on one shared participant, then checks that every operation was counted as a request. */ TEST_F(ArbitrationParticipantTest, arbitrationOperationFuzzerTest) { + const int numThreads = 10; + const int numOpsPerThread = 100; + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(0, 0, 0.0, 0, 0.0); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + + /* The element type had been stripped from the source text; the vector holds the worker threads joined below. */ std::vector<std::thread> arbitrationThreads; + for (int i = 0; i < numThreads; ++i) { + arbitrationThreads.emplace_back([&, i]() { + /* Seed each thread's generator with its index so the run is reproducible. */ folly::Random::DefaultGenerator rng; + rng.seed(i); + for (int j = 0; j < numOpsPerThread; ++j) { + const int numExecutionTimeUs = folly::Random::rand32(0, 1'000, rng); + ArbitrationOperation op(participant->lock().value(), 1 << 20, 1'000); + op.start(); + std::this_thread::sleep_for( + std::chrono::microseconds(numExecutionTimeUs)); // NOLINT + op.finish(); + } + }); + } + for (auto& thread : arbitrationThreads) { + thread.join(); + } + + ASSERT_EQ(participant->stats().numRequests, numThreads * numOpsPerThread); +} + +/* Checks stateName() for every named state and for an out-of-range value, which takes the "unknown state" default branch. */ TEST_F(ArbitrationParticipantTest, arbitrationOperationState) { + ASSERT_EQ( + ArbitrationOperation::stateName(ArbitrationOperation::State::kInit), + "init"); + ASSERT_EQ( + ArbitrationOperation::stateName(ArbitrationOperation::State::kWaiting), + "waiting"); + ASSERT_EQ( + ArbitrationOperation::stateName(ArbitrationOperation::State::kRunning), + "running"); + ASSERT_EQ( + ArbitrationOperation::stateName(ArbitrationOperation::State::kFinished), + "finished"); + ASSERT_EQ( + ArbitrationOperation::stateName( + static_cast<ArbitrationOperation::State>(10)), + "unknown state: 10"); +} +} // namespace +} // namespace facebook::velox::memory