Skip to content

Commit

Permalink
[GLUTEN-5620][CORE] Simplify Decimal process logic (apache#5621)
Browse files Browse the repository at this point in the history
* rescaleCastForDecimal refactor

* refactor isPromoteCast

* Simplify Decimal process logic and re-implement FunctionParserDivide, so divide.cpp is deleted.

* remove SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs

* rename noCheckOverflow to dontTransformCheckOverflow

* update per comments

* fix warning

* fix style warning

* fix typo
  • Loading branch information
baibaichen authored May 8, 2024
1 parent 071d891 commit c9018cd
Show file tree
Hide file tree
Showing 10 changed files with 593 additions and 273 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
// Always true for CHBackendSettings.
// NOTE(review): presumably tells the planner to pass the output schema along with the plan — confirm against BackendSettingsApi.
override def needOutputSchemaForPlan(): Boolean = true

// Decimal arithmetic is allowed only when Spark's `decimalOperationsAllowPrecisionLoss`
// setting is off (note the negation).
override def allowDecimalArithmetic: Boolean = !SQLConf.get.decimalOperationsAllowPrecisionLoss
// Always false for CHBackendSettings.
// NOTE(review): presumably makes the planner leave CheckOverflow expressions untransformed for this backend — confirm against BackendSettingsApi.
override def transformCheckOverflow: Boolean = false

// Always true for CHBackendSettings.
// NOTE(review): presumably means scans must be given their input file paths — confirm against BackendSettingsApi.
override def requiredInputFilePaths(): Boolean = true

Expand Down
63 changes: 47 additions & 16 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "CHUtil.h"
#include <filesystem>
#include <memory>
#include <optional>
#include <unistd.h>
#include <AggregateFunctions/Combinators/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/registerAggregateFunctions.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnMap.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnTuple.h>
#include <Columns/IColumn.h>
Expand All @@ -30,14 +32,17 @@
#include <Core/Defines.h>
#include <Core/NamesAndTypes.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/NestedUtils.h>
#include <Disks/registerDisks.h>
#include <Disks/registerGlutenDisks.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/registerFunctions.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/SharedThreadPools.h>
Expand All @@ -51,8 +56,11 @@
#include <Storages/Output/WriteBufferBuilder.h>
#include <Storages/StorageMergeTreeFactory.h>
#include <Storages/SubstraitSource/ReadBufferBuilder.h>
#include <boost/algorithm/string/case_conv.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/wrappers.pb.h>
#include <sys/resource.h>
#include <Poco/Logger.h>
#include <Poco/Util/MapConfiguration.h>
#include <Common/BitHelpers.h>
Expand All @@ -63,20 +71,12 @@
#include <Common/logger_useful.h>
#include <Common/typeid_cast.h>

#include <boost/algorithm/string/case_conv.hpp>
#include <boost/algorithm/string/predicate.hpp>

#include "CHUtil.h"
#include "Disks/registerGlutenDisks.h"

#include <unistd.h>
#include <sys/resource.h>

/// Error codes referenced by this translation unit. They are only declared here (`extern`);
/// the definitions are provided elsewhere.
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
/// Used by PlanUtil::checkOuputType when an output column type cannot be converted for Spark.
extern const int UNKNOWN_TYPE;
}
}

Expand Down Expand Up @@ -311,16 +311,48 @@ size_t PODArrayUtil::adjustMemoryEfficientSize(size_t n)

