Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Jul 9, 2024
1 parent a39411f commit a008e5b
Show file tree
Hide file tree
Showing 14 changed files with 50 additions and 35 deletions.
32 changes: 25 additions & 7 deletions cpp-ch/local-engine/Parser/FunctionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/IDataType.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Parser/TypeParser.h>
#include <Common/CHUtil.h>

Expand Down Expand Up @@ -70,13 +71,30 @@ const ActionsDAG::Node * FunctionParser::convertNodeTypeIfNeeded(
{
const auto & output_type = substrait_func.output_type();
if (!TypeParser::isTypeMatched(output_type, func_node->result_type))
return ActionsDAGUtil::convertNodeType(
actions_dag,
func_node,
// as stated in isTypeMatched, currently we don't change nullability of the result type
func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, TypeParser::parseType(output_type))->getName()
: DB::removeNullable(TypeParser::parseType(output_type))->getName(),
func_node->result_name);
{
auto result_type = TypeParser::parseType(substrait_func.output_type());
if (DB::isDecimalOrNullableDecimal(result_type))
{
return ActionsDAGUtil::convertNodeType(
actions_dag,
func_node,
// as stated in isTypeMatched, currently we don't change nullability of the result type
func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName()
: local_engine::removeNullable(result_type)->getName(),
func_node->result_name,
CastType::accurateOrNull);
}
else
{
return ActionsDAGUtil::convertNodeType(
actions_dag,
func_node,
// as stated in isTypeMatched, currently we don't change nullability of the result type
func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, TypeParser::parseType(output_type))->getName()
: DB::removeNullable(TypeParser::parseType(output_type))->getName(),
func_node->result_name);
}
}
else
return func_node;
}
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
const auto & scalar_function = rel.scalar_function();

auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference()));
auto function_name = getFunctionName(function_signature, scalar_function);
String function_name = "arrayJoin";

/// Whether the input argument of explode/posexplode is map type
bool is_map;
Expand Down Expand Up @@ -879,7 +879,7 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple(

const auto & scalar_function = rel.scalar_function();
auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference()));
auto function_name = getFunctionName(function_signature, scalar_function);
String function_name = "json_tuple";
auto args = scalar_function.arguments();
if (args.size() < 2)
{
Expand Down
3 changes: 1 addition & 2 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

namespace local_engine
{
static const std::set<std::string> FUNCTION_NEED_KEEP_ARGUMENTS = {"alias"};

DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type);
DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type);
Expand Down Expand Up @@ -138,7 +137,7 @@ class SerializedPlanParser

IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set<String> & columns);
IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header);

static std::pair<DataTypePtr, Field> parseLiteral(const substrait::Expression_Literal & literal);

static ContextMutablePtr global_context;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,9 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Cast, cast, CAST);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(GetTimestamp, get_timestamp, parseDateTimeInJodaSyntaxOrNull);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Quarter, quarter, toQuarter);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ToUnixTimestamp, to_unix_timestamp, parseDateTimeInJodaSyntaxOrNull);
//REGISTER_COMMON_SCALAR_FUNCTION_PARSER(DateFormat, date_format, formatDateTimeInJodaSyntax);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(TimestampAdd, timestamp_add, timestamp_add);

REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Substract, substract, minus);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Multiply, multiply, multiply);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Add, add, plus);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Divide, divide, divide);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Position, positive, identity);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Negative, negative, negate);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Modulus, modulus, modulo);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Pmod, pmod, pmod);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(abs, abs, abs);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Ceil, ceil, ceil);
Expand Down Expand Up @@ -97,7 +90,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(BitwiseXor, bitwise_xor, bitXor);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(BitGet, bit_get, bitTest);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(BitCount, bit_count, bitCount);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Sqrt, sqrt, sqrt);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Cbrc, cbrc, cbrt);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Cbrc, cbrt, cbrt);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Degrees, degrees, degrees);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(E, e, e);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Pi, pi, pi);
Expand All @@ -110,7 +103,6 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Greatest, greatest, sparkGreatest);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Least, least, sparkLeast);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftLeft, shiftleft, bitShiftLeft);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftRight, shiftright, bitShiftRight);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(CheckOverflow, check_overflow, checkDecimalOverflowSpark);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rand, rand, randCanonical);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bin, bin, sparkBin);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rint, rint, sparkRint);
Expand All @@ -120,7 +112,6 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(NotLike, not_like, notLike);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(StartsWith, starts_with, startsWithUTF8);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(EndsWith, ends_with, endsWithUTF8);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Contains, contains, countSubstrings);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Substring, substring, substringUTF8);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(SubstringIndex, substring_index, substringIndexUTF8);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Lower, lower, lowerUTF8);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Upper, upper, upperUTF8);
Expand All @@ -130,7 +121,6 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(RegexpReplace, regexp_replace, replaceReg
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(RegexpExtractAll, regexp_extract_all, regexpExtractAllSpark);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rlike, rlike, match);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Ascii, ascii, ascii);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ConcatWs, concat_ws, concat_ws);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Base64, base64, base64Encode);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Unbase64, unbase64, base64Decode);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Lpad, lpad, leftPadUTF8);
Expand Down Expand Up @@ -166,6 +156,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Array, array, array);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Shuffle, shuffle, arrayShuffle);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin);

