From 63fd5f4cdba1193c1dae50bcf450277db68be890 Mon Sep 17 00:00:00 2001 From: xiaoxmeng Date: Tue, 27 Aug 2024 12:27:18 -0700 Subject: [PATCH] =?UTF-8?q?Add=20back=20memory=20reclaim=20async=20callbac?= =?UTF-8?q?k=20to=20remove=20async=20dependency=20on=20=E2=80=A6=20(#10845?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The async source depends on the memory module to set up the memory arbitration context. The global memory arbitration optimization might also need the async source to parallelize the query spill, which can cause a circular dependency. The memory arbitration context is only used by memory reclamation, so it does not need to be part of the base lib. This PR brings back the previous async memory reclaim callback, which handles the memory arbitration context setup, to remove this dependency. Pull Request resolved: https://github.com/facebookincubator/velox/pull/10845 Reviewed By: Yuhta Differential Revision: D61828202 Pulled By: xiaoxmeng fbshipit-source-id: 09d1be89d42a7b618425f6efc221e761ba580e49 --- velox/common/base/AsyncSource.h | 12 +-------- velox/common/base/tests/AsyncSourceTest.cpp | 18 ------------- velox/common/memory/MemoryArbitrator.cpp | 9 ------- velox/common/memory/MemoryArbitrator.h | 25 +++++++++++++------ .../memory/tests/MockSharedArbitratorTest.cpp | 2 +- velox/exec/HashBuild.cpp | 2 +- velox/exec/HashProbe.cpp | 6 ++--- 7 files changed, 24 insertions(+), 50 deletions(-) diff --git a/velox/common/base/AsyncSource.h b/velox/common/base/AsyncSource.h index c5b241f2fb70a..76740dd8681a6 100644 --- a/velox/common/base/AsyncSource.h +++ b/velox/common/base/AsyncSource.h @@ -26,7 +26,6 @@ #include "velox/common/base/Exceptions.h" #include "velox/common/base/Portability.h" #include "velox/common/future/VeloxPromise.h" -#include "velox/common/memory/MemoryArbitrator.h" #include "velox/common/process/ThreadDebugInfo.h" #include "velox/common/testutil/TestValue.h" @@ -43,10 +42,6 @@ class AsyncSource { 
public: explicit AsyncSource(std::function()> make) : make_(std::move(make)) { - if (memory::memoryArbitrationContext() != nullptr) { - memoryArbitrationContext_ = *memory::memoryArbitrationContext(); - } - if (process::GetThreadDebugInfo() != nullptr) { auto* currentThreadDebugInfo = process::GetThreadDebugInfo(); // We explicitly leave out the callback when copying the ThreadDebugInfo @@ -203,18 +198,13 @@ class AsyncSource { private: std::unique_ptr runMake(std::function()>& make) { - memory::ScopedMemoryArbitrationContext memoryArbitrationContext( - memoryArbitrationContext_.has_value() - ? &memoryArbitrationContext_.value() - : nullptr); process::ScopedThreadDebugInfo threadDebugInfo( threadDebugInfo_.has_value() ? &threadDebugInfo_.value() : nullptr); return make(); } - // Stored contexts (if present upon construction) so they can be restored when + // Stored context (if present upon construction) so they can be restored when // make_ is invoked. - std::optional memoryArbitrationContext_; std::optional threadDebugInfo_; mutable std::mutex mutex_; diff --git a/velox/common/base/tests/AsyncSourceTest.cpp b/velox/common/base/tests/AsyncSourceTest.cpp index 14836753ea7ed..657a7ba8e08e8 100644 --- a/velox/common/base/tests/AsyncSourceTest.cpp +++ b/velox/common/base/tests/AsyncSourceTest.cpp @@ -20,10 +20,8 @@ #include #include #include -#include #include #include "velox/common/base/Exceptions.h" -#include "velox/common/memory/Memory.h" using namespace facebook::velox; using namespace std::chrono_literals; @@ -249,15 +247,10 @@ TEST(AsyncSourceTest, close) { void verifyContexts( const std::string& expectedPoolName, const std::string& expectedTaskId) { - EXPECT_EQ( - memory::memoryArbitrationContext()->requestor->name(), expectedPoolName); EXPECT_EQ(process::GetThreadDebugInfo()->taskId_, expectedTaskId); } TEST(AsyncSourceTest, emptyContexts) { - memory::MemoryManager::testingSetInstance({}); - - EXPECT_EQ(memory::memoryArbitrationContext(), nullptr); 
EXPECT_EQ(process::GetThreadDebugInfo(), nullptr); AsyncSource src([]() { @@ -268,9 +261,6 @@ TEST(AsyncSourceTest, emptyContexts) { return std::make_unique(true); }); - auto pool = memory::MemoryManager::getInstance()->addRootPool("test"); - memory::ScopedMemoryArbitrationContext scopedMemoryArbitrationContext( - pool.get()); process::ThreadDebugInfo debugInfo{"query_id", "task_id", nullptr}; process::ScopedThreadDebugInfo scopedDebugInfo(debugInfo); @@ -282,14 +272,9 @@ TEST(AsyncSourceTest, emptyContexts) { } TEST(AsyncSourceTest, setContexts) { - memory::MemoryManager::testingSetInstance({}); - - auto pool1 = memory::MemoryManager::getInstance()->addRootPool("test1"); process::ThreadDebugInfo debugInfo1{"query_id1", "task_id1", nullptr}; std::unique_ptr> src; - memory::ScopedMemoryArbitrationContext scopedMemoryArbitrationContext1( - pool1.get()); process::ScopedThreadDebugInfo scopedDebugInfo1(debugInfo1); verifyContexts("test1", "task_id1"); @@ -302,9 +287,6 @@ TEST(AsyncSourceTest, setContexts) { return std::make_unique(true); })); - auto pool2 = memory::MemoryManager::getInstance()->addRootPool("test2"); - memory::ScopedMemoryArbitrationContext scopedMemoryArbitrationContext2( - pool2.get()); process::ThreadDebugInfo debugInfo2{"query_id2", "task_id2", nullptr}; process::ScopedThreadDebugInfo scopedDebugInfo2(debugInfo2); diff --git a/velox/common/memory/MemoryArbitrator.cpp b/velox/common/memory/MemoryArbitrator.cpp index 048c32afbcd82..89d786e1acc43 100644 --- a/velox/common/memory/MemoryArbitrator.cpp +++ b/velox/common/memory/MemoryArbitrator.cpp @@ -480,15 +480,6 @@ ScopedMemoryArbitrationContext::ScopedMemoryArbitrationContext( arbitrationCtx = &currentArbitrationCtx_; } -ScopedMemoryArbitrationContext::ScopedMemoryArbitrationContext( - const MemoryArbitrationContext* contextToRestore) - : savedArbitrationCtx_(arbitrationCtx) { - if (contextToRestore != nullptr) { - currentArbitrationCtx_ = *contextToRestore; - arbitrationCtx = &currentArbitrationCtx_; - } -} - 
ScopedMemoryArbitrationContext::~ScopedMemoryArbitrationContext() { arbitrationCtx = savedArbitrationCtx_; } diff --git a/velox/common/memory/MemoryArbitrator.h b/velox/common/memory/MemoryArbitrator.h index 87274453c9aeb..51eb6739641c6 100644 --- a/velox/common/memory/MemoryArbitrator.h +++ b/velox/common/memory/MemoryArbitrator.h @@ -18,6 +18,7 @@ #include +#include "velox/common/base/AsyncSource.h" #include "velox/common/base/Exceptions.h" #include "velox/common/base/Portability.h" #include "velox/common/base/SuccinctPrinter.h" @@ -419,13 +420,6 @@ class ScopedMemoryArbitrationContext { public: explicit ScopedMemoryArbitrationContext(const MemoryPool* requestor); - /// Can be used to restore a previously captured MemoryArbitrationContext. - /// contextToRestore can be nullptr if there was no context at the time it was - /// captured, in which case arbitrationCtx is unchanged upon - /// contruction/destruction of this object. - explicit ScopedMemoryArbitrationContext( - const MemoryArbitrationContext* contextToRestore); - ~ScopedMemoryArbitrationContext(); private: @@ -451,6 +445,23 @@ const MemoryArbitrationContext* memoryArbitrationContext(); /// Returns true if the running thread is under memory arbitration or not. bool underMemoryArbitration(); +/// Creates an async memory reclaim task with memory arbitration context set. +/// This is to avoid recursive memory arbitration during memory reclaim. +/// +/// NOTE: this must be called under memory arbitration. 
+template +std::shared_ptr> createAsyncMemoryReclaimTask( + std::function()> task) { + auto* arbitrationCtx = memory::memoryArbitrationContext(); + VELOX_CHECK_NOT_NULL(arbitrationCtx); + return std::make_shared>( + [asyncTask = std::move(task), arbitrationCtx]() -> std::unique_ptr { + VELOX_CHECK_NOT_NULL(arbitrationCtx); + memory::ScopedMemoryArbitrationContext ctx(arbitrationCtx->requestor); + return asyncTask(); + }); +} + /// The function triggers memory arbitration by shrinking memory pools from /// 'manager' by invoking shrinkPools API. If 'manager' is not set, then it /// shrinks from the process wide memory manager. If 'targetBytes' is zero, then diff --git a/velox/common/memory/tests/MockSharedArbitratorTest.cpp b/velox/common/memory/tests/MockSharedArbitratorTest.cpp index 3b701391fab9c..59e14c460691d 100644 --- a/velox/common/memory/tests/MockSharedArbitratorTest.cpp +++ b/velox/common/memory/tests/MockSharedArbitratorTest.cpp @@ -706,7 +706,7 @@ TEST_F(MockSharedArbitrationTest, asyncArbitrationWork) { explicit Result(bool _succeeded) : succeeded(_succeeded) {} }; - auto asyncReclaimTask = std::make_shared>([&]() { + auto asyncReclaimTask = createAsyncMemoryReclaimTask([&]() { memoryOp->allocate(poolCapacity); return std::make_unique(true); }); diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp index d88b13a60ebbc..ea14000e2886a 100644 --- a/velox/exec/HashBuild.cpp +++ b/velox/exec/HashBuild.cpp @@ -1163,7 +1163,7 @@ void HashBuild::reclaim( for (auto* op : operators) { HashBuild* buildOp = static_cast(op); spillTasks.push_back( - std::make_shared>([buildOp]() { + memory::createAsyncMemoryReclaimTask([buildOp]() { try { buildOp->spiller_->spill(); buildOp->table_->clear(); diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index defbbcdd8e119..077dab88b6f08 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -1766,7 +1766,7 @@ void HashProbe::spillOutput(const std::vector& operators) { for (auto* op : 
operators) { HashProbe* probeOp = static_cast(op); spillTasks.push_back( - std::make_shared>([probeOp]() { + memory::createAsyncMemoryReclaimTask([probeOp]() { try { probeOp->spillOutput(); return std::make_unique(nullptr); @@ -1868,8 +1868,8 @@ SpillPartitionSet HashProbe::spillTable() { if (rowContainer->numRows() == 0) { continue; } - spillTasks.push_back( - std::make_shared>([this, rowContainer]() { + spillTasks.push_back(memory::createAsyncMemoryReclaimTask( + [this, rowContainer]() { try { return std::make_unique(spillTable(rowContainer)); } catch (const std::exception& e) {