diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp index 77814a427624..fc2e3e886117 100644 --- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp @@ -39,7 +39,6 @@ namespace local_engine DB::ActionsDAG::NodeRawConstPtrs AggregateFunctionParser::parseFunctionArguments( const CommonFunctionInfo & func_info, - const String & /*ch_func_name*/, DB::ActionsDAGPtr & actions_dag) const { DB::ActionsDAG::NodeRawConstPtrs collected_args; diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h index 464ad099a3b6..215c09626b7e 100644 --- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h +++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h @@ -97,12 +97,7 @@ class AggregateFunctionParser /// Do some preprojections for the function arguments, and return the necessary arguments for the CH function. virtual DB::ActionsDAG::NodeRawConstPtrs - parseFunctionArguments(const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const; - - DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const - { - return parseFunctionArguments(func_info, getCHFunctionName(func_info), actions_dag); - } + parseFunctionArguments(const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const; // `PartialMerge` is applied on the merging stages. // `If` is applied when the aggreate function has a filter. This should only happen on the 1st stage. diff --git a/cpp-ch/local-engine/Parser/FunctionParser.cpp b/cpp-ch/local-engine/Parser/FunctionParser.cpp index e786304b3e6e..206fb6c50ec9 100644 --- a/cpp-ch/local-engine/Parser/FunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/FunctionParser.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -39,24 +40,18 @@ using namespace DB; String FunctionParser::getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const { - auto func_signature = plan_parser->function_mapping.at(std::to_string(substrait_func.function_reference())); - auto pos = func_signature.find(':'); - auto func_name = func_signature.substr(0, pos); - - auto it = SCALAR_FUNCTIONS.find(func_name); - if (it == SCALAR_FUNCTIONS.end()) - throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unsupported substrait function: {}", func_name); - return it->second; + // no meaning + /// There is no any simple equivalent ch function. + return ""; } ActionsDAG::NodeRawConstPtrs FunctionParser::parseFunctionArguments( - const substrait::Expression_ScalarFunction & substrait_func, const String & ch_func_name, ActionsDAGPtr & actions_dag) const + const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const { ActionsDAG::NodeRawConstPtrs parsed_args; const auto & args = substrait_func.arguments(); parsed_args.reserve(args.size()); - for (const auto & arg : args) - plan_parser->parseFunctionArgument(actions_dag, parsed_args, ch_func_name, arg); + plan_parser->parseFunctionArguments(actions_dag, parsed_args, substrait_func); return parsed_args; } @@ -66,7 +61,7 @@ const ActionsDAG::Node * FunctionParser::parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const { auto ch_func_name = getCHFunctionName(substrait_func); - auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); const auto * func_node = toFunctionNode(actions_dag, ch_func_name, parsed_args); return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); } @@ -76,13 +71,30 @@ const ActionsDAG::Node * FunctionParser::convertNodeTypeIfNeeded( { const auto & output_type = substrait_func.output_type(); if (!TypeParser::isTypeMatched(output_type, func_node->result_type)) - return ActionsDAGUtil::convertNodeType( - actions_dag, - func_node, - // as stated in isTypeMatched, currently we don't change nullability of the result type - func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, TypeParser::parseType(output_type))->getName() - : DB::removeNullable(TypeParser::parseType(output_type))->getName(), - func_node->result_name); + { + auto result_type = TypeParser::parseType(substrait_func.output_type()); + if (DB::isDecimalOrNullableDecimal(result_type)) + { + return ActionsDAGUtil::convertNodeType( + actions_dag, + func_node, + // as stated in isTypeMatched, currently we don't change nullability of the result type + func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), + func_node->result_name, + CastType::accurateOrNull); + } + else + { + return ActionsDAGUtil::convertNodeType( + actions_dag, + func_node, + // as stated in isTypeMatched, currently we don't change nullability of the result type + func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, TypeParser::parseType(output_type))->getName() + : DB::removeNullable(TypeParser::parseType(output_type))->getName(), + func_node->result_name); + } + } else return func_node; } diff --git a/cpp-ch/local-engine/Parser/FunctionParser.h b/cpp-ch/local-engine/Parser/FunctionParser.h index 36cf9f5ce5d5..6ac162a953c6 100644 --- a/cpp-ch/local-engine/Parser/FunctionParser.h +++ b/cpp-ch/local-engine/Parser/FunctionParser.h @@ -51,10 +51,16 @@ class FunctionParser virtual String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const; protected: - + /// Deprecated method + virtual DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( + const substrait::Expression_ScalarFunction & substrait_func, + const String & /*function_name*/, + DB::ActionsDAGPtr & actions_dag) const + { + return parseFunctionArguments(substrait_func, actions_dag); + } virtual DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( const substrait::Expression_ScalarFunction & substrait_func, - const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const; virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( @@ -84,6 +90,11 @@ class FunctionParser return &action_dag->addFunction(function_builder, args, result_name); } + const ActionsDAG::Node * + parseFunctionWithDAG(const substrait::Expression & rel, std::string & result_name, DB::ActionsDAGPtr actions_dag, bool keep_result = false) const + { + return plan_parser->parseFunctionWithDAG(rel, result_name, actions_dag, keep_result); + } const DB::ActionsDAG::Node * parseExpression(DB::ActionsDAGPtr actions_dag, const substrait::Expression & rel) const { return plan_parser->parseExpression(actions_dag, rel); @@ -91,7 +102,7 @@ class FunctionParser std::pair parseLiteral(const substrait::Expression_Literal & literal) const { return plan_parser->parseLiteral(literal); } - SerializedPlanParser * plan_parser; + mutable SerializedPlanParser * plan_parser; }; using FunctionParserPtr = std::shared_ptr; diff --git a/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp b/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp index b51b76b97415..7d906a837441 100644 --- a/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp +++ b/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp @@ -383,13 +383,7 @@ void MergeTreeRelParser::collectColumns(const substrait::Expression & rel, NameS String MergeTreeRelParser::getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) { auto func_signature = getPlanParser()->function_mapping.at(std::to_string(substrait_func.function_reference())); - auto pos = func_signature.find(':'); - auto func_name = func_signature.substr(0, pos); - - auto it = SCALAR_FUNCTIONS.find(func_name); - if (it == SCALAR_FUNCTIONS.end()) - throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unsupported substrait function on mergetree prewhere parser: {}", func_name); - return it->second; + return getPlanParser()->getFunctionName(func_signature, substrait_func); } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 8c60c6e500a9..66b796060089 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -580,95 +580,10 @@ SerializedPlanParser::getFunctionName(const std::string & function_signature, co auto args = function.arguments(); auto pos = function_signature.find(':'); auto function_name = function_signature.substr(0, pos); - if (!SCALAR_FUNCTIONS.contains(function_name)) - throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unsupported function {}", function_name); - - std::string ch_function_name; - if (function_name == "trim") - ch_function_name = args.size() == 1 ? "trimBoth" : "trimBothSpark"; - else if (function_name == "ltrim") - ch_function_name = args.size() == 1 ? "trimLeft" : "trimLeftSpark"; - else if (function_name == "rtrim") - ch_function_name = args.size() == 1 ? "trimRight" : "trimRightSpark"; - else if (function_name == "extract") - { - if (args.size() != 2) - throw Exception( - ErrorCodes::BAD_ARGUMENTS, "Spark function extract requires two args, function:{}", function.ShortDebugString()); - - // Get the first arg: field - const auto & extract_field = args.at(0); - - if (extract_field.value().has_literal()) - { - const auto & field_value = extract_field.value().literal().string(); - if (field_value == "YEAR") - ch_function_name = "toYear"; // spark: extract(YEAR FROM) or year - else if (field_value == "YEAR_OF_WEEK") - ch_function_name = "toISOYear"; // spark: extract(YEAROFWEEK FROM) - else if (field_value == "QUARTER") - ch_function_name = "toQuarter"; // spark: extract(QUARTER FROM) or quarter - else if (field_value == "MONTH") - ch_function_name = "toMonth"; // spark: extract(MONTH FROM) or month - else if (field_value == "WEEK_OF_YEAR") - ch_function_name = "toISOWeek"; // spark: extract(WEEK FROM) or weekofyear - else if (field_value == "WEEK_DAY") - /// Spark WeekDay(date) (0 = Monday, 1 = Tuesday, ..., 6 = Sunday) - /// Substrait: extract(WEEK_DAY from date) - /// CH: toDayOfWeek(date, 1) - ch_function_name = "toDayOfWeek"; - else if (field_value == "DAY_OF_WEEK") - /// Spark: DayOfWeek(date) (1 = Sunday, 2 = Monday, ..., 7 = Saturday) - /// Substrait: extract(DAY_OF_WEEK from date) - /// CH: toDayOfWeek(date, 3) - /// DAYOFWEEK is alias of function toDayOfWeek. - /// This trick is to distinguish between extract fields DAY_OF_WEEK and WEEK_DAY in latter codes - ch_function_name = "DAYOFWEEK"; - else if (field_value == "DAY") - ch_function_name = "toDayOfMonth"; // spark: extract(DAY FROM) or dayofmonth - else if (field_value == "DAY_OF_YEAR") - ch_function_name = "toDayOfYear"; // spark: extract(DOY FROM) or dayofyear - else if (field_value == "HOUR") - ch_function_name = "toHour"; // spark: extract(HOUR FROM) or hour - else if (field_value == "MINUTE") - ch_function_name = "toMinute"; // spark: extract(MINUTE FROM) or minute - else if (field_value == "SECOND") - ch_function_name = "toSecond"; // spark: extract(SECOND FROM) or secondwithfraction - else - throw Exception(ErrorCodes::BAD_ARGUMENTS, "The first arg of spark extract function is wrong."); - } - else - throw Exception(ErrorCodes::BAD_ARGUMENTS, "The first arg of spark extract function is wrong."); - } - else if (function_name == "check_overflow") - { - if (args.size() < 2) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires at least two args."); - ch_function_name = SCALAR_FUNCTIONS.at(function_name); - auto null_on_overflow = args.at(1).value().literal().boolean(); - if (null_on_overflow) - ch_function_name = ch_function_name + "OrNull"; - } - else if (function_name == "make_decimal") - { - if (args.size() < 2) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "make_decimal function requires at least 2 args."); - ch_function_name = SCALAR_FUNCTIONS.at(function_name); - auto null_on_overflow = args.at(1).value().literal().boolean(); - if (null_on_overflow) - ch_function_name = ch_function_name + "OrNull"; - } - else if (function_name == "reverse") - { - if (function.output_type().has_list()) - ch_function_name = "arrayReverse"; - else - ch_function_name = "reverseUTF8"; - } - else - ch_function_name = SCALAR_FUNCTIONS.at(function_name); - - return ch_function_name; + auto function_parser = FunctionParserFactory::instance().tryGet(function_name, this); + if (!function_parser) + throw DB::Exception(DB::ErrorCodes::UNKNOWN_FUNCTION, "Unsupported function: {}", function_name); + return function_parser->getCHFunctionName(function); } void SerializedPlanParser::parseArrayJoinArguments( @@ -690,8 +605,7 @@ void SerializedPlanParser::parseArrayJoinArguments( throw Exception( ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be 1 instead of {}", scalar_function.arguments_size()); - auto function_name_copy = function_name; - parseFunctionArguments(actions_dag, parsed_args, function_name_copy, scalar_function); + parseFunctionArguments(actions_dag, parsed_args, scalar_function); auto arg = parsed_args[0]; auto arg_type = removeNullable(arg->result_type); @@ -735,7 +649,7 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( const auto & scalar_function = rel.scalar_function(); auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference())); - auto function_name = getFunctionName(function_signature, scalar_function); + String function_name = "arrayJoin"; /// Whether the input argument of explode/posexplode is map type bool is_map; @@ -856,289 +770,29 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( auto pos = function_signature.find(':'); auto func_name = function_signature.substr(0, pos); - if (auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this)) - { - LOG_DEBUG( - &Poco::Logger::get("SerializedPlanParser"), "parse function {} by function parser: {}", func_name, func_parser->getName()); - const auto * result_node = func_parser->parse(scalar_function, actions_dag); - if (keep_result) - actions_dag->addOrReplaceInOutputs(*result_node); - - result_name = result_node->result_name; - return result_node; - } - - auto ch_func_name = getFunctionName(function_signature, scalar_function); - ActionsDAG::NodeRawConstPtrs args; - parseFunctionArguments(actions_dag, args, ch_func_name, scalar_function); - - /// If the first argument of function formatDateTimeInJodaSyntax is integer, replace formatDateTimeInJodaSyntax with fromUnixTimestampInJodaSyntax - /// to avoid exception - if (ch_func_name == "formatDateTimeInJodaSyntax") - { - if (args.size() > 1 && isInteger(removeNullable(args[0]->result_type))) - ch_func_name = "fromUnixTimestampInJodaSyntax"; - } - - if (ch_func_name == "alias") - { - result_name = args[0]->result_name; - actions_dag->addOrReplaceInOutputs(*args[0]); - return &actions_dag->addAlias(actions_dag->findInOutputs(result_name), result_name); - } - - if (ch_func_name == "toYear") - { - const ActionsDAG::Node * arg_node = args[0]; - const String & arg_func_name = arg_node->function ? arg_node->function->getName() : ""; - if ((arg_func_name == "sparkToDate" || arg_func_name == "sparkToDateTime") && arg_node->children.size() > 0) - { - const ActionsDAG::Node * child_node = arg_node->children[0]; - if (child_node && isString(removeNullable(child_node->result_type))) - { - auto extract_year_builder = FunctionFactory::instance().get("sparkExtractYear", context); - auto func_result_name = "sparkExtractYear(" + child_node->result_name + ")"; - return &actions_dag->addFunction(extract_year_builder, {child_node}, func_result_name); - } - } - } - - const ActionsDAG::Node * result_node; - - if (ch_func_name == "splitByRegexp") - { - if (args.size() >= 2) - { - /// In Spark: split(str, regex [, limit] ) - /// In CH: splitByRegexp(regexp, str [, limit]) - std::swap(args[0], args[1]); - } - } - - /// TODO: FunctionParser for check_overflow and make_decimal - if (function_signature.find("check_overflow:", 0) != String::npos) - { - if (scalar_function.arguments().size() < 2) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires at least two args."); - - ActionsDAG::NodeRawConstPtrs new_args; - new_args.reserve(3); - new_args.emplace_back(args[0]); - - UInt32 precision = rel.scalar_function().output_type().decimal().precision(); - UInt32 scale = rel.scalar_function().output_type().decimal().scale(); - auto uint32_type = std::make_shared(); - new_args.emplace_back(&actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); - new_args.emplace_back(&actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); - args = std::move(new_args); - } - else if (startsWith(function_signature, "make_decimal:")) - { - if (scalar_function.arguments().size() < 2) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "make_decimal function requires at least 2 args."); - - ActionsDAG::NodeRawConstPtrs new_args; - new_args.reserve(3); - new_args.emplace_back(args[0]); - - UInt32 precision = rel.scalar_function().output_type().decimal().precision(); - UInt32 scale = rel.scalar_function().output_type().decimal().scale(); - auto uint32_type = std::make_shared(); - new_args.emplace_back(&actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); - new_args.emplace_back(&actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); - args = std::move(new_args); - } - - auto function_builder = FunctionFactory::instance().get(ch_func_name, context); - std::string args_name = join(args, ','); - result_name = ch_func_name + "(" + args_name + ")"; - const auto * function_node = &actions_dag->addFunction(function_builder, args, result_name); - result_node = function_node; - if (!TypeParser::isTypeMatched(rel.scalar_function().output_type(), function_node->result_type)) - { - auto result_type = TypeParser::parseType(rel.scalar_function().output_type()); - if (isDecimalOrNullableDecimal(result_type)) - { - result_node = ActionsDAGUtil::convertNodeType( - actions_dag, - function_node, - // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), - function_node->result_name, - CastType::accurateOrNull); - } - else - { - result_node = ActionsDAGUtil::convertNodeType( - actions_dag, - function_node, - // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), - function_node->result_name); - } - } - + auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this); + if (!func_parser) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not found function parser for {}", func_name); + LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "parse function {} by function parser: {}", func_name, func_parser->getName()); + const auto * result_node = func_parser->parse(scalar_function, actions_dag); if (keep_result) actions_dag->addOrReplaceInOutputs(*result_node); + result_name = result_node->result_name; return result_node; } void SerializedPlanParser::parseFunctionArguments( ActionsDAGPtr & actions_dag, ActionsDAG::NodeRawConstPtrs & parsed_args, - std::string & function_name, const substrait::Expression_ScalarFunction & scalar_function) { auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference())); const auto & args = scalar_function.arguments(); parsed_args.reserve(args.size()); + for (const auto & arg : args) + parsed_args.emplace_back(parseExpression(actions_dag, arg.value())); - // Some functions need to be handled specially. - if (function_name == "JSONExtract") - { - parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]); - auto data_type = TypeParser::parseType(scalar_function.output_type()); - parsed_args.emplace_back(addColumn(actions_dag, std::make_shared(), data_type->getName())); - } - else if (function_name == "sparkTupleElement" || function_name == "tupleElement") - { - parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]); - - if (!args[1].value().has_literal()) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be a literal"); - - auto [data_type, field] = parseLiteral(args[1].value().literal()); - if (data_type->getTypeId() != TypeIndex::Int32) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be i32"); - - // tuple indecies start from 1, in spark, start from 0 - Int32 field_index = static_cast(field.get() + 1); - const auto * index_node = addColumn(actions_dag, std::make_shared(), field_index); - parsed_args.emplace_back(index_node); - } - else if (function_name == "tuple") - { - // Arguments in the format, (, [, , ...]) - // We don't need to care the field names here. - for (int index = 1; index < args.size(); index += 2) - parseFunctionArgument(actions_dag, parsed_args, function_name, args[index]); - } - else if (function_name == "repeat") - { - // repeat. the field index must be unsigned integer in CH, cast the signed integer in substrait - // which must be a positive value into unsigned integer here. - parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]); - const ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, function_name, args[1]); - DataTypeNullable target_type(std::make_shared()); - repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag, repeat_times_node, target_type.getName()); - parsed_args.emplace_back(repeat_times_node); - } - else if (function_name == "isNaN") - { - // the result of isNaN(NULL) is NULL in CH, but false in Spark - const ActionsDAG::Node * arg_node = nullptr; - if (args[0].value().has_cast()) - { - arg_node = parseExpression(actions_dag, args[0].value().cast().input()); - const auto * res_type = arg_node->result_type.get(); - if (res_type->isNullable()) - { - res_type = typeid_cast(res_type)->getNestedType().get(); - } - if (isString(*res_type)) - { - ActionsDAG::NodeRawConstPtrs cast_func_args = {arg_node}; - arg_node = toFunctionNode(actions_dag, "toFloat64OrZero", cast_func_args); - } - else - { - arg_node = parseFunctionArgument(actions_dag, function_name, args[0]); - } - } - else - { - arg_node = parseFunctionArgument(actions_dag, function_name, args[0]); - } - - ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node, addColumn(actions_dag, std::make_shared(), 0)}; - parsed_args.emplace_back(toFunctionNode(actions_dag, "IfNull", ifnull_func_args)); - } - else if (function_name == "space") - { - // convert space function to repeat - const ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, "repeat", args[0]); - const ActionsDAG::Node * space_str_node = addColumn(actions_dag, std::make_shared(), " "); - function_name = "repeat"; - parsed_args.emplace_back(space_str_node); - parsed_args.emplace_back(repeat_times_node); - } - else if (function_name == "trimBothSpark" || function_name == "trimLeftSpark" || function_name == "trimRightSpark") - { - /// In substrait, the first arg is srcStr, the second arg is trimStr - /// But in CH, the first arg is trimStr, the second arg is srcStr - parseFunctionArgument(actions_dag, parsed_args, function_name, args[1]); - parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]); - } - else if (startsWith(function_signature, "extract:")) - { - /// Skip the first arg of extract in substrait - for (int i = 1; i < args.size(); i++) - parseFunctionArgument(actions_dag, parsed_args, function_name, args[i]); - - /// Append extra mode argument for extract(WEEK_DAY from date) or extract(DAY_OF_WEEK from date) in substrait - if (function_name == "toDayOfWeek" || function_name == "DAYOFWEEK") - { - UInt8 mode = function_name == "toDayOfWeek" ? 1 : 3; - auto mode_type = std::make_shared(); - ColumnWithTypeAndName mode_col(mode_type->createColumnConst(1, mode), mode_type, getUniqueName(std::to_string(mode))); - const auto & mode_node = actions_dag->addColumn(std::move(mode_col)); - parsed_args.emplace_back(&mode_node); - } - } - else if (startsWith(function_signature, "sha2:")) - { - for (int i = 0; i < args.size() - 1; i++) - parseFunctionArgument(actions_dag, parsed_args, function_name, args[i]); - } - else - { - // Default handle - for (const auto & arg : args) - parseFunctionArgument(actions_dag, parsed_args, function_name, arg); - } -} - -void SerializedPlanParser::parseFunctionArgument( - ActionsDAGPtr & actions_dag, - ActionsDAG::NodeRawConstPtrs & parsed_args, - const std::string & function_name, - const substrait::FunctionArgument & arg) -{ - parsed_args.emplace_back(parseFunctionArgument(actions_dag, function_name, arg)); -} - -const ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument( - ActionsDAGPtr & actions_dag, const std::string & function_name, const substrait::FunctionArgument & arg) -{ - const ActionsDAG::Node * res; - if (arg.value().has_scalar_function()) - { - std::string arg_name; - bool keep_arg = FUNCTION_NEED_KEEP_ARGUMENTS.contains(function_name); - res = parseFunctionWithDAG(arg.value(), arg_name, actions_dag, keep_arg); - } - else - { - res = parseExpression(actions_dag, arg.value()); - } - return res; } // Convert signed integer index into unsigned integer index @@ -1225,7 +879,7 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple( const auto & scalar_function = rel.scalar_function(); auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference())); - auto function_name = getFunctionName(function_signature, scalar_function); + String function_name = "json_tuple"; auto args = scalar_function.arguments(); if (args.size() < 2) { @@ -1847,8 +1501,7 @@ ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & auto substrait_name = function_signature.substr(0, function_signature.find(':')); auto func_parser = FunctionParserFactory::instance().tryGet(substrait_name, plan_parser); - String function_name - = func_parser ? func_parser->getName() : SerializedPlanParser::getFunctionName(function_signature, scalar_function); + String function_name = func_parser->getName(); ASTs ast_args; parseFunctionArgumentsToAST(names, scalar_function, ast_args); diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 477fdb1f6d44..cdeb4bdd7aeb 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -35,182 +35,6 @@ namespace local_engine { -static const std::map SCALAR_FUNCTIONS - = {{"is_not_null", "isNotNull"}, - {"is_null", "isNull"}, - {"gte", "greaterOrEquals"}, - {"gt", "greater"}, - {"lte", "lessOrEquals"}, - {"lt", "less"}, - {"equal", "equals"}, - - {"and", "and"}, - {"or", "or"}, - {"not", "not"}, - {"xor", "xor"}, - - {"extract", ""}, - {"cast", "CAST"}, - {"alias", "alias"}, - - /// datetime functions - {"get_timestamp", "parseDateTimeInJodaSyntaxOrNull"}, // for spark function: to_date/to_timestamp - {"quarter", "toQuarter"}, - {"to_unix_timestamp", "parseDateTimeInJodaSyntaxOrNull"}, - //{"unix_timestamp", "toUnixTimestamp"}, - {"date_format", "formatDateTimeInJodaSyntax"}, - {"timestamp_add", "timestamp_add"}, - - - /// arithmetic functions - {"subtract", "minus"}, - {"multiply", "multiply"}, - {"add", "plus"}, - {"divide", "divide"}, - {"positive", "identity"}, - {"negative", "negate"}, - {"modulus", "modulo"}, - {"pmod", "pmod"}, - {"abs", "abs"}, - {"ceil", "ceil"}, - {"round", "roundHalfUp"}, - {"bround", "roundBankers"}, - {"exp", "exp"}, - {"power", "power"}, - {"cos", "cos"}, - {"cosh", "cosh"}, - {"sin", "sin"}, - {"sinh", "sinh"}, - {"tan", "tan"}, - {"tanh", "tanh"}, - {"acos", "acos"}, - {"asin", "asin"}, - {"atan", "atan"}, - {"atan2", "atan2"}, - {"asinh", "asinh"}, - {"acosh", "acosh"}, - {"atanh", "atanh"}, - {"bitwise_not", "bitNot"}, - {"bitwise_and", "bitAnd"}, - {"bitwise_or", "bitOr"}, - {"bitwise_xor", "bitXor"}, - {"bit_get", "bitTest"}, - {"bit_count", "bitCount"}, - {"sqrt", "sqrt"}, - {"cbrt", "cbrt"}, - {"degrees", "degrees"}, - {"e", "e"}, - {"pi", "pi"}, - {"hex", "hex"}, - {"unhex", "unhex"}, - {"hypot", "hypot"}, - {"sign", "sign"}, - {"radians", "radians"}, - {"greatest", "sparkGreatest"}, - {"least", "sparkLeast"}, - {"shiftleft", "bitShiftLeft"}, - {"shiftright", "bitShiftRight"}, - {"check_overflow", "checkDecimalOverflowSpark"}, - {"rand", "randCanonical"}, - {"isnan", "isNaN"}, - {"bin", "sparkBin"}, - {"rint", "sparkRint"}, - - /// string functions - {"like", "like"}, - {"not_like", "notLike"}, - {"starts_with", "startsWithUTF8"}, - {"ends_with", "endsWithUTF8"}, - {"contains", "countSubstrings"}, - {"substring", "substringUTF8"}, - {"substring_index", "substringIndexUTF8"}, - {"lower", "lowerUTF8"}, - {"upper", "upperUTF8"}, - {"trim", ""}, // trimLeft or trimLeftSpark, depends on argument size - {"ltrim", ""}, // trimRight or trimRightSpark, depends on argument size - {"rtrim", ""}, // trimBoth or trimBothSpark, depends on argument size - {"strpos", "positionUTF8"}, - {"replace", "replaceAll"}, - {"regexp_replace", "replaceRegexpAll"}, - {"regexp_extract_all", "regexpExtractAllSpark"}, - {"rlike", "match"}, - {"ascii", "ascii"}, - {"split", "splitByRegexp"}, - {"concat_ws", "concat_ws"}, - {"base64", "base64Encode"}, - {"unbase64", "base64Decode"}, - {"lpad", "leftPadUTF8"}, - {"rpad", "rightPadUTF8"}, - {"reverse", ""}, /// dummy mapping - {"translate", "translateUTF8"}, - {"repeat", "repeat"}, - {"space", "space"}, - {"initcap", "initcapUTF8"}, - {"conv", "sparkConv"}, - {"uuid", "generateUUIDv4"}, - {"levenshteinDistance", "editDistanceUTF8"}, - - /// hash functions - {"crc32", "CRC32"}, - {"murmur3hash", "sparkMurmurHash3_32"}, - {"xxhash64", "sparkXxHash64"}, - - // in functions - {"in", "in"}, - - // null related functions - {"coalesce", "coalesce"}, - - // date or datetime functions - {"from_unixtime", "fromUnixTimestampInJodaSyntax"}, - {"date_add", "addDays"}, - {"date_sub", "subtractDays"}, - {"datediff", "dateDiff"}, - {"second", "toSecond"}, - {"add_months", "addMonths"}, - {"date_trunc", "dateTrunc"}, - {"floor_datetime", "dateTrunc"}, - {"floor", "sparkFloor"}, - {"months_between", "sparkMonthsBetween"}, - - // array functions - {"array", "array"}, - {"shuffle", "arrayShuffle"}, - {"range", "range"}, /// dummy mapping - {"flatten", "sparkArrayFlatten"}, - {"array_join", "sparkArrayJoin"}, - - // map functions - {"map", "map"}, - {"get_map_value", "arrayElement"}, - {"map_keys", "mapKeys"}, - {"map_values", "mapValues"}, - {"map_from_arrays", "mapFromArrays"}, - - // tuple functions - {"get_struct_field", "sparkTupleElement"}, - {"get_array_struct_fields", "sparkTupleElement"}, - {"named_struct", "tuple"}, - - // table-valued generator function - {"explode", "arrayJoin"}, - {"posexplode", "arrayJoin"}, - - // json functions - {"flattenJSONStringOnRequired", "flattenJSONStringOnRequired"}, - {"get_json_object", "get_json_object"}, - {"to_json", "toJSONString"}, - {"from_json", "JSONExtract"}, - {"json_tuple", "json_tuple"}, - {"json_array_length", "JSONArrayLength"}, - {"make_decimal", "makeDecimalSpark"}, - {"unscaled_value", "unscaleValueSpark"}, - - // runtime filter - {"might_contain", "bloomFilterContains"}}; - -static const std::set FUNCTION_NEED_KEEP_ARGUMENTS = {"alias"}; - DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type); DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type); @@ -308,12 +132,12 @@ class SerializedPlanParser RelMetricPtr getMetric() { return metrics.empty() ? nullptr : metrics.at(0); } const std::unordered_map & getFunctionMapping() { return function_mapping; } - static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); + std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); std::optional getFunctionSignatureName(UInt32 function_ref) const; IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set & columns); IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header); - + static std::pair parseLiteral(const substrait::Expression_Literal & literal); static ContextMutablePtr global_context; @@ -364,15 +188,7 @@ class SerializedPlanParser void parseFunctionArguments( DB::ActionsDAGPtr & actions_dag, ActionsDAG::NodeRawConstPtrs & parsed_args, - std::string & function_name, const substrait::Expression_ScalarFunction & scalar_function); - void parseFunctionArgument( - DB::ActionsDAGPtr & actions_dag, - ActionsDAG::NodeRawConstPtrs & parsed_args, - const std::string & function_name, - const substrait::FunctionArgument & arg); - const DB::ActionsDAG::Node * - parseFunctionArgument(DB::ActionsDAGPtr & actions_dag, const std::string & function_name, const substrait::FunctionArgument & arg); void parseArrayJoinArguments( DB::ActionsDAGPtr & actions_dag, diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.cpp index da99eb19537f..123d13c36587 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.cpp @@ -43,7 +43,7 @@ String CountParser::getCHFunctionName(DB::DataTypes &) const } DB::ActionsDAG::NodeRawConstPtrs CountParser::parseFunctionArguments( - const CommonFunctionInfo & func_info, const String & /*ch_func_name*/, DB::ActionsDAGPtr & actions_dag) const + const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const { if (func_info.arguments.size() < 1) { diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.h index a561f87d940d..a83ec2d5a337 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.h @@ -30,6 +30,6 @@ class CountParser : public AggregateFunctionParser String getCHFunctionName(const CommonFunctionInfo &) const override; String getCHFunctionName(DB::DataTypes &) const override; DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( - const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override; + const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const override; }; } diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp index dd9e3ff445fb..6a56a82d5044 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp @@ -24,7 +24,7 @@ namespace local_engine { DB::ActionsDAG::NodeRawConstPtrs -LeadParser::parseFunctionArguments(const CommonFunctionInfo & func_info, const String & /*ch_func_name*/, DB::ActionsDAGPtr & actions_dag) const +LeadParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const { DB::ActionsDAG::NodeRawConstPtrs args; const auto & arg0 = func_info.arguments[0].value(); @@ -67,7 +67,7 @@ LeadParser::parseFunctionArguments(const CommonFunctionInfo & func_info, const S AggregateFunctionParserRegister lead_register; DB::ActionsDAG::NodeRawConstPtrs -LagParser::parseFunctionArguments(const CommonFunctionInfo & func_info, const String & /*ch_func_name*/, DB::ActionsDAGPtr & actions_dag) const +LagParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const { DB::ActionsDAG::NodeRawConstPtrs args; const auto & arg0 = func_info.arguments[0].value(); diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.h index 4fa1c1bbca13..25f679c77b40 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.h @@ -29,7 +29,7 @@ class LeadParser : public AggregateFunctionParser String getCHFunctionName(const CommonFunctionInfo &) const override { return "leadInFrame"; } String getCHFunctionName(DB::DataTypes &) const override { return "leadInFrame"; } DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( - const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override; + const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const override; }; class LagParser : public AggregateFunctionParser @@ -42,6 +42,6 @@ class LagParser : public AggregateFunctionParser String getCHFunctionName(const CommonFunctionInfo &) const override { return "lagInFrame"; } String getCHFunctionName(DB::DataTypes &) const override { return "lagInFrame"; } DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( - const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override; + const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const override; }; } diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp index 49a59c6570fb..19d7930fc1fc 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp @@ -22,7 +22,7 @@ namespace local_engine { DB::ActionsDAG::NodeRawConstPtrs -NtileParser::parseFunctionArguments(const CommonFunctionInfo & func_info, const String & /*ch_func_name*/, DB::ActionsDAGPtr & actions_dag) const +NtileParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const { if (func_info.arguments.size() != 1) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function ntile takes exactly one argument"); diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.h index 441de2353247..28878a9f89db 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.h @@ -29,6 +29,6 @@ class NtileParser : public AggregateFunctionParser String getCHFunctionName(const CommonFunctionInfo &) const override { return "ntile"; } String getCHFunctionName(DB::DataTypes &) const override { return "ntile"; } DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( - const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override; + const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const override; }; } diff --git a/cpp-ch/local-engine/Parser/example_udf/myMd5.cpp b/cpp-ch/local-engine/Parser/example_udf/myMd5.cpp index 99e0a0041335..1e70c775e130 100644 --- a/cpp-ch/local-engine/Parser/example_udf/myMd5.cpp +++ b/cpp-ch/local-engine/Parser/example_udf/myMd5.cpp @@ -43,7 +43,7 @@ class FunctionParserMyMd5 : public FunctionParser { // In Spark: md5(str) // In CH: lower(hex(MD5(str))) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp index 9ed777131a7e..ca90b0bdbf08 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp @@ -33,7 +33,6 @@ namespace local_engine { \ return #substrait_name; \ } \ - protected: \ String getCHFunctionName(const substrait::Expression_ScalarFunction & /*substrait_func*/) const override \ { \ return #ch_name; \ @@ -44,4 +43,136 @@ namespace local_engine REGISTER_COMMON_SCALAR_FUNCTION_PARSER(NextDay, next_day, spark_next_day) REGISTER_COMMON_SCALAR_FUNCTION_PARSER(LastDay, last_day, toLastDayOfMonth) REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Str2Map, str_to_map, spark_str_to_map) + +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(IsNotNull, is_not_null, isNotNull); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(IsNull, is_null, isNull); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(GTE, gte, greaterOrEquals); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(GT, gt, greater); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(LTE, lte, lessOrEquals); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(LT, lt, less); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(And, and, and); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Or, or, or); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Equal, equal, equals); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Not, not, not); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Xor, xor, xor); + +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Cast, cast, CAST); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(GetTimestamp, get_timestamp, parseDateTimeInJodaSyntaxOrNull); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Quarter, quarter, toQuarter); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ToUnixTimestamp, to_unix_timestamp, parseDateTimeInJodaSyntaxOrNull); + +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Position, positive, identity); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Negative, negative, negate); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Pmod, pmod, pmod); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(abs, abs, abs); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Ceil, ceil, ceil); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Round, round, roundHalfUp); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bround, bround, roundBankers); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Exp, exp, exp); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Power, power, power); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Cos, cos, cos); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Cosh, cosh, cosh); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Sin, sin, sin); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Sinh, sinh, sinh); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Tan, tan, tan); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Tanh, tanh, tanh); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Acos, acos, acos); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Asin, asin, asin); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Atan, atan, atan); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Atan2, atan2, atan2); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Asinh, asinh, asinh); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Acosh, acosh, acosh); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Atanh, atanh, atanh); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(BitwiseNot, bitwise_not, bitNot); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(BitwiseAnd, bitwise_and, bitAnd); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(BitwiseOr, bitwise_or, bitOr); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(BitwiseXor, bitwise_xor, bitXor); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(BitGet, bit_get, bitTest); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(BitCount, bit_count, bitCount); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Sqrt, sqrt, sqrt); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Cbrc, cbrt, cbrt); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Degrees, degrees, degrees); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(E, e, e); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Pi, pi, pi); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Hex, hex, hex); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Unhex, unhex, unhex); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Hypot, hypot, hypot); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Sign, sign, sign); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Radians, radians, radians); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Greatest, greatest, sparkGreatest); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Least, least, sparkLeast); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftLeft, shiftleft, bitShiftLeft); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftRight, shiftright, bitShiftRight); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rand, rand, randCanonical); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bin, bin, sparkBin); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rint, rint, sparkRint); + +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Like, like, like); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(NotLike, not_like, notLike); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(StartsWith, starts_with, startsWithUTF8); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(EndsWith, ends_with, endsWithUTF8); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Contains, contains, countSubstrings); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(SubstringIndex, substring_index, substringIndexUTF8); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Lower, lower, lowerUTF8); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Upper, upper, upperUTF8); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Strpos, strpos, positionUTF8); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Replace, replace, replaceAll); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(RegexpReplace, regexp_replace, replaceRegexpAll); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(RegexpExtractAll, regexp_extract_all, regexpExtractAllSpark); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rlike, rlike, match); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Ascii, ascii, ascii); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Base64, base64, base64Encode); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Unbase64, unbase64, base64Decode); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Lpad, lpad, leftPadUTF8); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rpad, rpad, rightPadUTF8); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Translate, translate, translateUTF8); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Initcap, initcap, initcapUTF8); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Conv, conv, sparkConv); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Uuid, uuid, generateUUIDv4); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(LevenshteinDistance, levenshteinDistance, editDistanceUTF8); + +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Crc32, crc32, CRC32); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Murmur3Hash, murmur3hash, sparkMurmurHash3_32); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Xxhash64, xxhash64, sparkXxHash64); + +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(In, in, in); + +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Coalesce, coalesce, coalesce); + +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(FromUnixtime, from_unixtime, fromUnixTimestampInJodaSyntax); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(DateAdd, date_add, addDays); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(DateSub, date_sub, subtractDays); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(DateDiff, datediff, dateDiff); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Second, second, toSecond); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(AddMonths, add_months, addMonths); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(DateTrunc, date_trunc, dateTrunc); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(FloorDatetime, floor_datetime, dateTrunc); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Floor, floor, sparkFloor); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MothsBetween, months_between, sparkMonthsBetween); + + +// array functions +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Array, array, array); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Shuffle, shuffle, arrayShuffle); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin); + +// map functions +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Map, map, map); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(GetMapValue, get_map_value, arrayElement); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MapKeys, map_keys, mapKeys); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MapValues, map_values, mapValues); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MapFromArrays, map_from_arrays, mapFromArrays); + + +// json functions +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(FlattenJsonStringOnRequired, flattenJSONStringOnRequired, flattenJSONStringOnRequired); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ToJson, to_json, toJSONString); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(JsonTuple, json_tuple, json_tuple); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(JsonArrayLen, json_array_length, JSONArrayLength); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(UnscaledValue, unscaled_value, unscaleValueSpark); + +// runtime filter +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MightContain, might_contain, bloomFilterContains); } diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/alias.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/alias.cpp new file mode 100644 index 000000000000..e5493eb80b2a --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/alias.cpp @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +namespace local_engine +{ +class SparkFunctionAliasParser : public FunctionParser +{ +public: + SparkFunctionAliasParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionAliasParser() override = default; + static constexpr auto name = "alias"; + String getName() const { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return name; } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + const auto & args = substrait_func.arguments(); + for (const auto & arg : args) + { + if (arg.value().has_scalar_function()) + { + String empty_result_name;// no meaning + parsed_args.emplace_back(parseFunctionWithDAG(arg.value(), empty_result_name, actions_dag, true)); + } + else + parsed_args.emplace_back(parseExpression(actions_dag, arg.value())); + } + String result_name = parsed_args[0]->result_name; + actions_dag->addOrReplaceInOutputs(*parsed_args[0]); + return &actions_dag->addAlias(actions_dag->findInOutputs(result_name), result_name); + } + +}; +static FunctionParserRegister register_alias; +} + diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp index d58b22a87e6c..7d8c3f948c63 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp @@ -146,7 +146,7 @@ class FunctionParserBinaryArithmetic : public FunctionParser const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { const auto ch_func_name = getCHFunctionName(substrait_func); - auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); @@ -191,6 +191,7 @@ class FunctionParserPlus final : public FunctionParserBinaryArithmetic static constexpr auto name = "add"; String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "plus"; } protected: DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override @@ -206,6 +207,7 @@ class FunctionParserMinus final : public FunctionParserBinaryArithmetic static constexpr auto name = "subtract"; String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "minus"; } protected: DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override @@ -220,6 +222,7 @@ class FunctionParserMultiply final : public FunctionParserBinaryArithmetic explicit FunctionParserMultiply(SerializedPlanParser * plan_parser_) : FunctionParserBinaryArithmetic(plan_parser_) { } static constexpr auto name = "multiply"; String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "multiply"; } protected: DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override @@ -234,6 +237,7 @@ class FunctionParserModulo final : public FunctionParserBinaryArithmetic explicit FunctionParserModulo(SerializedPlanParser * plan_parser_) : FunctionParserBinaryArithmetic(plan_parser_) { } static constexpr auto name = "modulus"; String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "modulo"; } protected: DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override @@ -248,6 +252,7 @@ class FunctionParserDivide final : public FunctionParserBinaryArithmetic explicit FunctionParserDivide(SerializedPlanParser * plan_parser_) : FunctionParserBinaryArithmetic(plan_parser_) { } static constexpr auto name = "divide"; String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "divide"; } protected: DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayContains.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayContains.cpp index 05301cc82736..d92a1eac7da2 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayContains.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayContains.cpp @@ -45,8 +45,8 @@ class FunctionParserArrayContains : public FunctionParser String getName() const override { return name; } const ActionsDAG::Node * parse( - const substrait::Expression_ScalarFunction & substrait_func, - ActionsDAGPtr & actions_dag) const override + const substrait::Expression_ScalarFunction & substrait_func, + ActionsDAGPtr & actions_dag) const override { /** parse array_contains(arr, value) as @@ -65,7 +65,7 @@ class FunctionParserArrayContains : public FunctionParser arr.nullable || value.nullable || arr.dataType.asInstanceOf[ArrayType].containsNull */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); @@ -108,7 +108,7 @@ class FunctionParserArrayContains : public FunctionParser }); return convertNodeTypeIfNeeded(substrait_func, multi_if_node, actions_dag); } -protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & /*substrait_func*/) const override { return "has"; diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayDistinct.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayDistinct.cpp index b828b29f7ee2..30709a7e9ed6 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayDistinct.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayDistinct.cpp @@ -44,7 +44,7 @@ class FunctionParserArrayDistinct : public FunctionParser const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayElement.h b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayElement.h index 9081a37bd89e..5873d39cc22b 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayElement.h +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayElement.h @@ -52,7 +52,7 @@ class FunctionParserArrayElement : public FunctionParser else arrayElement(arr, idx) */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); @@ -70,7 +70,7 @@ class FunctionParserArrayElement : public FunctionParser auto * if_node = toFunctionNode(actions_dag, "if", {greater_or_equals_node, null_const_node, array_element_node}); return if_node; } -protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & /*substrait_func*/) const override { return "arrayElement"; diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp index 3811880aea63..eacd72ed044f 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp @@ -51,7 +51,7 @@ class ArrayFilter : public FunctionParser DB::ActionsDAGPtr & actions_dag) const { auto ch_func_name = getCHFunctionName(substrait_func); - auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); assert(parsed_args.size() == 2); if (collectLambdaArguments(*plan_parser, substrait_func.arguments()[1].value().scalar_function()).size() == 1) return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]}); @@ -86,7 +86,7 @@ class ArrayTransform : public FunctionParser { auto ch_func_name = getCHFunctionName(substrait_func); auto lambda_args = collectLambdaArguments(*plan_parser, substrait_func.arguments()[1].value().scalar_function()); - auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); assert(parsed_args.size() == 2); if (lambda_args.size() == 1) { @@ -121,7 +121,7 @@ class ArrayAggregate : public FunctionParser DB::ActionsDAGPtr & actions_dag) const { auto ch_func_name = getCHFunctionName(substrait_func); - auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); assert(parsed_args.size() == 3); const auto * function_type = typeid_cast(parsed_args[2]->result_type.get()); if (!function_type) @@ -166,7 +166,7 @@ class ArraySort : public FunctionParser DB::ActionsDAGPtr & actions_dag) const { auto ch_func_name = getCHFunctionName(substrait_func); - auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "array_sort function must have two arguments"); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayIntersect.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayIntersect.cpp index 03eec981a516..2891846ef014 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayIntersect.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayIntersect.cpp @@ -47,7 +47,7 @@ class FunctionParserArrayIntersect : public FunctionParser const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayMaxAndMin.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayMaxAndMin.cpp index 87ae8ab7764a..a0e6786442ee 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayMaxAndMin.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayMaxAndMin.cpp @@ -42,7 +42,7 @@ class BaseFunctionParserArrayMaxAndMin : public FunctionParser const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp index 35b61c99d273..d3eed7c67568 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp @@ -58,7 +58,7 @@ class FunctionParserArrayPosition : public FunctionParser 2. CH indexOf function cannot accept Nullable(Array()) type as first argument */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); @@ -92,7 +92,7 @@ class FunctionParserArrayPosition : public FunctionParser const auto * if_node = toFunctionNode(actions_dag, "if", {or_condition_node, null_const_node, wrap_index_of_node}); return convertNodeTypeIfNeeded(substrait_func, if_node, actions_dag); } -protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & /*substrait_func*/) const override { return "indexOf"; diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayUnion.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayUnion.cpp index 917f37d53104..7a48d7920d2c 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayUnion.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayUnion.cpp @@ -45,7 +45,7 @@ class FunctionParserArrayUnion : public FunctionParser ActionsDAGPtr & actions_dag) const override { /// parse array_union(a, b) as arrayDistinctSpark(arrayConcat(a, b)) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp index 9358c45788cf..b2389d276f10 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp @@ -41,7 +41,7 @@ class FunctionParserBitLength : public FunctionParser const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { // parse bit_length(a) as octet_length(a) * 8 - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/checkOverflow.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/checkOverflow.cpp new file mode 100644 index 000000000000..e5228d160870 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/checkOverflow.cpp @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} +} + +namespace local_engine +{ +class SparkFunctionCheckOverflow : public FunctionParser +{ +public: + SparkFunctionCheckOverflow(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionCheckOverflow() override = default; + + static constexpr auto name = "check_overflow"; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & func) const override + { + const auto & args = func.arguments(); + if (args.size() < 2) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires at least two arguments"); + String ch_function_name = "checkDecimalOverflowSpark"; + auto null_on_overflow = args[1].value().literal().boolean(); + if (null_on_overflow) + ch_function_name += "OrNull"; + return ch_function_name; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + const auto & args = substrait_func.arguments(); + parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); + UInt32 precision = substrait_func.output_type().decimal().precision(); + UInt32 scale = substrait_func.output_type().decimal().scale(); + auto uint32_type = std::make_shared(); + parsed_args.emplace_back(addColumnToActionsDAG(actions_dag, uint32_type, precision)); + parsed_args.emplace_back(addColumnToActionsDAG(actions_dag, uint32_type, scale)); + + auto ch_function_name = getCHFunctionName(substrait_func); + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; +static FunctionParserRegister register_check_overflow; +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/chr.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/chr.cpp index d168e63d11dc..7b755b185637 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/chr.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/chr.cpp @@ -39,10 +39,10 @@ class FunctionParserChr : public FunctionParser String getName() const override { return name; } const ActionsDAG::Node * parse( - const substrait::Expression_ScalarFunction & substrait_func, - ActionsDAGPtr & actions_dag) const override + const substrait::Expression_ScalarFunction & substrait_func, + ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires two or three arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp index 416fe7741812..cfafdfd98c37 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp @@ -42,6 +42,7 @@ class FunctionParserConcat : public FunctionParser static constexpr auto name = "concat"; String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return name; } const ActionsDAG::Node * parse( const substrait::Expression_ScalarFunction & substrait_func, @@ -55,7 +56,7 @@ class FunctionParserConcat : public FunctionParser 2) if args have size 1, return identity(args[0]) 3) otherwise return concat(args) */ - auto args = parseFunctionArguments(substrait_func, "", actions_dag); + auto args = parseFunctionArguments(substrait_func, actions_dag); const auto & output_type = substrait_func.output_type(); const ActionsDAG::Node * result_node = nullptr; if (output_type.has_list()) diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/concatWs.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/concatWs.cpp index d7f9d39b139a..e2993f1f2d66 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/concatWs.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/concatWs.cpp @@ -42,6 +42,7 @@ class FunctionParserConcatWS : public FunctionParser static constexpr auto name = "concat_ws"; String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return name; } const ActionsDAG::Node * parse( const substrait::Expression_ScalarFunction & substrait_func, @@ -51,7 +52,7 @@ class FunctionParserConcatWS : public FunctionParser parse concat_ws(sep, s1, s2, arr1, arr2, ...)) as arrayStringConcat(arrayFlatten(array(s1), array(s2), arr1, arr2, ...), sep) */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.empty()) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least one argument", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/cot.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/cot.cpp index 84eaa3ea615a..47750403049c 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/cot.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/cot.cpp @@ -44,7 +44,7 @@ class FunctionParserCot : public FunctionParser ActionsDAGPtr & actions_dag) const override { /// parse cot(x) as 1 / tan(x) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/csc.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/csc.cpp index b63a76ed5305..009c1b764f98 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/csc.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/csc.cpp @@ -44,7 +44,7 @@ class FunctionParserCsc : public FunctionParser ActionsDAGPtr & actions_dag) const override { /// parse csc(x) as 1 / sin(x) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/dateFormat.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/dateFormat.cpp new file mode 100644 index 000000000000..980fdd4cfec0 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/dateFormat.cpp @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +namespace local_engine +{ +class SparkFunctionDateFormatParser : public FunctionParser +{ +public: + SparkFunctionDateFormatParser(SerializedPlanParser * plan_paser_) : FunctionParser(plan_paser_) {} + ~SparkFunctionDateFormatParser() override = default; + + static constexpr auto name = "date_format"; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & func) const override + { + return "formatDateTimeInJodaSyntax"; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + const auto & args = substrait_func.arguments(); + for (const auto & arg : args) + parsed_args.emplace_back(parseExpression(actions_dag, arg.value())); + /// If the first argument of function formatDateTimeInJodaSyntax is integer, replace formatDateTimeInJodaSyntax with fromUnixTimestampInJodaSyntax + /// to avoid exception + auto ch_function_name = getCHFunctionName(substrait_func); + if (args.size() > 1 && DB::isInteger(parsed_args[0]->result_type)) + ch_function_name = "fromUnixTimestampInJodaSyntax"; + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + + } +}; +static FunctionParserRegister register_date_format; +} + diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/decode.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/decode.cpp index 212c40115675..48b86ed6b58b 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/decode.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/decode.cpp @@ -41,11 +41,11 @@ class FunctionParserDecode : public FunctionParser String getName() const override { return name; } const ActionsDAG::Node * parse( - const substrait::Expression_ScalarFunction & substrait_func, - ActionsDAGPtr & actions_dag) const override + const substrait::Expression_ScalarFunction & substrait_func, + ActionsDAGPtr & actions_dag) const override { /// Parse decode(bin, charset) as convertCharset(bin, charset, 'UTF-8') - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/elementAt.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/elementAt.cpp index eb369a373bf5..ce18859174ad 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/elementAt.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/elementAt.cpp @@ -31,7 +31,7 @@ class FunctionParserElementAt : public FunctionParserArrayElement const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); if (isMap(removeNullable(parsed_args[0]->result_type))) diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp index 916e16e7b70b..23f372e5aef8 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp @@ -56,7 +56,7 @@ class FunctionParserElt : public FunctionParser else arrayElement(array(e1, e2, e3, ...), index) */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() < 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/empty2null.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/empty2null.cpp index 557d1e986e77..081cff67ee44 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/empty2null.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/empty2null.cpp @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include #include @@ -27,7 +43,7 @@ class FunctionParserEmpty2Null : public FunctionParser const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function {} requires exactly one arguments", getName()); if (parsed_args.at(0)->result_type->getName() != "Nullable(String)") diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/encode.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/encode.cpp index f370c9957f25..2dcbffca2098 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/encode.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/encode.cpp @@ -41,11 +41,11 @@ class FunctionParserEncode : public FunctionParser String getName() const override { return name; } const ActionsDAG::Node * parse( - const substrait::Expression_ScalarFunction & substrait_func, - ActionsDAGPtr & actions_dag) const override + const substrait::Expression_ScalarFunction & substrait_func, + ActionsDAGPtr & actions_dag) const override { /// Parse encode(str, charset) as convertCharset(str, 'UTF-8', charset) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/equalNullSafe.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/equalNullSafe.cpp index 1ec0df52dce3..d35bf810ffc6 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/equalNullSafe.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/equalNullSafe.cpp @@ -49,7 +49,7 @@ class FunctionParserEqualNullSafe : public FunctionParser /// return false /// else /// return equals(left, right) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); @@ -71,4 +71,4 @@ class FunctionParserEqualNullSafe : public FunctionParser }; static FunctionParserRegister register_equal_null_safe; -} \ No newline at end of file +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/expm1.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/expm1.cpp index 7470da62b424..ef98de6417ff 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/expm1.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/expm1.cpp @@ -42,7 +42,7 @@ class FunctionParserExpm1 : public FunctionParser ActionsDAGPtr & actions_dag) const override { /// parse expm1(x) as exp(x) - 1 - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/extract.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/extract.cpp new file mode 100644 index 000000000000..43cf1f3a34ef --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/extract.cpp @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} + +namespace local_engine +{ +class SparkFunctionExtractParser : public FunctionParser +{ +public: + SparkFunctionExtractParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionExtractParser() override = default; + + static constexpr auto name = "extract"; + String getName() const override { return name; } + + String getCHFunctionName(const substrait::Expression_ScalarFunction & func) const override + { + const auto & args = func.arguments(); + if (args.size() != 2) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Spark function extract requires two args, function:{}", func.ShortDebugString()); + const auto & extract_field = args.at(0); + String ch_function_name = ""; + if (extract_field.value().has_literal()) + { + const auto & field_value = extract_field.value().literal().string(); + if (field_value == "YEAR") + ch_function_name = "toYear"; // spark: extract(YEAR FROM) or year + else if (field_value == "YEAR_OF_WEEK") + ch_function_name = "toISOYear"; // spark: extract(YEAROFWEEK FROM) + else if (field_value == "QUARTER") + ch_function_name = "toQuarter"; // spark: extract(QUARTER FROM) or quarter + else if (field_value == "MONTH") + ch_function_name = "toMonth"; // spark: extract(MONTH FROM) or month + else if (field_value == "WEEK_OF_YEAR") + ch_function_name = "toISOWeek"; // spark: extract(WEEK FROM) or weekofyear + else if (field_value == "WEEK_DAY") + /// Spark WeekDay(date) (0 = Monday, 1 = Tuesday, ..., 6 = Sunday) + /// Substrait: extract(WEEK_DAY from date) + /// CH: toDayOfWeek(date, 1) + ch_function_name = "toDayOfWeek"; + else if (field_value == "DAY_OF_WEEK") + /// Spark: DayOfWeek(date) (1 = Sunday, 2 = Monday, ..., 7 = Saturday) + /// Substrait: extract(DAY_OF_WEEK from date) + /// CH: toDayOfWeek(date, 3) + /// DAYOFWEEK is alias of function toDayOfWeek. + /// This trick is to distinguish between extract fields DAY_OF_WEEK and WEEK_DAY in latter codes + ch_function_name = "DAYOFWEEK"; + else if (field_value == "DAY") + ch_function_name = "toDayOfMonth"; // spark: extract(DAY FROM) or dayofmonth + else if (field_value == "DAY_OF_YEAR") + ch_function_name = "toDayOfYear"; // spark: extract(DOY FROM) or dayofyear + else if (field_value == "HOUR") + ch_function_name = "toHour"; // spark: extract(HOUR FROM) or hour + else if (field_value == "MINUTE") + ch_function_name = "toMinute"; // spark: extract(MINUTE FROM) or minute + else if (field_value == "SECOND") + ch_function_name = "toSecond"; // spark: extract(SECOND FROM) or secondwithfraction + } + + if (ch_function_name.empty()) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The first arg of spark extract function is wrong."); + return ch_function_name; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + auto ch_function_name = getCHFunctionName(substrait_func); + const auto & args = substrait_func.arguments(); + + /// Skip the first arg of extract in substrait + for (int i = 1; i < args.size(); i++) + parsed_args.emplace_back(parseExpression(actions_dag, args[i].value())); + + /// Append extra mode argument for extract(WEEK_DAY from date) or extract(DAY_OF_WEEK from date) in substrait + if (ch_function_name == "toDayOfWeek" || ch_function_name == "DAYOFWEEK") + { + UInt8 mode = ch_function_name == "toDayOfWeek" ? 1 : 3; + auto mode_type = std::make_shared(); + parsed_args.emplace_back(addColumnToActionsDAG(actions_dag, mode_type, mode)); + } + + const DB::ActionsDAG::Node * func_node = nullptr; + if (ch_function_name == "toYear") + { + auto arg_func_name = parsed_args[0]->function ? parsed_args[0]->function->getName() : ""; + if (arg_func_name == "sparkToDate" || arg_func_name == "sparkToDateTime" && parsed_args[0]->children.size() > 0) + { + const auto * child_node = parsed_args[0]->children[0]; + if (child_node && DB::isString(DB::removeNullable(child_node->result_type))) + { + func_node = toFunctionNode(actions_dag, "sparkExtractYear", {child_node}); + } + } + } + + if (!func_node) + func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; +static FunctionParserRegister register_extract; +} + diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/factorial.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/factorial.cpp index c2f0a383b4ef..f1ef4ec8b9ba 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/factorial.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/factorial.cpp @@ -46,7 +46,7 @@ class FunctionParserFactorial : public FunctionParser ActionsDAGPtr & actions_dag) const override { /// parse factorial(x) as if (x > 20 || x < 0) null else factorial(x) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp index b297d7a67f66..345343119963 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp @@ -55,7 +55,7 @@ class FunctionParserFindInSet : public FunctionParser null else indexOf(assumeNotNull(splitByChar(',', str_array)), str) */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/fromJson.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/fromJson.cpp new file mode 100644 index 000000000000..2dd8754189b7 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/fromJson.cpp @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include + +namespace local_engine +{ +class SparkFunctionFromJsonParser : public FunctionParser +{ +public: + SparkFunctionFromJsonParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionFromJsonParser() override = default; + + static constexpr auto name = "from_json"; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & /*func*/) const override + { + return "JSONExtract"; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + auto ch_function_name = getCHFunctionName(substrait_func); + parsed_args.emplace_back(parseExpression(actions_dag, substrait_func.arguments()[0].value())); + auto data_type = TypeParser::parseType(substrait_func.output_type()); + parsed_args.emplace_back(addColumnToActionsDAG(actions_dag, std::make_shared(), data_type->getName())); + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; +static FunctionParserRegister register_from_json; +} + diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/getJSONObject.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/getJSONObject.cpp index 5757cb7d6f45..aad75130aa47 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/getJSONObject.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/getJSONObject.cpp @@ -41,7 +41,6 @@ class GetJSONObjectParser : public FunctionParser ~GetJSONObjectParser() override = default; String getName() const override { return name; } -protected: String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override { const auto & args = scalar_function.arguments(); @@ -53,16 +52,16 @@ class GetJSONObjectParser : public FunctionParser return name; } +protected: /// Force to reuse the same flatten json column node DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( const substrait::Expression_ScalarFunction & substrait_func, - const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override { const auto & args = substrait_func.arguments(); if (args.size() != 2) { - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Function {} requires 2 arguments", ch_func_name); + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Function {} requires 2 arguments", getCHFunctionName(substrait_func)); } if (args[0].value().has_scalar_function() && args[0].value().scalar_function().function_reference() == SelfDefinedFunctionReference::GET_JSON_OBJECT) diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/isNaN.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/isNaN.cpp new file mode 100644 index 000000000000..3409c61d4651 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/isNaN.cpp @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +class SparkFunctionIsNaNParser : public FunctionParser +{ +public: + SparkFunctionIsNaNParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionIsNaNParser() override = default; + + static constexpr auto name = "isnan"; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return "isNaN"; } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + // the result of isNaN(NULL) is NULL in CH, but false in Spark + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + auto ch_function_name = getCHFunctionName(substrait_func); + const auto & args = substrait_func.arguments(); + const DB::ActionsDAG::Node * arg_node = nullptr; + if (args[0].value().has_cast()) + { + arg_node = parseExpression(actions_dag, args[0].value().cast().input()); + auto result_type = DB::removeNullable(arg_node->result_type); + if (DB::isString(*result_type)) + arg_node = toFunctionNode(actions_dag, "toFloat64OrZero", {arg_node}); + else + arg_node = parseExpression(actions_dag, args[0].value()); + } + else + arg_node = parseExpression(actions_dag, args[0].value()); + + DB::ActionsDAG::NodeRawConstPtrs ifnull_args = {arg_node, addColumnToActionsDAG(actions_dag, std::make_shared(), 0)}; + parsed_args.emplace_back(toFunctionNode(actions_dag, "IfNull", ifnull_args)); + + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; +static FunctionParserRegister register_isnan; +} + diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp index 57c076ed2670..6647b82b9566 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp @@ -65,28 +65,12 @@ class LambdaFunction : public FunctionParser ~LambdaFunction() override = default; String getName() const override { return name; } -protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override { throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for LambdaFunction"); } - DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( - const substrait::Expression_ScalarFunction & substrait_func, - const String & ch_func_name, - DB::ActionsDAGPtr & actions_dag) const override - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction"); - } - - const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( - const substrait::Expression_ScalarFunction & substrait_func, - const DB::ActionsDAG::Node * func_node, - DB::ActionsDAGPtr & actions_dag) const override - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); - } - const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override { /// Some special cases, for example, `transform(arr, x -> concat(arr, array(x)))` refers to @@ -166,6 +150,21 @@ class LambdaFunction : public FunctionParser const auto * result = &actions_dag->addFunction(function_capture, lambda_children, lambda_body_node->result_name); return result; } +protected: + DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( + const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction"); + } + + const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( + const substrait::Expression_ScalarFunction & substrait_func, + const DB::ActionsDAG::Node * func_node, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); + } }; static FunctionParserRegister register_lambda_function; @@ -179,28 +178,12 @@ class NamedLambdaVariable : public FunctionParser ~NamedLambdaVariable() override = default; String getName() const override { return name; } -protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override { throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for NamedLambdaVariable"); } - DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( - const substrait::Expression_ScalarFunction & substrait_func, - const String & ch_func_name, - DB::ActionsDAGPtr & actions_dag) const override - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for NamedLambdaVariable"); - } - - const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( - const substrait::Expression_ScalarFunction & substrait_func, - const DB::ActionsDAG::Node * func_node, - DB::ActionsDAGPtr & actions_dag) const override - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); - } - const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override { auto [_, col_name_field] = parseLiteral(substrait_func.arguments()[0].value().literal()); @@ -215,6 +198,21 @@ class NamedLambdaVariable : public FunctionParser } return *it; } +protected: + DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( + const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for NamedLambdaVariable"); + } + + const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( + const substrait::Expression_ScalarFunction & substrait_func, + const DB::ActionsDAG::Node * func_node, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); + } }; static FunctionParserRegister register_named_lambda_variable; diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp index 85fe1f29aa25..af998d4d2e69 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp @@ -47,7 +47,7 @@ class FunctionParserLength : public FunctionParser else length(a) as char_length(a) */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp index bc9ea41f853a..efc6da7c4659 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp @@ -44,7 +44,7 @@ class FunctionParserLocate : public FunctionParser const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { /// Parse locate(substr, str, start_pos) as if(isNull(start_pos), 0, positionUTF8Spark(str, substr, start_pos) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 3) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly three arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/log.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/log.cpp index bafca3b213d7..75a6894597f5 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/log.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/log.cpp @@ -53,7 +53,7 @@ class FunctionParserLog : public FunctionParser else ln(y) / ln(x) */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h b/cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h index 2a879623ad88..7a83d78fa845 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h @@ -52,7 +52,7 @@ class FunctionParserLogBase : public FunctionParser else log(x) */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/makeDecimal.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/makeDecimal.cpp new file mode 100644 index 000000000000..977167ef3601 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/makeDecimal.cpp @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} +} + +namespace local_engine +{ +class SparkFunctionMakeDecimalParser : public FunctionParser +{ +public: + SparkFunctionMakeDecimalParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionMakeDecimalParser() override = default; + + static constexpr auto name = "make_decimal"; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & func) const override + { + const auto & args = func.arguments(); + if (args.size() < 2) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "make_decimal function requires at least two arguments"); + String ch_function_name = "makeDecimalSpark"; + if (args[1].value().literal().boolean()) + ch_function_name += "OrNull"; + return ch_function_name; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + const auto & args = substrait_func.arguments(); + parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); + UInt32 precision = substrait_func.output_type().decimal().precision(); + UInt32 scale = substrait_func.output_type().decimal().scale(); + auto uint32_type = std::make_shared(); + parsed_args.emplace_back(addColumnToActionsDAG(actions_dag, uint32_type, precision)); + parsed_args.emplace_back(addColumnToActionsDAG(actions_dag, uint32_type, scale)); + + auto ch_function_name = getCHFunctionName(substrait_func); + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; +static FunctionParserRegister register_make_decimal; +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/md5.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/md5.cpp index 84d497fa90ea..c57197e70d0b 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/md5.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/md5.cpp @@ -43,7 +43,7 @@ class FunctionParserMd5 : public FunctionParser const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { /// Parse md5(str) as lower(hex(md5(str))) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/nanvl.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/nanvl.cpp index 010f3eef32b5..d8f29d727576 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/nanvl.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/nanvl.cpp @@ -55,7 +55,7 @@ class FunctionParserNaNvl : public FunctionParser else e1 */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp index 52cbd0317290..d2c159a1b69e 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp @@ -40,7 +40,7 @@ class FunctionParserOctetLength : public FunctionParser const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/parseUrl.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/parseUrl.cpp index ed30b0727f5d..af573367448f 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/parseUrl.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/parseUrl.cpp @@ -99,7 +99,7 @@ String ParseURLParser::selectCHFunctionName(const substrait::Expression_ScalarFu } DB::ActionsDAG::NodeRawConstPtrs ParseURLParser::parseFunctionArguments( - const substrait::Expression_ScalarFunction & substrait_func, const String & /*ch_func_name*/, DB::ActionsDAGPtr & actions_dag) const + const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const { DB::ActionsDAG::NodeRawConstPtrs arg_nodes; arg_nodes.push_back(parseExpression(actions_dag, substrait_func.arguments(0).value())); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/parseUrl.h b/cpp-ch/local-engine/Parser/scalar_function_parser/parseUrl.h index 9d8aae8e21a6..a4d6e0f057ea 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/parseUrl.h +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/parseUrl.h @@ -26,12 +26,11 @@ class ParseURLParser final : public FunctionParser ~ParseURLParser() override = default; String getName() const override { return name; } -protected: String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override; +protected: DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( const substrait::Expression_ScalarFunction & substrait_func, - const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override; const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/regexp_extract.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/regexp_extract.cpp index 8f75baf689b2..ba30a3c59e4c 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/regexp_extract.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/regexp_extract.cpp @@ -61,7 +61,7 @@ class FunctionParserRegexpExtract : public FunctionParser String sparkRegexp = adjustSparkRegexpRule(expr_str); const auto * regex_expr_node = addColumnToActionsDAG(actions_dag, std::make_shared(), sparkRegexp); - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); parsed_args[1] = regex_expr_node; const auto * result_node = toFunctionNode(actions_dag, "regexpExtract", parsed_args); return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp new file mode 100644 index 000000000000..cc32fc015535 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +class SparkFunctionRepeatParser : public FunctionParser +{ +public: + SparkFunctionRepeatParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionRepeatParser() override = default; + + static constexpr auto name = "repeat"; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return name; } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + // repeat. the field index must be unsigned integer in CH, cast the signed integer in substrait + // which must be a positive value into unsigned integer here. + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + auto ch_function_name = getCHFunctionName(substrait_func); + const auto & args = substrait_func.arguments(); + parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); + const auto * repeat_times_node = parseExpression(actions_dag, args[1].value()); + DB::DataTypeNullable target_type(std::make_shared()); + repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag, repeat_times_node, target_type.getName()); + parsed_args.emplace_back(repeat_times_node); + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; +static FunctionParserRegister register_repeat; +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/reverse.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/reverse.cpp new file mode 100644 index 000000000000..86406c433959 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/reverse.cpp @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} +} + +namespace local_engine +{ +class SparkFunctionReverseParser : public FunctionParser +{ +public: + SparkFunctionReverseParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionReverseParser() override = default; + + static constexpr auto name = "reverse"; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & func) const override + { + if (func.output_type().has_list()) + return "arrayReverse"; + return "reverseUTF8"; + } +}; +static FunctionParserRegister register_reverse; +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/sec.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/sec.cpp index 70765e07d037..4b95bcbe530f 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/sec.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/sec.cpp @@ -44,7 +44,7 @@ class FunctionParserSec : public FunctionParser ActionsDAGPtr & actions_dag) const override { /// parse sec(x) as 1 / cos(x) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/sequence.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/sequence.cpp index 373bd53132b3..0e98759f6c7f 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/sequence.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/sequence.cpp @@ -61,7 +61,7 @@ class FunctionParserSequence : public FunctionParser step = if(start <= end, 1, -1) */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() < 2 || parsed_args.size() > 3) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires two or three arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/sha1.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/sha1.cpp index 4e7872c9633e..eb7578a3f4b6 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/sha1.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/sha1.cpp @@ -43,7 +43,7 @@ class FunctionParserSha1 : public FunctionParser const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { /// Parse sha1(str) as lower(hex(sha1(str))) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 1) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/sha2.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/sha2.cpp index 139d49936964..75db4cd173fd 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/sha2.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/sha2.cpp @@ -47,7 +47,7 @@ class FunctionParserSha2 : public FunctionParser /// Parse sha2(str, 224) as lower(hex(SHA224(str))) /// Parse sha2(str, 384) as lower(hex(SHA384(str))) /// Parse sha2(str, 512) as lower(hex(SHA512(str))) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp index b638e234fbc5..e0932e621b75 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp @@ -49,7 +49,7 @@ class FunctionParserShiftRightUnsigned : public FunctionParser /// else /// throw Exception - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/size.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/size.cpp index 544da96f92cc..09db14ced0f0 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/size.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/size.cpp @@ -44,7 +44,7 @@ class FunctionParserSize : public FunctionParser { /// Parse size(child, true) as ifNull(length(child), -1) /// Parse size(child, false) as length(child) - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp index 3fd26a41f6b3..2dca0cee182e 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp @@ -60,7 +60,7 @@ class FunctionParserArraySlice : public FunctionParser 2. Spark slice returns null if any of the argument is null */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 3) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function {} requires exactly three arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp index 4fd2fd4f6800..3386b642fa21 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp @@ -45,7 +45,7 @@ class FunctionParserSortArray : public FunctionParser const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/space.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/space.cpp new file mode 100644 index 000000000000..3698ddad78cf --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/space.cpp @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +namespace local_engine +{ +class SparkFunctionSpaceParser : public FunctionParser +{ +public: + SparkFunctionSpaceParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionSpaceParser() override = default; + + static constexpr auto name = "space"; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return "repeat"; } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + // convert space function to repeat + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + auto ch_function_name = getCHFunctionName(substrait_func); + const auto & args = substrait_func.arguments(); + + const auto * repeat_times_node = parseExpression(actions_dag, args[0].value()); + const auto * space_str_node = addColumnToActionsDAG(actions_dag, std::make_shared(), " "); + parsed_args = {space_str_node, repeat_times_node}; + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; +static FunctionParserRegister register_space; +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/split.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/split.cpp new file mode 100644 index 000000000000..05749da89552 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/split.cpp @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +namespace local_engine +{ +class SparkFunctionSplitParser : public FunctionParser +{ +public: + SparkFunctionSplitParser(SerializedPlanParser * plan_paser_) : FunctionParser(plan_paser_) {} + ~SparkFunctionSplitParser() override = default; + static constexpr auto name = "split"; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return "splitByRegexp"; } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + const auto & args = substrait_func.arguments(); + for (const auto & arg : args) + parsed_args.emplace_back(parseExpression(actions_dag, arg.value())); + /// In Spark: split(str, regex [, limit] ) + /// In CH: splitByRegexp(regexp, str [, limit]) + if (parsed_args.size() >= 2) + std::swap(parsed_args[0], parsed_args[1]); + auto ch_function_name = getCHFunctionName(substrait_func); + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; +static FunctionParserRegister register_split; +} + diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/substring.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/substring.cpp index 550e77344ddf..444213973cb2 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/substring.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/substring.cpp @@ -42,7 +42,7 @@ class FunctionParserSubstring : public FunctionParser const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 3) throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires two or three arguments", getName()); @@ -64,7 +64,7 @@ class FunctionParserSubstring : public FunctionParser const auto * substring_func_node = toFunctionNode(actions_dag, "substringUTF8", {str_arg, if_node, if_len_node}); return convertNodeTypeIfNeeded(substrait_func, substring_func_node, actions_dag); } -protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & /*substrait_func*/) const override { return "substringUTF8"; diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp index d76431c0a096..6e92a7b928bc 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp @@ -42,10 +42,11 @@ class FunctionParserTimestampAdd : public FunctionParser static constexpr auto name = "timestamp_add"; String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return "timestamp_add"; } const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 4) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly four arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/trimFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/trimFunctions.cpp new file mode 100644 index 000000000000..e07196b282e0 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/trimFunctions.cpp @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +namespace local_engine +{ +class SparkFunctionTrimParser : public FunctionParser +{ +public: + SparkFunctionTrimParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionTrimParser() override = default; + + static constexpr auto name = "trim"; + String getName() const override { return name; } + + String getCHFunctionName(const substrait::Expression_ScalarFunction & func) const override + { + return func.arguments().size() == 1 ? "trimBoth" : "trimBothSpark"; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + auto ch_function_name = getCHFunctionName(substrait_func); + const auto & args = substrait_func.arguments(); + + /// In substrait, the first arg is srcStr, the second arg is trimStr + /// But in CH, the first arg is trimStr, the second arg is srcStr + if (args.size() > 1) + { + parsed_args.emplace_back(parseExpression(actions_dag, args[1].value())); + parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); + } + else + parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); + + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; + +static FunctionParserRegister register_trim; + +class SparkFunctionLtrimParser : public FunctionParser +{ +public: + SparkFunctionLtrimParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionLtrimParser() override = default; + + static constexpr auto name = "ltrim"; + String getName() const override { return name; } + + String getCHFunctionName(const substrait::Expression_ScalarFunction & func) const override + { + return func.arguments().size() == 1 ? "trimLeft" : "trimLeftSpark"; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + auto ch_function_name = getCHFunctionName(substrait_func); + const auto & args = substrait_func.arguments(); + + /// In substrait, the first arg is srcStr, the second arg is trimStr + /// But in CH, the first arg is trimStr, the second arg is srcStr + if (args.size() > 1) + { + parsed_args.emplace_back(parseExpression(actions_dag, args[1].value())); + parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); + } + else + parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); + + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; +static FunctionParserRegister register_ltrim; + +class SparkFunctionRtrimParser : public FunctionParser +{ +public: + SparkFunctionRtrimParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionRtrimParser() override = default; + + static constexpr auto name = "rtrim"; + String getName() const override { return name; } + + String getCHFunctionName(const substrait::Expression_ScalarFunction & func) const override + { + return func.arguments().size() == 1 ? "trimRight" : "trimRightSpark"; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + auto ch_function_name = getCHFunctionName(substrait_func); + const auto & args = substrait_func.arguments(); + + /// In substrait, the first arg is srcStr, the second arg is trimStr + /// But in CH, the first arg is trimStr, the second arg is srcStr + if (args.size() > 1) + { + parsed_args.emplace_back(parseExpression(actions_dag, args[1].value())); + parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); + } + else + parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); + + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; +static FunctionParserRegister register_rtrim; +} + diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/trunc.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/trunc.cpp index db45bb464a52..625d67a7e1c6 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/trunc.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/trunc.cpp @@ -47,7 +47,7 @@ class FunctionParserTrunc : public FunctionParser const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override { - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/tuple.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/tuple.cpp new file mode 100644 index 000000000000..3228efb0ed88 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/tuple.cpp @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +namespace local_engine +{ +class SparkFunctionNamedStructParser : public FunctionParser +{ +public: + SparkFunctionNamedStructParser(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~SparkFunctionNamedStructParser() override = default; + + static constexpr auto name = "named_struct"; + String getName () const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return "tuple"; } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + DB::ActionsDAG::NodeRawConstPtrs parsed_args; + const auto & args = substrait_func.arguments(); + auto ch_function_name = getCHFunctionName(substrait_func); + // Arguments in the format, (, [, , ...]) + // We don't need to care the field names here. + for (int i = 1; i < args.size(); i += 2) + { + parsed_args.emplace_back(parseExpression(actions_dag, args[i].value())); + } + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; + +static FunctionParserRegister register_named_struct; +} + diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp new file mode 100644 index 000000000000..6cf0acff0d04 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +namespace DB::ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} +namespace local_engine +{ +// tuple indecies start from 1, in spark, start from 0 +#define REGISTER_TUPLE_ELEMENT_PARSER(class_name, substrait_name, ch_name) \ + class SparkFunctionParser##class_name : public FunctionParser \ + { \ + public: \ + SparkFunctionParser##class_name(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}\ + ~SparkFunctionParser##class_name() override = default; \ + static constexpr auto name = #substrait_name; \ + String getName () const override { return name; } \ + String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return #ch_name; } \ + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override \ + { \ + DB::ActionsDAG::NodeRawConstPtrs parsed_args; \ + auto ch_function_name = getCHFunctionName(substrait_func); \ + const auto & args = substrait_func.arguments(); \ + parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); \ + if (!args[1].value().has_literal()) \ + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{}'s sceond argument must be a literal", #substrait_name); \ + auto [data_type, field] = parseLiteral(args[1].value().literal()); \ + if (!DB::WhichDataType(data_type).isInt32()) \ + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{}'s second argument must be i32", #substrait_name); \ + Int32 field_index = static_cast(field.get() + 1); \ + const auto * index_node = addColumnToActionsDAG(actions_dag, std::make_shared(), field_index); \ + parsed_args.emplace_back(index_node); \ + const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); \ + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); \ + } \ + }; \ + static FunctionParserRegister register_##substrait_name; + +REGISTER_TUPLE_ELEMENT_PARSER(GetStructField, get_struct_field, sparkTupleElement); +REGISTER_TUPLE_ELEMENT_PARSER(GetArrayStructFields, get_array_struct_fields, sparkTupleElement); +} + diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp index c12a6d33fd17..9488b89be67a 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp @@ -53,7 +53,7 @@ class FunctionParserUnixTimestamp : public FunctionParser 2. If expr type is date/TIMESTAMP, ch function = toUnixTimestamp(expr, format) 3. Otherwise, throw exception */ - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/utcTimestampTransform.h b/cpp-ch/local-engine/Parser/scalar_function_parser/utcTimestampTransform.h index 87ea19024169..b3b639c562bd 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/utcTimestampTransform.h +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/utcTimestampTransform.h @@ -44,14 +44,14 @@ class FunctionParserUtcTimestampTransform : public FunctionParser /// Convert timezone value to clickhouse backend supported, i.e. GMT+8 -> Etc/GMT-8, +08:00 -> Etc/GMT-8 if (substrait_func.arguments_size() != 2) throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s must have 2 arguments", getName()); - + const substrait::Expression & arg1 = substrait_func.arguments()[1].value(); if (!arg1.has_literal() || !arg1.literal().has_string()) throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s 2nd argument should be string literal", getName()); - + const String & arg1_literal = arg1.literal().string(); String time_zone_val = DateTimeUtil::convertTimeZone(arg1_literal); - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); auto nullable_string_type = DB::makeNullable(std::make_shared()); const auto * time_zone_node = addColumnToActionsDAG(actions_dag, nullable_string_type, time_zone_val); const auto * result_node = toFunctionNode(actions_dag, getCHFunctionName(substrait_func), {parsed_args[0], time_zone_node});