Commit 6badc77: finish dev

taiyang-li committed Aug 2, 2024
1 parent 6c0d34d commit 6badc77

Showing 7 changed files with 61 additions and 50 deletions.

@@ -26,11 +26,10 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.orc.OrcUtils
 import org.apache.spark.sql.types.StructType
 
+import io.substrait.proto.{NamedStruct, Type}
 import org.apache.hadoop.fs.FileStatus
 import org.apache.hadoop.mapreduce.TaskAttemptContext
 
-import io.substrait.proto.{NamedStruct, Type}
-
 trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {
 
   override def createOutputWriter(
@@ -40,21 +39,20 @@ trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {
       nativeConf: java.util.Map[String, String]): OutputWriter = {
     val originPath = path
     val datasourceJniWrapper = new CHDatasourceJniWrapper();
-    // scalastyle:off println
-    println(s"xxx dataSchema:$dataSchema")
-    // scalastyle:on println
 
     val namedStructBuilder = NamedStruct.newBuilder
-    for (name <- dataSchema.fieldNames) {
-      namedStructBuilder.addNames(name)
+    val structBuilder = Type.Struct.newBuilder
+    for (field <- dataSchema.fields) {
+      namedStructBuilder.addNames(field.name)
+      structBuilder.addTypes(ConverterUtils.getTypeNode(field.dataType, field.nullable).toProtobuf)
     }
-    val structNode = ConverterUtils.getTypeNode(dataSchema, nullable = true)
-    namedStructBuilder.setStruct(structNode.toProtobuf.asInstanceOf[Type.Struct])
+    namedStructBuilder.setStruct(structBuilder.build)
+    var namedStruct = namedStructBuilder.build
 
     val instance =
       datasourceJniWrapper.nativeInitFileWriterWrapper(
         path,
-        namedStructBuilder.build.toByteArray,
+        namedStruct.toByteArray,
         getFormatName());
 
     new OutputWriter {
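
Note: the builder code above serializes the writer schema as a Substrait NamedStruct and ships the bytes over JNI; the native side parses them back in local_engine_jni.cpp (last file below). A minimal C++ sketch of that round trip, assuming Substrait's type.proto has been compiled with protoc (the generated header path and the example field types are assumptions, not taken from this commit):

// Sketch only: mirrors the Scala builder logic and the native-side parse.
#include <optional>
#include <string>
#include <substrait/type.pb.h> // assumed location of the generated header

std::optional<substrait::NamedStruct> roundTripSchema()
{
    substrait::NamedStruct named_struct;
    named_struct.add_names("id");
    named_struct.add_names("name");

    // "struct" is a C++ keyword, so protoc generates struct_() accessors.
    auto * struct_type = named_struct.mutable_struct_();
    struct_type->add_types()->mutable_i64();    // id: 64-bit integer
    struct_type->add_types()->mutable_string(); // name: string

    // The Scala side sends namedStruct.toByteArray; the native side parses it back.
    const std::string bytes = named_struct.SerializeAsString();
    substrait::NamedStruct parsed;
    if (!parsed.ParseFromString(bytes))
        return std::nullopt; // maps to the CANNOT_PARSE_PROTOBUF_SCHEMA path in the JNI entry
    return parsed;
}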

26 changes: 20 additions & 6 deletions cpp-ch/local-engine/Storages/Output/FileWriterWrappers.cpp

@@ -15,7 +15,6 @@
  * limitations under the License.
  */
 #include "FileWriterWrappers.h"
-#include <algorithm>
 
 namespace local_engine
 {
@@ -28,7 +27,6 @@ NormalFileWriter::NormalFileWriter(const OutputFormatFilePtr & file_, const DB::
 {
 }
 
-
 void NormalFileWriter::consume(DB::Block & block)
 {
     if (!writer) [[unlikely]]
@@ -39,6 +37,22 @@ void NormalFileWriter::consume(DB::Block & block)
         writer = std::make_unique<DB::PushingPipelineExecutor>(*pipeline);
     }
 
+    /// In case the input block doesn't have the same types as the preferred schema, we cast it to the preferred schema.
+    /// Notice that preferred_schema is the actual file schema, which is also the data schema of the table being inserted into.
+    /// Refer to issue: https://github.com/apache/incubator-gluten/issues/6588
+    size_t index = 0;
+    const auto & preferred_schema = file->getPreferredSchema();
+    for (auto & column : block)
+    {
+        if (column.name.starts_with("__bucket_value__"))
+            continue;
+
+        const auto & preferred_column = preferred_schema.getByPosition(index++);
+        column.column = DB::castColumn(column, preferred_column.type);
+        column.name = preferred_column.name;
+        column.type = preferred_column.type;
+    }
+
     /// Although gluten will append MaterializingTransform to the end of the pipeline before native insert in most cases, there are some cases in which MaterializingTransform won't be appended.
     /// e.g. https://github.com/oap-project/gluten/issues/2900
     /// So we need to materialize here again to make sure all blocks passed to the native writer are materialized.
@@ -55,7 +69,7 @@ void NormalFileWriter::close()
 }
 
 OutputFormatFilePtr createOutputFormatFile(
-    const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & header, const std::string & format_hint)
+    const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & preferred_schema, const std::string & format_hint)
 {
     // the passed-in file_uri is exactly what is expected to appear in the output folder
     // e.g. /xxx/中文/timestamp_field=2023-07-13 03%3A00%3A17.622/abc.parquet
@@ -64,13 +78,13 @@
     Poco::URI::encode(file_uri, "", encoded); // encode the space and % seen in the file_uri
     Poco::URI poco_uri(encoded);
    auto write_buffer_builder = WriteBufferBuilderFactory::instance().createBuilder(poco_uri.getScheme(), context);
-    return OutputFormatFileUtil::createFile(context, write_buffer_builder, encoded, header, format_hint);
+    return OutputFormatFileUtil::createFile(context, write_buffer_builder, encoded, preferred_schema, format_hint);
 }
 
 std::unique_ptr<FileWriterWrapper> createFileWriterWrapper(
-    const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & header, const std::string & format_hint)
+    const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & preferred_schema, const std::string & format_hint)
 {
-    return std::make_unique<NormalFileWriter>(createOutputFormatFile(context, file_uri, header, format_hint), context);
+    return std::make_unique<NormalFileWriter>(createOutputFormatFile(context, file_uri, preferred_schema, format_hint), context);
 }
 
 }
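
Note: the new cast block in consume() aligns each incoming column to the preferred (file) schema by position, passing synthetic bucket columns through untouched. A standalone toy sketch of that mapping rule in plain C++20, with simple stand-ins for ClickHouse's ColumnWithTypeAndName and DB::castColumn (the stand-in types are assumptions for illustration only):

// Toy model of the cast loop above: name/type strings stand in for ClickHouse
// columns, and the actual data cast (DB::castColumn) is elided.
#include <cassert>
#include <string>
#include <vector>

struct Column
{
    std::string name;
    std::string type;
};

void alignToPreferredSchema(std::vector<Column> & block, const std::vector<Column> & preferred)
{
    size_t index = 0;
    for (auto & column : block)
    {
        // Bucket columns are synthetic and keep their name/type.
        if (column.name.starts_with("__bucket_value__")) // std::string::starts_with is C++20
            continue;

        // Every other column is renamed/retyped positionally from the file schema.
        const Column & preferred_column = preferred.at(index++);
        column.name = preferred_column.name;
        column.type = preferred_column.type;
    }
    // The preferred schema covers exactly the non-bucket columns.
    assert(index == preferred.size());
}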

12 changes: 6 additions & 6 deletions cpp-ch/local-engine/Storages/Output/FileWriterWrappers.h

@@ -41,6 +41,7 @@ class FileWriterWrapper
 public:
     explicit FileWriterWrapper(const OutputFormatFilePtr & file_) : file(file_) { }
     virtual ~FileWriterWrapper() = default;
+
     virtual void consume(DB::Block & block) = 0;
     virtual void close() = 0;
 
@@ -53,10 +54,9 @@ using FileWriterWrapperPtr = std::shared_ptr<FileWriterWrapper>;
 class NormalFileWriter : public FileWriterWrapper
 {
 public:
-    //TODO: EmptyFileReader and ConstColumnsFileReader ?
-    //TODO: to support complex types
     NormalFileWriter(const OutputFormatFilePtr & file_, const DB::ContextPtr & context_);
     ~NormalFileWriter() override = default;
+
     void consume(DB::Block & block) override;
     void close() override;
 
@@ -71,13 +71,13 @@
 std::unique_ptr<FileWriterWrapper> createFileWriterWrapper(
     const DB::ContextPtr & context,
     const std::string & file_uri,
-    const DB::Block & header,
+    const DB::Block & preferred_schema,
     const std::string & format_hint);
 
-static OutputFormatFilePtr createOutputFormatFile(
+OutputFormatFilePtr createOutputFormatFile(
     const DB::ContextPtr & context,
     const std::string & file_uri,
-    const DB::Block & header,
+    const DB::Block & preferred_schema,
     const std::string & format_hint);
 
 class WriteStats : public DB::ISimpleTransform
@@ -191,7 +191,7 @@ class SubstraitFileSink final : public SinkToStorage
         : SinkToStorage(header)
         , partition_id_(partition_id.empty() ? NO_PARTITION_ID : partition_id)
         , relative_path_(relative)
-        , output_format_(createOutputFormatFile(context, makeFilename(base_path, partition_id, relative), header.getNames(), format_hint)
+        , output_format_(createOutputFormatFile(context, makeFilename(base_path, partition_id, relative), header, format_hint)
               ->createOutputFormat(header))
     {
     }
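
Note: dropping static from the createOutputFormatFile declaration is a linkage fix rather than a style change: a static free function declared in a header has internal linkage, so every translation unit that includes the header expects its own private definition, and only the defining .cpp has one. A hypothetical two-file reproduction (file and function names invented, not from this commit):

// schema_util.h (hypothetical)
static int makeFile(); // internal linkage: each includer needs its OWN definition

// schema_util.cpp
// #include "schema_util.h"
int makeFile() { return 42; } // satisfies schema_util.cpp only

// caller.cpp
// #include "schema_util.h"
// int use() { return makeFile(); } // fails to link: undefined reference to makeFile()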

32 changes: 15 additions & 17 deletions cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp

@@ -45,26 +45,24 @@ OutputFormatFile::OutputFormatFile(
 
 Block OutputFormatFile::creatHeaderWithPreferredSchema(const Block & header)
 {
-    if (preferred_schema)
+    if (!preferred_schema)
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "preferred_schema is empty");
+
+    /// Create a new header with the preferred column name and type
+    DB::ColumnsWithTypeAndName columns;
+    columns.reserve(preferred_schema.columns());
+    size_t index = 0;
+    for (const auto & name_type : header.getNamesAndTypesList())
     {
-        /// Create a new header with the preferred column name and type
-        DB::ColumnsWithTypeAndName columns;
-        columns.reserve(preferred_schema.columns());
-        size_t index = 0;
-        for (const auto & name_type : header.getNamesAndTypesList())
-        {
-            if (name_type.name.starts_with("__bucket_value__"))
-                continue;
+        if (name_type.name.starts_with("__bucket_value__"))
+            continue;
 
-            const auto & preferred_column = preferred_schema.getByPosition(index++);
-            ColumnWithTypeAndName column(preferred_column.type->createColumn(), preferred_column.type, preferred_column.name);
-            columns.emplace_back(std::move(column));
-        }
-        assert(preferred_column_names.size() == index);
-        return {std::move(columns)};
+        const auto & preferred_column = preferred_schema.getByPosition(index++);
+        ColumnWithTypeAndName column(preferred_column.type->createColumn(), preferred_column.type, preferred_column.name);
+        columns.emplace_back(std::move(column));
     }
-    else
-        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "preferred_schema is empty");
+    assert(preferred_column_names.size() == index);
+    return {std::move(columns)};
 }
 
 OutputFormatFilePtr OutputFormatFileUtil::createFile(

2 changes: 2 additions & 0 deletions cpp-ch/local-engine/Storages/Output/OutputFormatFile.h

@@ -49,6 +49,8 @@ class OutputFormatFile
 
     virtual OutputFormatPtr createOutputFormat(const DB::Block & header_) = 0;
 
+    virtual const DB::Block getPreferredSchema() const { return preferred_schema; }
+
 protected:
     DB::Block creatHeaderWithPreferredSchema(const DB::Block & header);
 

@@ -20,7 +20,6 @@
 
 #if USE_PARQUET
 
-#include <memory>
 #include <IO/WriteBuffer.h>
 #include <Storages/Output/OutputFormatFile.h>
 

20 changes: 10 additions & 10 deletions cpp-ch/local-engine/local_engine_jni.cpp

@@ -867,33 +867,33 @@ JNIEXPORT void Java_org_apache_gluten_vectorized_CHBlockWriterJniWrapper_nativeC
 }
 
 JNIEXPORT jlong Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniWrapper_nativeInitFileWriterWrapper(
-    JNIEnv * env, jobject, jstring file_uri_, jbyteArray data_schema_, jstring format_hint_)
+    JNIEnv * env, jobject, jstring file_uri_, jbyteArray preferred_schema_, jstring format_hint_)
 {
     LOCAL_ENGINE_JNI_METHOD_START
 
-    const auto data_schema_ref = local_engine::getByteArrayElementsSafe(env, data_schema_);
-    auto parse_data_shema = [&]() -> std::optional<substrait::NamedStruct>
+    const auto preferred_schema_ref = local_engine::getByteArrayElementsSafe(env, preferred_schema_);
+    auto parse_named_struct = [&]() -> std::optional<substrait::NamedStruct>
     {
-        std::string_view data_schema_view{
-            reinterpret_cast<const char *>(data_schema_ref.elems()), static_cast<size_t>(data_schema_ref.length())};
+        std::string_view view{
+            reinterpret_cast<const char *>(preferred_schema_ref.elems()), static_cast<size_t>(preferred_schema_ref.length())};
 
         substrait::NamedStruct res;
-        bool ok = res.ParseFromString(data_schema_view);
+        bool ok = res.ParseFromString(view);
         if (!ok)
             return {};
-        return res;
+        return std::move(res);
     };
 
-    auto named_struct = parse_data_shema();
+    auto named_struct = parse_named_struct();
     if (!named_struct.has_value())
         throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse schema from substrait protobuf failed");
 
-    Block header = local_engine::TypeParser::buildBlockFromNamedStructWithoutDFS(*named_struct);
+    Block preferred_schema = local_engine::TypeParser::buildBlockFromNamedStructWithoutDFS(*named_struct);
     const auto file_uri = jstring2string(env, file_uri_);
     const auto format_hint = jstring2string(env, format_hint_);
     // for HiveFileFormat, the file url may not end with .parquet, so we pass in the format as a hint
     const auto context = DB::Context::createCopy(local_engine::SerializedPlanParser::global_context);
-    auto * writer = local_engine::createFileWriterWrapper(context, file_uri, names, format_hint).release();
+    auto * writer = local_engine::createFileWriterWrapper(context, file_uri, preferred_schema, format_hint).release();
     return reinterpret_cast<jlong>(writer);
     LOCAL_ENGINE_JNI_METHOD_END(env, 0)
 }
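
Note: getByteArrayElementsSafe is a local_engine helper; with the raw JNI API, the byte-array handling in the function above looks roughly like this sketch (error handling elided, function name invented):

// Sketch: copy a jbyteArray's contents into a std::string suitable for
// protobuf's ParseFromString, then release the pinned elements.
#include <jni.h>
#include <string>

std::string readByteArray(JNIEnv * env, jbyteArray array)
{
    const jsize length = env->GetArrayLength(array);
    jbyte * elems = env->GetByteArrayElements(array, nullptr);
    std::string bytes(reinterpret_cast<const char *>(elems), static_cast<size_t>(length));
    // JNI_ABORT: we only read, so skip copying the buffer back to the Java array.
    env->ReleaseByteArrayElements(array, elems, JNI_ABORT);
    return bytes;
}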
