Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Apr 1, 2024
1 parent 4d780ee commit 8251b61
Show file tree
Hide file tree
Showing 20 changed files with 424 additions and 171 deletions.
36 changes: 36 additions & 0 deletions velox/expression/tests/ArgumentGenerator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/core/ITypedExpr.h"
#include "velox/expression/tests/utils/FuzzerToolkit.h"

namespace facebook::velox::test {

class ExpressionFuzzer;

class ArgumentGenerator {
public:
virtual ~ArgumentGenerator() = default;

// Generates function arguments of the specified signature.
virtual std::vector<core::TypedExprPtr> generate(
ExpressionFuzzer* expressionFuzzer,
const CallableSignature& input,
int32_t maxNumVarArgs) = 0;
};

} // namespace facebook::velox::test
10 changes: 6 additions & 4 deletions velox/expression/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,12 @@ target_link_libraries(

add_executable(velox_expression_fuzzer_test ExpressionFuzzerTest.cpp)

target_link_libraries(velox_expression_fuzzer_test velox_expression_fuzzer
velox_functions_prestosql gtest gtest_main)
target_link_libraries(
velox_expression_fuzzer_test velox_expression_fuzzer_utility
velox_expression_fuzzer velox_functions_prestosql gtest gtest_main)

add_executable(spark_expression_fuzzer_test SparkExpressionFuzzerTest.cpp)

target_link_libraries(spark_expression_fuzzer_test velox_expression_fuzzer
velox_functions_spark gtest gtest_main)
target_link_libraries(
spark_expression_fuzzer_test spark_expression_fuzzer_utility
velox_expression_fuzzer velox_functions_spark gtest gtest_main)
11 changes: 0 additions & 11 deletions velox/expression/tests/ExprTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2615,17 +2615,6 @@ TEST_P(ParameterizedExprTest, constantToSql) {
ASSERT_EQ(toSql(2134456LL), "'2134456'::BIGINT");
ASSERT_EQ(toSql(variant::null(TypeKind::BIGINT)), "NULL::BIGINT");

ASSERT_EQ(toSql(2134456LL, DECIMAL(18, 2)), "'21344.56'::DECIMAL(18, 2)");
ASSERT_EQ(
toSql(variant::null(TypeKind::BIGINT), DECIMAL(18, 2)),
"NULL::DECIMAL(18, 2)");
ASSERT_EQ(
toSql((int128_t)1'000'000'000'000'000'000, DECIMAL(38, 2)),
"'10000000000000000.00'::DECIMAL(38, 2)");
ASSERT_EQ(
toSql(variant::null(TypeKind::HUGEINT), DECIMAL(38, 2)),
"NULL::DECIMAL(38, 2)");

ASSERT_EQ(toSql(18'506, DATE()), "'2020-09-01'::DATE");
ASSERT_EQ(toSql(variant::null(TypeKind::INTEGER), DATE()), "NULL::DATE");

Expand Down
128 changes: 13 additions & 115 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,13 @@ ExpressionFuzzer::ExpressionFuzzer(
FunctionSignatureMap signatureMap,
size_t initialSeed,
const std::shared_ptr<VectorFuzzer>& vectorFuzzer,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators,
const std::optional<ExpressionFuzzer::Options>& options)
: options_(options.value_or(Options())),
vectorFuzzer_(vectorFuzzer),
state{rng_, std::max(1, options_.maxLevelOfNesting)} {
state{rng_, std::max(1, options_.maxLevelOfNesting)},
customArgumentGenerators_(customArgumentGenerators) {
VELOX_CHECK(vectorFuzzer, "Vector fuzzer must be provided");
seed(initialSeed);

Expand Down Expand Up @@ -711,13 +714,6 @@ ExpressionFuzzer::ExpressionFuzzer(
// Register function override (for cases where we want to restrict the types
// or parameters we pass to functions).
registerFuncOverride(&ExpressionFuzzer::generateSwitchArgs, "switch");
registerFuncOverride(
&ExpressionFuzzer::generateExtremeFunctionArgs, "greatest");
registerFuncOverride(&ExpressionFuzzer::generateExtremeFunctionArgs, "least");
registerFuncOverride(
&ExpressionFuzzer::generateMakeTimestampArgs, "make_timestamp");
registerFuncOverride(
&ExpressionFuzzer::generateUnscaledValueArgs, "unscaled_value");
}

void ExpressionFuzzer::getTicketsForFunctions() {
Expand Down Expand Up @@ -950,84 +946,6 @@ core::TypedExprPtr ExpressionFuzzer::generateArg(
}
}

std::vector<core::TypedExprPtr> ExpressionFuzzer::generateExtremeFunctionArgs(
const CallableSignature& input) {
const auto argTypes = input.args;
VELOX_CHECK_GE(
argTypes.size(),
1,
"At least one input is expected from the template signature.");
if (!argTypes[0]->isDecimal()) {
return generateArgs(input);
}

auto numVarArgs =
!input.variableArity ? 0 : rand32(0, options_.maxNumVarArgs);
std::vector<core::TypedExprPtr> inputExpressions;
inputExpressions.reserve(argTypes.size() + numVarArgs);
inputExpressions.emplace_back(
generateArg(argTypes.at(0), input.constantArgs.at(0)));

// Append varargs to the argument list.
for (int i = 0; i < numVarArgs; i++) {
core::TypedExprPtr argExpr;
// The varargs need to be generated following the result type of the first
// argument. But when nested expression is generated, that cannot be
// guaranteed as argument precisions and scales cannot be inferred from the
// result type through a decimal function signature. Given this limitation,
// generate constant or column only.
const auto argType = inputExpressions[0]->type();
if (rand32(0, 1) == kArgConstant) {
argExpr = generateArgConstant(argType);
} else {
argExpr = generateArgColumn(argType);
}
inputExpressions.emplace_back(argExpr);
}
return inputExpressions;
}

std::vector<core::TypedExprPtr> ExpressionFuzzer::generateMakeTimestampArgs(
const CallableSignature& input) {
VELOX_CHECK_GE(
input.args.size(),
6,
"At least six inputs are expected from the template signature.");
bool useTimezone = vectorFuzzer_->coinToss(0.5);
std::vector<core::TypedExprPtr> inputExpressions;
inputExpressions.reserve(6);
for (int index = 0; index < 5; ++index) {
inputExpressions.emplace_back(generateArg(input.args[index]));
}

// The required result type of the sixth argument is a short decimal type with
// scale being 6. But when nested expression is generated, that cannot be
// guaranteed as argument precisions and scales cannot be inferred from the
// result type through a decimal function signature. Given this limitation,
// generate constant or column only.
core::TypedExprPtr argExpr;
if (rand32(0, 1) == kArgConstant) {
argExpr = generateArgConstant(input.args[5]);
} else {
argExpr = generateArgColumn(input.args[5]);
}
inputExpressions.emplace_back(argExpr);

if (input.args.size() == 7) {
// The 7th. argument cannot be randomly generated as it should be a valid
// timezone string.
std::vector<std::string> timezoneSet = {
"Asia/Kolkata",
"America/Los_Angeles",
"Canada/Atlantic",
"+08:00",
"-10:00"};
inputExpressions.emplace_back(std::make_shared<core::ConstantTypedExpr>(
VARCHAR(), variant(timezoneSet[rand32(0, 4)])));
}
return inputExpressions;
}

std::vector<core::TypedExprPtr> ExpressionFuzzer::generateSwitchArgs(
const CallableSignature& input) {
VELOX_CHECK_EQ(
Expand All @@ -1050,29 +968,6 @@ std::vector<core::TypedExprPtr> ExpressionFuzzer::generateSwitchArgs(
return inputExpressions;
}

std::vector<core::TypedExprPtr> ExpressionFuzzer::generateUnscaledValueArgs(
const CallableSignature& input) {
VELOX_CHECK_EQ(
input.args.size(),
1,
"Only one input is expected from the template signature.");

// The required result type of input argument is a short decimal type. But
// when nested expression is generated, that cannot be guaranteed as argument
// precisions and scales cannot be inferred from the result type through a
// decimal function signature. Given this limitation, generate constant or
// column only.
std::vector<core::TypedExprPtr> inputExpressions;
core::TypedExprPtr argExpr;
if (rand32(0, 1) == kArgConstant) {
argExpr = generateArgConstant(input.args[0]);
} else {
argExpr = generateArgColumn(input.args[0]);
}
inputExpressions.emplace_back(argExpr);
return inputExpressions;
}

ExpressionFuzzer::FuzzedExpressionData ExpressionFuzzer::fuzzExpressions(
const RowTypePtr& outType) {
state.reset();
Expand Down Expand Up @@ -1167,6 +1062,10 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(

std::vector<core::TypedExprPtr> ExpressionFuzzer::getArgsForCallable(
const CallableSignature& callable) {
if (customArgumentGenerators_.count(callable.name)) {
return customArgumentGenerators_[callable.name]->generate(
this, callable, options_.maxNumVarArgs);
}
auto funcIt = funcArgOverrides_.find(callable.name);
if (funcIt == funcArgOverrides_.end()) {
return generateArgs(callable);
Expand All @@ -1180,7 +1079,6 @@ TypePtr ExpressionFuzzer::getConstrainedOutputType(
if (signature == nullptr) {
return nullptr;
}

// Checks if any variable is integer constrained, and get the decimal name
// style.
bool integerConstrained = false;
Expand Down Expand Up @@ -1259,10 +1157,10 @@ core::TypedExprPtr ExpressionFuzzer::getCallExprFromCallable(
// For a decimal function (especially a nested one), as argument precisions
// and scales are randomly generated, callable.returnType does not follow the
// required constraints, and the matched result type needs to be recalculated
// from the argument types. If a constrained output type can be generated, use
// it to avoid breaking the constraints between input types and output types.
// Otherwise, generate a CallTypedExpr with type because callable.returnType
// may not have the required field names.
// from the argument types. If function signature is provided, generates a
// constrained type to avoid breaking the constraints between input types and
// output types. Otherwise, generate a CallTypedExpr with type because
// callable.returnType may not have the required field names.
const auto constrainedType = getConstrainedOutputType(args, signature);
return std::make_shared<core::CallTypedExpr>(
constrainedType ? constrainedType : type, args, callable.name);
Expand Down Expand Up @@ -1347,7 +1245,7 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromConcreteSignatures(
}

markSelected(chosen->name);
return getCallExprFromCallable(*chosen, returnType, nullptr);
return getCallExprFromCallable(*chosen, returnType);
}

const SignatureTemplate* ExpressionFuzzer::chooseRandomSignatureTemplate(
Expand Down
55 changes: 24 additions & 31 deletions velox/expression/tests/ExpressionFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "velox/core/ITypedExpr.h"
#include "velox/core/QueryCtx.h"
#include "velox/expression/Expr.h"
#include "velox/expression/tests/ArgumentGenerator.h"
#include "velox/expression/tests/ExpressionVerifier.h"
#include "velox/expression/tests/utils/FuzzerToolkit.h"
#include "velox/functions/FunctionRegistry.h"
Expand Down Expand Up @@ -107,6 +108,8 @@ class ExpressionFuzzer {
FunctionSignatureMap signatureMap,
size_t initialSeed,
const std::shared_ptr<VectorFuzzer>& vectorFuzzer,
const std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>&
customArgumentGenerators,
const std::optional<ExpressionFuzzer::Options>& options = std::nullopt);

template <typename TFunc>
Expand Down Expand Up @@ -195,6 +198,19 @@ class ExpressionFuzzer {

RowTypePtr fuzzRowReturnType(size_t size, char prefix = 'p');

core::TypedExprPtr generateArg(const TypePtr& arg);

core::TypedExprPtr generateArg(const TypePtr& arg, bool isConstant);

std::vector<core::TypedExprPtr> generateArgs(const CallableSignature& input);

core::TypedExprPtr generateArgColumn(const TypePtr& arg);

core::TypedExprPtr generateArgConstant(const TypePtr& arg);

// Returns random integer between min and max inclusive.
int32_t rand32(int32_t min, int32_t max);

private:
// Either generates a new expression of the required return type or if
// already generated expressions of the same return type exist then there is
Expand All @@ -218,12 +234,6 @@ class ExpressionFuzzer {

void appendConjunctSignatures();

core::TypedExprPtr generateArgConstant(const TypePtr& arg);

core::TypedExprPtr generateArgColumn(const TypePtr& arg);

core::TypedExprPtr generateArg(const TypePtr& arg);

// Given lambda argument type, generate matching LambdaTypedExpr.
//
// The 'arg' specifies inputs types and result type for the lambda. This
Expand All @@ -234,24 +244,14 @@ class ExpressionFuzzer {
// all input. The constant value is generated using 'generateArgConstant'.
core::TypedExprPtr generateArgFunction(const TypePtr& arg);

std::vector<core::TypedExprPtr> generateArgs(const CallableSignature& input);

std::vector<core::TypedExprPtr> generateArgs(
const std::vector<TypePtr>& argTypes,
const std::vector<bool>& constantArgs,
uint32_t numVarArgs = 0);

core::TypedExprPtr generateArg(const TypePtr& arg, bool isConstant);

/// Specialization for the "greatest" and "least" functions: decimal varargs
/// need to be constant or column.
std::vector<core::TypedExprPtr> generateExtremeFunctionArgs(
const CallableSignature& input);

/// Specialization for the "make_timestamp" function: 1) decimal argument
/// needs to be constant or column. 2) timezone argument needs to be valid.
std::vector<core::TypedExprPtr> generateMakeTimestampArgs(
const CallableSignature& input);
// Return a vector of expressions for each argument of callable in order.
std::vector<core::TypedExprPtr> getArgsForCallable(
const CallableSignature& callable);

/// Specialization for the "switch" function. Takes in a signature that is
/// of the form Switch (condition, then): boolean, T -> T where the type
Expand All @@ -262,15 +262,6 @@ class ExpressionFuzzer {
std::vector<core::TypedExprPtr> generateSwitchArgs(
const CallableSignature& input);

/// Specialization for the "unscaled_value" function: decimal argument needs
/// to be constant or column.
std::vector<core::TypedExprPtr> generateUnscaledValueArgs(
const CallableSignature& input);

// Return a vector of expressions for each argument of callable in order.
std::vector<core::TypedExprPtr> getArgsForCallable(
const CallableSignature& callable);

/// Given the argument types, calculates the return type of a decimal function
/// by evaluating constraints.
TypePtr getConstrainedOutputType(
Expand Down Expand Up @@ -352,9 +343,6 @@ class ExpressionFuzzer {
state.expressionStats_[funcName]++;
}

// Returns random integer between min and max inclusive.
int32_t rand32(int32_t min, int32_t max);

static const inline std::string kTypeParameterName = "T";

const Options options_;
Expand Down Expand Up @@ -441,6 +429,11 @@ class ExpressionFuzzer {
int32_t remainingLevelOfNesting_;

} state;

// Maps from function name to a custom arguments generator.
std::unordered_map<std::string, std::shared_ptr<ArgumentGenerator>>
customArgumentGenerators_;

friend class ExpressionFuzzerUnitTest;
};

Expand Down
Loading

0 comments on commit 8251b61

Please sign in to comment.