[GLUTEN-1632][CH]Daily Update Clickhouse Version (20240814) #6844

Closed · wants to merge 5 commits
@@ -937,4 +937,37 @@ class GlutenClickHouseNativeWriteTableSuite
      _ => {})
    )
  }

  test("GLUTEN-2584: fix native write and read mismatch about complex types") {
    def table(format: String): String = s"t_$format"
    def create(format: String, table_name: Option[String] = None): String =
      s"""CREATE TABLE ${table_name.getOrElse(table(format))}(
         | id INT,
         | info STRUCT<name:STRING, age:INT>,
         | data MAP<STRING, INT>,
         | values ARRAY<INT>
         |) stored as $format""".stripMargin
    def insert(format: String, table_name: Option[String] = None): String =
      s"""INSERT overwrite ${table_name.getOrElse(table(format))} VALUES
         | (6, null, null, null);
      """.stripMargin

    nativeWrite2(
      format => (table(format), create(format), insert(format)),
      (table_name, format) => {
        val vanilla_table = s"${table_name}_v"
        val vanilla_create = create(format, Some(vanilla_table))
        vanillaWrite {
          withDestinationTable(vanilla_table, Option(vanilla_create)) {
            checkInsertQuery(insert(format, Some(vanilla_table)), checkNative = false)
          }
        }
        val rowsFromOriginTable =
          spark.sql(s"select * from $vanilla_table").collect()
        val dfFromWriteTable =
          spark.sql(s"select * from $table_name")
        checkAnswer(dfFromWriteTable, rowsFromOriginTable)
      }
    )
  }
}
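The new test above covers GLUTEN-2584: it creates a table with STRUCT, MAP and ARRAY columns, writes a row whose complex values are all NULL through the native writer, replays the same insert through the vanilla Spark writer (via the vanillaWrite helper added below), and asserts that both tables read back identically.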
@@ -75,6 +75,12 @@ trait NativeWriteChecker
    }
  }

  def vanillaWrite(block: => Unit): Unit = {
    withSQLConf(("spark.gluten.sql.native.writer.enabled", "false")) {
      block
    }
  }

  def withSource(df: Dataset[Row], viewName: String, pairs: (String, String)*)(
      block: => Unit): Unit = {
    withSQLConf(pairs: _*) {
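vanillaWrite is a small scoping helper: it disables spark.gluten.sql.native.writer.enabled for the duration of the block so the write runs through Spark's built-in writer, producing the reference output for a native-vs-vanilla comparison.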
4 changes: 2 additions & 2 deletions cpp-ch/clickhouse.version
@@ -1,4 +1,4 @@
CH_ORG=Kyligence
-CH_BRANCH=rebase_ch/20240809
-CH_COMMIT=01e780d46d9
+CH_BRANCH=rebase_ch/20240814
+CH_COMMIT=176802ed082
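cpp-ch/clickhouse.version pins the ClickHouse fork (CH_ORG/CH_BRANCH/CH_COMMIT) that the CH backend builds against; moving to the 20240814 rebase appears to be what pulls in the upstream API changes (notably Field::get<T>() giving way to Field::safeGet<T>(), and the createInternalCast-based cast construction) that the rest of this PR adapts to.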

@@ -62,10 +62,10 @@ createAggregateFunctionBloomFilter(const std::string & name, const DataTypes & a
        if (type != Field::Types::Int64 && type != Field::Types::UInt64)
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be Int64 or UInt64", name);

-       if ((type == Field::Types::Int64 && parameters[i].get<Int64>() < 0))
+       if ((type == Field::Types::Int64 && parameters[i].safeGet<Int64>() < 0))
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be non-negative number", name);

-       return parameters[i].get<UInt64>();
+       return parameters[i].safeGet<UInt64>();
    };

    filter_size = get_parameter(0);
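Most of the C++ changes in this PR are the same mechanical migration: DB::Field::get<T>() becomes DB::Field::safeGet<T>(). A minimal sketch of the behavioral difference, assuming the Field API from Core/Field.h in the pinned ClickHouse revision; readFilterSize is a hypothetical helper, not code from this PR:

#include <Core/Field.h>

// Illustrative only: safeGet<T>() checks the Field's runtime type tag and
// throws a DB::Exception on mismatch, where the old get<T>() accessor did an
// unchecked read (and is no longer usable here after the version bump).
static UInt64 readFilterSize(const DB::Field & param)
{
    // Throws if `param` does not actually hold a UInt64.
    return param.safeGet<UInt64>();
}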
@@ -140,7 +140,7 @@ createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argu
        throw Exception(
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name);

-   bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get<bool>();
+   bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).safeGet<bool>();
    const UInt32 p1 = DB::getDecimalPrecision(*data_type);
    const UInt32 s1 = DB::getDecimalScale(*data_type);
    auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
15 changes: 9 additions & 6 deletions cpp-ch/local-engine/Common/CHUtil.cpp
@@ -51,11 +51,13 @@
#include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Parser/RelParser.h>
#include <Parser/SerializedPlanParser.h>
+#include <Planner/PlannerActionsVisitor.h>
#include <Processors/Chunk.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <QueryPipeline/printPipeline.h>
+#include <Storages/Cache/CacheManager.h>
#include <Storages/Output/WriteBufferBuilder.h>
#include <Storages/StorageMergeTreeFactory.h>
#include <Storages/SubstraitSource/ReadBufferBuilder.h>
@@ -72,7 +74,6 @@
#include <Common/LoggerExtend.h>
#include <Common/logger_useful.h>
#include <Common/typeid_cast.h>
-#include <Storages/Cache/CacheManager.h>

namespace DB
{
@@ -463,20 +464,22 @@ const DB::ColumnWithTypeAndName * NestedColumnExtractHelper::findColumn(const DB
const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
    DB::ActionsDAG & actions_dag,
    const DB::ActionsDAG::Node * node,
-   const std::string & type_name,
+   const DataTypePtr & cast_to_type,
    const std::string & result_name,
    CastType cast_type)
{
    DB::ColumnWithTypeAndName type_name_col;
-   type_name_col.name = type_name;
+   type_name_col.name = cast_to_type->getName();
    type_name_col.column = DB::DataTypeString().createColumnConst(0, type_name_col.name);
    type_name_col.type = std::make_shared<DB::DataTypeString>();
    const auto * right_arg = &actions_dag.addColumn(std::move(type_name_col));
    const auto * left_arg = node;
    DB::CastDiagnostic diagnostic = {node->result_name, node->result_name};
+   ColumnWithTypeAndName left_column{nullptr, node->result_type, {}};
    DB::ActionsDAG::NodeRawConstPtrs children = {left_arg, right_arg};
-   return &actions_dag.addFunction(
-       DB::createInternalCastOverloadResolver(cast_type, std::move(diagnostic)), std::move(children), result_name);
+   auto func_base_cast = createInternalCast(std::move(left_column), cast_to_type, cast_type, diagnostic);
+
+   return &actions_dag.addFunction(func_base_cast, std::move(children), result_name);
}

const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded(
@@ -489,7 +492,7 @@ const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded(
    if (node->result_type->equals(*dst_type))
        return node;

-   return convertNodeType(actions_dag, node, dst_type->getName(), result_name, cast_type);
+   return convertNodeType(actions_dag, node, dst_type, result_name, cast_type);
}

String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline)
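ActionsDAGUtil::convertNodeType now receives the target DB::DataTypePtr directly instead of a type-name string, and builds the cast via createInternalCast rather than createInternalCastOverloadResolver. A hedged sketch of the new call shape (variable names are stand-ins; the real call-site update is in AggregateFunctionParser.cpp further down):

// Sketch: the third argument is now the parsed DataTypePtr itself,
// not dst_type->getName() as before.
const DB::ActionsDAG::Node * casted = ActionsDAGUtil::convertNodeType(
    actions_dag, func_node, TypeParser::parseType(output_type), func_node->result_name);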
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Common/CHUtil.h
@@ -128,8 +128,8 @@ class ActionsDAGUtil
public:
    static const DB::ActionsDAG::Node * convertNodeType(
        DB::ActionsDAG & actions_dag,
-       const DB::ActionsDAG::Node * node,
-       const std::string & type_name,
+       const DB::ActionsDAG::Node * node_to_cast,
+       const DB::DataTypePtr & cast_to_type,
        const std::string & result_name = "",
        DB::CastType cast_type = DB::CastType::nonAccurate);

2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
@@ -60,7 +60,7 @@ struct LambdaLess
        auto compare_res_col = lambda_->reduce();
        DB::Field field;
        compare_res_col.column->get(0, field);
-       return field.get<Int32>() < 0;
+       return field.safeGet<Int32>() < 0;
    }
private:
    ALWAYS_INLINE DB::ColumnPtr oneRowColumn(size_t i) const
@@ -50,17 +50,17 @@ template <typename To>
Field convertNumericType(const Field & from)
{
    if (from.getType() == Field::Types::UInt64)
-       return convertNumericTypeImpl<UInt64, To>(from.get<UInt64>());
+       return convertNumericTypeImpl<UInt64, To>(from.safeGet<UInt64>());
    if (from.getType() == Field::Types::Int64)
-       return convertNumericTypeImpl<Int64, To>(from.get<Int64>());
+       return convertNumericTypeImpl<Int64, To>(from.safeGet<Int64>());
    if (from.getType() == Field::Types::UInt128)
-       return convertNumericTypeImpl<UInt128, To>(from.get<UInt128>());
+       return convertNumericTypeImpl<UInt128, To>(from.safeGet<UInt128>());
    if (from.getType() == Field::Types::Int128)
-       return convertNumericTypeImpl<Int128, To>(from.get<Int128>());
+       return convertNumericTypeImpl<Int128, To>(from.safeGet<Int128>());
    if (from.getType() == Field::Types::UInt256)
-       return convertNumericTypeImpl<UInt256, To>(from.get<UInt256>());
+       return convertNumericTypeImpl<UInt256, To>(from.safeGet<UInt256>());
    if (from.getType() == Field::Types::Int256)
-       return convertNumericTypeImpl<Int256, To>(from.get<Int256>());
+       return convertNumericTypeImpl<Int256, To>(from.safeGet<Int256>());

    throw Exception(ErrorCodes::TYPE_MISMATCH, "Type mismatch. Expected: Integer. Got: {}", from.getType());
}
@@ -81,7 +81,7 @@ inline UInt32 extractArgument(const ColumnWithTypeAndName & named_column)
        throw Exception(
            ErrorCodes::DECIMAL_OVERFLOW, "{} convert overflow, precision/scale value must in UInt32", named_column.type->getName());
    }
-   return static_cast<UInt32>(to.get<UInt32>());
+   return static_cast<UInt32>(to.safeGet<UInt32>());
}

}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Functions/SparkFunctionFloor.h
@@ -197,7 +197,7 @@ class SparkFunctionFloor : public DB::FunctionFloor
        if (scale_field.getType() != Field::Types::UInt64 && scale_field.getType() != Field::Types::Int64)
            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type");

-       Int64 scale64 = scale_field.get<Int64>();
+       Int64 scale64 = scale_field.safeGet<Int64>();
        if (scale64 > std::numeric_limits<Scale>::max() || scale64 < std::numeric_limits<Scale>::min())
            throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale argument for rounding function is too large");

40 changes: 20 additions & 20 deletions cpp-ch/local-engine/Functions/SparkFunctionHashingExtended.h
@@ -101,42 +101,42 @@ class SparkFunctionAnyHash : public IFunction
        if (which.isNothing())
            return seed;
        else if (which.isUInt8())
-           return applyNumber<UInt8>(field.get<UInt8>(), seed);
+           return applyNumber<UInt8>(field.safeGet<UInt8>(), seed);
        else if (which.isUInt16())
-           return applyNumber<UInt16>(field.get<UInt16>(), seed);
+           return applyNumber<UInt16>(field.safeGet<UInt16>(), seed);
        else if (which.isUInt32())
-           return applyNumber<UInt32>(field.get<UInt32>(), seed);
+           return applyNumber<UInt32>(field.safeGet<UInt32>(), seed);
        else if (which.isUInt64())
-           return applyNumber<UInt64>(field.get<UInt64>(), seed);
+           return applyNumber<UInt64>(field.safeGet<UInt64>(), seed);
        else if (which.isInt8())
-           return applyNumber<Int8>(field.get<Int8>(), seed);
+           return applyNumber<Int8>(field.safeGet<Int8>(), seed);
        else if (which.isInt16())
-           return applyNumber<Int16>(field.get<Int16>(), seed);
+           return applyNumber<Int16>(field.safeGet<Int16>(), seed);
        else if (which.isInt32())
-           return applyNumber<Int32>(field.get<Int32>(), seed);
+           return applyNumber<Int32>(field.safeGet<Int32>(), seed);
        else if (which.isInt64())
-           return applyNumber<Int64>(field.get<Int64>(), seed);
+           return applyNumber<Int64>(field.safeGet<Int64>(), seed);
        else if (which.isFloat32())
-           return applyNumber<Float32>(field.get<Float32>(), seed);
+           return applyNumber<Float32>(field.safeGet<Float32>(), seed);
        else if (which.isFloat64())
-           return applyNumber<Float64>(field.get<Float64>(), seed);
+           return applyNumber<Float64>(field.safeGet<Float64>(), seed);
        else if (which.isDate())
-           return applyNumber<UInt16>(field.get<UInt16>(), seed);
+           return applyNumber<UInt16>(field.safeGet<UInt16>(), seed);
        else if (which.isDate32())
-           return applyNumber<Int32>(field.get<Int32>(), seed);
+           return applyNumber<Int32>(field.safeGet<Int32>(), seed);
        else if (which.isDateTime())
-           return applyNumber<UInt32>(field.get<UInt32>(), seed);
+           return applyNumber<UInt32>(field.safeGet<UInt32>(), seed);
        else if (which.isDateTime64())
-           return applyDecimal<DateTime64>(field.get<DateTime64>(), seed);
+           return applyDecimal<DateTime64>(field.safeGet<DateTime64>(), seed);
        else if (which.isDecimal32())
-           return applyDecimal<Decimal32>(field.get<Decimal32>(), seed);
+           return applyDecimal<Decimal32>(field.safeGet<Decimal32>(), seed);
        else if (which.isDecimal64())
-           return applyDecimal<Decimal64>(field.get<Decimal64>(), seed);
+           return applyDecimal<Decimal64>(field.safeGet<Decimal64>(), seed);
        else if (which.isDecimal128())
-           return applyDecimal<Decimal128>(field.get<Decimal128>(), seed);
+           return applyDecimal<Decimal128>(field.safeGet<Decimal128>(), seed);
        else if (which.isStringOrFixedString())
        {
-           const String & str = field.get<String>();
+           const String & str = field.safeGet<String>();
            return applyUnsafeBytes(str.data(), str.size(), seed);
        }
        else if (which.isTuple())
@@ -145,7 +145,7 @@ class SparkFunctionAnyHash : public IFunction
            assert(tuple_type);

            const auto & elements = tuple_type->getElements();
-           const Tuple & tuple = field.get<Tuple>();
+           const Tuple & tuple = field.safeGet<Tuple>();
            assert(tuple.size() == elements.size());

            for (size_t i = 0; i < elements.size(); ++i)
@@ -160,7 +160,7 @@ class SparkFunctionAnyHash : public IFunction
            assert(array_type);

            const auto & nested_type = array_type->getNestedType();
-           const Array & array = field.get<Array>();
+           const Array & array = field.safeGet<Array>();
            for (size_t i=0; i < array.size(); ++i)
            {
                seed = applyGeneric(array[i], seed, nested_type);
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp
@@ -205,7 +205,7 @@ namespace
        else
            return false;
    }
-   result = static_cast<ToNativeType>(convert_to.get<ToNativeType>());
+   result = static_cast<ToNativeType>(convert_to.safeGet<ToNativeType>());

    ToNativeType pow10 = intExp10OfSize<ToNativeType>(precision_value);
    if ((result < 0 && result <= -pow10) || result >= pow10)
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
@@ -271,7 +271,7 @@ class FunctionRoundingHalfUp : public IFunction
        if (scale_field.getType() != Field::Types::UInt64 && scale_field.getType() != Field::Types::Int64)
            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type");

-       Int64 scale64 = scale_field.get<Int64>();
+       Int64 scale64 = scale_field.safeGet<Int64>();
        if (scale64 > std::numeric_limits<Scale>::max() || scale64 < std::numeric_limits<Scale>::min())
            throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale argument for rounding function is too large");

2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
@@ -128,7 +128,7 @@ class SparkFunctionConvertToDateTime : public IFunction

        Field field;
        named_column.column->get(0, field);
-       return static_cast<UInt32>(field.get<UInt32>());
+       return static_cast<UInt32>(field.safeGet<UInt32>());
    }

    DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Operator/ExpandTransform.cpp
@@ -104,7 +104,7 @@ void ExpandTransform::work()

    if (kind == EXPAND_FIELD_KIND_SELECTION)
    {
-       const auto & original_col = original_cols.at(field.get<Int32>());
+       const auto & original_col = original_cols.at(field.safeGet<Int32>());
        if (type->isNullable() == original_col->isNullable())
        {
            cols.push_back(original_col);
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
@@ -155,7 +155,7 @@ const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
    if (need_convert_type)
    {
        func_node = ActionsDAGUtil::convertNodeType(
-           actions_dag, func_node, TypeParser::parseType(output_type)->getName(), func_node->result_name);
+           actions_dag, func_node, TypeParser::parseType(output_type), func_node->result_name);
        actions_dag.addOrReplaceInOutputs(*func_node);
    }
