From 1ce3c7ad2f8348c2e0edf6fe2185086b09d8bdc7 Mon Sep 17 00:00:00 2001 From: Xiaoxuan Meng Date: Wed, 27 Nov 2024 18:42:10 -0800 Subject: [PATCH] fix: Fix the memory reclaim bytes for hash join (#11642) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11642 Both hash build and hash probe do coordinated spill, so we shouldn't report the reclaimed bytes from a single operator but shall report them from the plan node. Also, a probe side spill might spill the built table from the build side, in which case the memory is actually reclaimed from the build side pool instead of the probe side. This PR also removes the unused wait-for-spill state from hash build. Reviewed By: bikramSingh91, tanjialiang Differential Revision: D66437719 fbshipit-source-id: 4ac7bb9cf87b4346d234ef5f7c04ed64ee12d249 --- velox/exec/Driver.cpp | 6 +- velox/exec/Driver.h | 2 + velox/exec/HashBuild.cpp | 13 -- velox/exec/HashBuild.h | 9 +- velox/exec/HashJoinBridge.cpp | 24 ++- velox/exec/Task.cpp | 6 +- velox/exec/Task.h | 4 +- velox/exec/TaskStructs.h | 28 ++- velox/exec/tests/HashJoinTest.cpp | 246 +++++++++++++++++++++++- velox/exec/tests/LocalPartitionTest.cpp | 4 +- 10 files changed, 299 insertions(+), 43 deletions(-) diff --git a/velox/exec/Driver.cpp b/velox/exec/Driver.cpp index 6fc28f822f8b..30357a1e240e 100644 --- a/velox/exec/Driver.cpp +++ b/velox/exec/Driver.cpp @@ -1140,14 +1140,14 @@ std::string blockingReasonToString(BlockingReason reason) { return "kWaitForMemory"; case BlockingReason::kWaitForConnector: return "kWaitForConnector"; - case BlockingReason::kWaitForSpill: - return "kWaitForSpill"; case BlockingReason::kYield: return "kYield"; case BlockingReason::kWaitForArbitration: return "kWaitForArbitration"; + default: + VELOX_UNREACHABLE( + fmt::format("Unknown blocking reason {}", static_cast(reason))); } - VELOX_UNREACHABLE(); } DriverThreadContext* driverThreadContext() { diff --git a/velox/exec/Driver.h b/velox/exec/Driver.h index 497a64c5215b..a5cfe209d290 100644 --- 
a/velox/exec/Driver.h +++ b/velox/exec/Driver.h @@ -205,6 +205,8 @@ enum class BlockingReason { kWaitForConnector, /// Build operator is blocked waiting for all its peers to stop to run group /// spill on all of them. + /// + /// TODO: remove this after Prestissimo is updated. kWaitForSpill, /// Some operators (like Table Scan) may run long loops and can 'voluntarily' /// exit them because Task requested to yield or stop or after a certain time. diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp index 638e006d6b78..f6220a5acf99 100644 --- a/velox/exec/HashBuild.cpp +++ b/velox/exec/HashBuild.cpp @@ -35,8 +35,6 @@ BlockingReason fromStateToBlockingReason(HashBuild::State state) { return BlockingReason::kNotBlocked; case HashBuild::State::kYield: return BlockingReason::kYield; - case HashBuild::State::kWaitForSpill: - return BlockingReason::kWaitForSpill; case HashBuild::State::kWaitForBuild: return BlockingReason::kWaitForJoinBuild; case HashBuild::State::kWaitForProbe: @@ -944,13 +942,6 @@ BlockingReason HashBuild::isBlocked(ContinueFuture* future) { break; case State::kFinish: break; - case State::kWaitForSpill: - if (!future_.valid()) { - setRunning(); - VELOX_CHECK_NOT_NULL(input_); - addInput(std::move(input_)); - } - break; case State::kWaitForBuild: [[fallthrough]]; case State::kWaitForProbe: @@ -1003,8 +994,6 @@ void HashBuild::checkStateTransition(State state) { break; case State::kWaitForBuild: [[fallthrough]]; - case State::kWaitForSpill: - [[fallthrough]]; case State::kWaitForProbe: [[fallthrough]]; case State::kFinish: @@ -1022,8 +1011,6 @@ std::string HashBuild::stateName(State state) { return "RUNNING"; case State::kYield: return "YIELD"; - case State::kWaitForSpill: - return "WAIT_FOR_SPILL"; case State::kWaitForBuild: return "WAIT_FOR_BUILD"; case State::kWaitForProbe: diff --git a/velox/exec/HashBuild.h b/velox/exec/HashBuild.h index 74be0f4cef81..0b12554afc8d 100644 --- a/velox/exec/HashBuild.h +++ b/velox/exec/HashBuild.h @@ -44,17 
+44,14 @@ class HashBuild final : public Operator { /// The yield state that voluntarily yield cpu after running too long when /// processing input from spilled file. kYield = 2, - /// The state that waits for the pending group spill to finish. This state - /// only applies if disk spilling is enabled. - kWaitForSpill = 3, /// The state that waits for the hash tables to be merged together. - kWaitForBuild = 4, + kWaitForBuild = 3, /// The state that waits for the hash probe to finish before start to build /// the hash table for one of previously spilled partition. This state only /// applies if disk spilling is enabled. - kWaitForProbe = 5, + kWaitForProbe = 4, /// The finishing state. - kFinish = 6, + kFinish = 5, }; static std::string stateName(State state); diff --git a/velox/exec/HashJoinBridge.cpp b/velox/exec/HashJoinBridge.cpp index 79eb6e47f36b..8961affb15b8 100644 --- a/velox/exec/HashJoinBridge.cpp +++ b/velox/exec/HashJoinBridge.cpp @@ -382,11 +382,12 @@ uint64_t HashJoinMemoryReclaimer::reclaim( uint64_t targetBytes, uint64_t maxWaitMs, memory::MemoryReclaimer::Stats& stats) { + const auto prevNodeReservedMemory = pool->reservedBytes(); + // The flags to track if we have reclaimed from both build and probe operators // under a hash join node. bool hasReclaimedFromBuild{false}; bool hasReclaimedFromProbe{false}; - uint64_t reclaimedBytes{0}; pool->visitChildren([&](memory::MemoryPool* child) { VELOX_CHECK_EQ(child->kind(), memory::MemoryPool::Kind::kLeaf); const bool isBuild = isHashBuildMemoryPool(*child); @@ -394,7 +395,7 @@ uint64_t HashJoinMemoryReclaimer::reclaim( if (!hasReclaimedFromBuild) { // We just need to reclaim from any one of the hash build operator. 
hasReclaimedFromBuild = true; - reclaimedBytes += child->reclaim(targetBytes, maxWaitMs, stats); + child->reclaim(targetBytes, maxWaitMs, stats); } return !hasReclaimedFromProbe; } @@ -403,22 +404,25 @@ uint64_t HashJoinMemoryReclaimer::reclaim( // The same as build operator, we only need to reclaim from any one of the // hash probe operator. hasReclaimedFromProbe = true; - reclaimedBytes += child->reclaim(targetBytes, maxWaitMs, stats); + child->reclaim(targetBytes, maxWaitMs, stats); } return !hasReclaimedFromBuild; }); - if (reclaimedBytes != 0) { - return reclaimedBytes; + + auto currNodeReservedMemory = pool->reservedBytes(); + VELOX_CHECK_LE(currNodeReservedMemory, prevNodeReservedMemory); + if (currNodeReservedMemory < prevNodeReservedMemory) { + return prevNodeReservedMemory - currNodeReservedMemory; } + auto joinBridge = joinBridge_.lock(); if (joinBridge == nullptr) { - return reclaimedBytes; + return 0; } - const auto oldNodeReservedMemory = pool->reservedBytes(); joinBridge->reclaim(); - const auto newNodeReservedMemory = pool->reservedBytes(); - VELOX_CHECK_LE(newNodeReservedMemory, oldNodeReservedMemory); - return oldNodeReservedMemory - newNodeReservedMemory; + currNodeReservedMemory = pool->reservedBytes(); + VELOX_CHECK_LE(currNodeReservedMemory, prevNodeReservedMemory); + return prevNodeReservedMemory - currNodeReservedMemory; } bool isHashBuildMemoryPool(const memory::MemoryPool& pool) { diff --git a/velox/exec/Task.cpp b/velox/exec/Task.cpp index dcf4aedba8f0..7b10948b53a0 100644 --- a/velox/exec/Task.cpp +++ b/velox/exec/Task.cpp @@ -647,9 +647,7 @@ RowVectorPtr Task::next(ContinueFuture* future) { } VELOX_CHECK_EQ( - static_cast(state_), - static_cast(kRunning), - "Task has already finished processing."); + state_, TaskState::kRunning, "Task has already finished processing."); // On first call, create the drivers. 
if (driverFactories_.empty()) { @@ -1480,7 +1478,7 @@ void Task::noMoreSplits(const core::PlanNodeId& planNodeId) { } if (allFinished) { - terminate(kFinished); + terminate(TaskState::kFinished); } } diff --git a/velox/exec/Task.h b/velox/exec/Task.h index 3ba28f6a572d..d205b15178af 100644 --- a/velox/exec/Task.h +++ b/velox/exec/Task.h @@ -613,13 +613,13 @@ class Task : public std::enable_shared_from_this { /// realized when the last thread stops running for 'this'. This is used to /// mark cancellation by the user. ContinueFuture requestCancel() { - return terminate(kCanceled); + return terminate(TaskState::kCanceled); } /// Like requestCancel but sets end state to kAborted. This is for stopping /// Tasks due to failures of other parts of the query. ContinueFuture requestAbort() { - return terminate(kAborted); + return terminate(TaskState::kAborted); } void requestYield() { diff --git a/velox/exec/TaskStructs.h b/velox/exec/TaskStructs.h index 3ddc147b6527..7d9236649589 100644 --- a/velox/exec/TaskStructs.h +++ b/velox/exec/TaskStructs.h @@ -27,8 +27,24 @@ class MergeSource; class MergeJoinSource; struct Split; +#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY +enum TaskState { + kRunning = 0, + kFinished = 1, + kCanceled = 2, + kAborted = 3, + kFailed = 4 +}; +#else /// Corresponds to Presto TaskState, needed for reporting query completion. 
-enum TaskState { kRunning, kFinished, kCanceled, kAborted, kFailed }; +enum class TaskState : int { + kRunning = 0, + kFinished = 1, + kCanceled = 2, + kAborted = 3, + kFailed = 4 +}; +#endif std::string taskStateString(TaskState state); @@ -139,3 +155,13 @@ struct SplitGroupState { }; } // namespace facebook::velox::exec + +template <> +struct fmt::formatter + : formatter { + auto format(facebook::velox::exec::TaskState state, format_context& ctx) + const { + return formatter::format( + facebook::velox::exec::taskStateString(state), ctx); + } +}; diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 25afc860c1fc..09eb903612a5 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -19,8 +19,6 @@ #include #include "folly/experimental/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" -#include "velox/common/memory/SharedArbitrator.h" -#include "velox/common/memory/tests/SharedArbitratorTestUtil.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/HashBuild.h" @@ -8391,4 +8389,248 @@ DEBUG_ONLY_TEST_F(HashJoinTest, spillOnBlockedProbe) { arbitrationThread.join(); waitForAllTasksToBeDeleted(30'000'000); } + +DEBUG_ONLY_TEST_F(HashJoinTest, buildReclaimedMemoryReport) { + constexpr int64_t kMaxBytes = 1LL << 30; // 1GB + const int32_t numBuildVectors = 3; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + VectorFuzzer fuzzer({.vectorSize = 200}, pool()); + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + + const int32_t numProbeVectors = 3; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + VectorFuzzer fuzzer({.vectorSize = 200}, pool()); + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + const int numDrivers{2}; + // duckdb need double probe and build inputs as we run two drivers for hash + // join. 
+ std::vector totalProbeVectors = probeVectors; + totalProbeVectors.insert( + totalProbeVectors.end(), probeVectors.begin(), probeVectors.end()); + std::vector totalBuildVectors = buildVectors; + totalBuildVectors.insert( + totalBuildVectors.end(), buildVectors.begin(), buildVectors.end()); + + createDuckDbTable("t", totalProbeVectors); + createDuckDbTable("u", totalBuildVectors); + + auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto queryPool = memory::memoryManager()->addRootPool( + "", kMaxBytes, memory::MemoryReclaimer::create()); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + folly::EventCount driverWait; + std::atomic_bool driverWaitFlag{true}; + folly::EventCount taskWait; + std::atomic_bool taskWaitFlag{true}; + + Operator* op{nullptr}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](Operator* testOp) { + if (testOp->operatorType() != "HashBuild") { + return; + } + op = testOp; + }))); + + std::atomic_bool injectOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", + std::function( + ([&](memory::MemoryPoolImpl* pool) { + ASSERT_TRUE(op != nullptr); + if (!isHashBuildMemoryPool(*pool)) { + return; + } + ASSERT_TRUE(op->canReclaim()); + if (op->pool()->usedBytes() == 0) { + // We skip trigger memory reclaim when the hash table is empty on + // memory reservation. + return; + } + if (op->pool()->parent()->reservedBytes() == + op->pool()->reservedBytes()) { + // We skip trigger memory reclaim if the other peer hash build + // operator hasn't run yet. 
+ return; + } + if (!injectOnce.exchange(false)) { + return; + } + uint64_t reclaimableBytes{0}; + const bool reclaimable = op->reclaimableBytes(reclaimableBytes); + ASSERT_TRUE(reclaimable); + ASSERT_GT(reclaimableBytes, 0); + auto* driver = op->testingOperatorCtx()->driver(); + SuspendedSection suspendedSection(driver); + taskWaitFlag = false; + taskWait.notifyAll(); + driverWait.await([&]() { return !driverWaitFlag.load(); }); + }))); + + std::thread taskThread([&]() { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers) + .planNode(plan) + .queryPool(std::move(queryPool)) + .injectSpill(false) + .spillDirectory(tempDirectory->getPath()) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + ASSERT_GT(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 16); + ASSERT_GT(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 16); + verifyTaskSpilledRuntimeStats(*task, true); + }) + .run(); + }); + + taskWait.await([&]() { return !taskWaitFlag.load(); }); + ASSERT_TRUE(op != nullptr); + auto task = op->testingOperatorCtx()->task(); + auto* nodePool = op->pool()->parent(); + const auto nodeMemoryUsage = nodePool->reservedBytes(); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + const uint64_t reclaimedBytes = task->pool()->reclaim( + task->pool()->capacity(), 1'000'000, reclaimerStats_); + ASSERT_GT(reclaimedBytes, 0); + ASSERT_EQ(nodeMemoryUsage - nodePool->reservedBytes(), reclaimedBytes); + } + // Verify all the memory has been freed. 
+ ASSERT_EQ(nodePool->reservedBytes(), 0); + + driverWaitFlag = false; + driverWait.notifyAll(); + task.reset(); + + taskThread.join(); +} + +DEBUG_ONLY_TEST_F(HashJoinTest, probeReclaimedMemoryReport) { + constexpr int64_t kMaxBytes = 1LL << 30; // 1GB + const int32_t numBuildVectors = 3; + std::vector buildVectors; + for (int32_t i = 0; i < numBuildVectors; ++i) { + VectorFuzzer fuzzer({.vectorSize = 200}, pool()); + buildVectors.push_back(fuzzer.fuzzRow(buildType_)); + } + + const int32_t numProbeVectors = 3; + std::vector probeVectors; + for (int32_t i = 0; i < numProbeVectors; ++i) { + VectorFuzzer fuzzer({.vectorSize = 200}, pool()); + probeVectors.push_back(fuzzer.fuzzRow(probeType_)); + } + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto tempDirectory = exec::test::TempDirectoryPath::create(); + auto queryPool = memory::memoryManager()->addRootPool( + "", kMaxBytes, memory::MemoryReclaimer::create()); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + concat(probeType_->names(), buildType_->names())) + .planNode(); + + folly::EventCount driverWait; + std::atomic_bool driverWaitFlag{true}; + folly::EventCount taskWait; + std::atomic_bool taskWaitFlag{true}; + + Operator* op{nullptr}; + std::atomic_int probeInputCount{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function(([&](Operator* testOp) { + if (testOp->operatorType() != "HashProbe") { + return; + } + op = testOp; + + ASSERT_TRUE(op->canReclaim()); + if (probeInputCount++ != 1) { + return; + } + auto* driver = op->testingOperatorCtx()->driver(); + SuspendedSection suspendedSection(driver); + taskWaitFlag = false; + taskWait.notifyAll(); + driverWait.await([&]() { return !driverWaitFlag.load(); }); + }))); + + 
std::thread taskThread([&]() { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(1) + .planNode(plan) + .queryPool(std::move(queryPool)) + .injectSpill(false) + .spillDirectory(tempDirectory->getPath()) + .referenceQuery( + "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + const auto statsPair = taskSpilledStats(*task); + // The spill triggered at the probe side. + ASSERT_EQ(statsPair.first.spilledBytes, 0); + ASSERT_EQ(statsPair.first.spilledPartitions, 0); + ASSERT_GT(statsPair.second.spilledBytes, 0); + ASSERT_EQ(statsPair.second.spilledPartitions, 16); + }) + .run(); + }); + + taskWait.await([&]() { return !taskWaitFlag.load(); }); + ASSERT_TRUE(op != nullptr); + auto task = op->testingOperatorCtx()->task(); + auto* nodePool = op->pool()->parent(); + const auto nodeMemoryUsage = nodePool->reservedBytes(); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + const uint64_t reclaimedBytes = task->pool()->reclaim( + task->pool()->capacity(), 1'000'000, reclaimerStats_); + ASSERT_GT(reclaimedBytes, 0); + ASSERT_EQ(nodeMemoryUsage - nodePool->reservedBytes(), reclaimedBytes); + } + // Verify all the memory has been freed. + ASSERT_EQ(nodePool->reservedBytes(), 0); + + driverWaitFlag = false; + driverWait.notifyAll(); + task.reset(); + + taskThread.join(); +} } // namespace diff --git a/velox/exec/tests/LocalPartitionTest.cpp b/velox/exec/tests/LocalPartitionTest.cpp index 81cd0210f7fc..97ff74d71c5a 100644 --- a/velox/exec/tests/LocalPartitionTest.cpp +++ b/velox/exec/tests/LocalPartitionTest.cpp @@ -535,7 +535,7 @@ TEST_F(LocalPartitionTest, earlyCancelation) { } // Wait for task to transition to final state. - waitForTaskCompletion(task, exec::kCanceled); + waitForTaskCompletion(task, exec::TaskState::kCanceled); // Make sure there is only one reference to Task left, i.e. 
no Driver is // blocked forever. @@ -571,7 +571,7 @@ TEST_F(LocalPartitionTest, producerError) { ASSERT_THROW(while (cursor->moveNext()) { ; }, VeloxException); // Wait for task to transition to failed state. - waitForTaskCompletion(task, exec::kFailed); + waitForTaskCompletion(task, exec::TaskState::kFailed); // Make sure there is only one reference to Task left, i.e. no Driver is // blocked forever.