Skip to content

Commit

Permalink
Add PrestoQueryRunner/DuckDbQueryRunner support for join fuzzer
Browse files Browse the repository at this point in the history
  • Loading branch information
yanngyoung committed Jun 5, 2024
1 parent a6bb4d8 commit e7d924e
Show file tree
Hide file tree
Showing 9 changed files with 556 additions and 72 deletions.
140 changes: 140 additions & 0 deletions velox/exec/fuzzer/DuckQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,17 @@ std::multiset<std::vector<velox::variant>> DuckQueryRunner::execute(
return queryRunner.execute(sql, resultType);
}

/// Runs a two-table query against a fresh DuckDB instance.
///
/// The probe-side rows are loaded into table "t" and the build-side rows into
/// table "u"; these are the table names the join toSql() overloads in this
/// file emit.
///
/// @param sql Query text to run.
/// @param probeInput Rows loaded into table "t".
/// @param buildInput Rows loaded into table "u".
/// @param resultType Row type passed to DuckDB when reading back the result.
/// @return The result rows as a multiset.
std::multiset<std::vector<velox::variant>> DuckQueryRunner::execute(
    const std::string& sql,
    const std::vector<RowVectorPtr>& probeInput,
    const std::vector<RowVectorPtr>& buildInput,
    const RowTypePtr& resultType) {
  DuckDbQueryRunner duckDb;

  // Probe side first, then build side.
  duckDb.createTable("t", probeInput);
  duckDb.createTable("u", buildInput);

  return duckDb.execute(sql, resultType);
}

