SparkMergeTreeWriter using PushingPipelineExecutor
baibaichen committed Sep 4, 2024
1 parent 717a5a6 commit 10c5b8e
Showing 6 changed files with 97 additions and 155 deletions.
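
In short, this commit retires SparkMergeTreeWriter's hand-rolled Squashing-and-writeTempPart loop and instead drives a standard ClickHouse insert chain (plan squashing -> apply squashing -> storage sink) with a PushingPipelineExecutor. Below is a minimal sketch of the new shape, with simplified types, hypothetical function and variable names, and no error handling; the authoritative version is SparkMergeTreeWriter::create() in the diff below.

    #include <Interpreters/Context.h>
    #include <Processors/Chain.h>
    #include <Processors/Executors/PushingPipelineExecutor.h>
    #include <Processors/Transforms/ApplySquashingTransform.h>
    #include <Processors/Transforms/PlanSquashingTransform.h>
    #include <QueryPipeline/QueryPipeline.h>
    #include <Storages/IStorage.h>

    // Hypothetical helper illustrating the new write path.
    void pushBlocksThroughSink(
        const DB::StoragePtr & dest_storage,
        const DB::ContextMutablePtr & context,
        const std::vector<DB::Block> & blocks)
    {
        const auto metadata_snapshot = dest_storage->getInMemoryMetadataPtr();
        const DB::Block header = metadata_snapshot->getSampleBlock();
        const DB::Settings & settings = context->getSettingsRef();

        // Build the chain sink-first; each addSource() prepends in front of it.
        DB::Chain chain;
        chain.addSink(dest_storage->write(nullptr, metadata_snapshot, context, /*async_insert=*/false));
        chain.addSource(std::make_shared<DB::ApplySquashingTransform>(
            header, settings.min_insert_block_size_rows, settings.min_insert_block_size_bytes));
        chain.addSource(std::make_shared<DB::PlanSquashingTransform>(
            header, settings.min_insert_block_size_rows, settings.min_insert_block_size_bytes));

        DB::QueryPipeline pipeline{std::move(chain)};
        DB::PushingPipelineExecutor executor{pipeline};
        for (const auto & block : blocks)
            executor.push(block);  // replaces the manual Squashing::add()/chunkToPart() calls
        executor.finish();         // flushes rows still buffered in the squashing transforms
    }
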
27 changes: 12 additions & 15 deletions cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.cpp
@@ -190,24 +190,25 @@ MergeTreeDataWriter::TemporaryPart SparkMergeTreeDataWriter::writeTempPart(


SinkToStoragePtr SparkStorageMergeTree::write(
const ASTPtr &, const StorageMetadataPtr & storage_in_memory_metadata, ContextPtr context, bool /*async_insert*/)
const ASTPtr &, const StorageMetadataPtr & /*storage_in_memory_metadata*/, ContextPtr context, bool /*async_insert*/)
{
GlutenMergeTreeWriteSettings settings{.partition_settings{MergeTreePartitionWriteSettings::get(context)}};
settings.load(context);
SinkHelperPtr sink_helper = SinkHelper::create(table, settings, getContext());
SinkHelperPtr sink_helper = SparkMergeTreeSink::create(table, settings, getContext());
#ifndef NDEBUG
auto dest_storage = MergeTreeRelParser::getStorage(table, getContext());
assert(dest_storage.get() == this);
#endif

return std::make_shared<SparkMergeTreeSink>(*this, storage_in_memory_metadata, context);
return std::make_shared<SparkMergeTreeSink>(sink_helper, context);
}

void SparkMergeTreeSink::consume(Chunk & chunk)
{
assert(!metadata_snapshot->hasPartitionKey());
assert(!sink_helper->metadata_snapshot->hasPartitionKey());

auto block = getHeader().cloneWithColumns(chunk.getColumns());
auto blocks_with_partition = MergeTreeDataWriter::splitBlockIntoParts(std::move(block), 10, metadata_snapshot, context);
auto blocks_with_partition = MergeTreeDataWriter::splitBlockIntoParts(std::move(block), 10, sink_helper->metadata_snapshot, context);

for (auto & item : blocks_with_partition)
{
@@ -217,15 +218,13 @@ void SparkMergeTreeSink::consume(Chunk & chunk)
CurrentThread::flushUntrackedMemory();
before_write_memory = memory_tracker->get();
}

MergeTreeDataWriter::TemporaryPart temp_part
= storage.writer.writeTempPart(item, metadata_snapshot, context, write_settings, part_num);
new_parts.emplace_back(temp_part.part);
sink_helper->writeTempPart(item, context, part_num);
part_num++;
/// Reset earlier to free memory
item.block.clear();
item.partition.clear();
}
sink_helper->checkAndMerge();
}

void SparkMergeTreeSink::onStart()
@@ -235,11 +234,11 @@ void SparkMergeTreeSink::onStart()

void SparkMergeTreeSink::onFinish()
{
// DO NOTHING
sink_helper->finish(context);
}

/////
SinkHelperPtr SinkHelper::create(
SinkHelperPtr SparkMergeTreeSink::create(
const MergeTreeTable & merge_tree_table, const GlutenMergeTreeWriteSettings & write_settings_, const DB::ContextMutablePtr & context)
{
auto dest_storage = MergeTreeRelParser::getStorage(merge_tree_table, context);
@@ -264,7 +263,6 @@ SinkHelper::SinkHelper(const CustomStorageMergeTreePtr & data_, const GlutenMerg
, thread_pool(CurrentMetrics::LocalThread, CurrentMetrics::LocalThreadActive, CurrentMetrics::LocalThreadScheduled, 1, 1, 100000)
, write_settings(write_settings_)
, metadata_snapshot(data->getInMemoryMetadataPtr())
, header(metadata_snapshot->getSampleBlock())
{
}

@@ -335,8 +333,8 @@ void SinkHelper::doMergePartsAsync(const std::vector<DB::MergeTreeDataPartPtr> &
}
void SinkHelper::writeTempPart(DB::BlockWithPartition & block_with_partition, const ContextPtr & context, int part_num)
{
auto tmp
= dataRef().writer.writeTempPart(block_with_partition, metadata_snapshot, context, write_settings.partition_settings, part_num);
auto tmp = dataRef().getWriter().writeTempPart(
block_with_partition, metadata_snapshot, context, write_settings.partition_settings, part_num);
new_parts.emplace_back(tmp.part);
}

@@ -454,7 +452,6 @@ void CopyToRemoteSinkHelper::commit(const ReadSettings & read_settings, const Wr
void DirectSinkHelper::cleanup()
{
// default storage need clean temp.

std::unordered_set<String> final_parts;
for (const auto & merge_tree_data_part : new_parts.unsafeGet())
final_parts.emplace(merge_tree_data_part->name);
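A note on the new consume() in this file: per-partition blocks are now written through the shared helper rather than through storage.writer directly, and small-part merging is checked once per consumed chunk. A sketch of the resulting per-item flow, assuming the member names from this diff:

    // Sketch of SparkMergeTreeSink::consume() after this commit.
    for (auto & item : blocks_with_partition)
    {
        DB::CurrentThread::flushUntrackedMemory(); // settle untracked allocations before measuring
        sink_helper->writeTempPart(item, context, part_num++);
        item.block.clear();                        // free the source data as early as possible
        item.partition.clear();
    }
    sink_helper->checkAndMerge();                  // merge accumulated small parts when thresholds hit
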
61 changes: 26 additions & 35 deletions cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h
@@ -74,15 +74,8 @@ class SparkMergeTreeDataWriter
LoggerPtr log;
};

class SparkMergeTreeSink;
class SinkHelper;
using SinkHelperPtr = std::shared_ptr<SinkHelper>;

class SparkStorageMergeTree final : public CustomStorageMergeTree
{
friend class SparkMergeTreeSink;
friend class SinkHelper;

public:
SparkStorageMergeTree(const MergeTreeTable & table_, const StorageInMemoryMetadata & metadata, const ContextMutablePtr & context_)
: CustomStorageMergeTree(
@@ -103,6 +96,8 @@ class SparkStorageMergeTree final : public CustomStorageMergeTree
SinkToStoragePtr
write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context, bool async_insert) override;

SparkMergeTreeDataWriter & getWriter() { return writer; }

private:
MergeTreeTable table;
SparkMergeTreeDataWriter writer;
@@ -177,32 +172,26 @@ class SinkHelper
public:
const GlutenMergeTreeWriteSettings write_settings;
const DB::StorageMetadataPtr metadata_snapshot;
const DB::Block header;

protected:
virtual CustomStorageMergeTree & dest_storage() { return *data; }

void doMergePartsAsync(const std::vector<DB::MergeTreeDataPartPtr> & prepare_merge_parts);
void finalizeMerge();
virtual void cleanup() { }
virtual void commit(const ReadSettings & read_settings, const WriteSettings & write_settings) { }
void saveMetadata(const DB::ContextPtr & context);
SparkStorageMergeTree & dataRef() const { return assert_cast<SparkStorageMergeTree &>(*data); }

public:
void writeTempPart(DB::BlockWithPartition & block_with_partition, const ContextPtr & context, int part_num);

const std::deque<DB::MergeTreeDataPartPtr> & unsafeGet() const { return new_parts.unsafeGet(); }

void writeTempPart(DB::BlockWithPartition & block_with_partition, const ContextPtr & context, int part_num);
void checkAndMerge(bool force = false);
void finish(const DB::ContextPtr & context);

virtual ~SinkHelper() = default;
SinkHelper(const CustomStorageMergeTreePtr & data_, const GlutenMergeTreeWriteSettings & write_settings_, bool isRemoteStorage_);
static SinkHelperPtr create(
const MergeTreeTable & merge_tree_table,
const GlutenMergeTreeWriteSettings & write_settings_,
const DB::ContextMutablePtr & context);

virtual CustomStorageMergeTree & dest_storage() { return *data; }

virtual void commit(const ReadSettings & read_settings, const WriteSettings & write_settings) { }
void saveMetadata(const DB::ContextPtr & context);
};

class DirectSinkHelper : public SinkHelper
@@ -222,6 +211,11 @@ class CopyToRemoteSinkHelper : public SinkHelper
{
CustomStorageMergeTreePtr dest;

protected:
void commit(const ReadSettings & read_settings, const WriteSettings & write_settings) override;
CustomStorageMergeTree & dest_storage() override { return *dest; }
const CustomStorageMergeTreePtr & temp_storage() const { return data; }

public:
explicit CopyToRemoteSinkHelper(
const CustomStorageMergeTreePtr & temp,
@@ -231,23 +225,20 @@
{
assert(data != dest);
}

CustomStorageMergeTree & dest_storage() override { return *dest; }
const CustomStorageMergeTreePtr & temp_storage() const { return data; }

void commit(const ReadSettings & read_settings, const WriteSettings & write_settings) override;
};

using SinkHelperPtr = std::shared_ptr<SinkHelper>;

class SparkMergeTreeSink : public DB::SinkToStorage
{
public:
explicit SparkMergeTreeSink(
SparkStorageMergeTree & storage_, const StorageMetadataPtr & metadata_snapshot_, const ContextPtr & context_)
: SinkToStorage(metadata_snapshot_->getSampleBlock())
, storage(storage_)
, metadata_snapshot(metadata_snapshot_)
, context(context_)
, write_settings(MergeTreePartitionWriteSettings::get(context_))
static SinkHelperPtr create(
const MergeTreeTable & merge_tree_table,
const GlutenMergeTreeWriteSettings & write_settings_,
const DB::ContextMutablePtr & context);

explicit SparkMergeTreeSink(const SinkHelperPtr & sink_helper_, const ContextPtr & context_)
: SinkToStorage(sink_helper_->metadata_snapshot->getSampleBlock()), context(context_), sink_helper(sink_helper_)
{
}
~SparkMergeTreeSink() override = default;
@@ -257,13 +248,13 @@
void onStart() override;
void onFinish() override;

const SinkHelper & sinkHelper() const { return *sink_helper; }

private:
SparkStorageMergeTree & storage;
StorageMetadataPtr metadata_snapshot;
ContextPtr context;
MergeTreePartitionWriteSettings write_settings;
SinkHelperPtr sink_helper;

int part_num = 1;
std::vector<DB::MergeTreeDataPartPtr> new_parts{};
};

}
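
Taken together, the header changes invert the ownership: SparkMergeTreeSink no longer keeps its own storage reference, metadata snapshot, and parts list; it borrows everything from a shared SinkHelper. A hedged construction sketch under these declarations, mirroring SparkStorageMergeTree::write() from the .cpp diff above (the table and context are assumed to come from the surrounding Gluten plumbing, and makeSink is a hypothetical wrapper):

    #include <Storages/MergeTree/SparkMergeTreeSink.h>

    using namespace local_engine;

    DB::SinkToStoragePtr makeSink(const MergeTreeTable & table, const DB::ContextMutablePtr & context)
    {
        GlutenMergeTreeWriteSettings settings{.partition_settings{MergeTreePartitionWriteSettings::get(context)}};
        settings.load(context);

        // The helper owns the metadata snapshot, the written parts, and the merge policy;
        // the sink only pushes chunks through it and outlives nothing.
        SinkHelperPtr helper = SparkMergeTreeSink::create(table, settings, context);
        return std::make_shared<SparkMergeTreeSink>(helper, context);
    }
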
87 changes: 37 additions & 50 deletions cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeWriter.cpp
@@ -18,6 +18,8 @@

#include <Interpreters/ActionsDAG.h>
#include <Parser/MergeTreeRelParser.h>
#include <Processors/Transforms/ApplySquashingTransform.h>
#include <Processors/Transforms/PlanSquashingTransform.h>
#include <Storages/MergeTree/DataPartStorageOnDiskFull.h>
#include <Storages/MergeTree/MetaDataHelper.h>
#include <Storages/MergeTree/SparkMergeTreeSink.h>
@@ -47,74 +49,58 @@ Block removeColumnSuffix(const Block & block)

namespace local_engine
{
SparkMergeTreeWriter::SparkMergeTreeWriter(
const MergeTreeTable & merge_tree_table, const GlutenMergeTreeWriteSettings & write_settings_, const DB::ContextPtr & context_)
: dataWrapper(SinkHelper::create(merge_tree_table, write_settings_, SerializedPlanParser::global_context)), context(context_)

std::unique_ptr<SparkMergeTreeWriter> SparkMergeTreeWriter::create(
const MergeTreeTable & merge_tree_table, const MergeTreePartitionWriteSettings & write_settings_, const DB::ContextMutablePtr & context)
{
const DB::Settings & settings = context->getSettingsRef();
squashing
= std::make_unique<DB::Squashing>(dataWrapper->header, settings.min_insert_block_size_rows, settings.min_insert_block_size_bytes);
if (!write_settings_.partition_settings.partition_dir.empty())
extractPartitionValues(write_settings_.partition_settings.partition_dir, partition_values);
const auto dest_storage = MergeTreeRelParser::getStorage(merge_tree_table, context);
StorageMetadataPtr metadata_snapshot = dest_storage->getInMemoryMetadataPtr();
Block header = metadata_snapshot->getSampleBlock();
ASTPtr none;
Chain chain;
auto sink = dest_storage->write(none, metadata_snapshot, context, false);
chain.addSink(sink);
chain.addSource(
std::make_shared<ApplySquashingTransform>(header, settings.min_insert_block_size_rows, settings.min_insert_block_size_bytes));
chain.addSource(
std::make_shared<PlanSquashingTransform>(header, settings.min_insert_block_size_rows, settings.min_insert_block_size_bytes));

std::unordered_map<String, String> partition_values;
if (!write_settings_.partition_dir.empty())
extractPartitionValues(write_settings_.partition_dir, partition_values);
return std::make_unique<SparkMergeTreeWriter>(
header, assert_cast<const SparkMergeTreeSink &>(*sink).sinkHelper(), QueryPipeline{std::move(chain)}, std::move(partition_values));
}

SparkMergeTreeWriter::SparkMergeTreeWriter(
const DB::Block & header_,
const SinkHelper & sink_helper_,
DB::QueryPipeline && pipeline_,
std::unordered_map<String, String> && partition_values_)
: header{header_}, sink_helper{sink_helper_}, pipeline{std::move(pipeline_)}, executor{pipeline}, partition_values{partition_values_}
{
}

void SparkMergeTreeWriter::write(const DB::Block & block)
{
auto new_block = removeColumnSuffix(block);
auto converter = ActionsDAG::makeConvertingActions(
new_block.getColumnsWithTypeAndName(), dataWrapper->header.getColumnsWithTypeAndName(), DB::ActionsDAG::MatchColumnsMode::Position);
new_block.getColumnsWithTypeAndName(), header.getColumnsWithTypeAndName(), DB::ActionsDAG::MatchColumnsMode::Position);
const ExpressionActions expression_actions{std::move(converter)};
expression_actions.execute(new_block);

if (chunkToPart(squashing->add({new_block.getColumns(), new_block.rows()})))
dataWrapper->checkAndMerge();
}

bool SparkMergeTreeWriter::chunkToPart(Chunk && plan_chunk)
{
if (Chunk result_chunk = DB::Squashing::squash(std::move(plan_chunk)))
{
auto result = squashing->getHeader().cloneWithColumns(result_chunk.detachColumns());
return blockToPart(result);
}
return false;
}

bool SparkMergeTreeWriter::blockToPart(Block & block)
{
auto blocks_with_partition = MergeTreeDataWriter::splitBlockIntoParts(std::move(block), 10, dataWrapper->metadata_snapshot, context);

if (blocks_with_partition.empty())
return false;

for (auto & item : blocks_with_partition)
{
size_t before_write_memory = 0;
if (auto * memory_tracker = CurrentThread::getMemoryTracker())
{
CurrentThread::flushUntrackedMemory();
before_write_memory = memory_tracker->get();
}
dataWrapper->writeTempPart(item, context, part_num);
part_num++;
/// Reset earlier to free memory
item.block.clear();
item.partition.clear();
}

return true;
executor.push(new_block);
}

void SparkMergeTreeWriter::finalize()
{
chunkToPart(squashing->flush());
dataWrapper->finish(context);
executor.finish();
}

std::vector<PartInfo> SparkMergeTreeWriter::getAllPartInfo() const
{
std::vector<PartInfo> res;
auto parts = dataWrapper->unsafeGet();
auto parts = sink_helper.unsafeGet();
res.reserve(parts.size());

for (const auto & part : parts)
@@ -125,7 +111,7 @@ std::vector<PartInfo> SparkMergeTreeWriter::getAllPartInfo() const
part->getBytesOnDisk(),
part->rows_count,
partition_values,
dataWrapper->write_settings.partition_settings.bucket_dir});
sink_helper.write_settings.partition_settings.bucket_dir});
}
return res;
}
@@ -161,4 +147,5 @@ String SparkMergeTreeWriter::partInfosToJson(const std::vector<PartInfo> & part_
writer.EndArray();
return result.GetString();
}

}
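
The net effect for callers shows up in SparkMergeTreeWriter::create() above: the writer becomes a thin facade over a pipeline. A usage sketch, with the table, settings, and input blocks assumed to come from the Spark side (writeAndReport is a hypothetical driver):

    #include <Storages/MergeTree/SparkMergeTreeWriter.h>

    using namespace local_engine;

    String writeAndReport(
        const MergeTreeTable & merge_tree_table,
        const MergeTreePartitionWriteSettings & write_settings,
        const DB::ContextMutablePtr & context,
        const std::vector<DB::Block> & input_blocks)
    {
        auto writer = SparkMergeTreeWriter::create(merge_tree_table, write_settings, context);

        for (const auto & block : input_blocks)
            writer->write(block); // column-suffix removal and conversion happen inside

        writer->finalize();       // executor.finish(): flushes squashing, runs the sink's onFinish()

        // Parts accumulated by the shared SinkHelper, serialized for the Spark driver.
        return SparkMergeTreeWriter::partInfosToJson(writer->getAllPartInfo());
    }
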
27 changes: 14 additions & 13 deletions cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeWriter.h
@@ -17,9 +17,8 @@
#pragma once

#include <Interpreters/Context.h>
#include <Interpreters/Squashing.h>
#include <Processors/Executors/PushingPipelineExecutor.h>
#include <Storages/MergeTree/IMergeTreeDataPart.h>
#include <Storages/MergeTree/MergeTreeDataWriter.h>
#include <Storages/MergeTree/MergeTreeTool.h>
#include <Storages/MergeTree/SparkMergeTreeSink.h>
#include <Storages/MergeTree/StorageMergeTreeFactory.h>
@@ -51,24 +50,26 @@ class SparkMergeTreeWriter
{
public:
static String partInfosToJson(const std::vector<PartInfo> & part_infos);
static std::unique_ptr<SparkMergeTreeWriter> create(
const MergeTreeTable & merge_tree_table,
const MergeTreePartitionWriteSettings & write_settings_,
const DB::ContextMutablePtr & context);

SparkMergeTreeWriter(
const MergeTreeTable & merge_tree_table, const GlutenMergeTreeWriteSettings & write_settings_, const DB::ContextPtr & context_);
const DB::Block & header_,
const SinkHelper & sink_helper_,
DB::QueryPipeline && pipeline_,
std::unordered_map<String, String> && partition_values_);

void write(const DB::Block & block);
void finalize();
std::vector<PartInfo> getAllPartInfo() const;

private:
bool chunkToPart(Chunk && plan_chunk);
bool blockToPart(Block & block);

SinkHelperPtr dataWrapper;
DB::ContextPtr context;
DB::Block header;
const SinkHelper & sink_helper;
DB::QueryPipeline pipeline;
DB::PushingPipelineExecutor executor;
std::unordered_map<String, String> partition_values;


std::unique_ptr<DB::Squashing> squashing;

int part_num = 1;
};
}
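
One header detail worth noting: executor is declared after pipeline and is initialized with a reference to it, so the member declaration order is load-bearing. A reduced, hypothetical skeleton of that layout:

    #include <Processors/Executors/PushingPipelineExecutor.h>
    #include <QueryPipeline/QueryPipeline.h>

    struct WriterSkeleton
    {
        DB::QueryPipeline pipeline;           // must precede executor: members initialize in declaration order
        DB::PushingPipelineExecutor executor; // holds a reference to `pipeline`

        explicit WriterSkeleton(DB::QueryPipeline && pipeline_)
            : pipeline(std::move(pipeline_))
            , executor(pipeline) // safe: `pipeline` is fully constructed at this point
        {
        }
    };
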
