diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 015c53a466bb8..d59a51881c9bc 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -35,7 +35,6 @@ #include #include #include -#include #include #include #include @@ -390,7 +389,11 @@ const DB::ColumnWithTypeAndName * NestedColumnExtractHelper::findColumn(const DB } const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType( - DB::ActionsDAGPtr & actions_dag, const DB::ActionsDAG::Node * node, const std::string & type_name, const std::string & result_name) + DB::ActionsDAGPtr & actions_dag, + const DB::ActionsDAG::Node * node, + const std::string & type_name, + const std::string & result_name, + CastType cast_type) { DB::ColumnWithTypeAndName type_name_col; type_name_col.name = type_name; @@ -399,13 +402,9 @@ const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType( const auto * right_arg = &actions_dag->addColumn(std::move(type_name_col)); const auto * left_arg = node; DB::CastDiagnostic diagnostic = {node->result_name, node->result_name}; - auto type = CastType::nonAccurate; - - if (startsWith(type_name, "Nullable")) - type = CastType::accurateOrNull; - DB::ActionsDAG::NodeRawConstPtrs children = {left_arg, right_arg}; - return &actions_dag->addFunction(DB::createInternalCastOverloadResolver(type, std::move(diagnostic)), std::move(children), result_name); + return &actions_dag->addFunction( + DB::createInternalCastOverloadResolver(cast_type, std::move(diagnostic)), std::move(children), result_name); } String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline) diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 5766b1c802b8a..5ceaacf97a50d 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -99,7 +100,8 @@ class ActionsDAGUtil DB::ActionsDAGPtr & actions_dag, const DB::ActionsDAG::Node * node, const std::string & type_name, - const std::string & result_name = ""); + const std::string & result_name = "", + DB::CastType cast_type = DB::CastType::nonAccurate); }; class QueryPipelineUtil diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index a2c959b9324c9..c6c1e2b6462a0 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -1038,14 +1038,28 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( if (!TypeParser::isTypeMatched(rel.scalar_function().output_type(), function_node->result_type) && !converted_decimal_args) { - result_node = ActionsDAGUtil::convertNodeType( - actions_dag, - function_node, - // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() - ? local_engine::wrapNullableType(true, TypeParser::parseType(rel.scalar_function().output_type()))->getName() - : local_engine::removeNullable(TypeParser::parseType(rel.scalar_function().output_type()))->getName(), - function_node->result_name); + auto result_type = TypeParser::parseType(rel.scalar_function().output_type()); + if (isDecimalOrNullableDecimal(result_type)) + { + result_node = ActionsDAGUtil::convertNodeType( + actions_dag, + function_node, + // as stated in isTypeMatched, currently we don't change nullability of the result type + function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), + function_node->result_name, + DB::CastType::accurateOrNull); + } + else + { + result_node = ActionsDAGUtil::convertNodeType( + actions_dag, + function_node, + // as stated in isTypeMatched, currently we don't change nullability of the result type + function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), + function_node->result_name); + } } if (ch_func_name == "JSON_VALUE")