Skip to content

Commit

Permalink
Add back memory reclaim async callback to remove async dependency on … (
Browse files Browse the repository at this point in the history
facebookincubator#10845)

Summary:
The async source depends on the memory module to set up the memory arbitration context.
The global memory arbitration optimization might also need the async source to parallelize
query spilling, which can cause a circular dependency. The memory arbitration context is
only used by memory reclamation, so it does not need to be part of the base lib. This PR brings
back the previous async memory reclaim callback to handle the memory arbitration context setup
and remove this dependency.

Pull Request resolved: facebookincubator#10845

Reviewed By: Yuhta

Differential Revision: D61828202

Pulled By: xiaoxmeng

fbshipit-source-id: 09d1be89d42a7b618425f6efc221e761ba580e49
  • Loading branch information
xiaoxmeng authored and Joe-Abraham committed Sep 3, 2024
1 parent eacb0ec commit 63fd5f4
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 50 deletions.
12 changes: 1 addition & 11 deletions velox/common/base/AsyncSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include "velox/common/base/Exceptions.h"
#include "velox/common/base/Portability.h"
#include "velox/common/future/VeloxPromise.h"
#include "velox/common/memory/MemoryArbitrator.h"
#include "velox/common/process/ThreadDebugInfo.h"
#include "velox/common/testutil/TestValue.h"

Expand All @@ -43,10 +42,6 @@ class AsyncSource {
public:
explicit AsyncSource(std::function<std::unique_ptr<Item>()> make)
: make_(std::move(make)) {
if (memory::memoryArbitrationContext() != nullptr) {
memoryArbitrationContext_ = *memory::memoryArbitrationContext();
}

if (process::GetThreadDebugInfo() != nullptr) {
auto* currentThreadDebugInfo = process::GetThreadDebugInfo();
// We explicitly leave out the callback when copying the ThreadDebugInfo
Expand Down Expand Up @@ -203,18 +198,13 @@ class AsyncSource {

private:
std::unique_ptr<Item> runMake(std::function<std::unique_ptr<Item>()>& make) {
memory::ScopedMemoryArbitrationContext memoryArbitrationContext(
memoryArbitrationContext_.has_value()
? &memoryArbitrationContext_.value()
: nullptr);
process::ScopedThreadDebugInfo threadDebugInfo(
threadDebugInfo_.has_value() ? &threadDebugInfo_.value() : nullptr);
return make();
}

// Stored contexts (if present upon construction) so they can be restored when
// Stored context (if present upon construction) so it can be restored when
// make_ is invoked.
std::optional<memory::MemoryArbitrationContext> memoryArbitrationContext_;
std::optional<process::ThreadDebugInfo> threadDebugInfo_;

mutable std::mutex mutex_;
Expand Down
18 changes: 0 additions & 18 deletions velox/common/base/tests/AsyncSourceTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@
#include <folly/Synchronized.h>
#include <folly/synchronization/Baton.h>
#include <gtest/gtest.h>
#include <chrono>
#include <thread>
#include "velox/common/base/Exceptions.h"
#include "velox/common/memory/Memory.h"

using namespace facebook::velox;
using namespace std::chrono_literals;
Expand Down Expand Up @@ -249,15 +247,10 @@ TEST(AsyncSourceTest, close) {
// Checks that the current memory arbitration context's requestor pool is named
// 'expectedPoolName' and that the current thread debug info carries
// 'expectedTaskId'. Both contexts must be set when this is called since they
// are dereferenced without a null check.
void verifyContexts(
    const std::string& expectedPoolName,
    const std::string& expectedTaskId) {
  const auto* arbitrationCtx = memory::memoryArbitrationContext();
  EXPECT_EQ(arbitrationCtx->requestor->name(), expectedPoolName);

  const auto* debugInfo = process::GetThreadDebugInfo();
  EXPECT_EQ(debugInfo->taskId_, expectedTaskId);
}

TEST(AsyncSourceTest, emptyContexts) {
memory::MemoryManager::testingSetInstance({});

EXPECT_EQ(memory::memoryArbitrationContext(), nullptr);
EXPECT_EQ(process::GetThreadDebugInfo(), nullptr);

AsyncSource<bool> src([]() {
Expand All @@ -268,9 +261,6 @@ TEST(AsyncSourceTest, emptyContexts) {
return std::make_unique<bool>(true);
});

auto pool = memory::MemoryManager::getInstance()->addRootPool("test");
memory::ScopedMemoryArbitrationContext scopedMemoryArbitrationContext(
pool.get());
process::ThreadDebugInfo debugInfo{"query_id", "task_id", nullptr};
process::ScopedThreadDebugInfo scopedDebugInfo(debugInfo);

Expand All @@ -282,14 +272,9 @@ TEST(AsyncSourceTest, emptyContexts) {
}

TEST(AsyncSourceTest, setContexts) {
memory::MemoryManager::testingSetInstance({});

auto pool1 = memory::MemoryManager::getInstance()->addRootPool("test1");
process::ThreadDebugInfo debugInfo1{"query_id1", "task_id1", nullptr};

std::unique_ptr<AsyncSource<bool>> src;
memory::ScopedMemoryArbitrationContext scopedMemoryArbitrationContext1(
pool1.get());
process::ScopedThreadDebugInfo scopedDebugInfo1(debugInfo1);

verifyContexts("test1", "task_id1");
Expand All @@ -302,9 +287,6 @@ TEST(AsyncSourceTest, setContexts) {
return std::make_unique<bool>(true);
}));

auto pool2 = memory::MemoryManager::getInstance()->addRootPool("test2");
memory::ScopedMemoryArbitrationContext scopedMemoryArbitrationContext2(
pool2.get());
process::ThreadDebugInfo debugInfo2{"query_id2", "task_id2", nullptr};
process::ScopedThreadDebugInfo scopedDebugInfo2(debugInfo2);

Expand Down
9 changes: 0 additions & 9 deletions velox/common/memory/MemoryArbitrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,15 +480,6 @@ ScopedMemoryArbitrationContext::ScopedMemoryArbitrationContext(
arbitrationCtx = &currentArbitrationCtx_;
}

// Remembers the arbitration context that is current at construction time and,
// when 'contextToRestore' is non-null, installs a copy of it as the current
// context. With a null 'contextToRestore' the current context is left as is.
ScopedMemoryArbitrationContext::ScopedMemoryArbitrationContext(
    const MemoryArbitrationContext* contextToRestore)
    : savedArbitrationCtx_(arbitrationCtx) {
  if (contextToRestore == nullptr) {
    // Nothing was captured, so there is nothing to install.
    return;
  }
  // Keep a private copy so the installed context does not depend on the
  // lifetime of the object 'contextToRestore' points to.
  currentArbitrationCtx_ = *contextToRestore;
  arbitrationCtx = &currentArbitrationCtx_;
}

// Puts back the arbitration context pointer held in savedArbitrationCtx_,
// undoing whatever context this object installed while it was alive.
ScopedMemoryArbitrationContext::~ScopedMemoryArbitrationContext() {
arbitrationCtx = savedArbitrationCtx_;
}
Expand Down
25 changes: 18 additions & 7 deletions velox/common/memory/MemoryArbitrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <vector>

#include "velox/common/base/AsyncSource.h"
#include "velox/common/base/Exceptions.h"
#include "velox/common/base/Portability.h"
#include "velox/common/base/SuccinctPrinter.h"
Expand Down Expand Up @@ -419,13 +420,6 @@ class ScopedMemoryArbitrationContext {
public:
explicit ScopedMemoryArbitrationContext(const MemoryPool* requestor);

/// Can be used to restore a previously captured MemoryArbitrationContext.
/// contextToRestore can be nullptr if there was no context at the time it was
/// captured, in which case arbitrationCtx is unchanged upon
/// construction/destruction of this object.
explicit ScopedMemoryArbitrationContext(
const MemoryArbitrationContext* contextToRestore);

~ScopedMemoryArbitrationContext();

private:
Expand All @@ -451,6 +445,23 @@ const MemoryArbitrationContext* memoryArbitrationContext();
/// Returns true if the running thread is under memory arbitration or not.
bool underMemoryArbitration();

/// Creates an async memory reclaim task which runs 'task' with a memory
/// arbitration context set for the same requestor pool as the caller's current
/// arbitration context. This is to avoid recursive memory arbitration during
/// memory reclaim.
///
/// The requestor is read from the caller's arbitration context when this
/// function is called, not when the task runs, so the returned task does not
/// keep a pointer to the caller's context object.
///
/// NOTE: this must be called under memory arbitration; the call fails a
/// VELOX_CHECK if no arbitration context is set.
///
/// @param task The reclaim work to run; its result is returned by the async
/// source unchanged.
/// @return An AsyncSource wrapping 'task'.
template <typename Item>
std::shared_ptr<AsyncSource<Item>> createAsyncMemoryReclaimTask(
    std::function<std::unique_ptr<Item>()> task) {
  const auto* arbitrationCtx = memory::memoryArbitrationContext();
  VELOX_CHECK_NOT_NULL(arbitrationCtx);
  // The context object belongs to the ScopedMemoryArbitrationContext that
  // installed it, so a pointer to it is only safe to dereference while that
  // scope is alive. Copy the requestor out now instead of dereferencing the
  // context later from whichever thread ends up running the task.
  const auto* requestor = arbitrationCtx->requestor;
  return std::make_shared<AsyncSource<Item>>(
      [asyncTask = std::move(task), requestor]() -> std::unique_ptr<Item> {
        memory::ScopedMemoryArbitrationContext ctx(requestor);
        return asyncTask();
      });
}

/// The function triggers memory arbitration by shrinking memory pools from
/// 'manager' by invoking shrinkPools API. If 'manager' is not set, then it
/// shrinks from the process wide memory manager. If 'targetBytes' is zero, then
Expand Down
2 changes: 1 addition & 1 deletion velox/common/memory/tests/MockSharedArbitratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ TEST_F(MockSharedArbitrationTest, asyncArbitrationWork) {

explicit Result(bool _succeeded) : succeeded(_succeeded) {}
};
auto asyncReclaimTask = std::make_shared<AsyncSource<Result>>([&]() {
auto asyncReclaimTask = createAsyncMemoryReclaimTask<Result>([&]() {
memoryOp->allocate(poolCapacity);
return std::make_unique<Result>(true);
});
Expand Down
2 changes: 1 addition & 1 deletion velox/exec/HashBuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ void HashBuild::reclaim(
for (auto* op : operators) {
HashBuild* buildOp = static_cast<HashBuild*>(op);
spillTasks.push_back(
std::make_shared<AsyncSource<SpillResult>>([buildOp]() {
memory::createAsyncMemoryReclaimTask<SpillResult>([buildOp]() {
try {
buildOp->spiller_->spill();
buildOp->table_->clear();
Expand Down
6 changes: 3 additions & 3 deletions velox/exec/HashProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1766,7 +1766,7 @@ void HashProbe::spillOutput(const std::vector<HashProbe*>& operators) {
for (auto* op : operators) {
HashProbe* probeOp = static_cast<HashProbe*>(op);
spillTasks.push_back(
std::make_shared<AsyncSource<SpillResult>>([probeOp]() {
memory::createAsyncMemoryReclaimTask<SpillResult>([probeOp]() {
try {
probeOp->spillOutput();
return std::make_unique<SpillResult>(nullptr);
Expand Down Expand Up @@ -1868,8 +1868,8 @@ SpillPartitionSet HashProbe::spillTable() {
if (rowContainer->numRows() == 0) {
continue;
}
spillTasks.push_back(
std::make_shared<AsyncSource<SpillResult>>([this, rowContainer]() {
spillTasks.push_back(memory::createAsyncMemoryReclaimTask<SpillResult>(
[this, rowContainer]() {
try {
return std::make_unique<SpillResult>(spillTable(rowContainer));
} catch (const std::exception& e) {
Expand Down

0 comments on commit 63fd5f4

Please sign in to comment.