From b7540f8389230d3d4a3dba38e6fe65bfdd1d1035 Mon Sep 17 00:00:00 2001 From: duanmeng Date: Wed, 16 Oct 2024 19:43:25 +0800 Subject: [PATCH] Add hash join replayer --- velox/exec/Operator.cpp | 4 + velox/exec/TraceUtil.cpp | 12 + velox/exec/TraceUtil.h | 3 + velox/exec/tests/OperatorTraceTest.cpp | 168 +++++++- velox/tool/trace/CMakeLists.txt | 3 +- velox/tool/trace/HashJoinReplayer.cpp | 47 +++ velox/tool/trace/HashJoinReplayer.h | 47 +++ velox/tool/trace/OperatorReplayerBase.h | 2 +- velox/tool/trace/TraceReplayRunner.cpp | 9 + velox/tool/trace/tests/CMakeLists.txt | 7 +- .../tool/trace/tests/HashJoinReplayerTest.cpp | 396 ++++++++++++++++++ 11 files changed, 692 insertions(+), 6 deletions(-) create mode 100644 velox/tool/trace/HashJoinReplayer.cpp create mode 100644 velox/tool/trace/HashJoinReplayer.h create mode 100644 velox/tool/trace/tests/HashJoinReplayerTest.cpp diff --git a/velox/exec/Operator.cpp b/velox/exec/Operator.cpp index bb55b790426bd..577b00ab66d18 100644 --- a/velox/exec/Operator.cpp +++ b/velox/exec/Operator.cpp @@ -126,6 +126,10 @@ void Operator::maybeSetTracer() { } tracedOpMap.emplace(operatorId(), operatorType()); + if (!trace::canTrace(operatorType())) { + VELOX_UNSUPPORTED("{} does not support tracing", operatorType()); + } + const auto pipelineId = operatorCtx_->driverCtx()->pipelineId; const auto driverId = operatorCtx_->driverCtx()->driverId; LOG(INFO) << "Trace input for operator type: " << operatorType() diff --git a/velox/exec/TraceUtil.cpp b/velox/exec/TraceUtil.cpp index 4df3b3f12e100..2817f1640f46b 100644 --- a/velox/exec/TraceUtil.cpp +++ b/velox/exec/TraceUtil.cpp @@ -179,4 +179,16 @@ size_t getNumDrivers( const std::shared_ptr& fs) { return listDriverIds(nodeTraceDir, pipelineId, fs).size(); } + +bool canTrace(const std::string& operatorType) { + static const std::unordered_set kSupportedOperatorTypes{ + "FilterProject", + "TableWrite", + "Aggregation", + "PartialAggregation", + "PartitionedOutput", + "HashBuild", + "HashProbe"}; + return kSupportedOperatorTypes.count(operatorType) > 0; +} } // namespace facebook::velox::exec::trace diff --git a/velox/exec/TraceUtil.h b/velox/exec/TraceUtil.h index 7b226844b7595..807b51c0eddb5 100644 --- a/velox/exec/TraceUtil.h +++ b/velox/exec/TraceUtil.h @@ -123,4 +123,7 @@ std::vector getTaskIds( folly::dynamic getTaskMetadata( const std::string& taskMetaFilePath, const std::shared_ptr& fs); + +/// Checks whether the operator can be traced. +bool canTrace(const std::string& operatorType); } // namespace facebook::velox::exec::trace diff --git a/velox/exec/tests/OperatorTraceTest.cpp b/velox/exec/tests/OperatorTraceTest.cpp index a766b7794db84..2a8f075511a10 100644 --- a/velox/exec/tests/OperatorTraceTest.cpp +++ b/velox/exec/tests/OperatorTraceTest.cpp @@ -334,6 +334,7 @@ TEST_F(OperatorTraceTest, task) { } auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId hashJoinNodeId; const auto planNode = PlanBuilder(planNodeIdGenerator) .values(rows, false) @@ -349,6 +350,7 @@ TEST_F(OperatorTraceTest, task) { "c0 < 135", {"c0", "c1", "c2"}, core::JoinType::kInner) + .capturePlanNodeId(hashJoinNodeId) .planNode(); const auto expectedResult = AssertQueryBuilder(planNode).maxDrivers(1).copyResults(pool()); @@ -374,7 +376,7 @@ TEST_F(OperatorTraceTest, task) { std::to_string(100UL << 30)}, {core::QueryConfig::kQueryTraceDir, outputDir->getPath()}, {core::QueryConfig::kQueryTraceTaskRegExp, testData.taskRegExpr}, - {core::QueryConfig::kQueryTraceNodeIds, "1,2"}, + {core::QueryConfig::kQueryTraceNodeIds, hashJoinNodeId}, {"key1", "value1"}, }; @@ -595,7 +597,7 @@ TEST_F(OperatorTraceTest, traceTableWriter) { continue; } - // Query metadta file should exist. + // Query metadata file should exist. const auto traceMetaFilePath = getTaskTraceMetaFilePath(taskTraceDir); ASSERT_TRUE(fs->exists(traceMetaFilePath)); @@ -724,4 +726,166 @@ TEST_F(OperatorTraceTest, filterProject) { ASSERT_EQ(numOutputVectors, testData.numTracedBatches); } } + +TEST_F(OperatorTraceTest, hashJoin) { + std::vector probeInput; + RowTypePtr probeType = + ROW({"c0", "c1", "c2"}, {BIGINT(), TINYINT(), VARCHAR()}); + constexpr auto numBatch = 5; + probeInput.reserve(numBatch); + for (auto i = 0; i < numBatch; ++i) { + probeInput.push_back(vectorFuzzer_.fuzzInputFlatRow(probeType)); + } + + std::vector buildInput; + RowTypePtr buildType = + ROW({"u0", "u1", "u2"}, {BIGINT(), SMALLINT(), BIGINT()}); + buildInput.reserve(numBatch); + for (auto i = 0; i < numBatch; ++i) { + buildInput.push_back(vectorFuzzer_.fuzzInputFlatRow(buildType)); + } + + struct { + std::string taskRegExpr; + uint64_t maxTracedBytes; + uint8_t numTracedBatches; + bool limitExceeded; + + std::string debugString() const { + return fmt::format( + "taskRegExpr: {}, maxTracedBytes: {}, numTracedBatches: {}, limitExceeded {}", + taskRegExpr, + maxTracedBytes, + numTracedBatches, + limitExceeded); + } + } testSettings[]{ + {".*", 10UL << 30, numBatch, false}, + {".*", 0, numBatch, true}, + {"wrong id", 10UL << 30, 0, false}, + {"test_cursor \\d+", 10UL << 30, numBatch, false}, + {"test_cursor \\d+", 800, 2, true}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + const auto outputDir = TempDirectoryPath::create(); + const auto planNodeIdGenerator{ + std::make_shared()}; + core::PlanNodeId hashJoinNodeId; + const auto planNode = PlanBuilder(planNodeIdGenerator) + .values(probeInput, false) + .hashJoin( + {"c0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildInput, true) + .planNode(), + "c0 < 135", + {"c0", "c1", "c2"}, + core::JoinType::kInner) + .capturePlanNodeId(hashJoinNodeId) + .planNode(); + const auto testDir = TempDirectoryPath::create(); + const auto traceRoot = + fmt::format("{}/{}", testDir->getPath(), "traceRoot"); + std::shared_ptr task; + if (testData.limitExceeded) { + VELOX_ASSERT_THROW( + AssertQueryBuilder(planNode) + .maxDrivers(1) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config( + core::QueryConfig::kQueryTraceMaxBytes, + testData.maxTracedBytes) + .config( + core::QueryConfig::kQueryTraceTaskRegExp, + testData.taskRegExpr) + .config(core::QueryConfig::kQueryTraceNodeIds, hashJoinNodeId) + .copyResults(pool(), task), + "Query exceeded per-query local trace limit of"); + continue; + } + AssertQueryBuilder(planNode) + .maxDrivers(1) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config(core::QueryConfig::kQueryTraceMaxBytes, testData.maxTracedBytes) + .config(core::QueryConfig::kQueryTraceTaskRegExp, testData.taskRegExpr) + .config(core::QueryConfig::kQueryTraceNodeIds, hashJoinNodeId) + .copyResults(pool(), task); + + const auto taskTraceDir = getTaskTraceDirectory(traceRoot, *task); + const auto fs = filesystems::getFileSystem(taskTraceDir, nullptr); + + if (testData.taskRegExpr == "wrong id") { + ASSERT_FALSE(fs->exists(traceRoot)); + continue; + } + + // Query metadata file should exist. + const auto traceMetaFilePath = getTaskTraceMetaFilePath(taskTraceDir); + ASSERT_TRUE(fs->exists(traceMetaFilePath)); + + for (uint32_t pipelineId = 0; pipelineId < 2; ++pipelineId) { + const auto opTraceProbeDir = + getOpTraceDirectory(taskTraceDir, hashJoinNodeId, pipelineId, 0); + + ASSERT_EQ(fs->list(opTraceProbeDir).size(), 2); + + const auto summary = + OperatorTraceSummaryReader(opTraceProbeDir, pool()).read(); + RowTypePtr dataType; + if (pipelineId == 0) { + dataType = probeType; + } else { + dataType = buildType; + } + const auto reader = + trace::OperatorTraceInputReader(opTraceProbeDir, dataType, pool()); + RowVectorPtr actual; + size_t numOutputVectors{0}; + RowVectorPtr expected; + if (pipelineId == 0) { + expected = probeInput[numOutputVectors]; + } else { + expected = buildInput[numOutputVectors]; + } + while (reader.read(actual)) { + const auto size = actual->size(); + ASSERT_EQ(size, expected->size()); + for (auto i = 0; i < size; ++i) { + actual->compare(expected.get(), i, i, {.nullsFirst = true}); + } + ++numOutputVectors; + } + ASSERT_EQ(numOutputVectors, testData.numTracedBatches); + } + } +} + +TEST_F(OperatorTraceTest, canTrace) { + struct { + const std::string operatorType; + const bool canTrace; + + std::string debugString() const { + return fmt::format( + "operatorType: {}, canTrace: {}", operatorType, canTrace); + } + } testSettings[] = { + {"PartitionedOutput", true}, + {"HashBuild", true}, + {"HashProbe", true}, + {"RowNumber", false}, + {"OrderBy", false}, + {"PartialAggregation", true}, + {"Aggregation", true}, + {"TableWrite", true}, + {"FilterProject", true}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + ASSERT_EQ(testData.canTrace, trace::canTrace(testData.operatorType)); + } +} } // namespace facebook::velox::exec::trace::test diff --git a/velox/tool/trace/CMakeLists.txt b/velox/tool/trace/CMakeLists.txt index 7390bba8afa47..a0a3a9151d753 100644 --- a/velox/tool/trace/CMakeLists.txt +++ b/velox/tool/trace/CMakeLists.txt @@ -18,6 +18,7 @@ velox_add_library( OperatorReplayerBase.cpp PartitionedOutputReplayer.cpp TableWriterReplayer.cpp + HashJoinReplayer.cpp FilterProjectReplayer.cpp TraceReplayRunner.cpp) @@ -35,7 +36,7 @@ velox_link_libraries( glog::glog gflags::gflags) -add_executable(velox_query_replayer TraceReplayerMain.cpp TraceReplayRunner.cpp) +add_executable(velox_query_replayer TraceReplayerMain.cpp) target_link_libraries( velox_query_replayer diff --git a/velox/tool/trace/HashJoinReplayer.cpp b/velox/tool/trace/HashJoinReplayer.cpp new file mode 100644 index 0000000000000..76336c9d018c4 --- /dev/null +++ b/velox/tool/trace/HashJoinReplayer.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/tool/trace/HashJoinReplayer.h" +#include "velox/exec/TraceUtil.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +namespace facebook::velox::tool::trace { +core::PlanNodePtr HashJoinReplayer::createPlanNode( + const core::PlanNode* node, + const core::PlanNodeId& nodeId, + const core::PlanNodePtr& source) const { + const auto* hashJoinNode = dynamic_cast(node); + return std::make_shared( + nodeId, + hashJoinNode->joinType(), + hashJoinNode->isNullAware(), + hashJoinNode->leftKeys(), + hashJoinNode->rightKeys(), + hashJoinNode->filter(), + source, + PlanBuilder(planNodeIdGenerator_) + .traceScan( + nodeTraceDir_, + pipelineId_ + 1, // Build side + exec::trace::getDataType(planFragment_, nodeId_, 1)) + .planNode(), + hashJoinNode->outputType()); +} +} // namespace facebook::velox::tool::trace diff --git a/velox/tool/trace/HashJoinReplayer.h b/velox/tool/trace/HashJoinReplayer.h new file mode 100644 index 0000000000000..830ded9e6910f --- /dev/null +++ b/velox/tool/trace/HashJoinReplayer.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/core/PlanNode.h" +#include "velox/tool/trace/OperatorReplayerBase.h" + +namespace facebook::velox::tool::trace { +/// The replayer to replay the traced 'HashJoin' operator. +class HashJoinReplayer final : public OperatorReplayerBase { + public: + HashJoinReplayer( + const std::string& rootDir, + const std::string& queryId, + const std::string& taskId, + const std::string& nodeId, + const int32_t pipelineId, + const std::string& operatorType) + : OperatorReplayerBase( + rootDir, + queryId, + taskId, + nodeId, + pipelineId, + operatorType) {} + + private: + core::PlanNodePtr createPlanNode( + const core::PlanNode* node, + const core::PlanNodeId& nodeId, + const core::PlanNodePtr& source) const override; +}; +} // namespace facebook::velox::tool::trace diff --git a/velox/tool/trace/OperatorReplayerBase.h b/velox/tool/trace/OperatorReplayerBase.h index 6fa1db7adce86..78c65dead3a34 100644 --- a/velox/tool/trace/OperatorReplayerBase.h +++ b/velox/tool/trace/OperatorReplayerBase.h @@ -18,7 +18,7 @@ #include "velox/common/file/FileSystems.h" #include "velox/core/PlanNode.h" -#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/parse/PlanNodeIdGenerator.h" namespace facebook::velox::exec { class Task; diff --git a/velox/tool/trace/TraceReplayRunner.cpp b/velox/tool/trace/TraceReplayRunner.cpp index 7ed56c6ca653e..357ce9c0425f1 100644 --- a/velox/tool/trace/TraceReplayRunner.cpp +++ b/velox/tool/trace/TraceReplayRunner.cpp @@ -42,6 +42,7 @@ #include "velox/parse/TypeResolver.h" #include "velox/tool/trace/AggregationReplayer.h" #include "velox/tool/trace/FilterProjectReplayer.h" +#include "velox/tool/trace/HashJoinReplayer.h" #include "velox/tool/trace/OperatorReplayerBase.h" #include "velox/tool/trace/PartitionedOutputReplayer.h" #include "velox/tool/trace/TableWriterReplayer.h" @@ -118,6 +119,14 @@ std::unique_ptr createReplayer() { FLAGS_node_id, FLAGS_pipeline_id, FLAGS_operator_type); + } else if (FLAGS_operator_type == "HashJoin") { + replayer = std::make_unique( + FLAGS_root_dir, + FLAGS_query_id, + FLAGS_task_id, + FLAGS_node_id, + FLAGS_pipeline_id, + FLAGS_operator_type); } else { VELOX_UNSUPPORTED("Unsupported operator type: {}", FLAGS_operator_type); } diff --git a/velox/tool/trace/tests/CMakeLists.txt b/velox/tool/trace/tests/CMakeLists.txt index 91b9dcfc73f90..9a5d377012718 100644 --- a/velox/tool/trace/tests/CMakeLists.txt +++ b/velox/tool/trace/tests/CMakeLists.txt @@ -14,8 +14,11 @@ add_executable( velox_tool_trace_test - AggregationReplayerTest.cpp PartitionedOutputReplayerTest.cpp - TableWriterReplayerTest.cpp FilterProjectReplayerTest.cpp) + AggregationReplayerTest.cpp + PartitionedOutputReplayerTest.cpp + TableWriterReplayerTest.cpp + HashJoinReplayerTest.cpp + FilterProjectReplayerTest.cpp) add_test( NAME velox_tool_trace_test diff --git a/velox/tool/trace/tests/HashJoinReplayerTest.cpp b/velox/tool/trace/tests/HashJoinReplayerTest.cpp new file mode 100644 index 0000000000000..2555faed34335 --- /dev/null +++ b/velox/tool/trace/tests/HashJoinReplayerTest.cpp @@ -0,0 +1,396 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +#include + +#include "velox/common/file/FileSystems.h" +#include "velox/common/hyperloglog/SparseHll.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/dwio/dwrf/writer/Writer.h" +#include "velox/exec/PartitionFunction.h" +#include "velox/exec/TableWriter.h" +#include "velox/exec/TraceUtil.h" +#include "velox/exec/tests/utils/ArbitratorTestUtil.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/serializers/PrestoSerializer.h" +#include "velox/tool/trace/HashJoinReplayer.h" + +#include "velox/common/file/Utils.h" +#include "velox/exec/PlanNodeStats.h" + +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::velox; +using namespace facebook::velox::core; +using namespace facebook::velox::common; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::connector; +using namespace facebook::velox::connector::hive; +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::common::testutil; +using namespace facebook::velox::common::hll; + +namespace facebook::velox::tool::trace::test { +class HashJoinReplayerTest : public HiveConnectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + HiveConnectorTestBase::SetUpTestCase(); + filesystems::registerLocalFileSystem(); + if (!isRegisteredVectorSerde()) { + serializer::presto::PrestoVectorSerde::registerVectorSerde(); + } + Type::registerSerDe(); + common::Filter::registerSerDe(); + connector::hive::HiveTableHandle::registerSerDe(); + connector::hive::LocationHandle::registerSerDe(); + connector::hive::HiveColumnHandle::registerSerDe(); + connector::hive::HiveInsertTableHandle::registerSerDe(); + connector::hive::HiveConnectorSplit::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + registerPartitionFunctionSerDe(); + } + + void TearDown() override { + probeInput_.clear(); + buildInput_.clear(); + HiveConnectorTestBase::TearDown(); + } + + struct PlanWithSplits { + core::PlanNodePtr plan; + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + std::unordered_map> splits; + + explicit PlanWithSplits( + const core::PlanNodePtr& _plan, + const core::PlanNodeId& _probeScanId = "", + const core::PlanNodeId& _buildScanId = "", + const std::unordered_map< + core::PlanNodeId, + std::vector>& _splits = {}) + : plan(_plan), + probeScanId(_probeScanId), + buildScanId(_buildScanId), + splits(_splits) {} + }; + + RowTypePtr concat(const RowTypePtr& a, const RowTypePtr& b) { + std::vector names = a->names(); + std::vector types = a->children(); + + for (auto i = 0; i < b->size(); ++i) { + names.push_back(b->nameOf(i)); + types.push_back(b->childAt(i)); + } + + return ROW(std::move(names), std::move(types)); + } + + std::vector + makeVectors(int32_t count, int32_t rowsPerVector, const RowTypePtr& rowType) { + return HiveConnectorTestBase::makeVectors(rowType, count, rowsPerVector); + } + + std::vector makeSplits( + const std::vector& inputs, + const std::string& path, + memory::MemoryPool* writerPool) { + std::vector splits; + for (auto i = 0; i < 4; ++i) { + const std::string filePath = fmt::format("{}/{}", path, i); + writeToFile(filePath, inputs); + splits.emplace_back(makeHiveConnectorSplit(filePath)); + } + + return splits; + } + + PlanWithSplits createPlan( + const std::string& tableDir, + core::JoinType joinType, + const std::vector& probeKeys, + const std::vector& buildKeys, + const std::vector& probeInput, + const std::vector& buildInput) { + auto planNodeIdGenerator = std::make_shared(); + const std::vector probeSplits = + makeSplits(probeInput, fmt::format("{}/probe", tableDir), pool()); + const std::vector buildSplits = + makeSplits(buildInput, fmt::format("{}/build", tableDir), pool()); + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + const auto outputColumns = concat( + asRowType(probeInput_[0]->type()), + asRowType(buildInput_[0]->type())) + ->names(); + auto plan = PlanBuilder(planNodeIdGenerator) + .tableScan(probeType_) + .capturePlanNodeId(probeScanId) + .hashJoin( + probeKeys, + buildKeys, + PlanBuilder(planNodeIdGenerator) + .tableScan(buildType_) + .capturePlanNodeId(buildScanId) + .planNode(), + /*filter=*/"", + outputColumns, + joinType, + false) + .capturePlanNodeId(traceNodeId_) + .planNode(); + return PlanWithSplits{ + plan, + probeScanId, + buildScanId, + {{probeScanId, probeSplits}, {buildScanId, buildSplits}}}; + } + + core::PlanNodeId traceNodeId_; + RowTypePtr probeType_{ + ROW({"t0", "t1", "t2", "t3"}, {BIGINT(), VARCHAR(), SMALLINT(), REAL()})}; + + RowTypePtr buildType_{ + ROW({"u0", "u1", "u2", "u3"}, + {BIGINT(), INTEGER(), SMALLINT(), VARCHAR()})}; + std::vector probeInput_ = makeVectors(5, 100, probeType_); + std::vector buildInput_ = makeVectors(3, 100, buildType_); + + const std::vector probeKeys_{"t0"}; + const std::vector buildKeys_{"u0"}; + const std::shared_ptr testDir_ = + TempDirectoryPath::create(); + const std::string tableDir_ = + fmt::format("{}/{}", testDir_->getPath(), "table"); +}; + +TEST_F(HashJoinReplayerTest, basic) { + const auto planWithSplits = createPlan( + tableDir_, + core::JoinType::kInner, + probeKeys_, + buildKeys_, + probeInput_, + buildInput_); + AssertQueryBuilder builder(planWithSplits.plan); + for (const auto& [planNodeId, nodeSplits] : planWithSplits.splits) { + builder.splits(planNodeId, nodeSplits); + } + const auto result = builder.copyResults(pool()); + + const auto traceRoot = + fmt::format("{}/{}/traceRoot/", testDir_->getPath(), "basic"); + std::shared_ptr task; + auto tracePlanWithSplits = createPlan( + tableDir_, + core::JoinType::kInner, + probeKeys_, + buildKeys_, + probeInput_, + buildInput_); + AssertQueryBuilder traceBuilder(tracePlanWithSplits.plan); + traceBuilder.config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30) + .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") + .config(core::QueryConfig::kQueryTraceNodeIds, traceNodeId_); + for (const auto& [planNodeId, nodeSplits] : tracePlanWithSplits.splits) { + traceBuilder.splits(planNodeId, nodeSplits); + } + auto traceResult = traceBuilder.copyResults(pool(), task); + + assertEqualResults({result}, {traceResult}); + + const auto taskId = task->taskId(); + const auto replayingResult = HashJoinReplayer( + traceRoot, + task->queryCtx()->queryId(), + task->taskId(), + traceNodeId_, + 0, + "HashJoin") + .run(); + assertEqualResults({result}, {replayingResult}); +} + +DEBUG_ONLY_TEST_F(HashJoinReplayerTest, hashBuildSpill) { + const auto planWithSplits = createPlan( + tableDir_, + core::JoinType::kInner, + probeKeys_, + buildKeys_, + probeInput_, + buildInput_); + AssertQueryBuilder builder(planWithSplits.plan); + for (const auto& [planNodeId, nodeSplits] : planWithSplits.splits) { + builder.splits(planNodeId, nodeSplits); + } + const auto result = builder.copyResults(pool()); + + const auto traceRoot = + fmt::format("{}/{}/traceRoot/", testDir_->getPath(), "hash_build_spill"); + const auto spillDir = + fmt::format("{}/{}/spillDir/", testDir_->getPath(), "hash_build_spill"); + std::shared_ptr task; + auto tracePlanWithSplits = createPlan( + tableDir_, + core::JoinType::kInner, + probeKeys_, + buildKeys_, + probeInput_, + buildInput_); + + std::atomic_bool injectSpillOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::finishHashBuild", + std::function([&](Operator* op) { + if (!injectSpillOnce.exchange(false)) { + return; + } + Operator::ReclaimableSectionGuard guard(op); + testingRunArbitration(op->pool()); + })); + + AssertQueryBuilder traceBuilder(tracePlanWithSplits.plan); + traceBuilder.config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30) + .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") + .config(core::QueryConfig::kQueryTraceNodeIds, traceNodeId_) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .spillDirectory(spillDir); + for (const auto& [planNodeId, nodeSplits] : tracePlanWithSplits.splits) { + traceBuilder.splits(planNodeId, nodeSplits); + } + + auto traceResult = traceBuilder.copyResults(pool(), task); + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(traceNodeId_); + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_GT( + opStats.at("HashBuild").runtimeStats[Operator::kSpillWrites].sum, 0); + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + + assertEqualResults({result}, {traceResult}); + + const auto taskId = task->taskId(); + const auto replayingResult = HashJoinReplayer( + traceRoot, + task->queryCtx()->queryId(), + task->taskId(), + traceNodeId_, + 0, + "HashJoin") + .run(); + assertEqualResults({result}, {replayingResult}); +} + +DEBUG_ONLY_TEST_F(HashJoinReplayerTest, hashProbeSpill) { + const auto planWithSplits = createPlan( + tableDir_, + core::JoinType::kInner, + probeKeys_, + buildKeys_, + probeInput_, + buildInput_); + + AssertQueryBuilder builder(planWithSplits.plan); + for (const auto& [planNodeId, nodeSplits] : planWithSplits.splits) { + builder.splits(planNodeId, nodeSplits); + } + const auto result = builder.copyResults(pool()); + + const auto traceRoot = + fmt::format("{}/{}/traceRoot/", testDir_->getPath(), "hash_probe_spill"); + const auto spillDir = + fmt::format("{}/{}/spillDir/", testDir_->getPath(), "hash_probe_spill"); + std::shared_ptr task; + auto tracePlanWithSplits = createPlan( + tableDir_, + core::JoinType::kInner, + probeKeys_, + buildKeys_, + probeInput_, + buildInput_); + + std::atomic_bool injectProbeSpillOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (!injectProbeSpillOnce.exchange(false)) { + return; + } + testingRunArbitration(op->pool()); + })); + + AssertQueryBuilder traceBuilder(tracePlanWithSplits.plan); + traceBuilder.config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30) + .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") + .config(core::QueryConfig::kQueryTraceNodeIds, traceNodeId_) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .spillDirectory(spillDir); + for (const auto& [planNodeId, nodeSplits] : tracePlanWithSplits.splits) { + traceBuilder.splits(planNodeId, nodeSplits); + } + + auto traceResult = traceBuilder.copyResults(pool(), task); + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(traceNodeId_); + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_GT( + opStats.at("HashProbe").runtimeStats[Operator::kSpillWrites].sum, 0); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + + assertEqualResults({result}, {traceResult}); + + const auto taskId = task->taskId(); + const auto replayingResult = HashJoinReplayer( + traceRoot, + task->queryCtx()->queryId(), + task->taskId(), + traceNodeId_, + 0, + "HashJoin") + .run(); + assertEqualResults({result}, {replayingResult}); +} +} // namespace facebook::velox::tool::trace::test