diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala
index 9b414d6f8cd08..9f0554db6b6f5 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHFormatWriterInjects.scala
@@ -26,11 +26,10 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.orc.OrcUtils
 import org.apache.spark.sql.types.StructType
 
+import io.substrait.proto.{NamedStruct, Type}
 import org.apache.hadoop.fs.FileStatus
 import org.apache.hadoop.mapreduce.TaskAttemptContext
 
-import io.substrait.proto.{NamedStruct, Type}
-
 trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {
 
   override def createOutputWriter(
@@ -40,21 +39,20 @@ trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {
       nativeConf: java.util.Map[String, String]): OutputWriter = {
     val originPath = path
     val datasourceJniWrapper = new CHDatasourceJniWrapper();
-    // scalastyle:off println
-    println(s"xxx dataSchema:$dataSchema")
-    // scalastyle:on println
     val namedStructBuilder = NamedStruct.newBuilder
-    for (name <- dataSchema.fieldNames) {
-      namedStructBuilder.addNames(name)
+    val structBuilder = Type.Struct.newBuilder
+    for (field <- dataSchema.fields) {
+      namedStructBuilder.addNames(field.name)
+      structBuilder.addTypes(ConverterUtils.getTypeNode(field.dataType, field.nullable).toProtobuf)
     }
-    val structNode = ConverterUtils.getTypeNode(dataSchema, nullable = true)
-    namedStructBuilder.setStruct(structNode.toProtobuf.asInstanceOf[Type.Struct])
+    namedStructBuilder.setStruct(structBuilder.build)
+    val namedStruct = namedStructBuilder.build
 
     val instance =
       datasourceJniWrapper.nativeInitFileWriterWrapper(
         path,
-        namedStructBuilder.build.toByteArray,
+        namedStruct.toByteArray,
         getFormatName());
 
     new OutputWriter {
diff --git a/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.cpp b/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.cpp
index 2d371dfd4d714..46edb7f30d5b8 100644
--- a/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.cpp
+++ b/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.cpp
@@ -15,7 +15,6 @@
  * limitations under the License.
  */
 #include "FileWriterWrappers.h"
-#include
 
 namespace local_engine
 {
@@ -28,7 +27,6 @@ NormalFileWriter::NormalFileWriter(const OutputFormatFilePtr & file_, const DB::
 {
 }
 
-
 void NormalFileWriter::consume(DB::Block & block)
 {
     if (!writer) [[unlikely]]
@@ -39,6 +37,22 @@ void NormalFileWriter::consume(DB::Block & block)
         writer = std::make_unique<DB::PushingPipelineExecutor>(*pipeline);
     }
 
+    /// If the input block does not have the same types as the preferred schema, cast it to the preferred schema.
+    /// Note that preferred_schema is the actual file schema, which is also the data schema of the table being inserted into.
+    /// Refer to issue: https://github.com/apache/incubator-gluten/issues/6588
+    size_t index = 0;
+    const auto & preferred_schema = file->getPreferredSchema();
+    for (auto & column : block)
+    {
+        if (column.name.starts_with("__bucket_value__"))
+            continue;
+
+        const auto & preferred_column = preferred_schema.getByPosition(index++);
+        column.column = DB::castColumn(column, preferred_column.type);
+        column.name = preferred_column.name;
+        column.type = preferred_column.type;
+    }
+
     /// Although gluten will append MaterializingTransform to the end of the pipeline before native insert in most cases, there are some cases in which MaterializingTransform won't be appended.
     /// e.g. https://github.com/oap-project/gluten/issues/2900
     /// So we need to do materialize here again to make sure all blocks passed to native writer are all materialized.
@@ -55,7 +69,7 @@ void NormalFileWriter::close()
 }
 
 OutputFormatFilePtr createOutputFormatFile(
-    const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & header, const std::string & format_hint)
+    const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & preferred_schema, const std::string & format_hint)
 {
     // the passed in file_uri is exactly what is expected to see in the output folder
     // e.g /xxx/中文/timestamp_field=2023-07-13 03%3A00%3A17.622/abc.parquet
     std::string encoded;
     Poco::URI::encode(file_uri, "", encoded); // encode the space and % seen in the file_uri
     Poco::URI poco_uri(encoded);
     auto write_buffer_builder = WriteBufferBuilderFactory::instance().createBuilder(poco_uri.getScheme(), context);
-    return OutputFormatFileUtil::createFile(context, write_buffer_builder, encoded, header, format_hint);
+    return OutputFormatFileUtil::createFile(context, write_buffer_builder, encoded, preferred_schema, format_hint);
 }
 
 std::unique_ptr<FileWriterWrapper> createFileWriterWrapper(
-    const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & header, const std::string & format_hint)
+    const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & preferred_schema, const std::string & format_hint)
 {
-    return std::make_unique<NormalFileWriter>(createOutputFormatFile(context, file_uri, header, format_hint), context);
+    return std::make_unique<NormalFileWriter>(createOutputFormatFile(context, file_uri, preferred_schema, format_hint), context);
 }
 }
diff --git a/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.h b/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.h
index 68cca1b2e473d..736f5a95f6bd1 100644
--- a/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.h
+++ b/cpp-ch/local-engine/Storages/Output/FileWriterWrappers.h
@@ -41,6 +41,7 @@ class FileWriterWrapper
 public:
     explicit FileWriterWrapper(const OutputFormatFilePtr & file_) : file(file_) { }
     virtual ~FileWriterWrapper() = default;
+
     virtual void consume(DB::Block & block) = 0;
     virtual void close() = 0;
@@ -53,10 +54,9 @@ using FileWriterWrapperPtr = std::shared_ptr<FileWriterWrapper>;
 class NormalFileWriter : public FileWriterWrapper
 {
 public:
-    //TODO: EmptyFileReader and ConstColumnsFileReader ?
-    //TODO: to support complex types
     NormalFileWriter(const OutputFormatFilePtr & file_, const DB::ContextPtr & context_);
     ~NormalFileWriter() override = default;
+
     void consume(DB::Block & block) override;
     void close() override;
@@ -71,13 +71,13 @@ class NormalFileWriter : public FileWriterWrapper
 std::unique_ptr<FileWriterWrapper> createFileWriterWrapper(
     const DB::ContextPtr & context,
     const std::string & file_uri,
-    const DB::Block & header,
+    const DB::Block & preferred_schema,
     const std::string & format_hint);
 
-static OutputFormatFilePtr createOutputFormatFile(
+OutputFormatFilePtr createOutputFormatFile(
     const DB::ContextPtr & context,
     const std::string & file_uri,
-    const DB::Block & header,
+    const DB::Block & preferred_schema,
     const std::string & format_hint);
 
 class WriteStats : public DB::ISimpleTransform
@@ -191,7 +191,7 @@ class SubstraitFileSink final : public SinkToStorage
         : SinkToStorage(header)
         , partition_id_(partition_id.empty() ? NO_PARTITION_ID : partition_id)
         , relative_path_(relative)
-        , output_format_(createOutputFormatFile(context, makeFilename(base_path, partition_id, relative), header.getNames(), format_hint)
+        , output_format_(createOutputFormatFile(context, makeFilename(base_path, partition_id, relative), header, format_hint)
               ->createOutputFormat(header))
     {
     }
diff --git a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
index 5d456479a6649..1e8364c6dac2d 100644
--- a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
+++ b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
@@ -45,26 +45,24 @@ OutputFormatFile::OutputFormatFile(
 
 Block OutputFormatFile::creatHeaderWithPreferredSchema(const Block & header)
 {
-    if (preferred_schema)
+    if (!preferred_schema)
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "preferred_schema is empty");
+
+    /// Create a new header with the preferred column name and type
+    DB::ColumnsWithTypeAndName columns;
+    columns.reserve(preferred_schema.columns());
+    size_t index = 0;
+    for (const auto & name_type : header.getNamesAndTypesList())
     {
-        /// Create a new header with the preferred column name and type
-        DB::ColumnsWithTypeAndName columns;
-        columns.reserve(preferred_schema.columns());
-        size_t index = 0;
-        for (const auto & name_type : header.getNamesAndTypesList())
-        {
-            if (name_type.name.starts_with("__bucket_value__"))
-                continue;
+        if (name_type.name.starts_with("__bucket_value__"))
+            continue;
 
-            const auto & preferred_column = preferred_schema.getByPosition(index++);
-            ColumnWithTypeAndName column(preferred_column.type->createColumn(), preferred_column.type, preferred_column.name);
-            columns.emplace_back(std::move(column));
-        }
-        assert(preferred_column_names.size() == index);
-        return {std::move(columns)};
+        const auto & preferred_column = preferred_schema.getByPosition(index++);
+        ColumnWithTypeAndName column(preferred_column.type->createColumn(), preferred_column.type, preferred_column.name);
+        columns.emplace_back(std::move(column));
     }
-    else
-        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "preferred_schema is empty");
+    assert(preferred_column_names.size() == index);
+    return {std::move(columns)};
 }
 
 OutputFormatFilePtr OutputFormatFileUtil::createFile(
diff --git a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h
index 0064da8ae135f..93c26d7d188bc 100644
--- a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h
+++ b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.h
@@ -49,6 +49,8 @@ class OutputFormatFile
     virtual OutputFormatPtr createOutputFormat(const DB::Block & header_) = 0;
 
+    virtual const DB::Block getPreferredSchema() const { return preferred_schema; }
+
 protected:
     DB::Block creatHeaderWithPreferredSchema(const DB::Block & header);
 
diff --git a/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.h b/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.h
index fecc651b0c616..cc87da7da8542 100644
--- a/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.h
+++ b/cpp-ch/local-engine/Storages/Output/ParquetOutputFormatFile.h
@@ -20,7 +20,6 @@
 
 #if USE_PARQUET
 
-#include
 #include
 #include
 
diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp
index 8ce98ec03f5bf..680957b3db3b5 100644
--- a/cpp-ch/local-engine/local_engine_jni.cpp
+++ b/cpp-ch/local-engine/local_engine_jni.cpp
@@ -867,33 +867,33 @@ JNIEXPORT void Java_org_apache_gluten_vectorized_CHBlockWriterJniWrapper_nativeC
 }
 
 JNIEXPORT jlong Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniWrapper_nativeInitFileWriterWrapper(
-    JNIEnv * env, jobject, jstring file_uri_, jbyteArray data_schema_, jstring format_hint_)
+    JNIEnv * env, jobject, jstring file_uri_, jbyteArray preferred_schema_, jstring format_hint_)
 {
     LOCAL_ENGINE_JNI_METHOD_START
-    const auto data_schema_ref = local_engine::getByteArrayElementsSafe(env, data_schema_);
-    auto parse_data_shema = [&]() -> std::optional<substrait::NamedStruct>
+    const auto preferred_schema_ref = local_engine::getByteArrayElementsSafe(env, preferred_schema_);
+    auto parse_named_struct = [&]() -> std::optional<substrait::NamedStruct>
     {
-        std::string_view data_schema_view{
-            reinterpret_cast<const char *>(data_schema_ref.elems()), static_cast<size_t>(data_schema_ref.length())};
+        std::string_view view{
+            reinterpret_cast<const char *>(preferred_schema_ref.elems()), static_cast<size_t>(preferred_schema_ref.length())};
         substrait::NamedStruct res;
-        bool ok = res.ParseFromString(data_schema_view);
+        bool ok = res.ParseFromString(view);
         if (!ok)
             return {};
-        return res;
+        return std::move(res);
     };
 
-    auto named_struct = parse_data_shema();
+    auto named_struct = parse_named_struct();
    if (!named_struct.has_value())
         throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse schema from substrait protobuf failed");
-    Block header = local_engine::TypeParser::buildBlockFromNamedStructWithoutDFS(*named_struct);
+    Block preferred_schema = local_engine::TypeParser::buildBlockFromNamedStructWithoutDFS(*named_struct);
 
     const auto file_uri = jstring2string(env, file_uri_);
     const auto format_hint = jstring2string(env, format_hint_);
     // for HiveFileFormat, the file url may not end with .parquet, so we pass in the format as a hint
     const auto context = DB::Context::createCopy(local_engine::SerializedPlanParser::global_context);
-    auto * writer = local_engine::createFileWriterWrapper(context, file_uri, names, format_hint).release();
+    auto * writer = local_engine::createFileWriterWrapper(context, file_uri, preferred_schema, format_hint).release();
     return reinterpret_cast<jlong>(writer);
     LOCAL_ENGINE_JNI_METHOD_END(env, 0)
 }
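
The core of the fix is the positional rename-and-cast loop added to NormalFileWriter::consume() (mirrored by the header rebuild in creatHeaderWithPreferredSchema()): input columns are matched to the preferred schema by position rather than by name, and internal __bucket_value__ columns are skipped because they have no counterpart in the file schema. Below is a minimal stand-alone model of that loop; the Column struct, the toy castColumn and the sample schemas are simplified stand-ins for ClickHouse's ColumnWithTypeAndName and DB::castColumn, not the real Gluten types.

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for ClickHouse's ColumnWithTypeAndName.
struct Column
{
    std::string name;
    std::string type; // stands in for a real DataTypePtr
};

// Stand-in for DB::castColumn: the real function converts the column data,
// here we only record the target type.
Column castColumn(const Column & col, const std::string & target_type)
{
    return {col.name, target_type};
}

int main()
{
    // Input block as produced upstream: names and types may drift from the file schema.
    std::vector<Column> block = {{"_c0", "Date32"}, {"_c1", "Int64"}, {"__bucket_value__0", "Int32"}};
    // Preferred schema: the data schema of the table being inserted into.
    const std::vector<Column> preferred_schema = {{"event_date", "Date"}, {"cnt", "Int64"}};

    // Positional mapping, as in NormalFileWriter::consume(): bucket columns are
    // internal bookkeeping and are skipped without consuming a schema slot.
    size_t index = 0;
    for (auto & column : block)
    {
        if (column.name.starts_with("__bucket_value__"))
            continue;
        const auto & preferred = preferred_schema.at(index++);
        column = castColumn(column, preferred.type);
        column.name = preferred.name;
    }
    assert(index == preferred_schema.size());

    for (const auto & c : block)
        std::cout << c.name << " : " << c.type << '\n'; // event_date : Date, cnt : Int64, __bucket_value__0 : Int32
}
```

Positional matching is what makes the write path robust to auto-generated upstream column names (the symptom in issue #6588), at the cost of trusting that the plan emits columns in table-schema order.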
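Across the JNI boundary the schema travels as a serialized substrait NamedStruct: the Scala side now pairs every field name with its protobuf type node, and nativeInitFileWriterWrapper parses the bytes back into a Block. The sketch below shows only the parse-or-fail shape of that entry point; Message and parseNamedStruct are hypothetical stand-ins for substrait::NamedStruct and the patch's parse_named_struct lambda (the real code uses protobuf's ParseFromString and throws CANNOT_PARSE_PROTOBUF_SCHEMA).

```cpp
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>

// Hypothetical stand-in for substrait::NamedStruct.
struct Message
{
    std::string payload;
    bool ParseFromString(std::string_view bytes)
    {
        if (bytes.empty())
            return false; // protobuf returns false on malformed input
        payload = std::string(bytes);
        return true;
    }
};

// Mirrors the parse_named_struct lambda: failure is signalled with an empty
// optional, and the caller decides how to report it.
std::optional<Message> parseNamedStruct(std::string_view bytes)
{
    Message res;
    if (!res.ParseFromString(bytes))
        return {};
    return std::move(res);
}

int main()
{
    auto named_struct = parseNamedStruct("serialized NamedStruct bytes");
    if (!named_struct.has_value())
        throw std::runtime_error("Parse schema from substrait protobuf failed");
    std::cout << "parsed " << named_struct->payload.size() << " bytes\n";
}
```

Keeping the parse step in a small lambda that returns std::optional separates the failure path from the JNI error-handling macros, which then see a single throw site.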