diff --git a/utils/local-engine/Parser/SerializedPlanParser.cpp b/utils/local-engine/Parser/SerializedPlanParser.cpp index be14b74addbd..63f5ebdce882 100644 --- a/utils/local-engine/Parser/SerializedPlanParser.cpp +++ b/utils/local-engine/Parser/SerializedPlanParser.cpp @@ -182,6 +182,35 @@ std::shared_ptr SerializedPlanParser::expressionsToActionsDAG( return actions_dag; } +std::string getDecimalFunction(const substrait::Type_Decimal & decimal, const bool null_on_overflow) { + std::string ch_function_name; + UInt32 precision = decimal.precision(); + UInt32 scale = decimal.scale(); + + if (precision <= DataTypeDecimal32::maxPrecision()) + { + ch_function_name = "toDecimal32"; + } + else if (precision <= DataTypeDecimal64::maxPrecision()) + { + ch_function_name = "toDecimal64"; + } + else if (precision <= DataTypeDecimal128::maxPrecision()) + { + ch_function_name = "toDecimal128"; + } + else + { + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision); + } + + if (null_on_overflow) { + ch_function_name = ch_function_name + "OrNull"; + } + + return ch_function_name; +} + /// TODO: This function needs to be improved for Decimal/Array/Map/Tuple types. std::string getCastFunction(const substrait::Type & type) { @@ -227,6 +256,10 @@ std::string getCastFunction(const substrait::Type & type) { ch_function_name = "toUInt8"; } + else if (type.has_decimal()) + { + ch_function_name = getDecimalFunction(type.decimal(), false); + } else throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support cast type {}", type.DebugString()); @@ -1025,6 +1058,12 @@ SerializedPlanParser::getFunctionName(const std::string & function_signature, co else throw Exception(ErrorCodes::BAD_ARGUMENTS, "The first arg of extract function is wrong."); } + else if (function_name == "check_overflow") + { + if (args.size() != 2) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires two args."); + ch_function_name = getDecimalFunction(output_type.decimal(), args.at(1).value().literal().boolean()); + } else ch_function_name = SCALAR_FUNCTIONS.at(function_name); @@ -1090,6 +1129,33 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( args.erase(args.begin()); } + if (function_signature.find("check_overflow:", 0) != function_signature.npos) + { + if (scalar_function.arguments().size() != 2) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires two args."); + + // if toDecimalxxOrNull, first arg need string type + if (scalar_function.arguments().at(1).value().literal().boolean()) + { + std::string check_overflow_args_trans_function = "toString"; + DB::ActionsDAG::NodeRawConstPtrs to_string_args({args[0]}); + + auto to_string_cast = FunctionFactory::instance().get(check_overflow_args_trans_function, context); + std::string to_string_cast_args_name; + join(to_string_args, ',', to_string_cast_args_name); + result_name = check_overflow_args_trans_function + "(" + to_string_cast_args_name + ")"; + const auto * to_string_cast_node = &actions_dag->addFunction(to_string_cast, to_string_args, result_name); + args[0] = to_string_cast_node; + } + + // delete the latest arg + args.pop_back(); + auto type = std::make_shared(); + UInt32 scale = rel.scalar_function().output_type().decimal().scale(); + args.emplace_back( + &actions_dag->addColumn(ColumnWithTypeAndName(type->createColumnConst(1, scale), type, getUniqueName(toString(scale))))); + } + auto function_builder = FunctionFactory::instance().get(function_name, context); std::string args_name; join(args, ',', args_name); @@ -1100,6 +1166,15 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( { auto cast_function = getCastFunction(rel.scalar_function().output_type()); DB::ActionsDAG::NodeRawConstPtrs cast_args({function_node}); + + if (cast_function.starts_with("toDecimal")) + { + auto type = std::make_shared(); + UInt32 scale = rel.scalar_function().output_type().decimal().scale(); + cast_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(type->createColumnConst(1, scale), type, getUniqueName(toString(scale))))); + } + auto cast = FunctionFactory::instance().get(cast_function, context); std::string cast_args_name; join(cast_args, ',', cast_args_name); @@ -1299,7 +1374,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseArgument(ActionsDAGPtr actio std::string ch_function_name = getCastFunction(rel.cast().type()); DB::ActionsDAG::NodeRawConstPtrs args; auto cast_input = rel.cast().input(); - if (cast_input.has_selection()) + if (cast_input.has_selection() || cast_input.has_literal()) { args.emplace_back(parseArgument(action_dag, rel.cast().input())); } @@ -1318,6 +1393,13 @@ const ActionsDAG::Node * SerializedPlanParser::parseArgument(ActionsDAGPtr actio { throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported cast input {}", rel.cast().input().DebugString()); } + + if (ch_function_name.starts_with("toDecimal")) + { + UInt32 scale = rel.cast().type().decimal().scale(); + args.emplace_back(add_column(std::make_shared(), scale)); + } + const auto * function_node = toFunctionNode(action_dag, ch_function_name, args); action_dag->addOrReplaceInIndex(*function_node); return function_node; diff --git a/utils/local-engine/Parser/SerializedPlanParser.h b/utils/local-engine/Parser/SerializedPlanParser.h index dcafb4cb0dac..ba32903b240e 100644 --- a/utils/local-engine/Parser/SerializedPlanParser.h +++ b/utils/local-engine/Parser/SerializedPlanParser.h @@ -89,6 +89,7 @@ static const std::map SCALAR_FUNCTIONS = { {"quarter", "toQuarter"}, {"shiftleft", "bitShiftLeft"}, {"shiftright", "bitShiftRight"}, + {"check_overflow", "check_overflow"}, /// string functions {"like", "like"},