// map functions
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Map, map, map);
Expand All @@ -177,11 +168,9 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MapFromArrays, map_from_arrays, mapFromAr

// json functions
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(FlattenJsonStringOnRequired, flattenJSONStringOnRequired, flattenJSONStringOnRequired);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(GetJsonObject, get_json_object, get_json_object);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ToJson, to_json, toJSONString);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(JsonTuple, json_tuple, json_tuple);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(JsonArrayLen, json_array_length, JSONArrayLength);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(MakeDecimal, make_decimal, makeDecimalSpark);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(UnscaledValue, unscaled_value, unscaleValueSpark);

// runtime filter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ class FunctionParserPlus final : public FunctionParserBinaryArithmetic

static constexpr auto name = "add";
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "plus"; }

protected:
DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
Expand All @@ -206,6 +207,7 @@ class FunctionParserMinus final : public FunctionParserBinaryArithmetic

static constexpr auto name = "subtract";
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "minus"; }

protected:
DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
Expand All @@ -220,6 +222,7 @@ class FunctionParserMultiply final : public FunctionParserBinaryArithmetic
explicit FunctionParserMultiply(SerializedPlanParser * plan_parser_) : FunctionParserBinaryArithmetic(plan_parser_) { }
static constexpr auto name = "multiply";
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "multiply"; }

protected:
DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
Expand All @@ -234,6 +237,7 @@ class FunctionParserModulo final : public FunctionParserBinaryArithmetic
explicit FunctionParserModulo(SerializedPlanParser * plan_parser_) : FunctionParserBinaryArithmetic(plan_parser_) { }
static constexpr auto name = "modulus";
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "modulo"; }

protected:
DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
Expand All @@ -248,6 +252,7 @@ class FunctionParserDivide final : public FunctionParserBinaryArithmetic
explicit FunctionParserDivide(SerializedPlanParser * plan_parser_) : FunctionParserBinaryArithmetic(plan_parser_) { }
static constexpr auto name = "divide";
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "divide"; }

protected:
DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class FunctionParserArrayContains : public FunctionParser
String getName() const override { return name; }

const ActionsDAG::Node * parse(
const substrait::Expression_ScalarFunction & substrait_func,
ActionsDAGPtr & actions_dag) const override
const substrait::Expression_ScalarFunction & substrait_func,
ActionsDAGPtr & actions_dag) const override
{
/**
parse array_contains(arr, value) as
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Parser/scalar_function_parser/chr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class FunctionParserChr : public FunctionParser
String getName() const override { return name; }

const ActionsDAG::Node * parse(
const substrait::Expression_ScalarFunction & substrait_func,
ActionsDAGPtr & actions_dag) const override
const substrait::Expression_ScalarFunction & substrait_func,
ActionsDAGPtr & actions_dag) const override
{
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
if (parsed_args.size() != 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class FunctionParserConcat : public FunctionParser
static constexpr auto name = "concat";

String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return name; }

const ActionsDAG::Node * parse(
const substrait::Expression_ScalarFunction & substrait_func,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class FunctionParserConcatWS : public FunctionParser
static constexpr auto name = "concat_ws";

String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return name; }

const ActionsDAG::Node * parse(
const substrait::Expression_ScalarFunction & substrait_func,
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Parser/scalar_function_parser/decode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class FunctionParserDecode : public FunctionParser
String getName() const override { return name; }

const ActionsDAG::Node * parse(
const substrait::Expression_ScalarFunction & substrait_func,
ActionsDAGPtr & actions_dag) const override
const substrait::Expression_ScalarFunction & substrait_func,
ActionsDAGPtr & actions_dag) const override
{
/// Parse decode(bin, charset) as convertCharset(bin, charset, 'UTF-8')
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Parser/scalar_function_parser/encode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class FunctionParserEncode : public FunctionParser
String getName() const override { return name; }

const ActionsDAG::Node * parse(
const substrait::Expression_ScalarFunction & substrait_func,
ActionsDAGPtr & actions_dag) const override
const substrait::Expression_ScalarFunction & substrait_func,
ActionsDAGPtr & actions_dag) const override
{
/// Parse encode(str, charset) as convertCharset(str, 'UTF-8', charset)
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ class SparkFunctionExtractParser : public FunctionParser
}
}
}
else

if (!func_node)
func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args);
return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class FunctionParserTimestampAdd : public FunctionParser
static constexpr auto name = "timestamp_add";

String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return "timestamp_add"; }

const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ class FunctionParserUtcTimestampTransform : public FunctionParser
/// Convert timezone value to clickhouse backend supported, i.e. GMT+8 -> Etc/GMT-8, +08:00 -> Etc/GMT-8
if (substrait_func.arguments_size() != 2)
throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s must have 2 arguments", getName());

const substrait::Expression & arg1 = substrait_func.arguments()[1].value();
if (!arg1.has_literal() || !arg1.literal().has_string())
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s 2nd argument should be string literal", getName());

const String & arg1_literal = arg1.literal().string();
String time_zone_val = DateTimeUtil::convertTimeZone(arg1_literal);
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
Expand Down

0 comments on commit a008e5b

Please sign in to comment.