Skip to content

Commit

Permalink
Add spark check_overflow function and cast toDecimal32/64/128 (#231)
Browse files Browse the repository at this point in the history
* add spark check_overflow function and cast toDecimal32/64/128

* fix check overflow allow null
  • Loading branch information
loneylee authored Dec 13, 2022
1 parent e571255 commit dc9b918
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 1 deletion.
84 changes: 83 additions & 1 deletion utils/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,35 @@ std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
return actions_dag;
}

/// Selects the name of the ClickHouse decimal conversion function for a Substrait decimal type.
///
/// The narrowest of toDecimal32 / toDecimal64 / toDecimal128 whose maximum precision can hold
/// `decimal.precision()` is chosen. When `null_on_overflow` is true, the "OrNull" suffix is
/// appended to the chosen name.
///
/// Only the function name is produced here: the scale of `decimal` is not consulted, so callers
/// have to pass it to the resulting function as a separate argument.
///
/// Throws Exception(ErrorCodes::UNKNOWN_TYPE) when the precision is larger than
/// DataTypeDecimal128::maxPrecision().
std::string getDecimalFunction(const substrait::Type_Decimal & decimal, const bool null_on_overflow)
{
    const UInt32 precision = decimal.precision();

    std::string ch_function_name;
    if (precision <= DataTypeDecimal32::maxPrecision())
        ch_function_name = "toDecimal32";
    else if (precision <= DataTypeDecimal64::maxPrecision())
        ch_function_name = "toDecimal64";
    else if (precision <= DataTypeDecimal128::maxPrecision())
        ch_function_name = "toDecimal128";
    else
        throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision);

    if (null_on_overflow)
        ch_function_name += "OrNull";

    return ch_function_name;
}

/// TODO: This function needs to be improved for Decimal/Array/Map/Tuple types.
std::string getCastFunction(const substrait::Type & type)
{
Expand Down Expand Up @@ -226,6 +255,10 @@ std::string getCastFunction(const substrait::Type & type)
{
ch_function_name = "toUInt8";
}
else if (type.has_decimal())
{
ch_function_name = getDecimalFunction(type.decimal(), false);
}
else
throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support cast type {}", type.DebugString());

Expand Down Expand Up @@ -1046,6 +1079,12 @@ SerializedPlanParser::getFunctionName(const std::string & function_signature, co
else
throw Exception(ErrorCodes::BAD_ARGUMENTS, "The first arg of extract function is wrong.");
}
else if (function_name == "check_overflow")
{
if (args.size() != 2)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires two args.");
ch_function_name = getDecimalFunction(output_type.decimal(), args.at(1).value().literal().boolean());
}
else
ch_function_name = SCALAR_FUNCTIONS.at(function_name);

Expand Down Expand Up @@ -1120,6 +1159,33 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
args.erase(args.begin());
}

if (function_signature.find("check_overflow:", 0) != function_signature.npos)
{
if (scalar_function.arguments().size() != 2)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires two args.");

// The toDecimalXXOrNull variants require their first argument to be a String, so convert it with toString first.
if (scalar_function.arguments().at(1).value().literal().boolean())
{
std::string check_overflow_args_trans_function = "toString";
DB::ActionsDAG::NodeRawConstPtrs to_string_args({args[0]});

auto to_string_cast = FunctionFactory::instance().get(check_overflow_args_trans_function, context);
std::string to_string_cast_args_name;
join(to_string_args, ',', to_string_cast_args_name);
result_name = check_overflow_args_trans_function + "(" + to_string_cast_args_name + ")";
const auto * to_string_cast_node = &actions_dag->addFunction(to_string_cast, to_string_args, result_name);
args[0] = to_string_cast_node;
}

// Drop the last argument (the null-on-overflow flag) and pass the output scale in its place.
args.pop_back();
auto type = std::make_shared<DataTypeUInt32>();
UInt32 scale = rel.scalar_function().output_type().decimal().scale();
args.emplace_back(
&actions_dag->addColumn(ColumnWithTypeAndName(type->createColumnConst(1, scale), type, getUniqueName(toString(scale)))));
}

auto function_builder = FunctionFactory::instance().get(function_name, context);
std::string args_name;
join(args, ',', args_name);
Expand All @@ -1130,6 +1196,15 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
{
auto cast_function = getCastFunction(rel.scalar_function().output_type());
DB::ActionsDAG::NodeRawConstPtrs cast_args({function_node});

if (cast_function.starts_with("toDecimal"))
{
auto type = std::make_shared<DataTypeUInt32>();
UInt32 scale = rel.scalar_function().output_type().decimal().scale();
cast_args.emplace_back(&actions_dag->addColumn(
ColumnWithTypeAndName(type->createColumnConst(1, scale), type, getUniqueName(toString(scale)))));
}

auto cast = FunctionFactory::instance().get(cast_function, context);
std::string cast_args_name;
join(cast_args, ',', cast_args_name);
Expand Down Expand Up @@ -1329,7 +1404,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseArgument(ActionsDAGPtr actio
std::string ch_function_name = getCastFunction(rel.cast().type());
DB::ActionsDAG::NodeRawConstPtrs args;
auto cast_input = rel.cast().input();
if (cast_input.has_selection())
if (cast_input.has_selection() || cast_input.has_literal())
{
args.emplace_back(parseArgument(action_dag, rel.cast().input()));
}
Expand All @@ -1348,6 +1423,13 @@ const ActionsDAG::Node * SerializedPlanParser::parseArgument(ActionsDAGPtr actio
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported cast input {}", rel.cast().input().DebugString());
}

if (ch_function_name.starts_with("toDecimal"))
{
UInt32 scale = rel.cast().type().decimal().scale();
args.emplace_back(add_column(std::make_shared<DataTypeUInt32>(), scale));
}

const auto * function_node = toFunctionNode(action_dag, ch_function_name, args);
action_dag->addOrReplaceInIndex(*function_node);
return function_node;
Expand Down
1 change: 1 addition & 0 deletions utils/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ static const std::map<std::string, std::string> SCALAR_FUNCTIONS = {
{"quarter", "toQuarter"},
{"shiftleft", "bitShiftLeft"},
{"shiftright", "bitShiftRight"},
{"check_overflow", "check_overflow"},

/// string functions
{"like", "like"},
Expand Down

0 comments on commit dc9b918

Please sign in to comment.