[Bug Fix] collect_partition_cols and Remove ApplySquashingTransform and PlanSquashingTransform
baibaichen committed Dec 15, 2024
1 parent 9b89791 commit 937caff
Showing 11 changed files with 324 additions and 215 deletions.
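Reviewer's summary (not part of the commit): the driver now resolves each partition column to its positional index in the child output and ships those indexes in Write.Common via addPartitionColIndex, so the native collect_partition_cols reads the indexes instead of rescanning the NamedStruct, and block squashing moves from the removed ApplySquashingTransform/PlanSquashingTransform pipeline steps into SparkMergeTreeSink itself. A minimal Scala sketch of the driver-side index resolution; the helper name and standalone shape are illustrative only:

import org.apache.spark.sql.catalyst.expressions.Attribute

// Illustrative helper mirroring the CHTransformerApi change: partition columns are
// mapped to their positions in the child output before being sent to the native writer.
def partitionColIndexes(childOutput: Seq[Attribute], partitionColumns: Seq[Attribute]): Seq[Int] = {
  val indexes = partitionColumns.map(p => childOutput.indexWhere(_.exprId == p.exprId))
  require(indexes.forall(_ >= 0), "every partition column must come from the child output")
  indexes
}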
CHTransformerApi.scala
@@ -249,14 +249,23 @@ class CHTransformerApi extends TransformerApi with Logging {
         register.shortName
       case _ => "UnknownFileFormat"
     }
-    val write = Write
+    val childOutput = writeExec.child.output
+
+    val partitionIndexes =
+      writeExec.partitionColumns.map(p => childOutput.indexWhere(_.exprId == p.exprId))
+    require(partitionIndexes.forall(_ >= 0))
+
+    val common = Write.Common
       .newBuilder()
-      .setCommon(
-        Write.Common
-          .newBuilder()
-          .setFormat(s"$fileFormatStr")
-          .setJobTaskAttemptId("") // we cannot get job and task id at the driver side)
-          .build())
+      .setFormat(fileFormatStr)
+      .setJobTaskAttemptId("") // we cannot get job and task id at the driver side
+    partitionIndexes.foreach {
+      idx =>
+        require(idx >= 0)
+        common.addPartitionColIndex(idx)
+    }
+
+    val write = Write.newBuilder().setCommon(common.build())

     writeExec.fileFormat match {
       case d: MergeTreeFileFormat =>
@@ -271,5 +280,5 @@ class CHTransformerApi extends TransformerApi with Logging {

   /** use Hadoop Path class to encode the file path */
   override def encodeFilePathIfNeed(filePath: String): String =
-    (new Path(filePath)).toUri.toASCIIString
+    new Path(filePath).toUri.toASCIIString
 }
RuntimeConfig.scala
@@ -22,6 +22,7 @@ object RuntimeConfig {
   import CHConf._
   import SQLConf._

+  /** Clickhouse Configuration */
   val PATH =
     buildConf(runtimeConfig("path"))
       .doc(
@@ -37,9 +38,25 @@ object RuntimeConfig {
       .createWithDefault("/tmp/libch")
   // scalastyle:on line.size.limit

+  // scalastyle:off line.size.limit
+  val LOGGER_LEVEL =
+    buildConf(runtimeConfig("logger.level"))
+      .doc(
+        "https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings#logger")
+      .stringConf
+      .createWithDefault("warning")
+  // scalastyle:on line.size.limit
+
+  /** Gluten Configuration */
   val USE_CURRENT_DIRECTORY_AS_TMP =
     buildConf(runtimeConfig("use_current_directory_as_tmp"))
       .doc("Use the current directory as the temporary directory.")
       .booleanConf
       .createWithDefault(false)
+
+  val DUMP_PIPELINE =
+    buildConf(runtimeConfig("dump_pipeline"))
+      .doc("Dump pipeline to file after execution")
+      .booleanConf
+      .createWithDefault(false)
 }
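For context (not part of the commit): these entries are read through Gluten's ClickHouse runtime-config namespace, so they are typically set on the Spark session. A hedged usage sketch, assuming runtimeConfig(...) expands to the spark.gluten.sql.columnar.backend.ch.runtime_config.* prefix:

import org.apache.spark.sql.SparkSession

// Assumed key prefix: spark.gluten.sql.columnar.backend.ch.runtime_config.
val spark = SparkSession
  .builder()
  .appName("gluten-ch-runtime-config-example")
  .config("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "debug")
  .config("spark.gluten.sql.columnar.backend.ch.runtime_config.use_current_directory_as_tmp", "true")
  .config("spark.gluten.sql.columnar.backend.ch.runtime_config.dump_pipeline", "true")
  .getOrCreate()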
GlutenClickHouseMergeTreeWriteSuite.scala
@@ -583,7 +583,7 @@ class GlutenClickHouseMergeTreeWriteSuite
   }

   test("test mergetree write with partition") {
-    withSQLConf((CHConf.ENABLE_ONEPIPELINE_MERGETREE_WRITE.key, spark35.toString)) {
+    withSQLConf((CHConf.ENABLE_ONEPIPELINE_MERGETREE_WRITE.key, "false")) {
       spark.sql(s"""
                    |DROP TABLE IF EXISTS lineitem_mergetree_partition;
                    |""".stripMargin)
@@ -703,7 +703,7 @@ class GlutenClickHouseMergeTreeWriteSuite

         val mergetreeScan = scanExec.head
         assert(mergetreeScan.nodeName.startsWith("ScanTransformer mergetree"))
-        // assertResult(3745)(mergetreeScan.metrics("numFiles").value)
+        assertResult(3745)(mergetreeScan.metrics("numFiles").value)

         val fileIndex = mergetreeScan.relation.location.asInstanceOf[TahoeFileIndex]
         assert(ClickHouseTableV2.getTable(fileIndex.deltaLog).clickhouseTableConfigs.nonEmpty)
@@ -1601,7 +1601,7 @@ class GlutenClickHouseMergeTreeWriteSuite
             case scanExec: BasicScanExecTransformer => scanExec
           }
           assertResult(1)(plans.size)
-          assertResult(conf._2)(plans.head.getSplitInfos.size)
+          assertResult(conf._2)(plans.head.getSplitInfos().size)
         }
       }
     })
@@ -1625,7 +1625,7 @@ class GlutenClickHouseMergeTreeWriteSuite
           case scanExec: BasicScanExecTransformer => scanExec
         }
         assertResult(1)(plans.size)
-        assertResult(1)(plans.head.getSplitInfos.size)
+        assertResult(1)(plans.head.getSplitInfos().size)
       }
     }
   }
@@ -1733,7 +1733,7 @@ class GlutenClickHouseMergeTreeWriteSuite
           case f: BasicScanExecTransformer => f
         }
         assertResult(2)(scanExec.size)
-        assertResult(conf._2)(scanExec(1).getSplitInfos.size)
+        assertResult(conf._2)(scanExec(1).getSplitInfos().size)
       }
     }
   })
@@ -1779,7 +1779,7 @@ class GlutenClickHouseMergeTreeWriteSuite

     Seq("true", "false").foreach {
       skip =>
-        withSQLConf("spark.databricks.delta.stats.skipping" -> skip.toString) {
+        withSQLConf("spark.databricks.delta.stats.skipping" -> skip) {
          val sqlStr =
            s"""
               |SELECT
@@ -1903,7 +1903,7 @@ class GlutenClickHouseMergeTreeWriteSuite
     Seq(("-1", 3), ("3", 3), ("6", 1)).foreach(
       conf => {
         withSQLConf(
-          ("spark.gluten.sql.columnar.backend.ch.files.per.partition.threshold" -> conf._1)) {
+          "spark.gluten.sql.columnar.backend.ch.files.per.partition.threshold" -> conf._1) {
           val sql =
             s"""
               |select count(1), min(l_returnflag) from lineitem_split
@@ -1916,7 +1916,7 @@ class GlutenClickHouseMergeTreeWriteSuite
           val scanExec = collect(df.queryExecution.executedPlan) {
             case f: FileSourceScanExecTransformer => f
           }
-          assert(scanExec(0).getPartitions.size == conf._2)
+          assert(scanExec.head.getPartitions.size == conf._2)
         }
       }
     })
58 changes: 35 additions & 23 deletions cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp
@@ -21,9 +21,8 @@
 #include <DataTypes/DataTypeTuple.h>
 #include <Interpreters/Context.h>
 #include <Parser/TypeParser.h>
-#include <Processors/Transforms/ApplySquashingTransform.h>
 #include <Processors/Transforms/ExpressionTransform.h>
-#include <Processors/Transforms/PlanSquashingTransform.h>
+#include <Processors/Transforms/MaterializingTransform.h>
 #include <QueryPipeline/QueryPipelineBuilder.h>
 #include <Storages/MergeTree/SparkMergeTreeMeta.h>
 #include <Storages/MergeTree/SparkMergeTreeSink.h>
@@ -103,7 +102,7 @@ void adjust_output(const DB::QueryPipelineBuilderPtr & builder, const DB::Block
     {
         throw DB::Exception(
             DB::ErrorCodes::LOGICAL_ERROR,
-            "Missmatch result columns size, input size is {}, but output size is {}",
+            "Mismatch result columns size, input size is {}, but output size is {}",
             input.columns(),
             output.columns());
     }
@@ -164,12 +163,6 @@ void addMergeTreeSinkTransform(
         : std::make_shared<SparkMergeTreePartitionedFileSink>(header, partition_by, merge_tree_table, write_settings, context, stats);

     chain.addSource(sink);
-    const DB::Settings & settings = context->getSettingsRef();
-    chain.addSource(std::make_shared<ApplySquashingTransform>(
-        header, settings[Setting::min_insert_block_size_rows], settings[Setting::min_insert_block_size_bytes]));
-    chain.addSource(std::make_shared<PlanSquashingTransform>(
-        header, settings[Setting::min_insert_block_size_rows], settings[Setting::min_insert_block_size_bytes]));
-
     builder->addChain(std::move(chain));
 }
@@ -212,6 +205,7 @@ void addNormalFileWriterSinkTransform(
 namespace local_engine
 {

+
 IMPLEMENT_GLUTEN_SETTINGS(GlutenWriteSettings, WRITE_RELATED_SETTINGS)

 void addSinkTransform(const DB::ContextPtr & context, const substrait::WriteRel & write_rel, const DB::QueryPipelineBuilderPtr & builder)
@@ -224,36 +218,54 @@ void addSinkTransform(const DB::ContextPtr & context, const substrait::WriteRel
         throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Failed to unpack write optimization with local_engine::Write.");
     assert(write.has_common());
     const substrait::NamedStruct & table_schema = write_rel.table_schema();
-    auto output = TypeParser::buildBlockFromNamedStruct(table_schema);
-    adjust_output(builder, output);
-    const auto partitionCols = collect_partition_cols(output, table_schema);
+    auto partition_indexes = write.common().partition_col_index();
     if (write.has_mergetree())
     {
-        local_engine::MergeTreeTable merge_tree_table(write, table_schema);
+        MergeTreeTable merge_tree_table(write, table_schema);
+        auto output = TypeParser::buildBlockFromNamedStruct(table_schema, merge_tree_table.low_card_key);
+        adjust_output(builder, output);
+
+        builder->addSimpleTransform(
+            [&](const Block & in_header) -> ProcessorPtr { return std::make_shared<MaterializingTransform>(in_header, false); });
+
+        const auto partition_by = collect_partition_cols(output, table_schema, partition_indexes);

         GlutenWriteSettings write_settings = GlutenWriteSettings::get(context);
         if (write_settings.task_write_tmp_dir.empty())
             throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "MergeTree Write Pipeline need inject relative path.");
         if (!merge_tree_table.relative_path.empty())
             throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Non empty relative path for MergeTree table in pipeline mode.");

         merge_tree_table.relative_path = write_settings.task_write_tmp_dir;
-        addMergeTreeSinkTransform(context, builder, merge_tree_table, output, partitionCols);
+        addMergeTreeSinkTransform(context, builder, merge_tree_table, output, partition_by);
     }
     else
-        addNormalFileWriterSinkTransform(context, builder, write.common().format(), output, partitionCols);
+    {
+        auto output = TypeParser::buildBlockFromNamedStruct(table_schema);
+        adjust_output(builder, output);
+        const auto partition_by = collect_partition_cols(output, table_schema, partition_indexes);
+        addNormalFileWriterSinkTransform(context, builder, write.common().format(), output, partition_by);
+    }
 }

-DB::Names collect_partition_cols(const DB::Block & header, const substrait::NamedStruct & struct_)
+DB::Names collect_partition_cols(const DB::Block & header, const substrait::NamedStruct & struct_, const PartitionIndexes & partition_by)
 {
-    DB::Names result;
+    if (partition_by.empty())
+    {
+        assert(std::ranges::all_of(
+            struct_.column_types(), [](const int32_t type) { return type != ::substrait::NamedStruct::PARTITION_COL; }));
+        return {};
+    }
+    assert(struct_.column_types_size() == header.columns());
+    assert(struct_.column_types_size() == struct_.struct_().types_size());

-    auto name_iter = header.begin();
-    auto type_iter = struct_.column_types().begin();
-    for (; name_iter != header.end(); ++name_iter, ++type_iter)
-        if (*type_iter == ::substrait::NamedStruct::PARTITION_COL)
-            result.push_back(name_iter->name);
+    DB::Names result;
+    result.reserve(partition_by.size());
+    for (auto idx : partition_by)
+    {
+        assert(idx >= 0 && idx < header.columns());
+        assert(struct_.column_types(idx) == ::substrait::NamedStruct::PARTITION_COL);
+        result.emplace_back(header.getByPosition(idx).name);
+    }
     return result;
 }
5 changes: 4 additions & 1 deletion cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h
@@ -21,6 +21,7 @@
 #include <Core/Block.h>
 #include <Core/Names.h>
 #include <Interpreters/Context_fwd.h>
+#include <google/protobuf/repeated_field.h>
 #include <Common/GlutenSettings.h>

 namespace substrait
@@ -38,9 +39,11 @@
 using QueryPipelineBuilderPtr = std::unique_ptr<QueryPipelineBuilder>;
 namespace local_engine
 {

+using PartitionIndexes = google::protobuf::RepeatedField<::int32_t>;
+
 void addSinkTransform(const DB::ContextPtr & context, const substrait::WriteRel & write_rel, const DB::QueryPipelineBuilderPtr & builder);

-DB::Names collect_partition_cols(const DB::Block & header, const substrait::NamedStruct & struct_);
+DB::Names collect_partition_cols(const DB::Block & header, const substrait::NamedStruct & struct_, const PartitionIndexes & partition_by);

 #define WRITE_RELATED_SETTINGS(M, ALIAS) \
     M(String, task_write_tmp_dir, , "The temporary directory for writing data") \
43 changes: 30 additions & 13 deletions cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.cpp
@@ -31,27 +31,37 @@ extern const Metric GlobalThreadActive;
 extern const Metric GlobalThreadScheduled;
 }

+namespace DB::Setting
+{
+extern const SettingsUInt64 min_insert_block_size_rows;
+extern const SettingsUInt64 min_insert_block_size_bytes;
+}
 namespace local_engine
 {

-void SparkMergeTreeSink::consume(Chunk & chunk)
+void SparkMergeTreeSink::write(const Chunk & chunk)
 {
-    assert(!sink_helper->metadata_snapshot->hasPartitionKey());
-
-    /// Reset earlier, so put it in the scope
+    CurrentThread::flushUntrackedMemory();
     BlockWithPartition item{getHeader().cloneWithColumns(chunk.getColumns()), Row{}};
-    size_t before_write_memory = 0;
-    if (auto * memory_tracker = CurrentThread::getMemoryTracker())
-    {
-        CurrentThread::flushUntrackedMemory();
-        before_write_memory = memory_tracker->get();
-    }
-
     sink_helper->writeTempPart(item, context, part_num);
     part_num++;
+    /// Reset earlier to free memory
     item.block.clear();
     item.partition.clear();
-    sink_helper->checkAndMerge();
+}
+
+void SparkMergeTreeSink::consume(Chunk & chunk)
+{
+    Chunk tmp;
+    tmp.swap(chunk);
+    squashed_chunk = squashing.add(std::move(tmp));
+    if (static_cast<bool>(squashed_chunk))
+    {
+        write(Squashing::squash(std::move(squashed_chunk)));
+        sink_helper->checkAndMerge();
+    }
+    assert(squashed_chunk.getNumRows() == 0);
+    assert(chunk.getNumRows() == 0);
 }

 void SparkMergeTreeSink::onStart()
@@ -61,6 +71,11 @@ void SparkMergeTreeSink::onStart()

 void SparkMergeTreeSink::onFinish()
 {
+    assert(squashed_chunk.getNumRows() == 0);
+    squashed_chunk = squashing.flush();
+    if (static_cast<bool>(squashed_chunk))
+        write(Squashing::squash(std::move(squashed_chunk)));
+    assert(squashed_chunk.getNumRows() == 0);
     sink_helper->finish(context);
     if (stats_.has_value())
         (*stats_)->collectStats(sink_helper->unsafeGet(), sink_helper->write_settings.partition_settings.partition_dir);
@@ -91,7 +106,9 @@ SinkToStoragePtr SparkMergeTreeSink::create(
     }
     else
         sink_helper = std::make_shared<DirectSinkHelper>(dest_storage, write_settings_, isRemoteStorage);
-    return std::make_shared<SparkMergeTreeSink>(sink_helper, context, stats);
+    const DB::Settings & settings = context->getSettingsRef();
+    return std::make_shared<SparkMergeTreeSink>(
+        sink_helper, context, stats, settings[Setting::min_insert_block_size_rows], settings[Setting::min_insert_block_size_bytes]);
 }

 SinkHelper::SinkHelper(const SparkStorageMergeTreePtr & data_, const SparkMergeTreeWriteSettings & write_settings_, bool isRemoteStorage_)
17 changes: 15 additions & 2 deletions cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h
@@ -227,8 +227,17 @@ class SparkMergeTreeSink : public DB::SinkToStorage
         const DB::ContextMutablePtr & context,
         const SinkStatsOption & stats = {});

-    explicit SparkMergeTreeSink(const SinkHelperPtr & sink_helper_, const ContextPtr & context_, const SinkStatsOption & stats)
-        : SinkToStorage(sink_helper_->metadata_snapshot->getSampleBlock()), context(context_), sink_helper(sink_helper_), stats_(stats)
+    explicit SparkMergeTreeSink(
+        const SinkHelperPtr & sink_helper_,
+        const ContextPtr & context_,
+        const SinkStatsOption & stats,
+        size_t min_block_size_rows,
+        size_t min_block_size_bytes)
+        : SinkToStorage(sink_helper_->metadata_snapshot->getSampleBlock())
+        , context(context_)
+        , sink_helper(sink_helper_)
+        , stats_(stats)
+        , squashing(sink_helper_->metadata_snapshot->getSampleBlock(), min_block_size_rows, min_block_size_bytes)
     {
     }
     ~SparkMergeTreeSink() override = default;
@@ -241,9 +250,13 @@ class SparkMergeTreeSink : public DB::SinkToStorage
     const SinkHelper & sinkHelper() const { return *sink_helper; }

 private:
+    void write(const Chunk & chunk);
+
     ContextPtr context;
     SinkHelperPtr sink_helper;
     std::optional<std::shared_ptr<MergeTreeStats>> stats_;
+    Squashing squashing;
+    Chunk squashed_chunk;
     int part_num = 1;
 };