Add NativeOutputWriter
baibaichen committed Sep 10, 2024
1 parent 614e4e4 commit 866a60a
Showing 23 changed files with 791 additions and 386 deletions.
@@ -118,7 +118,6 @@ object OptimizeTableCommandOverwrites extends Logging {
val datasourceJniWrapper = new CHDatasourceJniWrapper()
val returnedMetrics =
datasourceJniWrapper.nativeMergeMTParts(
planWithSplitInfo.plan,
planWithSplitInfo.splitInfo,
uuid,
taskId.getId.toString,
@@ -118,7 +118,6 @@ object OptimizeTableCommandOverwrites extends Logging {
val datasourceJniWrapper = new CHDatasourceJniWrapper()
val returnedMetrics =
datasourceJniWrapper.nativeMergeMTParts(
planWithSplitInfo.plan,
planWithSplitInfo.splitInfo,
uuid,
taskId.getId.toString,
@@ -120,7 +120,6 @@ object OptimizeTableCommandOverwrites extends Logging {
val datasourceJniWrapper = new CHDatasourceJniWrapper()
val returnedMetrics =
datasourceJniWrapper.nativeMergeMTParts(
planWithSplitInfo.plan,
planWithSplitInfo.splitInfo,
uuid,
taskId.getId.toString,
@@ -172,7 +171,7 @@ object OptimizeTableCommandOverwrites extends Logging {
bucketNum: String,
bin: Seq[AddFile],
maxFileSize: Long): Seq[FileAction] = {
val tableV2 = ClickHouseTableV2.getTable(txn.deltaLog);
val tableV2 = ClickHouseTableV2.getTable(txn.deltaLog)

val sparkSession = SparkSession.getActiveSession.get

@@ -18,11 +18,15 @@

public class CHDatasourceJniWrapper {

public native long nativeInitFileWriterWrapper(
String filePath, byte[] preferredSchema, String formatHint);
public native void write(long instanceId, long blockAddress);

public native String close(long instanceId);

/// FileWriter
public native long createFilerWriter(String filePath, byte[] preferredSchema, String formatHint);

public native long nativeInitMergeTreeWriterWrapper(
byte[] plan,
/// MergeTreeWriter
public native long createMergeTreeWriter(
byte[] splitInfo,
String uuid,
String taskId,
@@ -31,43 +35,28 @@ public native long nativeInitMergeTreeWriterWrapper(
byte[] confArray);

public native String nativeMergeMTParts(
byte[] plan,
byte[] splitInfo,
String uuid,
String taskId,
String partition_dir,
String bucket_dir);
byte[] splitInfo, String uuid, String taskId, String partition_dir, String bucket_dir);

public static native String filterRangesOnDriver(byte[] plan, byte[] read);

public native void write(long instanceId, long blockAddress);

public native void writeToMergeTree(long instanceId, long blockAddress);

public native void close(long instanceId);

public native String closeMergeTreeWriter(long instanceId);

/*-
/**
* The input block is already sorted by partition columns + bucket expressions. (check
* org.apache.spark.sql.execution.datasources.FileFormatWriter#write)
* However, the input block may contain parts(we call it stripe here) belonging to
* different partition/buckets.
* org.apache.spark.sql.execution.datasources.FileFormatWriter#write) However, the input block may
* contain parts(we call it stripe here) belonging to different partition/buckets.
*
* If bucketing is enabled, the input block's last column is guaranteed to be _bucket_value_.
* <p>If bucketing is enabled, the input block's last column is guaranteed to be _bucket_value_.
*
* This function splits the input block in to several blocks, each of which belonging
* to the same partition/bucket. Notice the stripe will NOT contain partition columns
* <p>This function splits the input block in to several blocks, each of which belonging to the
* same partition/bucket. Notice the stripe will NOT contain partition columns
*
* Since all rows in a stripe share the same partition/bucket,
* we only need to check the heading row.
* So, for each stripe, the native code also returns each stripe's first row's index.
* Caller can use these indice to get UnsafeRows from the input block,
* to help FileFormatDataWriter to aware partition/bucket changes.
* <p>Since all rows in a stripe share the same partition/bucket, we only need to check the
* heading row. So, for each stripe, the native code also returns each stripe's first row's index.
* Caller can use these indices to get UnsafeRows from the input block, to help
* FileFormatDataWriter to aware partition/bucket changes.
*/
public static native BlockStripes splitBlockByPartitionAndBucket(
long blockAddress,
int[] partitionColIndice,
int[] partitionColIndices,
boolean hasBucket,
boolean reserve_partition_columns);
}
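The class above now exposes a single create/write/close lifecycle for both the plain file writer and the MergeTree writer. Below is a minimal sketch of how a caller might drive it — not code from this commit: the import of CHDatasourceJniWrapper is omitted because its package is not shown in this diff, and the "parquet" format hint and the `blockAddresses` parameter are placeholders.

```scala
// Minimal sketch of the unified writer lifecycle introduced by this commit.
// Assumes CHDatasourceJniWrapper is in scope; block addresses come from the
// native side, e.g. CHColumnVector.getBlockAddress as in MergeTreeOutputWriter.
def writeBlocks(path: String, preferredSchema: Array[Byte], blockAddresses: Seq[Long]): String = {
  val jni = new CHDatasourceJniWrapper()
  // createFilerWriter for plain file formats; createMergeTreeWriter(...) for MergeTree parts.
  val instanceId = jni.createFilerWriter(path, preferredSchema, "parquet")
  blockAddresses.foreach(addr => jni.write(instanceId, addr))
  // close() returns a result string; for the MergeTree writer it is the
  // part-metrics JSON that MergeTreeOutputWriter turns into AddFile entries.
  jni.close(instanceId)
}
```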
@@ -39,7 +39,7 @@ trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {
context: TaskAttemptContext,
nativeConf: java.util.Map[String, String]): OutputWriter = {
val originPath = path
val datasourceJniWrapper = new CHDatasourceJniWrapper();
val datasourceJniWrapper = new CHDatasourceJniWrapper()
CHThreadGroup.registerNewThreadGroup()

val namedStructBuilder = NamedStruct.newBuilder
@@ -52,10 +52,7 @@ trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {
var namedStruct = namedStructBuilder.build

val instance =
datasourceJniWrapper.nativeInitFileWriterWrapper(
path,
namedStruct.toByteArray,
getFormatName());
datasourceJniWrapper.createFilerWriter(path, namedStruct.toByteArray, getFormatName())

new OutputWriter {
override def write(row: InternalRow): Unit = {
@@ -98,8 +98,7 @@ class CHMergeTreeWriterInjects extends GlutenFormatWriterInjectsBase {
)
val datasourceJniWrapper = new CHDatasourceJniWrapper()
val instance =
datasourceJniWrapper.nativeInitMergeTreeWriterWrapper(
planWithSplitInfo.plan,
datasourceJniWrapper.createMergeTreeWriter(
planWithSplitInfo.splitInfo,
uuid,
context.getTaskAttemptID.getTaskID.getId.toString,
@@ -59,7 +59,7 @@ abstract class MergeTreeFileFormatDataWriter(
protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]()
protected var currentWriter: OutputWriter = _

protected val returnedMetrics = mutable.HashMap[String, AddFile]()
protected val returnedMetrics: mutable.Map[String, AddFile] = mutable.HashMap[String, AddFile]()

/** Trackers for computing various statistics on the data as it's being written out. */
protected val statsTrackers: Seq[WriteTaskStatsTracker] =
@@ -71,10 +71,10 @@
try {
currentWriter.close()
statsTrackers.foreach(_.closeFile(currentWriter.path()))
val ret = currentWriter.asInstanceOf[MergeTreeOutputWriter].getAddFiles()
if (ret.nonEmpty) {
ret.foreach(addFile => returnedMetrics.put(addFile.path, addFile))
}
currentWriter
.asInstanceOf[MergeTreeOutputWriter]
.getAddFiles
.foreach(addFile => returnedMetrics.put(addFile.path, addFile))
} finally {
currentWriter = null
}
@@ -117,12 +117,7 @@ abstract class MergeTreeFileFormatDataWriter(
releaseResources()
val (taskCommitMessage, taskCommitTime) = Utils.timeTakenMs {
// committer.commitTask(taskAttemptContext)
val statuses = returnedMetrics
.map(
v => {
v._2
})
.toSeq
val statuses = returnedMetrics.values.toSeq
new TaskCommitMessage(statuses)
}

@@ -142,7 +137,7 @@

override def close(): Unit = {}

def getReturnedMetrics(): mutable.Map[String, AddFile] = returnedMetrics
def getReturnedMetrics: mutable.Map[String, AddFile] = returnedMetrics
}

/** FileFormatWriteTask for empty partitions */
@@ -443,7 +438,11 @@ class MergeTreeDynamicPartitionDataSingleWriter(
case fakeRow: FakeRow =>
if (fakeRow.batch.numRows() > 0) {
val blockStripes = GlutenRowSplitter.getInstance
.splitBlockByPartitionAndBucket(fakeRow, partitionColIndice, isBucketed, true)
.splitBlockByPartitionAndBucket(
fakeRow,
partitionColIndice,
isBucketed,
reserve_partition_columns = true)

val iter = blockStripes.iterator()
while (iter.hasNext) {
@@ -526,10 +525,10 @@ class MergeTreeDynamicPartitionDataConcurrentWriter(
if (status.outputWriter != null) {
try {
status.outputWriter.close()
val ret = status.outputWriter.asInstanceOf[MergeTreeOutputWriter].getAddFiles()
if (ret.nonEmpty) {
ret.foreach(addFile => returnedMetrics.put(addFile.path, addFile))
}
status.outputWriter
.asInstanceOf[MergeTreeOutputWriter]
.getAddFiles
.foreach(addFile => returnedMetrics.put(addFile.path, addFile))
} finally {
status.outputWriter = null
}
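The javadoc on splitBlockByPartitionAndBucket (see CHDatasourceJniWrapper above) is the contract the dynamic-partition writers rely on: each returned stripe holds rows of a single partition/bucket, and the stripe's heading-row index points back into the original block so the caller can detect partition/bucket changes. The sketch below is illustration only — BlockStripes exposes just iterator() in this diff, so the per-stripe accessors (getBlockAddress, getHeadingRowIndex) and the handleStripe callback are hypothetical names.

```scala
// Illustration of the stripe contract, with hypothetical accessor names.
def consumeStripes(
    blockAddress: Long,
    partitionColIndices: Array[Int],
    hasBucket: Boolean)(handleStripe: (Long, Int) => Unit): Unit = {
  val stripes = CHDatasourceJniWrapper.splitBlockByPartitionAndBucket(
    blockAddress, partitionColIndices, hasBucket, false)
  val it = stripes.iterator()
  while (it.hasNext) {
    val stripe = it.next()
    // Each stripe is a single partition/bucket slice; its heading row's index
    // into the original block is what lets FileFormatDataWriter switch writers.
    handleStripe(stripe.getBlockAddress, stripe.getHeadingRowIndex) // hypothetical accessors
  }
}
```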
@@ -42,12 +42,12 @@ class MergeTreeOutputWriter(

if (nextBatch.numRows > 0) {
val col = nextBatch.column(0).asInstanceOf[CHColumnVector]
datasourceJniWrapper.writeToMergeTree(instance, col.getBlockAddress)
datasourceJniWrapper.write(instance, col.getBlockAddress)
} // else just ignore this empty block
}

override def close(): Unit = {
val returnedMetrics = datasourceJniWrapper.closeMergeTreeWriter(instance)
val returnedMetrics = datasourceJniWrapper.close(instance)
if (returnedMetrics != null && returnedMetrics.nonEmpty) {
addFiles.appendAll(
AddFileTags.partsMetricsToAddFile(
@@ -64,7 +64,7 @@ class MergeTreeOutputWriter(
originPath
}

def getAddFiles(): ArrayBuffer[AddFile] = {
def getAddFiles: ArrayBuffer[AddFile] = {
addFiles
}
}
@@ -56,6 +56,7 @@ class GlutenClickHouseMergeTreeWriteSuite
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.set("spark.sql.adaptive.enabled", "true")
.set("spark.sql.files.maxPartitionBytes", "20000000")
.set("spark.gluten.sql.native.writer.enabled", "true")
.setCHSettings("min_insert_block_size_rows", 100000)
.setCHSettings("mergetree.merge_after_insert", false)
.setCHSettings("input_format_parquet_max_block_size", 8192)
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/CMakeLists.txt
@@ -127,9 +127,9 @@ foreach(child ${children})
add_headers_and_sources(function_parsers ${child})
endforeach()

# Notice: soures files under Parser/*_udf subdirectories must be built into
# Notice: sources files under Parser/*_udf subdirectories must be built into
# target ${LOCALENGINE_SHARED_LIB} directly to make sure all function parsers
# are registered successly.
# are registered successfully.
add_library(
${LOCALENGINE_SHARED_LIB} SHARED
local_engine_jni.cpp ${local_udfs_sources} ${function_parsers_sources}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -70,7 +70,7 @@
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <QueryPipeline/printPipeline.h>
#include <Storages/MergeTree/MergeTreeData.h>
#include <Storages/Output/FileWriterWrappers.h>
#include <Storages/Output/NormalFileWriter.h>
#include <Storages/SubstraitSource/SubstraitFileSource.h>
#include <Storages/SubstraitSource/SubstraitFileSourceStep.h>
#include <google/protobuf/util/json_util.h>
6 changes: 2 additions & 4 deletions cpp-ch/local-engine/Parser/SubstraitParserUtils.cpp
@@ -26,14 +26,12 @@ namespace local_engine
{
void logDebugMessage(const google::protobuf::Message & message, const char * type)
{
auto * logger = &Poco::Logger::get("SubstraitPlan");
if (logger->debug())
if (auto * logger = &Poco::Logger::get("SubstraitPlan"); logger->debug())
{
namespace pb_util = google::protobuf::util;
pb_util::JsonOptions options;
std::string json;
auto s = pb_util::MessageToJsonString(message, &json, options);
if (!s.ok())
if (auto s = pb_util::MessageToJsonString(message, &json, options); !s.ok())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Can not convert {} to Json", type);
LOG_DEBUG(logger, "{}:\n{}", type, json);
}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/WriteRelParser.cpp
@@ -22,7 +22,7 @@
#include <Parser/TypeParser.h>
#include <Processors/Transforms/ExpressionTransform.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <Storages/Output/FileWriterWrappers.h>
#include <Storages/Output/NormalFileWriter.h>
#include <substrait/algebra.pb.h>
#include <substrait/type.pb.h>
#include <Poco/StringTokenizer.h>
70 changes: 36 additions & 34 deletions cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeWriter.cpp
@@ -43,11 +43,44 @@ Block removeColumnSuffix(const Block & block)
}
return Block(columns);
}

}

namespace local_engine
{

std::string PartInfo::toJson(const std::vector<PartInfo> & part_infos)
{
rapidjson::StringBuffer result;
rapidjson::Writer<rapidjson::StringBuffer> writer(result);
writer.StartArray();
for (const auto & item : part_infos)
{
writer.StartObject();
writer.Key("part_name");
writer.String(item.part_name.c_str());
writer.Key("mark_count");
writer.Uint(item.mark_count);
writer.Key("disk_size");
writer.Uint(item.disk_size);
writer.Key("row_count");
writer.Uint(item.row_count);
writer.Key("bucket_id");
writer.String(item.bucket_id.c_str());
writer.Key("partition_values");
writer.StartObject();
for (const auto & key_value : item.partition_values)
{
writer.Key(key_value.first.c_str());
writer.String(key_value.second.c_str());
}
writer.EndObject();
writer.EndObject();
}
writer.EndArray();
return result.GetString();
}

std::unique_ptr<SparkMergeTreeWriter> SparkMergeTreeWriter::create(
const MergeTreeTable & merge_tree_table,
const SparkMergeTreeWritePartitionSettings & write_settings_,
@@ -82,7 +115,7 @@ SparkMergeTreeWriter::SparkMergeTreeWriter(
{
}

void SparkMergeTreeWriter::write(const DB::Block & block)
void SparkMergeTreeWriter::write(DB::Block & block)
{
auto new_block = removeColumnSuffix(block);
auto converter = ActionsDAG::makeConvertingActions(
@@ -92,9 +125,10 @@ void SparkMergeTreeWriter::write(const DB::Block & block)
executor.push(new_block);
}

void SparkMergeTreeWriter::finalize()
std::string SparkMergeTreeWriter::close()
{
executor.finish();
return PartInfo::toJson(getAllPartInfo());
}

std::vector<PartInfo> SparkMergeTreeWriter::getAllPartInfo() const
@@ -116,36 +150,4 @@ std::vector<PartInfo> SparkMergeTreeWriter::getAllPartInfo() const
return res;
}

String SparkMergeTreeWriter::partInfosToJson(const std::vector<PartInfo> & part_infos)
{
rapidjson::StringBuffer result;
rapidjson::Writer<rapidjson::StringBuffer> writer(result);
writer.StartArray();
for (const auto & item : part_infos)
{
writer.StartObject();
writer.Key("part_name");
writer.String(item.part_name.c_str());
writer.Key("mark_count");
writer.Uint(item.mark_count);
writer.Key("disk_size");
writer.Uint(item.disk_size);
writer.Key("row_count");
writer.Uint(item.row_count);
writer.Key("bucket_id");
writer.String(item.bucket_id.c_str());
writer.Key("partition_values");
writer.StartObject();
for (const auto & key_value : item.partition_values)
{
writer.Key(key_value.first.c_str());
writer.String(key_value.second.c_str());
}
writer.EndObject();
writer.EndObject();
}
writer.EndArray();
return result.GetString();
}

}
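PartInfo::toJson above (moved here from the old SparkMergeTreeWriter::partInfosToJson) produces the string that SparkMergeTreeWriter::close() now hands back through the JNI close() call. Below is a minimal sketch of decoding that payload on the JVM side, assuming json4s (which ships with Spark) is available; the project's real path goes through AddFileTags.partsMetricsToAddFile, and NativePartInfo is an illustrative case class, not a type from this commit.

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

object NativePartInfoDecoder {
  // Field names mirror the keys written by PartInfo::toJson in the C++ code above.
  case class NativePartInfo(
      part_name: String,
      mark_count: Long,
      disk_size: Long,
      row_count: Long,
      bucket_id: String,
      partition_values: Map[String, String])

  implicit val formats: Formats = DefaultFormats

  // `metricsJson` is the string returned by CHDatasourceJniWrapper.close(instanceId)
  // for a MergeTree writer instance.
  def decode(metricsJson: String): Seq[NativePartInfo] =
    parse(metricsJson).extract[Seq[NativePartInfo]]
}
```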