diff --git a/velox/common/memory/tests/SharedArbitratorTest.cpp b/velox/common/memory/tests/SharedArbitratorTest.cpp index 0d1e0acf8b1de..ed14558403e39 100644 --- a/velox/common/memory/tests/SharedArbitratorTest.cpp +++ b/velox/common/memory/tests/SharedArbitratorTest.cpp @@ -266,8 +266,6 @@ class FakeMemoryReclaimer : public MemoryReclaimer { auto* driver = driverThreadCtx->driverCtx.driver; ASSERT_TRUE(driver != nullptr); if (driver->task()->enterSuspended(driver->state()) != StopReason::kNone) { - // There is no need for arbitration if the associated task has already - // terminated. VELOX_FAIL("Terminate detected when entering suspension"); } } @@ -331,17 +329,7 @@ class SharedArbitrationTest : public exec::test::HiveConnectorTestBase { options.memoryPoolInitCapacity = memoryPoolInitCapacity; options.memoryPoolTransferCapacity = memoryPoolTransferCapacity; options.checkUsageLeak = true; - options.arbitrationStateCheckCb = [](MemoryPool& pool) { - const auto* driverThreadCtx = driverThreadContext(); - if (driverThreadCtx != nullptr) { - if (!driverThreadCtx->driverCtx.driver->state().isSuspended) { - LOG(ERROR) - << "false " - << driverThreadCtx->driverCtx.driver->state().toJsonString(); - } - ASSERT_TRUE(driverThreadCtx->driverCtx.driver->state().isSuspended); - } - }; + options.arbitrationStateCheckCb = driverArbitrationStateCheck; memoryManager_ = std::make_unique(options); ASSERT_EQ(memoryManager_->arbitrator()->kind(), "SHARED"); arbitrator_ = static_cast(memoryManager_->arbitrator()); diff --git a/velox/exec/Driver.cpp b/velox/exec/Driver.cpp index 2053042577022..ab82a91528e63 100644 --- a/velox/exec/Driver.cpp +++ b/velox/exec/Driver.cpp @@ -828,6 +828,19 @@ std::string Driver::toJsonString() const { return folly::toPrettyJson(obj); } +void driverArbitrationStateCheck(memory::MemoryPool& pool) { + const auto* driverThreadCtx = driverThreadContext(); + if (driverThreadCtx != nullptr) { + Driver* driver = driverThreadCtx->driverCtx.driver; + if 
(!driver->state().isSuspended) { + VELOX_FAIL( + "Driver thread is not suspended under memory arbitration processing: {}, request memory pool: {}", + driver->toString(), + pool.name()); + } + } +} + SuspendedSection::SuspendedSection(Driver* driver) : driver_(driver) { if (driver->task()->enterSuspended(driver->state()) != StopReason::kNone) { VELOX_FAIL("Terminate detected when entering suspended section"); diff --git a/velox/exec/Driver.h b/velox/exec/Driver.h index e28fd3e4c6df7..2b11f8674ddb4 100644 --- a/velox/exec/Driver.h +++ b/velox/exec/Driver.h @@ -280,15 +280,15 @@ class Driver : public std::enable_shared_from_this { void initializeOperatorStats(std::vector& stats); - // Close operators and add operator stats to the task. + /// Close operators and add operator stats to the task. void closeOperators(); - // Returns true if all operators between the source and 'aggregation' are - // order-preserving and do not increase cardinality. + /// Returns true if all operators between the source and 'aggregation' are + /// order-preserving and do not increase cardinality. bool mayPushdownAggregation(Operator* aggregation) const; - // Returns a subset of channels for which there are operators upstream from - // filterSource that accept dynamically generated filters. + /// Returns a subset of channels for which there are operators upstream from + /// filterSource that accept dynamically generated filters. std::unordered_set canPushdownFilters( const Operator* filterSource, const std::vector& channels) const; @@ -300,7 +300,7 @@ class Driver : public std::enable_shared_from_this { /// Returns the Operator with 'operatorId' or nullptr if not found. Operator* findOperator(int32_t operatorId) const; - // Returns a list of all operators. + /// Returns a list of all operators. 
std::vector operators() const; std::string toString() const; @@ -315,8 +315,8 @@ class Driver : public std::enable_shared_from_this { return ctx_->task; } - // Updates the stats in Task and frees resources. Only called by Task for - // closing non-running Drivers. + /// Updates the stats in Task and frees resources. Only called by Task for + /// closing non-running Drivers. void closeByTask(); BlockingReason blockingReason() const { @@ -349,10 +349,10 @@ class Driver : public std::enable_shared_from_this { // position in the pipeline. void pushdownFilters(int operatorIndex); - /// If 'trackOperatorCpuUsage_' is true, returns initialized timer object to - /// track cpu and wall time of an operation. Returns null otherwise. - /// The delta CpuWallTiming object would be passes to 'func' upon - /// destruction of the timer. + // If 'trackOperatorCpuUsage_' is true, returns initialized timer object to + // track cpu and wall time of an operation. Returns null otherwise. + // The delta CpuWallTiming object would be passed to 'func' upon + // destruction of the timer. template std::unique_ptr> createDeltaCpuWallTimer(F&& func) { return trackOperatorCpuUsage_ @@ -397,6 +397,14 @@ class Driver : public std::enable_shared_from_this { friend struct DriverFactory; }; +/// Callback used by memory arbitration to check if a driver thread under memory +/// arbitration has been put in suspension state. This is to prevent arbitration +/// deadlock as the arbitrator might reclaim memory from the task of the driver +/// thread which is under arbitration. The task reclaim needs to wait for the +/// drivers to go off thread. A suspended driver thread is not counted as +/// running. +void driverArbitrationStateCheck(memory::MemoryPool& pool); + using OperatorSupplier = std::function< std::unique_ptr(int32_t operatorId, DriverCtx* ctx)>;