Skip to content

Commit

Permalink
Avoid cross thread access of thread local in parallel join build (#11402
Browse files Browse the repository at this point in the history
)

Summary:
parallel building threads in parallel join are accessing the main driver thread's thread local variable. The cross thread access of thread local variable can create undefined behavior. Instead of capturing the thread local, we capture the driver context directly to avoid this.

Pull Request resolved: #11402

Reviewed By: bikramSingh91

Differential Revision: D65308117

Pulled By: tanjialiang

fbshipit-source-id: be428f4967052b7240d5af6a0f550d5cd87d574a
  • Loading branch information
tanjialiang authored and facebook-github-bot committed Nov 1, 2024
1 parent ee8a683 commit 7b2bb7f
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 11 deletions.
8 changes: 4 additions & 4 deletions velox/exec/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ RowVectorPtr Driver::next(ContinueFuture* future) {
auto self = shared_from_this();
facebook::velox::process::ScopedThreadDebugInfo scopedInfo(
self->driverCtx()->threadDebugInfo);
ScopedDriverThreadContext scopedDriverThreadContext(*self->driverCtx());
ScopedDriverThreadContext scopedDriverThreadContext(self->driverCtx());
std::shared_ptr<BlockingState> blockingState;
RowVectorPtr result;
const auto stop = runInternal(self, blockingState, result);
Expand Down Expand Up @@ -759,7 +759,7 @@ void Driver::run(std::shared_ptr<Driver> self) {
process::TraceContext trace("Driver::run");
facebook::velox::process::ScopedThreadDebugInfo scopedInfo(
self->driverCtx()->threadDebugInfo);
ScopedDriverThreadContext scopedDriverThreadContext(*self->driverCtx());
ScopedDriverThreadContext scopedDriverThreadContext(self->driverCtx());
std::shared_ptr<BlockingState> blockingState;
RowVectorPtr nullResult;
auto reason = self->runInternal(self, blockingState, nullResult);
Expand Down Expand Up @@ -1151,9 +1151,9 @@ DriverThreadContext* driverThreadContext() {
return driverThreadCtx;
}

ScopedDriverThreadContext::ScopedDriverThreadContext(const DriverCtx& driverCtx)
ScopedDriverThreadContext::ScopedDriverThreadContext(const DriverCtx* driverCtx)
: savedDriverThreadCtx_(driverThreadCtx),
currentDriverThreadCtx_(DriverThreadContext(&driverCtx)) {
currentDriverThreadCtx_(DriverThreadContext(driverCtx)) {
driverThreadCtx = &currentDriverThreadCtx_;
}

Expand Down
2 changes: 1 addition & 1 deletion velox/exec/Driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ class DriverThreadContext {
/// starts/leaves the driver thread.
class ScopedDriverThreadContext {
public:
explicit ScopedDriverThreadContext(const DriverCtx& driverCtx);
explicit ScopedDriverThreadContext(const DriverCtx* driverCtx);
explicit ScopedDriverThreadContext(
const DriverThreadContext* _driverThreadCtx);
~ScopedDriverThreadContext();
Expand Down
16 changes: 11 additions & 5 deletions velox/exec/HashTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,13 @@ void HashTable<ignoreNullKeys>::parallelJoinBuild() {
rowPartitions.push_back(table->rows()->createRowPartitions(*rows_->pool()));
}

const auto* driverThreadCtx = driverThreadContext();
// Passing driver context directly to avoid cross thread access to thread
// local driver thread context.
const DriverCtx* driverCtx{nullptr};
if (const auto* driverThreadCtx = driverThreadContext()) {
driverCtx = driverThreadCtx->driverCtx();
}

// The parallel table partitioning step.
for (auto i = 0; i < numPartitions; ++i) {
auto* table = getTable(i);
Expand All @@ -938,8 +944,8 @@ void HashTable<ignoreNullKeys>::parallelJoinBuild() {
return std::make_unique<bool>(true);
}));
VELOX_CHECK(!partitionSteps.empty());
buildExecutor_->add([driverThreadCtx, step = partitionSteps.back()]() {
ScopedDriverThreadContext scopedDriverThreadContext(driverThreadCtx);
buildExecutor_->add([driverCtx, step = partitionSteps.back()]() {
ScopedDriverThreadContext scopedDriverThreadContext(driverCtx);
step->prepare();
});
}
Expand All @@ -965,8 +971,8 @@ void HashTable<ignoreNullKeys>::parallelJoinBuild() {
return std::make_unique<bool>(true);
}));
VELOX_CHECK(!buildSteps.empty());
buildExecutor_->add([driverThreadCtx, step = buildSteps.back()]() {
ScopedDriverThreadContext scopedDriverThreadContext(driverThreadCtx);
buildExecutor_->add([driverCtx, step = buildSteps.back()]() {
ScopedDriverThreadContext scopedDriverThreadContext(driverCtx);
step->prepare();
});
}
Expand Down
2 changes: 1 addition & 1 deletion velox/exec/tests/MemoryReclaimerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ TEST_F(MemoryReclaimerTest, enterArbitrationTest) {
fakeTask_->testingIncrementThreads();
if (underDriverContext) {
driver->state().setThread();
ScopedDriverThreadContext scopedDriverThreadCtx{*driver->driverCtx()};
ScopedDriverThreadContext scopedDriverThreadCtx{driver->driverCtx()};
reclaimer->enterArbitration();
ASSERT_TRUE(driver->state().isOnThread());
ASSERT_TRUE(driver->state().suspended());
Expand Down

0 comments on commit 7b2bb7f

Please sign in to comment.