From 4dcada9c7cdb3cf0e2a113de90709e45cc4ebb6d Mon Sep 17 00:00:00 2001 From: duanmeng Date: Wed, 16 Oct 2024 19:43:25 +0800 Subject: [PATCH] Add hash join replayer --- velox/exec/HashBuild.cpp | 3 + velox/exec/HashProbe.cpp | 3 + velox/exec/Operator.cpp | 4 + velox/exec/QueryTraceUtil.cpp | 10 + velox/exec/QueryTraceUtil.h | 3 + velox/exec/tests/QueryTraceTest.cpp | 26 ++ velox/tool/trace/CMakeLists.txt | 3 +- velox/tool/trace/HashJoinReplayer.cpp | 45 ++ velox/tool/trace/HashJoinReplayer.h | 47 ++ velox/tool/trace/OperatorReplayerBase.cpp | 2 +- velox/tool/trace/OperatorReplayerBase.h | 4 + velox/tool/trace/QueryReplayer.cpp | 8 + velox/tool/trace/tests/CMakeLists.txt | 2 +- .../tool/trace/tests/HashJoinReplayerTest.cpp | 428 ++++++++++++++++++ 14 files changed, 585 insertions(+), 3 deletions(-) create mode 100644 velox/tool/trace/HashJoinReplayer.cpp create mode 100644 velox/tool/trace/HashJoinReplayer.h create mode 100644 velox/tool/trace/tests/HashJoinReplayerTest.cpp diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp index 498f4569c627..cbeee3d10e6c 100644 --- a/velox/exec/HashBuild.cpp +++ b/velox/exec/HashBuild.cpp @@ -312,6 +312,9 @@ void HashBuild::removeInputRowsForAntiJoinFilter() { } void HashBuild::addInput(RowVectorPtr input) { + if (FOLLY_UNLIKELY(!isInputFromSpill())) { + traceInput(input); + } checkRunning(); ensureInputFits(input); diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index 3b9dc93f1c78..b08d869957e9 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -600,6 +600,9 @@ void HashProbe::decodeAndDetectNonNullKeys() { } void HashProbe::addInput(RowVectorPtr input) { + if (FOLLY_UNLIKELY(!isSpillInput())) { + traceInput(input); + } if (skipInput_) { VELOX_CHECK_NULL(input_); return; diff --git a/velox/exec/Operator.cpp b/velox/exec/Operator.cpp index 3b65536ba04c..17ec42dc4ab7 100644 --- a/velox/exec/Operator.cpp +++ b/velox/exec/Operator.cpp @@ -121,6 +121,10 @@ void Operator::maybeSetTracer() { return; } + if (!trace::canTrace(operatorType())) { + VELOX_UNSUPPORTED("{} does not support tracing", operatorType()); + } + auto& tracedOpMap = operatorCtx_->driverCtx()->tracedOperatorMap; if (const auto iter = tracedOpMap.find(operatorId()); iter != tracedOpMap.end()) { diff --git a/velox/exec/QueryTraceUtil.cpp b/velox/exec/QueryTraceUtil.cpp index f3a339eec080..5accb52081b8 100644 --- a/velox/exec/QueryTraceUtil.cpp +++ b/velox/exec/QueryTraceUtil.cpp @@ -108,4 +108,14 @@ getDataDir(const std::string& traceDir, int pipelineId, int driverId) { return fmt::format("{}/{}/{}/data", traceDir, pipelineId, driverId); } +bool canTrace(const std::string& operatorType) { + static const std::unordered_set kSupportedOperatorTypes{ + "TableWrite", + "Aggregation", + "PartialAggregation", + "PartitionedOutput", + "HashBuild", + "HashProbe"}; + return kSupportedOperatorTypes.count(operatorType) > 0; +} } // namespace facebook::velox::exec::trace diff --git a/velox/exec/QueryTraceUtil.h b/velox/exec/QueryTraceUtil.h index 633c0bc27b2a..e1cbbb4c125d 100644 --- a/velox/exec/QueryTraceUtil.h +++ b/velox/exec/QueryTraceUtil.h @@ -72,4 +72,7 @@ folly::dynamic getMetadata( /// given plan node, which is $traceRoot/$taskId/$nodeId. std::string getDataDir(const std::string& traceDir, int pipelineId, int driverId); + +/// Checks whether the operator can be traced. +bool canTrace(const std::string& operatorType); } // namespace facebook::velox::exec::trace diff --git a/velox/exec/tests/QueryTraceTest.cpp b/velox/exec/tests/QueryTraceTest.cpp index 71b49da483a0..e86864ed9bb2 100644 --- a/velox/exec/tests/QueryTraceTest.cpp +++ b/velox/exec/tests/QueryTraceTest.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include "velox/common/file/FileSystems.h" #include "velox/exec/PartitionFunction.h" @@ -541,4 +542,29 @@ TEST_F(QueryTracerTest, traceTableWriter) { ASSERT_EQ(numOutputVectors, testData.numTracedBatches); } } + +TEST_F(QueryTracerTest, canTrace) { + struct { + const std::string operatorType; + const bool canTrace; + + std::string debugString() const { + return fmt::format( + "operatorType: {}, canTrace: {}", operatorType, canTrace); + } + } testSettings[] = { + {"PartitionedOutput", true}, + {"HashBuild", true}, + {"HashProbe", true}, + {"RowNumber", false}, + {"OrderBy", false}, + {"PartialAggregation", true}, + {"Aggregation", true}, + {"TableWrite", true}, + {"TableScan", false}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + ASSERT_EQ(testData.canTrace, trace::canTrace(testData.operatorType)); + } +} } // namespace facebook::velox::exec::trace::test diff --git a/velox/tool/trace/CMakeLists.txt b/velox/tool/trace/CMakeLists.txt index 3218b4ee30b0..98c0f6d52422 100644 --- a/velox/tool/trace/CMakeLists.txt +++ b/velox/tool/trace/CMakeLists.txt @@ -17,7 +17,8 @@ velox_add_library( AggregationReplayer.cpp OperatorReplayerBase.cpp PartitionedOutputReplayer.cpp - TableWriterReplayer.cpp) + TableWriterReplayer.cpp + HashJoinReplayer.cpp) velox_link_libraries( velox_query_trace_replayer_base diff --git a/velox/tool/trace/HashJoinReplayer.cpp b/velox/tool/trace/HashJoinReplayer.cpp new file mode 100644 index 000000000000..9e871bfd5fba --- /dev/null +++ b/velox/tool/trace/HashJoinReplayer.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/tool/trace/HashJoinReplayer.h" +#include "velox/exec/QueryTraceUtil.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +namespace facebook::velox::tool::trace { +core::PlanNodePtr HashJoinReplayer::createPlanNode( + const core::PlanNode* node, + const core::PlanNodeId& nodeId, + const core::PlanNodePtr& source) const { + const auto* hashJoinNode = dynamic_cast(node); + return std::make_shared( + nodeId, + hashJoinNode->joinType(), + hashJoinNode->isNullAware(), + hashJoinNode->leftKeys(), + hashJoinNode->rightKeys(), + hashJoinNode->filter(), + source, + PlanBuilder(planNodeIdGenerator_) + .traceScan( + nodeDir_, exec::trace::getDataType(planFragment_, nodeId_, 1)) + .planNode(), + hashJoinNode->outputType()); +} +} // namespace facebook::velox::tool::trace diff --git a/velox/tool/trace/HashJoinReplayer.h b/velox/tool/trace/HashJoinReplayer.h new file mode 100644 index 000000000000..e77811759398 --- /dev/null +++ b/velox/tool/trace/HashJoinReplayer.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/core/PlanNode.h" +#include "velox/tool/trace/OperatorReplayerBase.h" + +#include + +namespace facebook::velox::tool::trace { +/// The replayer to replay the traced 'HashJoin' operator. +class HashJoinReplayer : public OperatorReplayerBase { + public: + HashJoinReplayer( + const std::string& rootDir, + const std::string& taskId, + const std::string& nodeId, + const int32_t pipelineId, + const std::string& operatorType) + : OperatorReplayerBase( + rootDir, + taskId, + nodeId, + pipelineId, + operatorType) {} + + private: + core::PlanNodePtr createPlanNode( + const core::PlanNode* node, + const core::PlanNodeId& nodeId, + const core::PlanNodePtr& source) const override; +}; +} // namespace facebook::velox::tool::trace diff --git a/velox/tool/trace/OperatorReplayerBase.cpp b/velox/tool/trace/OperatorReplayerBase.cpp index b54ad323a447..ff0af99de061 100644 --- a/velox/tool/trace/OperatorReplayerBase.cpp +++ b/velox/tool/trace/OperatorReplayerBase.cpp @@ -68,7 +68,7 @@ core::PlanNodePtr OperatorReplayerBase::createPlan() const { const auto* replayNode = core::PlanNode::findFirstNode( planFragment_.get(), [this](const core::PlanNode* node) { return node->id() == nodeId_; }); - return exec::test::PlanBuilder() + return exec::test::PlanBuilder(planNodeIdGenerator_) .traceScan(nodeDir_, exec::trace::getDataType(planFragment_, nodeId_)) .addNode(replayNodeFactory(replayNode)) .planNode(); diff --git a/velox/tool/trace/OperatorReplayerBase.h b/velox/tool/trace/OperatorReplayerBase.h index 863a8f5f80f6..62a217d6ef2e 100644 --- a/velox/tool/trace/OperatorReplayerBase.h +++ b/velox/tool/trace/OperatorReplayerBase.h @@ -18,6 +18,7 @@ #include "velox/common/file/FileSystems.h" #include "velox/core/PlanNode.h" +#include "velox/exec/tests/utils/PlanBuilder.h" namespace facebook::velox::exec { class Task; @@ -59,6 +60,9 @@ class OperatorReplayerBase { const std::string taskDir_; const std::string nodeDir_; + const std::shared_ptr planNodeIdGenerator_ = + std::make_shared(); + std::unordered_map queryConfigs_; std::unordered_map> connectorConfigs_; diff --git a/velox/tool/trace/QueryReplayer.cpp b/velox/tool/trace/QueryReplayer.cpp index a0680d30a387..234054da8ce5 100644 --- a/velox/tool/trace/QueryReplayer.cpp +++ b/velox/tool/trace/QueryReplayer.cpp @@ -40,6 +40,7 @@ #include "velox/parse/ExpressionsParser.h" #include "velox/parse/TypeResolver.h" #include "velox/tool/trace/AggregationReplayer.h" +#include "velox/tool/trace/HashJoinReplayer.h" #include "velox/tool/trace/OperatorReplayerBase.h" #include "velox/tool/trace/PartitionedOutputReplayer.h" #include "velox/tool/trace/TableWriterReplayer.h" @@ -148,6 +149,13 @@ std::unique_ptr createReplayer() { FLAGS_node_id, FLAGS_pipeline_id, FLAGS_operator_type); + } else if (FLAGS_operator_type == "HashJoin") { + replayer = std::make_unique( + FLAGS_root_dir, + FLAGS_task_id, + FLAGS_node_id, + FLAGS_pipeline_id, + FLAGS_operator_type); } else { VELOX_UNSUPPORTED("Unsupported operator type: {}", FLAGS_operator_type); } diff --git a/velox/tool/trace/tests/CMakeLists.txt b/velox/tool/trace/tests/CMakeLists.txt index 03d1705eebe7..42197fff233f 100644 --- a/velox/tool/trace/tests/CMakeLists.txt +++ b/velox/tool/trace/tests/CMakeLists.txt @@ -15,7 +15,7 @@ add_executable( velox_tool_trace_test AggregationReplayerTest.cpp PartitionedOutputReplayerTest.cpp - TableWriterReplayerTest.cpp) + TableWriterReplayerTest.cpp HashJoinReplayerTest.cpp) add_test( NAME velox_tool_trace_test diff --git a/velox/tool/trace/tests/HashJoinReplayerTest.cpp b/velox/tool/trace/tests/HashJoinReplayerTest.cpp new file mode 100644 index 000000000000..d13f312618cc --- /dev/null +++ b/velox/tool/trace/tests/HashJoinReplayerTest.cpp @@ -0,0 +1,428 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +#include + +#include "velox/common/file/FileSystems.h" +#include "velox/common/hyperloglog/SparseHll.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/dwio/dwrf/writer/Writer.h" +#include "velox/exec/PartitionFunction.h" +#include "velox/exec/QueryDataReader.h" +#include "velox/exec/QueryTraceUtil.h" +#include "velox/exec/TableWriter.h" +#include "velox/exec/tests/utils/ArbitratorTestUtil.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/serializers/PrestoSerializer.h" +#include "velox/tool/trace/HashJoinReplayer.h" + +#include "velox/common/file/Utils.h" +#include "velox/exec/PlanNodeStats.h" + +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::velox; +using namespace facebook::velox::core; +using namespace facebook::velox::common; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::connector; +using namespace facebook::velox::connector::hive; +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::common::testutil; +using namespace facebook::velox::common::hll; + +namespace facebook::velox::tool::trace::test { +class HashJoinReplayerTest : public HiveConnectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + HiveConnectorTestBase::SetUpTestCase(); + filesystems::registerLocalFileSystem(); + if (!isRegisteredVectorSerde()) { + serializer::presto::PrestoVectorSerde::registerVectorSerde(); + } + Type::registerSerDe(); + common::Filter::registerSerDe(); + connector::hive::HiveTableHandle::registerSerDe(); + connector::hive::LocationHandle::registerSerDe(); + connector::hive::HiveColumnHandle::registerSerDe(); + connector::hive::HiveInsertTableHandle::registerSerDe(); + connector::hive::HiveConnectorSplit::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + registerPartitionFunctionSerDe(); + } + + struct PlanWithSplits { + core::PlanNodePtr plan; + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + std::unordered_map> splits; + + explicit PlanWithSplits( + const core::PlanNodePtr& _plan, + const core::PlanNodeId& _probeScanId = "", + const core::PlanNodeId& _buildScanId = "", + const std::unordered_map< + core::PlanNodeId, + std::vector>& _splits = {}) + : plan(_plan), + probeScanId(_probeScanId), + buildScanId(_buildScanId), + splits(_splits) {} + }; + + int32_t randInt(int32_t min, int32_t max) { + return boost::random::uniform_int_distribution(min, max)(rng_); + } + + RowTypePtr concat(const RowTypePtr& a, const RowTypePtr& b) { + std::vector names = a->names(); + std::vector types = a->children(); + + for (auto i = 0; i < b->size(); ++i) { + names.push_back(b->nameOf(i)); + types.push_back(b->childAt(i)); + } + + return ROW(std::move(names), std::move(types)); + } + + std::vector + makeVectors(int32_t count, int32_t rowsPerVector, const RowTypePtr& rowType) { + return HiveConnectorTestBase::makeVectors(rowType, count, rowsPerVector); + } + + std::vector makeSplits( + const std::vector& inputs, + const std::string& path, + memory::MemoryPool* writerPool) { + std::vector splits; + for (auto i = 0; i < 4; ++i) { + const std::string filePath = fmt::format("{}/{}", path, i); + writeToFile(filePath, inputs); + splits.emplace_back(makeHiveConnectorSplit(filePath)); + } + + return splits; + } + + PlanWithSplits createPlan( + const std::string& tableDir, + core::JoinType joinType, + const std::vector& probeKeys, + const std::vector& buildKeys, + const std::vector& probeInput, + const std::vector& buildInput, + const std::vector& outputColumns) { + auto planNodeIdGenerator = std::make_shared(); + const std::vector probeSplits = + makeSplits(probeInput, fmt::format("{}/probe", tableDir), pool()); + const std::vector buildSplits = + makeSplits(buildInput, fmt::format("{}/build", tableDir), pool()); + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + auto plan = PlanBuilder(planNodeIdGenerator) + .tableScan(probeType_) + .capturePlanNodeId(probeScanId) + .hashJoin( + probeKeys, + buildKeys, + PlanBuilder(planNodeIdGenerator) + .tableScan(buildType_) + .capturePlanNodeId(buildScanId) + .planNode(), + /*filter=*/"", + outputColumns, + joinType, + false) + .capturePlanNodeId(traceNodeId_) + .planNode(); + return PlanWithSplits{ + plan, + probeScanId, + buildScanId, + {{probeScanId, probeSplits}, {buildScanId, buildSplits}}}; + } + + core::PlanNodeId traceNodeId_; + std::mt19937 rng_; + const RowTypePtr probeType_{ + ROW({"t0", "t1", "t2", "t3"}, {BIGINT(), INTEGER(), SMALLINT(), REAL()})}; + + const RowTypePtr buildType_{ + ROW({"u0", "u1", "u2", "u3"}, {BIGINT(), INTEGER(), SMALLINT(), REAL()})}; +}; + +TEST_F(HashJoinReplayerTest, test) { + const auto testDir = TempDirectoryPath::create(); + const auto tableDir = fmt::format("{}/{}", testDir->getPath(), "table"); + std::vector probeKeys{"t0", "t1"}; + std::vector buildKeys{"u0", "u1"}; + const auto probeInput = makeVectors(randInt(2, 5), 100, probeType_); + const auto buildInput = makeVectors(randInt(2, 5), 100, buildType_); + constexpr auto joinType = core::JoinType::kInner; + auto outputColumns = + (core::isLeftSemiProjectJoin(joinType) || + core::isLeftSemiFilterJoin(joinType) || core::isAntiJoin(joinType)) + ? asRowType(probeInput[0]->type())->names() + : concat( + asRowType(probeInput[0]->type()), asRowType(buildInput[0]->type())) + ->names(); + + if (core::isLeftSemiProjectJoin(joinType) || + core::isRightSemiProjectJoin(joinType)) { + outputColumns.emplace_back("match"); + } + + const auto planWithSplits = createPlan( + tableDir, + joinType, + probeKeys, + buildKeys, + probeInput, + buildInput, + outputColumns); + AssertQueryBuilder builder(planWithSplits.plan); + for (const auto& [planNodeId, nodeSplits] : planWithSplits.splits) { + builder.splits(planNodeId, nodeSplits); + } + const auto result = builder.copyResults(pool()); + + const auto traceRoot = + fmt::format("{}/{}/traceRoot/", testDir->getPath(), joinType); + std::shared_ptr task; + auto tracePlanWithSplits = createPlan( + tableDir, + joinType, + probeKeys, + buildKeys, + probeInput, + buildInput, + outputColumns); + AssertQueryBuilder traceBuilder(tracePlanWithSplits.plan); + traceBuilder.config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30) + .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") + .config(core::QueryConfig::kQueryTraceNodeIds, traceNodeId_); + for (const auto& [planNodeId, nodeSplits] : tracePlanWithSplits.splits) { + traceBuilder.splits(planNodeId, nodeSplits); + } + auto traceResult = traceBuilder.copyResults(pool(), task); + + assertEqualResults({result}, {traceResult}); + + const auto taskId = task->taskId(); + const auto replayingResult = + HashJoinReplayer(traceRoot, task->taskId(), traceNodeId_, 0, "HashJoin") + .run(); + assertEqualResults({result}, {replayingResult}); +} + +TEST_F(HashJoinReplayerTest, hashBuildSpill) { + const auto testDir = TempDirectoryPath::create(); + const auto tableDir = fmt::format("{}/{}", testDir->getPath(), "table"); + std::vector probeKeys{"t0", "t1"}; + std::vector buildKeys{"u0", "u1"}; + const auto probeInput = makeVectors(randInt(2, 5), 100, probeType_); + const auto buildInput = makeVectors(randInt(2, 5), 100, buildType_); + constexpr auto joinType = core::JoinType::kInner; + auto outputColumns = + (core::isLeftSemiProjectJoin(joinType) || + core::isLeftSemiFilterJoin(joinType) || core::isAntiJoin(joinType)) + ? asRowType(probeInput[0]->type())->names() + : concat( + asRowType(probeInput[0]->type()), asRowType(buildInput[0]->type())) + ->names(); + + if (core::isLeftSemiProjectJoin(joinType) || + core::isRightSemiProjectJoin(joinType)) { + outputColumns.emplace_back("match"); + } + + const auto planWithSplits = createPlan( + tableDir, + joinType, + probeKeys, + buildKeys, + probeInput, + buildInput, + outputColumns); + AssertQueryBuilder builder(planWithSplits.plan); + for (const auto& [planNodeId, nodeSplits] : planWithSplits.splits) { + builder.splits(planNodeId, nodeSplits); + } + const auto result = builder.copyResults(pool()); + + const auto traceRoot = + fmt::format("{}/{}/traceRoot/", testDir->getPath(), joinType); + const auto spillDir = + fmt::format("{}/{}/spillDir/", testDir->getPath(), joinType); + std::shared_ptr task; + auto tracePlanWithSplits = createPlan( + tableDir, + joinType, + probeKeys, + buildKeys, + probeInput, + buildInput, + outputColumns); + + std::atomic_bool injectSpillOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::finishHashBuild", + std::function([&](Operator* op) { + if (!injectSpillOnce.exchange(false)) { + return; + } + Operator::ReclaimableSectionGuard guard(op); + testingRunArbitration(op->pool()); + })); + + AssertQueryBuilder traceBuilder(tracePlanWithSplits.plan); + traceBuilder.config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30) + .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") + .config(core::QueryConfig::kQueryTraceNodeIds, traceNodeId_) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .spillDirectory(spillDir); + for (const auto& [planNodeId, nodeSplits] : tracePlanWithSplits.splits) { + traceBuilder.splits(planNodeId, nodeSplits); + } + + auto traceResult = traceBuilder.copyResults(pool(), task); + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(traceNodeId_); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + + assertEqualResults({result}, {traceResult}); + + const auto taskId = task->taskId(); + const auto replayingResult = + HashJoinReplayer(traceRoot, task->taskId(), traceNodeId_, 0, "HashJoin") + .run(); + assertEqualResults({result}, {replayingResult}); +} + +TEST_F(HashJoinReplayerTest, hashProbeSpill) { + const auto testDir = TempDirectoryPath::create(); + const auto tableDir = fmt::format("{}/{}", testDir->getPath(), "table"); + std::vector probeKeys{"t0", "t1"}; + std::vector buildKeys{"u0", "u1"}; + const auto probeInput = makeVectors(randInt(2, 5), 100, probeType_); + const auto buildInput = makeVectors(randInt(2, 5), 100, buildType_); + constexpr auto joinType = core::JoinType::kInner; + auto outputColumns = + (core::isLeftSemiProjectJoin(joinType) || + core::isLeftSemiFilterJoin(joinType) || core::isAntiJoin(joinType)) + ? asRowType(probeInput[0]->type())->names() + : concat( + asRowType(probeInput[0]->type()), asRowType(buildInput[0]->type())) + ->names(); + + if (core::isLeftSemiProjectJoin(joinType) || + core::isRightSemiProjectJoin(joinType)) { + outputColumns.emplace_back("match"); + } + + const auto planWithSplits = createPlan( + tableDir, + joinType, + probeKeys, + buildKeys, + probeInput, + buildInput, + outputColumns); + AssertQueryBuilder builder(planWithSplits.plan); + for (const auto& [planNodeId, nodeSplits] : planWithSplits.splits) { + builder.splits(planNodeId, nodeSplits); + } + const auto result = builder.copyResults(pool()); + + const auto traceRoot = + fmt::format("{}/{}/traceRoot/", testDir->getPath(), joinType); + const auto spillDir = + fmt::format("{}/{}/spillDir/", testDir->getPath(), joinType); + std::shared_ptr task; + auto tracePlanWithSplits = createPlan( + tableDir, + joinType, + probeKeys, + buildKeys, + probeInput, + buildInput, + outputColumns); + + std::atomic_bool injectProbeSpillOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (!injectProbeSpillOnce.exchange(false)) { + return; + } + testingRunArbitration(op->pool()); + })); + + AssertQueryBuilder traceBuilder(tracePlanWithSplits.plan); + traceBuilder.config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30) + .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") + .config(core::QueryConfig::kQueryTraceNodeIds, traceNodeId_) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .spillDirectory(spillDir); + for (const auto& [planNodeId, nodeSplits] : tracePlanWithSplits.splits) { + traceBuilder.splits(planNodeId, nodeSplits); + } + + auto traceResult = traceBuilder.copyResults(pool(), task); + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(traceNodeId_); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + + assertEqualResults({result}, {traceResult}); + + const auto taskId = task->taskId(); + const auto replayingResult = + HashJoinReplayer(traceRoot, task->taskId(), traceNodeId_, 0, "HashJoin") + .run(); + assertEqualResults({result}, {replayingResult}); +} +} // namespace facebook::velox::tool::trace::test