Skip to content

Commit

Permalink
Add PrestoQueryRunner/DuckDbQueryRunner support for join fuzzer
Browse files Browse the repository at this point in the history
  • Loading branch information
yanngyoung committed May 22, 2024
1 parent 60b9b6e commit 5f81757
Show file tree
Hide file tree
Showing 10 changed files with 458 additions and 68 deletions.
5 changes: 3 additions & 2 deletions velox/exec/fuzzer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ target_link_libraries(

add_library(velox_join_fuzzer JoinFuzzer.cpp)

target_link_libraries(velox_join_fuzzer velox_type velox_vector_fuzzer
velox_exec_test_lib velox_expression_test_utility)
target_link_libraries(
velox_join_fuzzer velox_fuzzer_util velox_type velox_vector_fuzzer
velox_exec_test_lib velox_expression_test_utility)
96 changes: 96 additions & 0 deletions velox/exec/fuzzer/DuckQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,17 @@ std::multiset<std::vector<velox::variant>> DuckQueryRunner::execute(
return queryRunner.execute(sql, resultType);
}

// Executes a two-input query on a locally constructed DuckDbQueryRunner.
// 'probeInput' is registered as table "t" and 'buildInput' as table "u"
// before 'sql' is run, so 'sql' must refer to the inputs by those names.
// 'resultType' is forwarded to the underlying runner together with 'sql'.
std::multiset<std::vector<velox::variant>> DuckQueryRunner::execute(
    const std::string& sql,
    const std::vector<RowVectorPtr>& probeInput,
    const std::vector<RowVectorPtr>& buildInput,
    const RowTypePtr& resultType) {
  DuckDbQueryRunner duckDb;

  // Probe side first, then build side.
  duckDb.createTable("t", probeInput);
  duckDb.createTable("u", buildInput);

  return duckDb.execute(sql, resultType);
}

std::optional<std::string> DuckQueryRunner::toSql(
const core::PlanNodePtr& plan) {
if (!isSupported(plan->outputType())) {
Expand Down Expand Up @@ -153,6 +164,11 @@ std::optional<std::string> DuckQueryRunner::toSql(
return toSql(rowNumberNode);
}

if (const auto joinNode =
std::dynamic_pointer_cast<const core::HashJoinNode>(plan)) {
return toSql(joinNode);
}

VELOX_NYI();
}

Expand Down Expand Up @@ -329,4 +345,84 @@ std::optional<std::string> DuckQueryRunner::toSql(

return sql.str();
}

// Translates a hash join plan node into a DuckDB SQL query over two tables:
// "t" (referenced through the node's left keys) and "u" (referenced through
// the node's right keys).
//
// Returns std::nullopt for a left semi filter join with more than one key,
// because the generated "<key> IN (SELECT <key> FROM u)" form only works with
// a single column. Fails with VELOX_UNREACHABLE for join types not listed in
// the switch below.
//
// NOTE(review): only the equi-join keys are translated here; if the node can
// carry an additional join filter, it is not reflected in the SQL — confirm
// callers only pass filter-less joins.
std::optional<std::string> DuckQueryRunner::toSql(
    const std::shared_ptr<const core::HashJoinNode>& joinNode) {
  // Renders the key names as a comma-separated list. Takes the keys by
  // reference to avoid copying the vector.
  const auto joinKeysToSql = [](const auto& keys) {
    std::stringstream out;
    for (size_t i = 0; i < keys.size(); ++i) {
      if (i > 0) {
        out << ", ";
      }
      out << keys[i]->name();
    }
    return out.str();
  };

  // Renders "l0 = r0 AND l1 = r1 ...", pairing left and right keys by
  // position. Takes the node by reference to avoid a shared_ptr copy.
  const auto equiClausesToSql = [](const auto& node) {
    const auto& leftKeys = node->leftKeys();
    const auto& rightKeys = node->rightKeys();
    std::stringstream out;
    for (size_t i = 0; i < leftKeys.size(); ++i) {
      if (i > 0) {
        out << " AND ";
      }
      out << leftKeys[i]->name() << " = " << rightKeys[i]->name();
    }
    return out.str();
  };

  const auto& outputNames = joinNode->outputType()->names();

  // For a left semi project join the last output column is not read from the
  // inputs; it is produced by the IN / EXISTS expression appended in the
  // switch below, so it is excluded from the plain select list. The end
  // iterator is only decremented when the list is non-empty.
  auto selectEnd = outputNames.end();
  if (joinNode->isLeftSemiProjectJoin() && !outputNames.empty()) {
    --selectEnd;
  }
  const bool hasSelectedColumns = selectEnd != outputNames.begin();

  std::stringstream sql;
  sql << "SELECT " << folly::join(", ", outputNames.begin(), selectEnd);

  switch (joinNode->joinType()) {
    case core::JoinType::kInner:
      sql << " FROM t INNER JOIN u ON " << equiClausesToSql(joinNode);
      break;
    case core::JoinType::kLeft:
      sql << " FROM t LEFT JOIN u ON " << equiClausesToSql(joinNode);
      break;
    case core::JoinType::kFull:
      sql << " FROM t FULL OUTER JOIN u ON " << equiClausesToSql(joinNode);
      break;
    case core::JoinType::kLeftSemiFilter:
      // The IN-subquery form below supports a single key column only.
      if (joinNode->leftKeys().size() > 1) {
        return std::nullopt;
      }
      sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
          << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
          << " FROM u)";
      break;
    case core::JoinType::kLeftSemiProject:
      // Only emit the separator when at least one column precedes the match
      // expression; otherwise the query would start with "SELECT , ...".
      if (hasSelectedColumns) {
        sql << ", ";
      }
      if (joinNode->isNullAware()) {
        sql << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT "
            << joinKeysToSql(joinNode->rightKeys()) << " FROM u) FROM t";
      } else {
        sql << "EXISTS (SELECT * FROM u WHERE " << equiClausesToSql(joinNode)
            << ") FROM t";
      }
      break;
    case core::JoinType::kAnti:
      if (joinNode->isNullAware()) {
        sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
            << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
            << " FROM u)";
      } else {
        sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE "
            << equiClausesToSql(joinNode) << ")";
      }
      break;
    default:
      VELOX_UNREACHABLE(
          "Unknown join type: {}", static_cast<int>(joinNode->joinType()));
  }

  return sql.str();
}
} // namespace facebook::velox::exec::test
9 changes: 9 additions & 0 deletions velox/exec/fuzzer/DuckQueryRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class DuckQueryRunner : public ReferenceQueryRunner {
const std::vector<RowVectorPtr>& input,
const RowTypePtr& resultType) override;

/// Executes 'sql' for a query that reads two inputs (e.g. a join). The
/// implementation registers 'probeInput' as table "t" and 'buildInput' as
/// table "u" before running the query, so 'sql' must refer to the inputs by
/// those names. 'resultType' is passed to the underlying DuckDB runner along
/// with 'sql'.
std::multiset<std::vector<velox::variant>> execute(
const std::string& sql,
const std::vector<RowVectorPtr>& probeInput,
const std::vector<RowVectorPtr>& buildInput,
const RowTypePtr& resultType) override;

private:
std::optional<std::string> toSql(
const std::shared_ptr<const core::AggregationNode>& aggregationNode);
Expand All @@ -52,6 +58,9 @@ class DuckQueryRunner : public ReferenceQueryRunner {
std::optional<std::string> toSql(
const std::shared_ptr<const core::RowNumberNode>& rowNumberNode);

/// Translates a hash join into a SQL query over tables "t" and "u", using the
/// node's left keys for "t" and right keys for "u". Returns std::nullopt when
/// the join cannot be expressed (a left semi filter join with more than one
/// key).
std::optional<std::string> toSql(
const std::shared_ptr<const velox::core::HashJoinNode>& joinNode);

std::unordered_set<std::string> aggregateFunctionNames_;
};

Expand Down
60 changes: 49 additions & 11 deletions velox/exec/fuzzer/JoinFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "velox/dwio/dwrf/reader/DwrfReader.h"
#include "velox/dwio/dwrf/writer/Writer.h"
#include "velox/exec/OperatorUtils.h"
#include "velox/exec/fuzzer/ReferenceQueryRunner.h"
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"
Expand Down Expand Up @@ -65,7 +66,9 @@ namespace {

class JoinFuzzer {
public:
explicit JoinFuzzer(size_t initialSeed);
// Creates a fuzzer that takes ownership of the given reference query runner;
// the runner is later used to translate plans to SQL and compute reference
// results.
explicit JoinFuzzer(
size_t initialSeed,
std::unique_ptr<ReferenceQueryRunner>);

void go();

Expand Down Expand Up @@ -216,6 +219,11 @@ class JoinFuzzer {

RowVectorPtr execute(const PlanWithSplits& plan, bool injectSpill);

// Computes the expected results of 'plan' using the reference query runner,
// feeding it 'probeInput' and 'buildInput'. Returns std::nullopt if either
// input contains unsupported types (checked on the first vector of each
// input) or if the reference query runner cannot translate 'plan' to SQL.
std::optional<MaterializedRowMultiset> computeReferenceResults(
core::PlanNodePtr& plan,
const std::vector<RowVectorPtr>& probeInput,
const std::vector<RowVectorPtr>& buildInput);

std::optional<MaterializedRowMultiset> computeDuckDbResult(
const std::vector<RowVectorPtr>& probeInput,
const std::vector<RowVectorPtr>& buildInput,
Expand All @@ -242,10 +250,14 @@ class JoinFuzzer {
exec::MemoryReclaimer::create())};

VectorFuzzer vectorFuzzer_;
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner_;
};

JoinFuzzer::JoinFuzzer(size_t initialSeed)
: vectorFuzzer_{getFuzzerOptions(), pool_.get()} {
JoinFuzzer::JoinFuzzer(
size_t initialSeed,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner)
: vectorFuzzer_{getFuzzerOptions(), pool_.get()},
referenceQueryRunner_{std::move(referenceQueryRunner)} {
filesystems::registerLocalFileSystem();

// Make sure not to run out of open file descriptors.
Expand Down Expand Up @@ -547,6 +559,27 @@ bool containsUnsupportedTypes(const TypePtr& type) {
containsType(type, INTERVAL_DAY_TIME());
}

// Computes the expected results of 'plan' with the reference query runner.
//
// Returns std::nullopt, meaning "nothing to compare against", when:
//  - the type of the first probe or build vector contains unsupported types;
//  - the reference query runner cannot express 'plan' as SQL.
//
// Both inputs are indexed at position 0, so callers must pass non-empty
// vectors.
std::optional<MaterializedRowMultiset> JoinFuzzer::computeReferenceResults(
    core::PlanNodePtr& plan,
    const std::vector<RowVectorPtr>& probeInput,
    const std::vector<RowVectorPtr>& buildInput) {
  // Probe side is checked before build side.
  if (containsUnsupportedTypes(probeInput[0]->type()) ||
      containsUnsupportedTypes(buildInput[0]->type())) {
    return std::nullopt;
  }

  const auto query = referenceQueryRunner_->toSql(plan);
  if (!query.has_value()) {
    LOG(INFO) << "Query not supported by the reference DB";
    return std::nullopt;
  }

  return referenceQueryRunner_->execute(
      *query, probeInput, buildInput, plan->outputType());
}

std::optional<MaterializedRowMultiset> JoinFuzzer::computeDuckDbResult(
const std::vector<RowVectorPtr>& probeInput,
const std::vector<RowVectorPtr>& buildInput,
Expand Down Expand Up @@ -948,7 +981,7 @@ void JoinFuzzer::verify(core::JoinType joinType) {

shuffleJoinKeys(probeKeys, buildKeys);

const auto defaultPlan = makeDefaultPlan(
auto defaultPlan = makeDefaultPlan(
joinType,
nullAware,
probeKeys,
Expand All @@ -959,14 +992,17 @@ void JoinFuzzer::verify(core::JoinType joinType) {

const auto expected = execute(defaultPlan, /*injectSpill=*/false);

// If OOM injection is not enabled verify the results against DuckDB.
// If OOM injection is not enabled, verify the results against the reference
// query runner.
if (!FLAGS_enable_oom_injection) {
if (auto duckDbResult =
computeDuckDbResult(probeInput, buildInput, defaultPlan.plan)) {
if (auto referenceResult =
computeReferenceResults(defaultPlan.plan, probeInput, buildInput)) {
VELOX_CHECK(
assertEqualResults(
duckDbResult.value(), defaultPlan.plan->outputType(), {expected}),
"Velox and DuckDB results don't match");
referenceResult.value(),
defaultPlan.plan->outputType(),
{expected}),
"Velox and Reference results don't match");
}
}

Expand Down Expand Up @@ -1239,7 +1275,9 @@ void JoinFuzzer::go() {

} // namespace

void joinFuzzer(size_t seed) {
JoinFuzzer(seed).go();
void joinFuzzer(
size_t seed,
std::unique_ptr<test::ReferenceQueryRunner> referenceQueryRunner) {
JoinFuzzer(seed, std::move(referenceQueryRunner)).go();
}
} // namespace facebook::velox::exec::test
5 changes: 4 additions & 1 deletion velox/exec/fuzzer/JoinFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
#pragma once

#include <cstddef>
#include "velox/exec/fuzzer/ReferenceQueryRunner.h"

namespace facebook::velox::exec::test {
void joinFuzzer(size_t seed);
/// Runs the join fuzzer with the given seed. Ownership of
/// 'referenceQueryRunner' is transferred to the fuzzer.
void joinFuzzer(
size_t seed,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner);
}
34 changes: 11 additions & 23 deletions velox/exec/fuzzer/JoinFuzzerRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "velox/common/memory/SharedArbitrator.h"
#include "velox/exec/MemoryReclaimer.h"
#include "velox/exec/fuzzer/JoinFuzzer.h"
#include "velox/exec/fuzzer/ReferenceQueryRunner.h"
#include "velox/serializers/PrestoSerializer.h"

/// Join FuzzerRunner leverages JoinFuzzer and VectorFuzzer to
Expand Down Expand Up @@ -55,31 +56,18 @@
/// --seed 123 \
/// --v=1

namespace facebook::velox::exec::test {

class JoinFuzzerRunner {
public:
static int run(size_t seed) {
setupMemory();
facebook::velox::serializer::presto::PrestoVectorSerde::
registerVectorSerde();
facebook::velox::filesystems::registerLocalFileSystem();

facebook::velox::exec::test::joinFuzzer(seed);
/// Registers the Presto vector serde and the local file system, runs the join
/// fuzzer with 'seed' and 'referenceQueryRunner' (ownership is handed to the
/// fuzzer), and returns the value of RUN_ALL_TESTS().
static int run(
size_t seed,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner) {
// Registrations are done before the fuzzer starts.
serializer::presto::PrestoVectorSerde::registerVectorSerde();
filesystems::registerLocalFileSystem();
joinFuzzer(seed, std::move(referenceQueryRunner));
return RUN_ALL_TESTS();
}

private:
// Invoked to set up memory system with arbitration.
static void setupMemory() {
FLAGS_velox_enable_memory_usage_track_in_default_memory_pool = true;
FLAGS_velox_memory_leak_check_enabled = true;
facebook::velox::memory::SharedArbitrator::registerFactory();
facebook::velox::memory::MemoryManagerOptions options;
options.allocatorCapacity = 8L << 30;
options.arbitratorCapacity = 6L << 30;
options.arbitratorKind = "SHARED";
options.checkUsageLeak = true;
options.arbitrationStateCheckCb =
facebook::velox::exec::memoryArbitrationStateCheck;
facebook::velox::memory::MemoryManager::initialize(options);
}
};

} // namespace facebook::velox::exec::test
Loading

0 comments on commit 5f81757

Please sign in to comment.