Skip to content

Commit

Permalink
Add driver arbitration state check callback (facebookincubator#6656)
Browse files Browse the repository at this point in the history
Summary:
This is a leftover from facebookincubator#6643.

Pull Request resolved: facebookincubator#6656

Reviewed By: tanjialiang

Differential Revision: D49469932

Pulled By: xiaoxmeng

fbshipit-source-id: 16a46afc634e97259b0f20d380d90958d58fd3e5
  • Loading branch information
xiaoxmeng authored and codyschierbeck committed Sep 27, 2023
1 parent d04f124 commit fd8c0a1
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 25 deletions.
14 changes: 1 addition & 13 deletions velox/common/memory/tests/SharedArbitratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,6 @@ class FakeMemoryReclaimer : public MemoryReclaimer {
auto* driver = driverThreadCtx->driverCtx.driver;
ASSERT_TRUE(driver != nullptr);
if (driver->task()->enterSuspended(driver->state()) != StopReason::kNone) {
// There is no need for arbitration if the associated task has already
// terminated.
VELOX_FAIL("Terminate detected when entering suspension");
}
}
Expand Down Expand Up @@ -331,17 +329,7 @@ class SharedArbitrationTest : public exec::test::HiveConnectorTestBase {
options.memoryPoolInitCapacity = memoryPoolInitCapacity;
options.memoryPoolTransferCapacity = memoryPoolTransferCapacity;
options.checkUsageLeak = true;
options.arbitrationStateCheckCb = [](MemoryPool& pool) {
const auto* driverThreadCtx = driverThreadContext();
if (driverThreadCtx != nullptr) {
if (!driverThreadCtx->driverCtx.driver->state().isSuspended) {
LOG(ERROR)
<< "false "
<< driverThreadCtx->driverCtx.driver->state().toJsonString();
}
ASSERT_TRUE(driverThreadCtx->driverCtx.driver->state().isSuspended);
}
};
options.arbitrationStateCheckCb = driverArbitrationStateCheck;
memoryManager_ = std::make_unique<MemoryManager>(options);
ASSERT_EQ(memoryManager_->arbitrator()->kind(), "SHARED");
arbitrator_ = static_cast<SharedArbitrator*>(memoryManager_->arbitrator());
Expand Down
13 changes: 13 additions & 0 deletions velox/exec/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,19 @@ std::string Driver::toJsonString() const {
return folly::toPrettyJson(obj);
}

void driverArbitrationStateCheck(memory::MemoryPool& pool) {
const auto* driverThreadCtx = driverThreadContext();
if (driverThreadCtx != nullptr) {
Driver* driver = driverThreadCtx->driverCtx.driver;
if (!driver->state().isSuspended) {
VELOX_FAIL(
"Driver thread is not suspended under memory arbitration processing: {}, request memory pool: {}",
driver->toString(),
pool.name());
}
}
}

SuspendedSection::SuspendedSection(Driver* driver) : driver_(driver) {
if (driver->task()->enterSuspended(driver->state()) != StopReason::kNone) {
VELOX_FAIL("Terminate detected when entering suspended section");
Expand Down
32 changes: 20 additions & 12 deletions velox/exec/Driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,15 @@ class Driver : public std::enable_shared_from_this<Driver> {

void initializeOperatorStats(std::vector<OperatorStats>& stats);

// Close operators and add operator stats to the task.
/// Close operators and add operator stats to the task.
void closeOperators();

// Returns true if all operators between the source and 'aggregation' are
// order-preserving and do not increase cardinality.
/// Returns true if all operators between the source and 'aggregation' are
/// order-preserving and do not increase cardinality.
bool mayPushdownAggregation(Operator* aggregation) const;

// Returns a subset of channels for which there are operators upstream from
// filterSource that accept dynamically generated filters.
/// Returns a subset of channels for which there are operators upstream from
/// filterSource that accept dynamically generated filters.
std::unordered_set<column_index_t> canPushdownFilters(
const Operator* filterSource,
const std::vector<column_index_t>& channels) const;
Expand All @@ -300,7 +300,7 @@ class Driver : public std::enable_shared_from_this<Driver> {
/// Returns the Operator with 'operatorId' or nullptr if not found.
Operator* findOperator(int32_t operatorId) const;

// Returns a list of all operators.
/// Returns a list of all operators.
std::vector<Operator*> operators() const;

std::string toString() const;
Expand All @@ -315,8 +315,8 @@ class Driver : public std::enable_shared_from_this<Driver> {
return ctx_->task;
}

// Updates the stats in Task and frees resources. Only called by Task for
// closing non-running Drivers.
/// Updates the stats in Task and frees resources. Only called by Task for
/// closing non-running Drivers.
void closeByTask();

BlockingReason blockingReason() const {
Expand Down Expand Up @@ -349,10 +349,10 @@ class Driver : public std::enable_shared_from_this<Driver> {
// position in the pipeline.
void pushdownFilters(int operatorIndex);

/// If 'trackOperatorCpuUsage_' is true, returns initialized timer object to
/// track cpu and wall time of an operation. Returns null otherwise.
/// The delta CpuWallTiming object would be passes to 'func' upon
/// destruction of the timer.
// If 'trackOperatorCpuUsage_' is true, returns initialized timer object to
// track cpu and wall time of an operation. Returns null otherwise.
// The delta CpuWallTiming object would be passes to 'func' upon
// destruction of the timer.
template <typename F>
std::unique_ptr<DeltaCpuWallTimer<F>> createDeltaCpuWallTimer(F&& func) {
return trackOperatorCpuUsage_
Expand Down Expand Up @@ -397,6 +397,14 @@ class Driver : public std::enable_shared_from_this<Driver> {
friend struct DriverFactory;
};

/// Callback used by memory arbitration to check if a driver thread under memory
/// arbitration has been put in suspension state. This is to prevent arbitration
/// deadlock as the arbitrator might reclaim memory from the task of the driver
/// thread which is under arbitration. The task reclaim needs to wait for the
/// drivers to go off thread. A suspended driver thread is not counted as
/// running.
void driverArbitrationStateCheck(memory::MemoryPool& pool);

using OperatorSupplier = std::function<
std::unique_ptr<Operator>(int32_t operatorId, DriverCtx* ctx)>;

Expand Down

0 comments on commit fd8c0a1

Please sign in to comment.