/// Renders `plan` through QueryPlan::explainPlan into an in-memory buffer and returns the text.
/// The explain options enable the `header`, `actions` and `indexes` sections.
///
/// @param plan Query plan to describe (taken by non-const reference).
/// @return The explain text written to the buffer.
std::string PlanUtil::explainPlan(DB::QueryPlan & plan)
{
std::string plan_str;
DB::QueryPlan::ExplainPlanOptions buf_opt{
constexpr DB::QueryPlan::ExplainPlanOptions buf_opt{
.header = true,
.actions = true,
.indexes = true,
};
// The buffer owns its string, so the result can be read back with str() once explainPlan has written to it.
DB::WriteBufferFromOwnString buf;
plan.explainPlan(buf, buf_opt);
plan_str = buf.str();
return plan_str;

return buf.str();
}

/// Verifies that every column in the output header of the plan's root step has a type that can be
/// converted for Spark.
///
/// Nullable wrappers are stripped before a column type is inspected. The function returns without
/// checking anything when the root step has no output stream or when that stream's header is falsy.
///
/// @param plan Query plan whose root step is inspected; assumed to be initialized.
/// @throws DB::Exception with ErrorCodes::UNKNOWN_TYPE when a column is
///         - a DateTime64 whose scale is not 6,
///         - a Decimal256, or
///         - a decimal whose precision and scale are both 0.
void PlanUtil::checkOuputType(const DB::QueryPlan & plan)
{
    // QueryPlan::checkInitialized is private, so initialization cannot be verified here. The plan is
    // assumed to be initialized (an uninitialized plan would crash on the dereference below); the
    // call sites only invoke this check on a plan that has already been built.
    const auto & root_step = *plan.getRootNode()->step;
    if (!root_step.hasOutputStream() || !root_step.getOutputStream().header)
        return;

    for (const auto & column : root_step.getOutputStream().header)
    {
        const DB::DataTypePtr & declared_type = column.type;
        const auto nested_type = DB::removeNullable(declared_type);
        const DB::WhichDataType type_id(nested_type);

        bool unsupported = false;
        if (type_id.isDateTime64())
        {
            // Only a DateTime64 with scale 6 is accepted.
            const auto * datetime64_type = checkAndGetDataType<DataTypeDateTime64>(nested_type.get());
            unsupported = datetime64_type->getScale() != 6;
        }
        else if (type_id.isDecimal())
        {
            if (type_id.isDecimal256())
            {
                unsupported = true;
            }
            else
            {
                const auto scale = getDecimalScale(*nested_type);
                const auto precision = getDecimalPrecision(*nested_type);
                unsupported = (scale == 0 && precision == 0);
            }
        }

        // The error message always reports the declared type, i.e. including any Nullable wrapper.
        if (unsupported)
            throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", declared_type->getName());
    }
}

NestedColumnExtractHelper::NestedColumnExtractHelper(const DB::Block & block_, bool case_insentive_)
Expand Down Expand Up @@ -713,7 +745,6 @@ void registerAllFunctions()
auto & factory = AggregateFunctionCombinatorFactory::instance();
registerAggregateFunctionCombinatorPartialMerge(factory);
}

}

void registerGlutenDisks()
Expand Down
16 changes: 9 additions & 7 deletions cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@
*/
#pragma once
#include <filesystem>
#include <Columns/IColumn.h>
#include <Core/Block.h>
#include <Core/ColumnWithTypeAndName.h>
#include <Core/NamesAndTypes.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Functions/CastOverloadResolver.h>
#include <Interpreters/ActionsDAG.h>
#include <Interpreters/Context.h>
#include <Processors/Chunk.h>
#include <Storages/IStorage.h>
#include <base/types.h>
#include <Common/CurrentThread.h>
#include <Common/logger_useful.h>

/// Forward declarations of ClickHouse classes referenced by declarations in this header
/// (for example, PlanUtil takes DB::QueryPlan by reference).
namespace DB
{
class QueryPipeline;
class QueryPlan;
}

namespace local_engine
{
Expand Down Expand Up @@ -96,10 +98,10 @@ class NestedColumnExtractHelper
const DB::ColumnWithTypeAndName * findColumn(const DB::Block & block, const std::string & name) const;
};

/// Helpers that operate on a DB::QueryPlan.
class PlanUtil
namespace PlanUtil
{
public:
/// Returns the explain text of `plan` with the header, actions and indexes sections enabled.
static std::string explainPlan(DB::QueryPlan & plan);
std::string explainPlan(DB::QueryPlan & plan);
/// Throws DB::Exception (UNKNOWN_TYPE) if the root step's output header contains a column type that
/// cannot be converted for Spark: DateTime64 with scale != 6, Decimal256, or a decimal whose
/// precision and scale are both 0.
void checkOuputType(const DB::QueryPlan & plan);
};

class ActionsDAGUtil
Expand Down
82 changes: 7 additions & 75 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -867,8 +867,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
auto pos = function_signature.find(':');
auto func_name = function_signature.substr(0, pos);

auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this);
if (func_parser)
if (auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this))
{
LOG_DEBUG(
&Poco::Logger::get("SerializedPlanParser"),
Expand Down Expand Up @@ -971,13 +970,12 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
args = std::move(new_args);
}

bool converted_decimal_args = convertBinaryArithmeticFunDecimalArgs(actions_dag, args, scalar_function);
auto function_builder = FunctionFactory::instance().get(ch_func_name, context);
std::string args_name = join(args, ',');
result_name = ch_func_name + "(" + args_name + ")";
const auto * function_node = &actions_dag->addFunction(function_builder, args, result_name);
result_node = function_node;
if (!TypeParser::isTypeMatched(rel.scalar_function().output_type(), function_node->result_type) && !converted_decimal_args)
if (!TypeParser::isTypeMatched(rel.scalar_function().output_type(), function_node->result_type))
{
auto result_type = TypeParser::parseType(rel.scalar_function().output_type());
if (isDecimalOrNullableDecimal(result_type))
Expand Down Expand Up @@ -1014,76 +1012,6 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
return result_node;
}

/// For the binary arithmetic functions divide / multiply / plus / minus applied to two decimals,
/// replaces the first argument with a CAST of it to a widened decimal type.
///
/// @param actions_dag   DAG that receives the constant type-name column and the CAST node.
/// @param args          Function arguments. On success args[0] is replaced by the CAST node and
///                      args[1] is kept; otherwise args is left untouched.
/// @param arithmeticFun Substrait scalar function; supplies the function reference used to look up
///                      the function name and the nullability of the decimal output type.
/// @return true if the arguments were rewritten; false if the function is not one of the four
///         arithmetic functions or either of the first two arguments is not a (nullable) decimal.
/// @note `function_mapping.at` throws if the function reference is not registered.
bool SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs(
ActionsDAGPtr actions_dag,
ActionsDAG::NodeRawConstPtrs & args,
const substrait::Expression_ScalarFunction & arithmeticFun)
{
auto function_signature = function_mapping.at(std::to_string(arithmeticFun.function_reference()));
// Keep the part of the signature before the first ':' (the whole signature if there is none).
auto pos = function_signature.find(':');
auto func_name = function_signature.substr(0, pos);

if (func_name == "divide" || func_name == "multiply" || func_name == "plus" || func_name == "minus")
{
/// for divide/plus/minus, we need to convert first arg to result precision and scale
/// for multiply, we need to convert first arg to result precision, but keep scale
// Nullable wrappers are removed so that the decimal check looks at the nested type.
auto arg1_type = removeNullable(args[0]->result_type);
auto arg2_type = removeNullable(args[1]->result_type);
if (isDecimal(arg1_type) && isDecimal(arg2_type))
{
// p = precision, s = scale of the first (1) and second (2) argument.
UInt32 p1 = getDecimalPrecision(*arg1_type);
UInt32 s1 = getDecimalScale(*arg1_type);
UInt32 p2 = getDecimalPrecision(*arg2_type);
UInt32 s2 = getDecimalScale(*arg2_type);

UInt32 precision;
UInt32 scale;

if (func_name == "plus" || func_name == "minus")
{
// Keep the first argument's scale; reserve the larger integer-digit count of both operands plus one digit.
scale = s1;
precision = scale + std::max(p1 - s1, p2 - s2) + 1;
}
else if (func_name == "divide")
{
// The scale is at least 6.
scale = std::max(static_cast<UInt32>(6), s1 + p2 + 1);
precision = p1 - s1 + s2 + scale;
}
else // multiply
{
scale = s1;
precision = p1 + p2 + 1;
}

// Precision is capped at the Decimal256 maximum precision, scale at the Decimal128 maximum precision.
UInt32 maxPrecision = DataTypeDecimal256::maxPrecision();
UInt32 maxScale = DataTypeDecimal128::maxPrecision();
precision = std::min(precision, maxPrecision);
scale = std::min(scale, maxScale);

ActionsDAG::NodeRawConstPtrs new_args;
new_args.reserve(args.size());

// CAST takes two arguments: the value to convert and the target type name as a constant String column.
ActionsDAG::NodeRawConstPtrs cast_args;
cast_args.reserve(2);
cast_args.emplace_back(args[0]);
DataTypePtr ch_type = createDecimal<DataTypeDecimal>(precision, scale);
// The cast target follows the nullability declared on the Substrait decimal output type.
ch_type = wrapNullableType(arithmeticFun.output_type().decimal().nullability(), ch_type);
String type_name = ch_type->getName();
DataTypePtr str_type = std::make_shared<DataTypeString>();
const ActionsDAG::Node * type_node = &actions_dag->addColumn(
ColumnWithTypeAndName(str_type->createColumnConst(1, type_name), str_type, getUniqueName(type_name)));
cast_args.emplace_back(type_node);
const ActionsDAG::Node * cast_node = toFunctionNode(actions_dag, "CAST", cast_args);
actions_dag->addOrReplaceInOutputs(*cast_node);
// Only the first argument is replaced by its CAST; the second is passed through unchanged.
new_args.emplace_back(cast_node);
new_args.emplace_back(args[1]);
args = std::move(new_args);
return true;
}
}
return false;
}

void SerializedPlanParser::parseFunctionArguments(
ActionsDAGPtr & actions_dag,
ActionsDAG::NodeRawConstPtrs & parsed_args,
Expand Down Expand Up @@ -1835,11 +1763,15 @@ QueryPlanPtr SerializedPlanParser::parse(const std::string & plan)

auto res = parse(std::move(plan_ptr));

#ifndef NDEBUG
PlanUtil::checkOuputType(*res);
#endif

auto * logger = &Poco::Logger::get("SerializedPlanParser");
if (logger->debug())
{
auto out = PlanUtil::explainPlan(*res);
LOG_DEBUG(logger, "clickhouse plan:\n{}", out);
LOG_ERROR(logger, "clickhouse plan:\n{}", out);
}
return res;
}
Expand Down
12 changes: 3 additions & 9 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,17 @@
#pragma once

#include <Core/Block.h>
#include <Core/ColumnWithTypeAndName.h>
#include <Core/SortDescription.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Interpreters/Aggregator.h>
#include <Parser/CHColumnToSparkRow.h>
#include <Parser/RelMetric.h>
#include <Processors/Executors/PullingPipelineExecutor.h>
#include <Processors/Formats/Impl/CHColumnToArrowColumn.h>
#include <Processors/QueryPlan/ISourceStep.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <QueryPipeline/Pipe.h>
#include <Storages/CustomStorageMergeTree.h>
#include <Storages/IStorage.h>
#include <Storages/SourceFromJavaIter.h>
#include <arrow/ipc/writer.h>
#include <base/types.h>
#include <substrait/plan.pb.h>
#include <Common/BlockIterator.h>
Expand Down Expand Up @@ -301,9 +296,6 @@ class SerializedPlanParser

static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function);