std::optional<std::string> DuckQueryRunner::toSql(
const core::PlanNodePtr& plan) {
if (!isSupported(plan->outputType())) {
Expand Down Expand Up @@ -153,6 +164,16 @@ std::optional<std::string> DuckQueryRunner::toSql(
return toSql(rowNumberNode);
}

if (const auto joinNode =
std::dynamic_pointer_cast<const core::HashJoinNode>(plan)) {
return toSql(joinNode);
}

if (const auto joinNode =
std::dynamic_pointer_cast<const core::NestedLoopJoinNode>(plan)) {
return toSql(joinNode);
}

VELOX_NYI();
}

Expand Down Expand Up @@ -329,4 +350,123 @@ std::optional<std::string> DuckQueryRunner::toSql(

return sql.str();
}

/// Translates a hash join node into a DuckDB query over table "t" (probe side)
/// and table "u" (build side).
///
/// The SELECT list is made of the node's output column names. For left semi
/// project joins the last output column is replaced by an IN / EXISTS
/// expression computed by the query.
///
/// NOTE(review): only the equality keys are rendered into the SQL. If the node
/// can carry an additional join filter, it is not reflected here - confirm
/// that callers only pass filter-less joins.
///
/// @param joinNode Hash join to translate.
/// @return SQL text, or std::nullopt when the join cannot be expressed (joins
/// that would need a multi-column IN / NOT IN subquery).
/// @throws Via VELOX_UNREACHABLE for join types not handled below.
std::optional<std::string> DuckQueryRunner::toSql(
    const std::shared_ptr<const core::HashJoinNode>& joinNode) {
  // Renders key names as a comma-separated list, e.g. "k0, k1".
  const auto joinKeysToSql = [](const auto& keys) {
    std::stringstream out;
    for (size_t i = 0; i < keys.size(); ++i) {
      if (i > 0) {
        out << ", ";
      }
      out << keys[i]->name();
    }
    return out.str();
  };

  // Renders "l0 = r0 AND l1 = r1 ..." from the paired left/right keys.
  const auto equiClausesToSql = [](const auto& node) {
    const auto& leftKeys = node->leftKeys();
    const auto& rightKeys = node->rightKeys();
    std::stringstream out;
    for (size_t i = 0; i < leftKeys.size(); ++i) {
      if (i > 0) {
        out << " AND ";
      }
      out << leftKeys[i]->name() << " = " << rightKeys[i]->name();
    }
    return out.str();
  };

  // The IN / NOT IN forms below would render as "a, b IN (SELECT c, d FROM u)"
  // for several keys, which is not a multi-column membership test.
  const bool hasMultipleKeys = joinNode->leftKeys().size() > 1;

  const auto& outputNames = joinNode->outputType()->names();

  std::stringstream sql;
  if (joinNode->isLeftSemiProjectJoin()) {
    // The last output column is produced by the IN / EXISTS expression that is
    // appended in the switch below, so it is left out of the plain column list.
    VELOX_CHECK(
        !outputNames.empty(),
        "Left semi project join must have at least one output column");
    sql << "SELECT "
        << folly::join(", ", outputNames.begin(), outputNames.end() - 1);
  } else {
    sql << "SELECT " << folly::join(", ", outputNames);
  }

  switch (joinNode->joinType()) {
    case core::JoinType::kInner:
      sql << " FROM t INNER JOIN u ON " << equiClausesToSql(joinNode);
      break;
    case core::JoinType::kLeft:
      sql << " FROM t LEFT JOIN u ON " << equiClausesToSql(joinNode);
      break;
    case core::JoinType::kFull:
      sql << " FROM t FULL OUTER JOIN u ON " << equiClausesToSql(joinNode);
      break;
    case core::JoinType::kLeftSemiFilter:
      if (hasMultipleKeys) {
        return std::nullopt;
      }
      sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
          << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
          << " FROM u)";
      break;
    case core::JoinType::kLeftSemiProject:
      if (joinNode->isNullAware()) {
        if (hasMultipleKeys) {
          return std::nullopt;
        }
        sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT "
            << joinKeysToSql(joinNode->rightKeys()) << " FROM u) FROM t";
      } else {
        sql << ", EXISTS (SELECT * FROM u WHERE " << equiClausesToSql(joinNode)
            << ") FROM t";
      }
      break;
    case core::JoinType::kAnti:
      if (joinNode->isNullAware()) {
        if (hasMultipleKeys) {
          return std::nullopt;
        }
        sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
            << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
            << " FROM u)";
      } else {
        sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE "
            << equiClausesToSql(joinNode) << ")";
      }
      break;
    default:
      VELOX_UNREACHABLE(
          "Unknown join type: {}", static_cast<int>(joinNode->joinType()));
  }

  return sql.str();
}

/// Translates a filter-less nested loop join into a DuckDB query over table
/// "t" (probe side) and table "u" (build side).
///
/// The join is rendered with an always-true ON condition, and the SELECT list
/// is made of the node's output column names.
///
/// @param joinNode Nested loop join to translate; must have no join condition.
/// @return SQL text for inner, left and full joins.
/// @throws Via VELOX_CHECK when the node has a join condition, and via
/// VELOX_UNREACHABLE for any other join type.
std::optional<std::string> DuckQueryRunner::toSql(
    const std::shared_ptr<const core::NestedLoopJoinNode>& joinNode) {
  // Nested loop join without filter.
  VELOX_CHECK(
      joinNode->joinCondition() == nullptr,
      "This code path should be called only for nested loop join without filter");

  // The projection list has to be emitted before the FROM clause; without it
  // the generated text is not a complete query.
  std::stringstream sql;
  sql << "SELECT " << folly::join(", ", joinNode->outputType()->names());

  const std::string joinCondition{"(1 = 1)"};
  switch (joinNode->joinType()) {
    case core::JoinType::kInner:
      sql << " FROM t INNER JOIN u ON " << joinCondition;
      break;
    case core::JoinType::kLeft:
      sql << " FROM t LEFT JOIN u ON " << joinCondition;
      break;
    case core::JoinType::kFull:
      sql << " FROM t FULL OUTER JOIN u ON " << joinCondition;
      break;
    default:
      VELOX_UNREACHABLE(
          "Unknown join type: {}", static_cast<int>(joinNode->joinType()));
  }

  return sql.str();
}
} // namespace facebook::velox::exec::test
12 changes: 12 additions & 0 deletions velox/exec/fuzzer/DuckQueryRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class DuckQueryRunner : public ReferenceQueryRunner {
const std::vector<RowVectorPtr>& input,
const RowTypePtr& resultType) override;

std::multiset<std::vector<velox::variant>> execute(
const std::string& sql,
const std::vector<RowVectorPtr>& probeInput,
const std::vector<RowVectorPtr>& buildInput,
const RowTypePtr& resultType) override;

private:
std::optional<std::string> toSql(
const std::shared_ptr<const core::AggregationNode>& aggregationNode);
Expand All @@ -52,6 +58,12 @@ class DuckQueryRunner : public ReferenceQueryRunner {
std::optional<std::string> toSql(
const std::shared_ptr<const core::RowNumberNode>& rowNumberNode);

std::optional<std::string> toSql(
const std::shared_ptr<const core::HashJoinNode>& joinNode);

std::optional<std::string> toSql(
const std::shared_ptr<const core::NestedLoopJoinNode>& joinNode);

std::unordered_set<std::string> aggregateFunctionNames_;
};

