diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala index 9f5dc4d3ca8c..652b15fc2da0 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala @@ -937,4 +937,37 @@ class GlutenClickHouseNativeWriteTableSuite _ => {}) ) } + + test("GLUTEN-2584: fix native write and read mismatch about complex types") { + def table(format: String): String = s"t_$format" + def create(format: String, table_name: Option[String] = None): String = + s"""CREATE TABLE ${table_name.getOrElse(table(format))}( + | id INT, + | info STRUCT, + | data MAP, + | values ARRAY + |) stored as $format""".stripMargin + def insert(format: String, table_name: Option[String] = None): String = + s"""INSERT overwrite ${table_name.getOrElse(table(format))} VALUES + | (6, null, null, null); + """.stripMargin + + nativeWrite2( + format => (table(format), create(format), insert(format)), + (table_name, format) => { + val vanilla_table = s"${table_name}_v" + val vanilla_create = create(format, Some(vanilla_table)) + vanillaWrite { + withDestinationTable(vanilla_table, Option(vanilla_create)) { + checkInsertQuery(insert(format, Some(vanilla_table)), checkNative = false) + } + } + val rowsFromOriginTable = + spark.sql(s"select * from $vanilla_table").collect() + val dfFromWriteTable = + spark.sql(s"select * from $table_name") + checkAnswer(dfFromWriteTable, rowsFromOriginTable) + } + ) + } } diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala index 4bee3f1771a9..49e368c888e7 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala @@ -75,6 +75,12 @@ trait NativeWriteChecker } } + def vanillaWrite(block: => Unit): Unit = { + withSQLConf(("spark.gluten.sql.native.writer.enabled", "false")) { + block + } + } + def withSource(df: Dataset[Row], viewName: String, pairs: (String, String)*)( block: => Unit): Unit = { withSQLConf(pairs: _*) { diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version index 8068b57f24c6..7c93bc1240ce 100644 --- a/cpp-ch/clickhouse.version +++ b/cpp-ch/clickhouse.version @@ -1,4 +1,3 @@ CH_ORG=Kyligence -CH_BRANCH=rebase_ch/20240809 -CH_COMMIT=01e780d46d9 - +CH_BRANCH=rebase_ch/20240815 +CH_COMMIT=d87dbba64fc diff --git a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionGroupBloomFilter.cpp b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionGroupBloomFilter.cpp index 5555302a5c2f..1b853cc67c69 100644 --- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionGroupBloomFilter.cpp +++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionGroupBloomFilter.cpp @@ -62,10 +62,10 @@ createAggregateFunctionBloomFilter(const std::string & name, const DataTypes & a if (type != Field::Types::Int64 && type != Field::Types::UInt64) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be Int64 or UInt64", name); - if ((type == Field::Types::Int64 && parameters[i].get() < 0)) + if ((type == Field::Types::Int64 && parameters[i].safeGet() < 0)) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be non-negative number", name); - return parameters[i].get(); + return parameters[i].safeGet(); }; filter_size = get_parameter(0); diff --git a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp index 5eb3a0b36057..0aa233145728 100644 --- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp +++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp @@ -140,7 +140,7 @@ createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argu throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name); - bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get(); + bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).safeGet(); const UInt32 p1 = DB::getDecimalPrecision(*data_type); const UInt32 s1 = DB::getDecimalScale(*data_type); auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL; diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 35b4f0c97806..0409b66bd920 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -51,11 +51,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -72,7 +74,6 @@ #include #include #include -#include namespace DB { @@ -463,20 +464,22 @@ const DB::ColumnWithTypeAndName * NestedColumnExtractHelper::findColumn(const DB const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType( DB::ActionsDAG & actions_dag, const DB::ActionsDAG::Node * node, - const std::string & type_name, + const DataTypePtr & cast_to_type, const std::string & result_name, CastType cast_type) { DB::ColumnWithTypeAndName type_name_col; - type_name_col.name = type_name; + type_name_col.name = cast_to_type->getName(); type_name_col.column = DB::DataTypeString().createColumnConst(0, type_name_col.name); type_name_col.type = std::make_shared(); const auto * right_arg = &actions_dag.addColumn(std::move(type_name_col)); const auto * left_arg = node; DB::CastDiagnostic diagnostic = {node->result_name, node->result_name}; + ColumnWithTypeAndName left_column{nullptr, node->result_type, {}}; DB::ActionsDAG::NodeRawConstPtrs children = {left_arg, right_arg}; - return &actions_dag.addFunction( - DB::createInternalCastOverloadResolver(cast_type, std::move(diagnostic)), std::move(children), result_name); + auto func_base_cast = createInternalCast(std::move(left_column), cast_to_type, cast_type, diagnostic); + + return &actions_dag.addFunction(func_base_cast, std::move(children), result_name); } const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded( @@ -489,7 +492,7 @@ const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded( if (node->result_type->equals(*dst_type)) return node; - return convertNodeType(actions_dag, node, dst_type->getName(), result_name, cast_type); + return convertNodeType(actions_dag, node, dst_type, result_name, cast_type); } String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline) diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 785d5d6c0056..c91b7264db31 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -128,8 +128,8 @@ class ActionsDAGUtil public: static const DB::ActionsDAG::Node * convertNodeType( DB::ActionsDAG & actions_dag, - const DB::ActionsDAG::Node * node, - const std::string & type_name, + const DB::ActionsDAG::Node * node_to_cast, + const DB::DataTypePtr & cast_to_type, const std::string & result_name = "", DB::CastType cast_type = DB::CastType::nonAccurate); diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp index 1371ec60e179..cf9d67f1696b 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp @@ -60,7 +60,7 @@ struct LambdaLess auto compare_res_col = lambda_->reduce(); DB::Field field; compare_res_col.column->get(0, field); - return field.get() < 0; + return field.safeGet() < 0; } private: ALWAYS_INLINE DB::ColumnPtr oneRowColumn(size_t i) const diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.h b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.h index 32bf79a563a7..e501c7fc5ffb 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.h @@ -50,17 +50,17 @@ template Field convertNumericType(const Field & from) { if (from.getType() == Field::Types::UInt64) - return convertNumericTypeImpl(from.get()); + return convertNumericTypeImpl(from.safeGet()); if (from.getType() == Field::Types::Int64) - return convertNumericTypeImpl(from.get()); + return convertNumericTypeImpl(from.safeGet()); if (from.getType() == Field::Types::UInt128) - return convertNumericTypeImpl(from.get()); + return convertNumericTypeImpl(from.safeGet()); if (from.getType() == Field::Types::Int128) - return convertNumericTypeImpl(from.get()); + return convertNumericTypeImpl(from.safeGet()); if (from.getType() == Field::Types::UInt256) - return convertNumericTypeImpl(from.get()); + return convertNumericTypeImpl(from.safeGet()); if (from.getType() == Field::Types::Int256) - return convertNumericTypeImpl(from.get()); + return convertNumericTypeImpl(from.safeGet()); throw Exception(ErrorCodes::TYPE_MISMATCH, "Type mismatch. Expected: Integer. Got: {}", from.getType()); } @@ -81,7 +81,7 @@ inline UInt32 extractArgument(const ColumnWithTypeAndName & named_column) throw Exception( ErrorCodes::DECIMAL_OVERFLOW, "{} convert overflow, precision/scale value must in UInt32", named_column.type->getName()); } - return static_cast(to.get()); + return static_cast(to.safeGet()); } } diff --git a/cpp-ch/local-engine/Functions/SparkFunctionFloor.h b/cpp-ch/local-engine/Functions/SparkFunctionFloor.h index ce33d11dbd8c..4a3f99a9a356 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionFloor.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionFloor.h @@ -197,7 +197,7 @@ class SparkFunctionFloor : public DB::FunctionFloor if (scale_field.getType() != Field::Types::UInt64 && scale_field.getType() != Field::Types::Int64) throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type"); - Int64 scale64 = scale_field.get(); + Int64 scale64 = scale_field.safeGet(); if (scale64 > std::numeric_limits::max() || scale64 < std::numeric_limits::min()) throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale argument for rounding function is too large"); diff --git a/cpp-ch/local-engine/Functions/SparkFunctionHashingExtended.h b/cpp-ch/local-engine/Functions/SparkFunctionHashingExtended.h index 57bf00ba9904..c6499031492e 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionHashingExtended.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionHashingExtended.h @@ -101,42 +101,42 @@ class SparkFunctionAnyHash : public IFunction if (which.isNothing()) return seed; else if (which.isUInt8()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isUInt16()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isUInt32()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isUInt64()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isInt8()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isInt16()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isInt32()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isInt64()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isFloat32()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isFloat64()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isDate()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isDate32()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isDateTime()) - return applyNumber(field.get(), seed); + return applyNumber(field.safeGet(), seed); else if (which.isDateTime64()) - return applyDecimal(field.get(), seed); + return applyDecimal(field.safeGet(), seed); else if (which.isDecimal32()) - return applyDecimal(field.get(), seed); + return applyDecimal(field.safeGet(), seed); else if (which.isDecimal64()) - return applyDecimal(field.get(), seed); + return applyDecimal(field.safeGet(), seed); else if (which.isDecimal128()) - return applyDecimal(field.get(), seed); + return applyDecimal(field.safeGet(), seed); else if (which.isStringOrFixedString()) { - const String & str = field.get(); + const String & str = field.safeGet(); return applyUnsafeBytes(str.data(), str.size(), seed); } else if (which.isTuple()) @@ -145,7 +145,7 @@ class SparkFunctionAnyHash : public IFunction assert(tuple_type); const auto & elements = tuple_type->getElements(); - const Tuple & tuple = field.get(); + const Tuple & tuple = field.safeGet(); assert(tuple.size() == elements.size()); for (size_t i = 0; i < elements.size(); ++i) @@ -160,7 +160,7 @@ class SparkFunctionAnyHash : public IFunction assert(array_type); const auto & nested_type = array_type->getNestedType(); - const Array & array = field.get(); + const Array & array = field.safeGet(); for (size_t i=0; i < array.size(); ++i) { seed = applyGeneric(array[i], seed, nested_type); diff --git a/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp b/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp index 231856b0288f..795e2b0be329 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp @@ -205,7 +205,7 @@ namespace else return false; } - result = static_cast(convert_to.get()); + result = static_cast(convert_to.safeGet()); ToNativeType pow10 = intExp10OfSize(precision_value); if ((result < 0 && result <= -pow10) || result >= pow10) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h index 441842d4e7e1..0bd28b116d9a 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h @@ -271,7 +271,7 @@ class FunctionRoundingHalfUp : public IFunction if (scale_field.getType() != Field::Types::UInt64 && scale_field.getType() != Field::Types::Int64) throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type"); - Int64 scale64 = scale_field.get(); + Int64 scale64 = scale_field.safeGet(); if (scale64 > std::numeric_limits::max() || scale64 < std::numeric_limits::min()) throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale argument for rounding function is too large"); diff --git a/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h b/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h index 980af85bd983..aab8aabc3a8d 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h @@ -128,7 +128,7 @@ class SparkFunctionConvertToDateTime : public IFunction Field field; named_column.column->get(0, field); - return static_cast(field.get()); + return static_cast(field.safeGet()); } DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override diff --git a/cpp-ch/local-engine/Operator/ExpandTransform.cpp b/cpp-ch/local-engine/Operator/ExpandTransform.cpp index 5100ad070638..29e254bc01a7 100644 --- a/cpp-ch/local-engine/Operator/ExpandTransform.cpp +++ b/cpp-ch/local-engine/Operator/ExpandTransform.cpp @@ -110,7 +110,7 @@ void ExpandTransform::work() if (kind == EXPAND_FIELD_KIND_SELECTION) { - auto index = field.get(); + auto index = field.safeGet(); const auto & input_column = input_columns[index]; DB::ColumnWithTypeAndName input_arg; diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp index f976d50ad3b2..b843d1565fce 100644 --- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp @@ -155,7 +155,7 @@ const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded( if (need_convert_type) { func_node = ActionsDAGUtil::convertNodeType( - actions_dag, func_node, TypeParser::parseType(output_type)->getName(), func_node->result_name); + actions_dag, func_node, TypeParser::parseType(output_type), func_node->result_name); actions_dag.addOrReplaceInOutputs(*func_node); } diff --git a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp index 3d5a7731bffb..602cd3d6837e 100644 --- a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp +++ b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp @@ -501,7 +501,7 @@ int64_t BackingDataLengthCalculator::calculate(const Field & field) const if (which.isStringOrFixedString()) { - const auto & str = field.get(); + const auto & str = field.safeGet(); return roundNumberOfBytesToNearestWord(str.size()); } @@ -511,7 +511,7 @@ int64_t BackingDataLengthCalculator::calculate(const Field & field) const if (which.isArray()) { /// 内存布局:numElements(8B) | null_bitmap(与numElements成正比) | values(每个值长度与类型有关) | backing buffer - const auto & array = field.get(); /// Array can not be wrapped with Nullable + const auto & array = field.safeGet(); /// Array can not be wrapped with Nullable const auto num_elems = array.size(); int64_t res = 8 + calculateBitSetWidthInBytes(num_elems); @@ -531,7 +531,7 @@ int64_t BackingDataLengthCalculator::calculate(const Field & field) const int64_t res = 8; /// Construct Array of keys and values from Map - const auto & map = field.get(); /// Map can not be wrapped with Nullable + const auto & map = field.safeGet(); /// Map can not be wrapped with Nullable const auto num_keys = map.size(); auto array_key = Array(); auto array_val = Array(); @@ -539,7 +539,7 @@ int64_t BackingDataLengthCalculator::calculate(const Field & field) const array_val.reserve(num_keys); for (size_t i = 0; i < num_keys; ++i) { - const auto & pair = map[i].get(); + const auto & pair = map[i].safeGet(); array_key.push_back(pair[0]); array_val.push_back(pair[1]); } @@ -561,7 +561,7 @@ int64_t BackingDataLengthCalculator::calculate(const Field & field) const if (which.isTuple()) { /// 内存布局:null_bitmap(字节数与字段数成正比) | field1 value(8B) | field2 value(8B) | ... | fieldn value(8B) | backing buffer - const auto & tuple = field.get(); /// Tuple can not be wrapped with Nullable + const auto & tuple = field.safeGet(); /// Tuple can not be wrapped with Nullable const auto * type_tuple = typeid_cast(type_without_nullable.get()); const auto & type_fields = type_tuple->getElements(); const auto num_fields = type_fields.size(); @@ -687,30 +687,7 @@ int64_t VariableLengthDataWriter::writeArray(size_t row_idx, const DB::Array & a bitSet(buffer_address + offset + start + 8, i); else { - if (writer.getWhichDataType().isFloat32()) - { - // We can not use get() directly here to process Float32 field, - // because it will get 8 byte data, but Float32 is 4 byte, which will cause error conversion. - auto v = static_cast(elem.get()); - writer.unsafeWrite( - reinterpret_cast(&v), buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size); - } - else if (writer.getWhichDataType().isFloat64()) - { - // Fix 'Invalid Field get from type Float64 to type Int64' in debug build. - auto v = elem.get(); - writer.unsafeWrite(reinterpret_cast(&v), buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size); - } - else if (writer.getWhichDataType().isDecimal32()) - { - // We can not use get() directly here to process Decimal32 field, - // because it will get 4 byte data, but Decimal32 is 8 byte in Spark, which will cause error conversion. - writer.write(elem, buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size); - } - else - writer.unsafeWrite( - reinterpret_cast(&elem.get()), - buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size); + writer.write(elem, buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size); } } } @@ -754,7 +731,7 @@ int64_t VariableLengthDataWriter::writeMap(size_t row_idx, const DB::Map & map, val_array.reserve(num_pairs); for (size_t i = 0; i < num_pairs; ++i) { - const auto & pair = map[i].get(); + const auto & pair = map[i].safeGet(); key_array.push_back(pair[0]); val_array.push_back(pair[1]); } @@ -812,27 +789,7 @@ int64_t VariableLengthDataWriter::writeStruct(size_t row_idx, const DB::Tuple & if (BackingDataLengthCalculator::isFixedLengthDataType(removeNullable(field_type))) { FixedLengthDataWriter writer(field_type); - if (writer.getWhichDataType().isFloat32()) - { - // We can not use get() directly here to process Float32 field, - // because it will get 8 byte data, but Float32 is 4 byte, which will cause error conversion. - auto v = static_cast(field_value.get()); - writer.unsafeWrite(reinterpret_cast(&v), buffer_address + offset + start + len_null_bitmap + i * 8); - } - else if (writer.getWhichDataType().isFloat64()) - { - // Fix 'Invalid Field get from type Float64 to type Int64' in debug build. - auto v = field_value.get(); - writer.unsafeWrite(reinterpret_cast(&v), buffer_address + offset + start + len_null_bitmap + i * 8); - } - else if (writer.getWhichDataType().isDecimal64() || writer.getWhichDataType().isDateTime64()) - { - auto v = field_value.get(); - writer.unsafeWrite(reinterpret_cast(&v), buffer_address + offset + start + len_null_bitmap + i * 8); - } - else - writer.unsafeWrite( - reinterpret_cast(&field_value.get()), buffer_address + offset + start + len_null_bitmap + i * 8); + writer.write(field_value, buffer_address + offset + start + len_null_bitmap + i * 8); } else { @@ -853,7 +810,7 @@ int64_t VariableLengthDataWriter::write(size_t row_idx, const DB::Field & field, if (which.isStringOrFixedString()) { - const auto & str = field.get(); + const auto & str = field.safeGet(); return writeUnalignedBytes(row_idx, str.data(), str.size(), parent_offset); } @@ -868,19 +825,19 @@ int64_t VariableLengthDataWriter::write(size_t row_idx, const DB::Field & field, if (which.isArray()) { - const auto & array = field.get(); + const auto & array = field.safeGet(); return writeArray(row_idx, array, parent_offset); } if (which.isMap()) { - const auto & map = field.get(); + const auto & map = field.safeGet(); return writeMap(row_idx, map, parent_offset); } if (which.isTuple()) { - const auto & tuple = field.get(); + const auto & tuple = field.safeGet(); return writeStruct(row_idx, tuple, parent_offset); } @@ -926,64 +883,64 @@ void FixedLengthDataWriter::write(const DB::Field & field, char * buffer) if (which.isUInt8()) { - const auto value = UInt8(field.get()); + const auto value = static_cast(field.safeGet()); memcpy(buffer, &value, 1); } else if (which.isUInt16() || which.isDate()) { - const auto value = UInt16(field.get()); + const auto value = static_cast(field.safeGet()); memcpy(buffer, &value, 2); } else if (which.isUInt32() || which.isDate32()) { - const auto value = UInt32(field.get()); + const auto value = static_cast(field.safeGet()); memcpy(buffer, &value, 4); } else if (which.isUInt64()) { - const auto & value = field.get(); + const auto & value = field.safeGet(); memcpy(buffer, &value, 8); } else if (which.isInt8()) { - const auto value = Int8(field.get()); + const auto value = static_cast(field.safeGet()); memcpy(buffer, &value, 1); } else if (which.isInt16()) { - const auto value = Int16(field.get()); + const auto value = static_cast(field.safeGet()); memcpy(buffer, &value, 2); } else if (which.isInt32()) { - const auto value = Int32(field.get()); + const auto value = static_cast(field.safeGet()); memcpy(buffer, &value, 4); } else if (which.isInt64()) { - const auto & value = field.get(); + const auto & value = field.safeGet(); memcpy(buffer, &value, 8); } else if (which.isFloat32()) { - const auto value = Float32(field.get()); + const auto value = static_cast(field.safeGet()); memcpy(buffer, &value, 4); } else if (which.isFloat64()) { - const auto & value = field.get(); + const auto & value = field.safeGet(); memcpy(buffer, &value, 8); } else if (which.isDecimal32()) { - const auto & value = field.get(); + const auto & value = field.safeGet(); const Int64 decimal = static_cast(value.getValue()); memcpy(buffer, &decimal, 8); } else if (which.isDecimal64() || which.isDateTime64()) { - const auto & value = field.get(); - auto decimal = value.getValue(); + const auto & value = field.safeGet(); + const auto decimal = value.getValue(); memcpy(buffer, &decimal, 8); } else diff --git a/cpp-ch/local-engine/Parser/FunctionParser.cpp b/cpp-ch/local-engine/Parser/FunctionParser.cpp index d46110431ab4..a875da275501 100644 --- a/cpp-ch/local-engine/Parser/FunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/FunctionParser.cpp @@ -80,8 +80,8 @@ const ActionsDAG::Node * FunctionParser::convertNodeTypeIfNeeded( actions_dag, func_node, // as stated in isTypeMatched, currently we don't change nullability of the result type - func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), + func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type) + : local_engine::removeNullable(result_type), func_node->result_name, CastType::accurateOrNull); } @@ -91,8 +91,8 @@ const ActionsDAG::Node * FunctionParser::convertNodeTypeIfNeeded( actions_dag, func_node, // as stated in isTypeMatched, currently we don't change nullability of the result type - func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, TypeParser::parseType(output_type))->getName() - : DB::removeNullable(TypeParser::parseType(output_type))->getName(), + func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, TypeParser::parseType(output_type)) + : DB::removeNullable(TypeParser::parseType(output_type)), func_node->result_name); } } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index d0924a745716..297551bcccc2 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -285,7 +285,10 @@ QueryPlanStepPtr SerializedPlanParser::parseReadRealWithLocalFile(const substrai if (rel.has_local_files()) local_files = rel.local_files(); else + { local_files = BinaryToMessage(split_infos.at(nextSplitInfoIndex())); + logDebugMessage(local_files, "local_files"); + } auto source = std::make_shared(context, header, local_files); auto source_pipe = Pipe(source); auto source_step = std::make_unique(context, std::move(source_pipe), "substrait local files"); @@ -496,7 +499,10 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list if (read.has_extension_table()) extension_table = read.extension_table(); else + { extension_table = BinaryToMessage(split_infos.at(nextSplitInfoIndex())); + logDebugMessage(extension_table, "extension_table"); + } MergeTreeRelParser mergeTreeParser(this, context); query_plan = mergeTreeParser.parseReadRel(std::make_unique(), read, extension_table); @@ -689,7 +695,7 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( /// pos = cast(arrayJoin(arg_not_null).1, "Int32") const auto * pos_node = add_tuple_element(array_join_node, 1); - pos_node = ActionsDAGUtil::convertNodeType(actions_dag, pos_node, "Int32"); + pos_node = ActionsDAGUtil::convertNodeType(actions_dag, pos_node, INT()); /// if is_map is false, output col = arrayJoin(arg_not_null).2 /// if is_map is true, output (key, value) = arrayJoin(arg_not_null).2 @@ -772,7 +778,7 @@ std::pair SerializedPlanParser::convertStructFieldType(const #define UINT_CONVERT(type_ptr, field, type_name) \ if ((type_ptr)->getTypeId() == TypeIndex::type_name) \ { \ - return {std::make_shared(), static_cast((field).get()) + 1}; \ + return {std::make_shared(), static_cast((field).safeGet()) + 1}; \ } auto type_id = type->getTypeId(); diff --git a/cpp-ch/local-engine/Parser/WriteRelParser.cpp b/cpp-ch/local-engine/Parser/WriteRelParser.cpp index 9b6226adbed8..1a468a41eef2 100644 --- a/cpp-ch/local-engine/Parser/WriteRelParser.cpp +++ b/cpp-ch/local-engine/Parser/WriteRelParser.cpp @@ -137,12 +137,12 @@ void addSinkTransfrom(const DB::ContextPtr & context, const substrait::WriteRel DB::Field field_tmp_dir; if (!settings.tryGet(SPARK_TASK_WRITE_TMEP_DIR, field_tmp_dir)) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline need inject temp directory."); - const auto & tmp_dir = field_tmp_dir.get(); + const auto & tmp_dir = field_tmp_dir.safeGet(); DB::Field field_filename; if (!settings.tryGet(SPARK_TASK_WRITE_FILENAME, field_filename)) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline need inject file name."); - const auto & filename = field_filename.get(); + const auto & filename = field_filename.safeGet(); assert(write_rel.has_named_table()); const substrait::NamedObjectWrite & named_table = write_rel.named_table(); diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp index 237da650c8e1..ceddbd2aef80 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp @@ -98,7 +98,7 @@ DB::Array ApproxPercentileParser::parseFunctionParameters( if (isArray(type2)) { /// Multiple percentages for quantilesGK - const Array & percentags = field2.get(); + const Array & percentags = field2.safeGet(); for (const auto & percentage : percentags) params.emplace_back(percentage); } diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp index 8788abb6dcf7..10bf0b09482e 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp @@ -63,8 +63,8 @@ DB::Array AggregateFunctionParserBloomFilterAgg::parseFunctionParameters( node->column->get(0, ret); return ret; }; - Int64 insert_num = get_parameter_field(arg_nodes[1], 1).get(); - Int64 bits_num = get_parameter_field(arg_nodes[2], 2).get(); + Int64 insert_num = get_parameter_field(arg_nodes[1], 1).safeGet(); + Int64 bits_num = get_parameter_field(arg_nodes[2], 2).safeGet(); // Delete all args except the first arg. arg_nodes.resize(1); diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp index 6d0075705c44..536aec1b60f4 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include namespace local_engine @@ -41,7 +42,7 @@ LeadParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Act node = ActionsDAGUtil::convertNodeType( actions_dag, &actions_dag.findInOutputs(arg0_col_name), - DB::makeNullable(arg0_col_type)->getName(), + DB::makeNullable(arg0_col_type), arg0_col_name); actions_dag.addOrReplaceInOutputs(*node); args.push_back(node); @@ -52,7 +53,7 @@ LeadParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Act } node = parseExpression(actions_dag, arg1); - node = ActionsDAGUtil::convertNodeType(actions_dag, node, DB::DataTypeInt64().getName()); + node = ActionsDAGUtil::convertNodeType(actions_dag, node, BIGINT()); actions_dag.addOrReplaceInOutputs(*node); args.push_back(node); @@ -84,7 +85,7 @@ LagParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Acti node = ActionsDAGUtil::convertNodeType( actions_dag, &actions_dag.findInOutputs(arg0_col_name), - DB::makeNullable(arg0_col_type)->getName(), + makeNullable(arg0_col_type), arg0_col_name); actions_dag.addOrReplaceInOutputs(*node); args.push_back(node); @@ -100,7 +101,7 @@ LagParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Acti auto real_field = 0 - literal_result.second.safeGet(); node = &actions_dag.addColumn(ColumnWithTypeAndName( literal_result.first->createColumnConst(1, real_field), literal_result.first, getUniqueName(toString(real_field)))); - node = ActionsDAGUtil::convertNodeType(actions_dag, node, DB::DataTypeInt64().getName()); + node = ActionsDAGUtil::convertNodeType(actions_dag, node, BIGINT()); actions_dag.addOrReplaceInOutputs(*node); args.push_back(node); diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp index 62f83223c06f..1a24e320609e 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp @@ -32,7 +32,7 @@ NtileParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Ac auto [data_type, field] = parseLiteral(arg0.literal()); if (!(DB::WhichDataType(data_type).isInt32())) throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's argument must be i32"); - Int32 field_index = static_cast(field.get()); + Int32 field_index = static_cast(field.safeGet()); // For CH, the data type of the args[0] must be the UInt32 const auto * index_node = addColumnToActionsDAG(actions_dag, std::make_shared(), field_index); args.emplace_back(index_node); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp index a475a1efb367..aa82b33a7a3c 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -60,7 +61,7 @@ class ArrayFilter : public FunctionParser /// filter with index argument. const auto * range_end_node = toFunctionNode(actions_dag, "length", {toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})}); range_end_node = ActionsDAGUtil::convertNodeType( - actions_dag, range_end_node, "Nullable(Int32)", range_end_node->result_name); + actions_dag, range_end_node, makeNullable(INT()), range_end_node->result_name); const auto * index_array_node = toFunctionNode( actions_dag, "range", @@ -106,7 +107,7 @@ class ArrayTransform : public FunctionParser /// transform with index argument. const auto * range_end_node = toFunctionNode(actions_dag, "length", {toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})}); range_end_node = ActionsDAGUtil::convertNodeType( - actions_dag, range_end_node, "Nullable(Int32)", range_end_node->result_name); + actions_dag, range_end_node, makeNullable(INT()), range_end_node->result_name); const auto * index_array_node = toFunctionNode( actions_dag, "range", @@ -141,7 +142,7 @@ class ArrayAggregate : public FunctionParser parsed_args[1] = ActionsDAGUtil::convertNodeType( actions_dag, parsed_args[1], - function_type->getReturnType()->getName(), + function_type->getReturnType(), parsed_args[1]->result_name); } @@ -215,14 +216,14 @@ class ArraySort : public FunctionParser if (!var_expr.has_literal()) return false; auto [_, name] = plan_parser->parseLiteral(var_expr.literal()); - return var == name.get(); + return var == name.safeGet(); }; auto is_int_value = [&](const substrait::Expression & expr, Int32 val) { if (!expr.has_literal()) return false; auto [_, x] = plan_parser->parseLiteral(expr.literal()); - return val == x.get(); + return val == x.safeGet(); }; auto is_variable_null = [&](const substrait::Expression & expr, const String & var) { diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp index 1fda3d8fa753..b0ade35a3590 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp @@ -86,7 +86,7 @@ class FunctionParserArrayPosition : public FunctionParser DataTypePtr wrap_arr_nullable_type = wrapNullableType(true, ch_function_node->result_type); const auto * wrap_index_of_node = ActionsDAGUtil::convertNodeType( - actions_dag, ch_function_node, wrap_arr_nullable_type->getName(), ch_function_node->result_name); + actions_dag, ch_function_node, wrap_arr_nullable_type, ch_function_node->result_name); const auto * null_const_node = addColumnToActionsDAG(actions_dag, wrap_arr_nullable_type, Field{}); const auto * or_condition_node = toFunctionNode(actions_dag, "or", {arr_is_null_node, val_is_null_node}); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp index 992235cd9a0b..accc6d418b9f 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp @@ -74,7 +74,7 @@ class FunctionParserElt : public FunctionParser auto nullable_result_type = makeNullable(result_type); const auto * nullable_array_element_node = ActionsDAGUtil::convertNodeType( - actions_dag, array_element_node, nullable_result_type->getName(), array_element_node->result_name); + actions_dag, array_element_node, nullable_result_type, array_element_node->result_name); const auto * null_const_node = addColumnToActionsDAG(actions_dag, nullable_result_type, Field()); const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {index_arg}); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp index ca9fb372c2fd..96fedc6fe646 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include namespace DB @@ -73,9 +74,9 @@ class FunctionParserFindInSet : public FunctionParser if (!str_is_nullable && !str_array_is_nullable) return convertNodeTypeIfNeeded(substrait_func, index_of_node, actions_dag); - auto nullable_result_type = makeNullable(std::make_shared()); + auto nullable_result_type = makeNullable(INT()); const auto * nullable_index_of_node = ActionsDAGUtil::convertNodeType( - actions_dag, index_of_node, nullable_result_type->getName(), index_of_node->result_name); + actions_dag, index_of_node, nullable_result_type, index_of_node->result_name); const auto * null_const_node = addColumnToActionsDAG(actions_dag, nullable_result_type, Field()); const auto * str_is_null_node = toFunctionNode(actions_dag, "isNull", {str_arg}); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp index 547ffd971fcd..c2841564e8c3 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp @@ -43,7 +43,7 @@ DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_p && plan_parser_.getFunctionSignatureName(arg.value().scalar_function().function_reference()) == "namedlambdavariable") { auto [_, col_name_field] = plan_parser_.parseLiteral(arg.value().scalar_function().arguments()[0].value().literal()); - String col_name = col_name_field.get(); + String col_name = col_name_field.safeGet(); if (collected_names.contains(col_name)) { continue; @@ -187,7 +187,7 @@ class NamedLambdaVariable : public FunctionParser const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override { auto [_, col_name_field] = parseLiteral(substrait_func.arguments()[0].value().literal()); - String col_name = col_name_field.get(); + String col_name = col_name_field.safeGet(); auto type = TypeParser::parseType(substrait_func.output_type()); const auto & inputs = actions_dag.getInputs(); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp index b948daeda0ea..17115895eaff 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp @@ -17,6 +17,7 @@ #include #include +#include #include namespace DB @@ -50,7 +51,7 @@ class FunctionParserLocate : public FunctionParser const auto * substr_arg = parsed_args[0]; const auto * str_arg = parsed_args[1]; - const auto * start_pos_arg = ActionsDAGUtil::convertNodeType(actions_dag, parsed_args[2], "Nullable(UInt32)"); + const auto * start_pos_arg = ActionsDAGUtil::convertNodeType(actions_dag, parsed_args[2], makeNullable(UINT())); const auto * is_start_pos_null_node = toFunctionNode(actions_dag, "isNull", {start_pos_arg}); const auto * const_1_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 0); const auto * position_node = toFunctionNode(actions_dag, "positionUTF8Spark", {str_arg, substr_arg, start_pos_arg}); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp index ada91f8537fe..74254911a0a0 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -42,8 +43,7 @@ class SparkFunctionRepeatParser : public FunctionParser const auto & args = substrait_func.arguments(); parsed_args.emplace_back(parseExpression(actions_dag, args[0].value())); const auto * repeat_times_node = parseExpression(actions_dag, args[1].value()); - DB::DataTypeNullable target_type(std::make_shared()); - repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag, repeat_times_node, target_type.getName()); + repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag, repeat_times_node, makeNullable(UINT())); parsed_args.emplace_back(repeat_times_node); const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp index 2643207354ae..a96dca8efe4d 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp @@ -89,7 +89,7 @@ class FunctionParserArraySlice : public FunctionParser DataTypePtr wrap_arr_nullable_type = wrapNullableType(true, slice_node->result_type); const auto * wrap_slice_node = ActionsDAGUtil::convertNodeType( - actions_dag, slice_node, wrap_arr_nullable_type->getName(), slice_node->result_name); + actions_dag, slice_node, wrap_arr_nullable_type, slice_node->result_name); const auto * null_const_node = addColumnToActionsDAG(actions_dag, wrap_arr_nullable_type, Field{}); const auto * arr_is_null_node = toFunctionNode(actions_dag, "isNull", {arr_arg}); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp index 179aa7860484..4809cc887b8d 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp @@ -45,7 +45,7 @@ namespace local_engine auto [data_type, field] = parseLiteral(args[1].value().literal()); \ if (!DB::WhichDataType(data_type).isInt32()) \ throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{}'s second argument must be i32", #substrait_name); \ - Int32 field_index = static_cast(field.get() + 1); \ + Int32 field_index = static_cast(field.safeGet() + 1); \ const auto * index_node = addColumnToActionsDAG(actions_dag, std::make_shared(), field_index); \ parsed_args.emplace_back(index_node); \ const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args); \ diff --git a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp index 6fee65efe593..93f4374d4ce1 100644 --- a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp +++ b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp @@ -71,16 +71,16 @@ SparkMergeTreeWriter::SparkMergeTreeWriter( , thread_pool(CurrentMetrics::LocalThread, CurrentMetrics::LocalThreadActive, CurrentMetrics::LocalThreadScheduled, 1, 1, 100000) { const DB::Settings & settings = context->getSettingsRef(); - merge_after_insert = settings.get(MERGETREE_MERGE_AFTER_INSERT).get(); - insert_without_local_storage = settings.get(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE).get(); + merge_after_insert = settings.get(MERGETREE_MERGE_AFTER_INSERT).safeGet(); + insert_without_local_storage = settings.get(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE).safeGet(); Field limit_size_field; if (settings.tryGet("optimize.minFileSize", limit_size_field)) - merge_min_size = limit_size_field.get() <= 0 ? merge_min_size : limit_size_field.get(); + merge_min_size = limit_size_field.safeGet() <= 0 ? merge_min_size : limit_size_field.safeGet(); Field limit_cnt_field; if (settings.tryGet("mergetree.max_num_part_per_merge_task", limit_cnt_field)) - merge_limit_parts = limit_cnt_field.get() <= 0 ? merge_limit_parts : limit_cnt_field.get(); + merge_limit_parts = limit_cnt_field.safeGet() <= 0 ? merge_limit_parts : limit_cnt_field.safeGet(); dest_storage = MergeTreeRelParser::parseStorage(merge_tree_table, SerializedPlanParser::global_context); isRemoteStorage = dest_storage->getStoragePolicy()->getAnyDisk()->isRemote(); diff --git a/cpp-ch/local-engine/Storages/Parquet/ParquetConverter.h b/cpp-ch/local-engine/Storages/Parquet/ParquetConverter.h index 89e83e668aeb..312cea7efc0a 100644 --- a/cpp-ch/local-engine/Storages/Parquet/ParquetConverter.h +++ b/cpp-ch/local-engine/Storages/Parquet/ParquetConverter.h @@ -38,9 +38,9 @@ struct ToParquet T as(const DB::Field & value, const parquet::ColumnDescriptor &) { if constexpr (std::is_same_v) - return static_cast(value.get()); + return static_cast(value.safeGet()); // parquet::BooleanType, parquet::Int64Type, parquet::FloatType, parquet::DoubleType - return value.get(); // FLOAT, DOUBLE, INT64 + return value.safeGet(); // FLOAT, DOUBLE, INT64 } }; @@ -51,7 +51,7 @@ struct ToParquet T as(const DB::Field & value, const parquet::ColumnDescriptor &) { assert(value.getType() == DB::Field::Types::String); - const std::string & s = value.get(); + const std::string & s = value.safeGet(); const auto * const ptr = reinterpret_cast(s.data()); return parquet::ByteArray(static_cast(s.size()), ptr); } @@ -74,7 +74,7 @@ struct ToParquet "descriptor.type_length() = {} , which is > {}, e.g. sizeof(Int128)", descriptor.type_length(), sizeof(Int128)); - Int128 val = value.get>().getValue(); + Int128 val = value.safeGet>().getValue(); std::reverse(reinterpret_cast(&val), reinterpret_cast(&val) + sizeof(val)); const int offset = sizeof(Int128) - descriptor.type_length(); memcpy(buf, reinterpret_cast(&val) + offset, descriptor.type_length()); diff --git a/cpp-ch/local-engine/tests/data/68135.snappy.parquet b/cpp-ch/local-engine/tests/data/68135.snappy.parquet new file mode 100644 index 000000000000..ddd627790cd1 Binary files /dev/null and b/cpp-ch/local-engine/tests/data/68135.snappy.parquet differ diff --git a/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp b/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp index 6352a819927c..9e4165d90437 100644 --- a/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp +++ b/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp @@ -63,7 +63,7 @@ TEST(Clickhouse, PR54881) Field field; const auto & col_1 = *(block.getColumns()[1]); col_1.get(0, field); - const Tuple & row_0 = field.get(); + const Tuple & row_0 = field.safeGet(); EXPECT_EQ(2, row_0.size()); Int64 actual{-1}; @@ -74,7 +74,7 @@ TEST(Clickhouse, PR54881) EXPECT_EQ(10, actual); col_1.get(1, field); - const Tuple & row_1 = field.get(); + const Tuple & row_1 = field.safeGet(); EXPECT_EQ(2, row_1.size()); EXPECT_TRUE(row_1[0].tryGet(actual)); EXPECT_EQ(10, actual); @@ -96,4 +96,24 @@ TEST(Clickhouse, PR65234) const auto plan = local_engine::JsonStringToMessage( {reinterpret_cast(gresource_embedded_pr_65234_jsonData), gresource_embedded_pr_65234_jsonSize}); auto query_plan = parser.parse(plan); +} + +INCBIN(resource_embedded_pr_68135_json, SOURCE_DIR "/utils/extern-local-engine/tests/json/clickhouse_pr_68135.json"); +TEST(Clickhouse, PR68135) +{ + const std::string split_template + = R"({"items":[{"uriFile":"{replace_local_files}","partitionIndex":"0","length":"461","parquet":{},"schema":{},"metadataColumns":[{}]}]})"; + const std::string split + = replaceLocalFilesWildcards(split_template, GLUTEN_DATA_DIR("/utils/extern-local-engine/tests/data/68135.snappy.parquet")); + + SerializedPlanParser parser(SerializedPlanParser::global_context); + parser.addSplitInfo(local_engine::JsonStringToBinary(split)); + + const auto plan = local_engine::JsonStringToMessage( + {reinterpret_cast(gresource_embedded_pr_68135_jsonData), gresource_embedded_pr_68135_jsonSize}); + + auto local_executor = parser.createExecutor(plan); + EXPECT_TRUE(local_executor->hasNext()); + const Block & x = *local_executor->nextColumnar(); + debug::headBlock(x); } \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/json/clickhouse_pr_68135.json b/cpp-ch/local-engine/tests/json/clickhouse_pr_68135.json new file mode 100644 index 000000000000..c8b49857c79a --- /dev/null +++ b/cpp-ch/local-engine/tests/json/clickhouse_pr_68135.json @@ -0,0 +1,160 @@ +{ + "relations": [ + { + "root": { + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "decimal": { + "scale": 2, + "precision": 9, + "nullability": "NULLABILITY_NULLABLE" + } + } + ] + }, + "columnTypes": [ + "NORMAL_COL" + ] + }, + "filter": { + "singularOrList": { + "value": { + "selection": { + "directReference": { + "structField": {} + } + } + }, + "options": [ + { + "literal": { + "decimal": { + "value": "yAAAAAAAAAAAAAAAAAAAAA==", + "precision": 9, + "scale": 2 + } + } + }, + { + "literal": { + "decimal": { + "value": "LAEAAAAAAAAAAAAAAAAAAA==", + "precision": 9, + "scale": 2 + } + } + }, + { + "literal": { + "decimal": { + "value": "kAEAAAAAAAAAAAAAAAAAAA==", + "precision": 9, + "scale": 2 + } + } + }, + { + "literal": { + "decimal": { + "value": "9AEAAAAAAAAAAAAAAAAAAA==", + "precision": 9, + "scale": 2 + } + } + } + ] + } + }, + "advancedExtension": { + "optimization": { + "@type": "type.googleapis.com/google.protobuf.StringValue", + "value": "isMergeTree=0\n" + } + } + } + }, + "condition": { + "singularOrList": { + "value": { + "selection": { + "directReference": { + "structField": {} + } + } + }, + "options": [ + { + "literal": { + "decimal": { + "value": "yAAAAAAAAAAAAAAAAAAAAA==", + "precision": 9, + "scale": 2 + } + } + }, + { + "literal": { + "decimal": { + "value": "LAEAAAAAAAAAAAAAAAAAAA==", + "precision": 9, + "scale": 2 + } + } + }, + { + "literal": { + "decimal": { + "value": "kAEAAAAAAAAAAAAAAAAAAA==", + "precision": 9, + "scale": 2 + } + } + }, + { + "literal": { + "decimal": { + "value": "9AEAAAAAAAAAAAAAAAAAAA==", + "precision": 9, + "scale": 2 + } + } + } + ] + } + } + } + }, + "names": [ + "a#26" + ], + "outputSchema": { + "types": [ + { + "decimal": { + "scale": 2, + "precision": 9, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + } + } + ] +} \ No newline at end of file