bool convertBinaryArithmeticFunDecimalArgs(
ActionsDAGPtr actions_dag, ActionsDAG::NodeRawConstPtrs & args, const substrait::Expression_ScalarFunction & arithmeticFun);

IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set<String> & columns);

static ContextMutablePtr global_context;
Expand Down Expand Up @@ -383,7 +375,6 @@ class SerializedPlanParser
void wrapNullable(
const std::vector<String> & columns, ActionsDAGPtr actions_dag, std::map<std::string, std::string> & nullable_measure_names);
static std::pair<DB::DataTypePtr, DB::Field> convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field);
const ActionsDAG::Node * addColumn(DB::ActionsDAGPtr actions_dag, const DataTypePtr & type, const Field & field);

int name_no = 0;
std::unordered_map<std::string, std::string> function_mapping;
Expand All @@ -395,6 +386,9 @@ class SerializedPlanParser
// for parse rel node, collect steps from a rel node
std::vector<IQueryPlanStep *> temp_step_collection;
std::vector<RelMetricPtr> metrics;

public:
const ActionsDAG::Node * addColumn(DB::ActionsDAGPtr actions_dag, const DataTypePtr & type, const Field & field);
};

struct SparkBuffer
Expand Down
Loading

0 comments on commit c9018cd

Please sign in to comment.