Expand Down
69 changes: 54 additions & 15 deletions velox/exec/fuzzer/JoinFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "velox/dwio/dwrf/writer/Writer.h"
#include "velox/exec/OperatorUtils.h"
#include "velox/exec/fuzzer/FuzzerUtil.h"
#include "velox/exec/fuzzer/ReferenceQueryRunner.h"
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"
Expand Down Expand Up @@ -66,7 +67,9 @@ namespace {

class JoinFuzzer {
public:
explicit JoinFuzzer(size_t initialSeed);
explicit JoinFuzzer(
size_t initialSeed,
std::unique_ptr<ReferenceQueryRunner>);

void go();

Expand Down Expand Up @@ -260,6 +263,11 @@ class JoinFuzzer {

RowVectorPtr execute(const PlanWithSplits& plan, bool injectSpill);

std::optional<MaterializedRowMultiset> computeReferenceResults(
core::PlanNodePtr& plan,
const std::vector<RowVectorPtr>& probeInput,
const std::vector<RowVectorPtr>& buildInput);

template <typename TNode>
std::optional<MaterializedRowMultiset> computeDuckDbResult(
const std::vector<RowVectorPtr>& probeInput,
Expand Down Expand Up @@ -298,10 +306,14 @@ class JoinFuzzer {
exec::MemoryReclaimer::create())};

VectorFuzzer vectorFuzzer_;
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner_;
};

JoinFuzzer::JoinFuzzer(size_t initialSeed)
: vectorFuzzer_{getFuzzerOptions(), pool_.get()} {
JoinFuzzer::JoinFuzzer(
size_t initialSeed,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner)
: vectorFuzzer_{getFuzzerOptions(), pool_.get()},
referenceQueryRunner_{std::move(referenceQueryRunner)} {
filesystems::registerLocalFileSystem();

// Make sure not to run out of open file descriptors.
Expand Down Expand Up @@ -610,6 +622,27 @@ core::PlanNodePtr tryFlipJoinSides(const core::NestedLoopJoinNode& joinNode) {
joinNode.outputType());
}

/// Computes the expected join result with the configured reference query
/// runner.
///
/// Returns std::nullopt when the first probe or build vector has a type
/// rejected by containsUnsupportedTypes(), or when the reference runner cannot
/// translate the plan into SQL.
///
/// NOTE(review): both inputs are indexed at [0] without a size check, so they
/// are assumed to be non-empty - confirm against callers.
std::optional<MaterializedRowMultiset> JoinFuzzer::computeReferenceResults(
    core::PlanNodePtr& plan,
    const std::vector<RowVectorPtr>& probeInput,
    const std::vector<RowVectorPtr>& buildInput) {
  // The probe side is checked first; the build side is only inspected when the
  // probe side is fine.
  const bool hasUnsupportedInput =
      containsUnsupportedTypes(probeInput[0]->type()) ||
      containsUnsupportedTypes(buildInput[0]->type());
  if (hasUnsupportedInput) {
    return std::nullopt;
  }

  const auto referenceSql = referenceQueryRunner_->toSql(plan);
  if (!referenceSql.has_value()) {
    LOG(INFO) << "Query not supported by the reference DB";
    return std::nullopt;
  }

  return referenceQueryRunner_->execute(
      *referenceSql, probeInput, buildInput, plan->outputType());
}

template <typename TNode>
std::optional<MaterializedRowMultiset> JoinFuzzer::computeDuckDbResult(
const std::vector<RowVectorPtr>& probeInput,
Expand Down Expand Up @@ -1021,13 +1054,14 @@ RowVectorPtr JoinFuzzer::testCrossProduct(
/*withFilter*/ false);
const auto expected = execute(plan, /*injectSpill=*/false);

// If OOM injection is not enabled verify the results against DuckDB.
// If OOM injection is not enabled verify the results against Reference query
// runner.
if (!FLAGS_enable_oom_injection) {
if (auto duckDbResult = computeDuckDbResult<core::NestedLoopJoinNode>(
probeInput, buildInput, plan.plan)) {
if (auto referenceResult =
computeReferenceResults(plan.plan, probeInput, buildInput)) {
VELOX_CHECK(
assertEqualResults(
duckDbResult.value(), plan.plan->outputType(), {expected}),
referenceResult.value(), plan.plan->outputType(), {expected}),
"Velox and DuckDB results don't match");
}
}
Expand Down Expand Up @@ -1142,7 +1176,7 @@ void JoinFuzzer::verify(core::JoinType joinType) {

shuffleJoinKeys(probeKeys, buildKeys);

const auto defaultPlan = makeDefaultPlan(
auto defaultPlan = makeDefaultPlan(
joinType,
nullAware,
probeKeys,
Expand All @@ -1153,14 +1187,17 @@ void JoinFuzzer::verify(core::JoinType joinType) {

const auto expected = execute(defaultPlan, /*injectSpill=*/false);

// If OOM injection is not enabled verify the results against DuckDB.
// If OOM injection is not enabled verify the results against Reference query
// runner.
if (!FLAGS_enable_oom_injection) {
if (auto duckDbResult = computeDuckDbResult<core::HashJoinNode>(
probeInput, buildInput, defaultPlan.plan)) {
if (auto referenceResult =
computeReferenceResults(defaultPlan.plan, probeInput, buildInput)) {
VELOX_CHECK(
assertEqualResults(
duckDbResult.value(), defaultPlan.plan->outputType(), {expected}),
"Velox and DuckDB results don't match");
referenceResult.value(),
defaultPlan.plan->outputType(),
{expected}),
"Velox and Reference results don't match");
}
}

Expand Down Expand Up @@ -1548,7 +1585,9 @@ void JoinFuzzer::go() {

} // namespace

void joinFuzzer(size_t seed) {
JoinFuzzer(seed).go();
void joinFuzzer(
size_t seed,
std::unique_ptr<test::ReferenceQueryRunner> referenceQueryRunner) {
JoinFuzzer(seed, std::move(referenceQueryRunner)).go();
}
} // namespace facebook::velox::exec::test
5 changes: 4 additions & 1 deletion velox/exec/fuzzer/JoinFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
#pragma once

#include <cstddef>
#include "velox/exec/fuzzer/ReferenceQueryRunner.h"

namespace facebook::velox::exec::test {
void joinFuzzer(size_t seed);
/// Runs the join fuzzer.
///
/// @param seed Initial seed handed to the fuzzer.
/// @param referenceQueryRunner Query runner the fuzzer takes ownership of.
void joinFuzzer(
size_t seed,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner);
}
38 changes: 13 additions & 25 deletions velox/exec/fuzzer/JoinFuzzerRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "velox/common/memory/SharedArbitrator.h"
#include "velox/exec/MemoryReclaimer.h"
#include "velox/exec/fuzzer/JoinFuzzer.h"
#include "velox/exec/fuzzer/ReferenceQueryRunner.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
#include "velox/parse/TypeResolver.h"
#include "velox/serializers/PrestoSerializer.h"
Expand Down Expand Up @@ -57,33 +58,20 @@
/// --seed 123 \
/// --v=1

namespace facebook::velox::exec::test {

/// Static helper that performs the registrations needed by the join fuzzer
/// and then runs it.
class JoinFuzzerRunner {
public:
// Single-argument form: sets up the memory system, registers the Presto
// vector serde, local file system, Presto scalar functions and type resolver,
// then runs the join fuzzer with the given seed.
static int run(size_t seed) {
setupMemory();
facebook::velox::serializer::presto::PrestoVectorSerde::
registerVectorSerde();
facebook::velox::filesystems::registerLocalFileSystem();
facebook::velox::functions::prestosql::registerAllScalarFunctions();
facebook::velox::parse::registerTypeResolver();

facebook::velox::exec::test::joinFuzzer(seed);
// Two-argument form: performs the same registrations (without calling
// setupMemory()), hands the reference query runner over to joinFuzzer() and
// returns the value of RUN_ALL_TESTS().
static int run(
size_t seed,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner) {
serializer::presto::PrestoVectorSerde::registerVectorSerde();
filesystems::registerLocalFileSystem();
functions::prestosql::registerAllScalarFunctions();
parse::registerTypeResolver();
joinFuzzer(seed, std::move(referenceQueryRunner));
return RUN_ALL_TESTS();
}

private:
// Invoked to set up memory system with arbitration.
// Enables usage tracking and leak checks through flags, registers the shared
// arbitrator factory, and initializes the MemoryManager with "SHARED"
// arbitration and the capacities set below.
static void setupMemory() {
FLAGS_velox_enable_memory_usage_track_in_default_memory_pool = true;
FLAGS_velox_memory_leak_check_enabled = true;
facebook::velox::memory::SharedArbitrator::registerFactory();
facebook::velox::memory::MemoryManagerOptions options;
// 8L << 30 and 6L << 30 are 8 GiB and 6 GiB respectively.
options.allocatorCapacity = 8L << 30;
options.arbitratorCapacity = 6L << 30;
options.arbitratorKind = "SHARED";
options.checkUsageLeak = true;
options.arbitrationStateCheckCb =
facebook::velox::exec::memoryArbitrationStateCheck;
facebook::velox::memory::MemoryManager::initialize(options);
}
};

} // namespace facebook::velox::exec::test
Loading

0 comments on commit e7d924e

Please sign in to comment.