Skip to content

Commit

Permalink
[GLUTEN-3817][CH] Optimize mergetree prewhere (#3818)
Browse files Browse the repository at this point in the history
* [GLUTEN-3817][CH] Optimize mergetree prewhere

* fix ci error
  • Loading branch information
loneylee authored Nov 24, 2023
1 parent 2a8cd88 commit e29a44e
Show file tree
Hide file tree
Showing 9 changed files with 510 additions and 4 deletions.
7 changes: 7 additions & 0 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,13 @@ void BackendInitializerUtil::initCompiledExpressionCache(DB::Context::Configurat
#endif
}

/// Initializes the backend from a Substrait plan supplied as JSON text.
///
/// The JSON is parsed into a substrait::Plan, re-serialized into its binary
/// protobuf form and forwarded to init().
///
/// @param plan_json  JSON text of the plan; must not be null (it is dereferenced).
void BackendInitializerUtil::init_json(std::string * plan_json)
{
    auto plan_ptr = std::make_unique<substrait::Plan>();
    // NOTE(review): the status returned by JsonStringToMessage is ignored (as it was
    // before), so malformed JSON yields a partially filled / empty plan — consider
    // reporting the failure to the caller.
    google::protobuf::util::JsonStringToMessage(plan_json->c_str(), plan_ptr.get());

    // Keep the binary plan in a local object instead of a heap-allocated String that
    // nobody frees. Presumably init() only reads the plan during the call and does not
    // retain the pointer — TODO confirm.
    std::string binary_plan = plan_ptr->SerializeAsString();
    init(&binary_plan);
}

void BackendInitializerUtil::init(std::string * plan)
{
std::map<std::string, std::string> backend_conf_map = getBackendConfMap(plan);
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class BackendInitializerUtil
/// 1. global level resources like global_context/shared_context, notice that they can only be initialized once in process lifetime
/// 2. session level resources like settings/configs, they can be initialized multiple times following the lifetime of executor/driver
static void init(std::string * plan);

static void init_json(std::string * plan_json);
static void updateConfig(DB::ContextMutablePtr, std::string *);


Expand Down
347 changes: 347 additions & 0 deletions cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,347 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


#include <google/protobuf/wrappers.pb.h>

#include <Parser/FunctionParser.h>
#include <Parser/TypeParser.h>
#include <Storages/StorageMergeTreeFactory.h>
#include <Common/CHUtil.h>
#include <Common/MergeTreeTool.h>

#include "MergeTreeRelParser.h"


/// Forward declarations of the error codes this translation unit throws with.
/// They are only declared here (`extern`); the definitions live elsewhere.
namespace DB
{
namespace ErrorCodes
{
/// Thrown by MergeTreeRelParser::parse when no data part matches the requested block range.
extern const int NO_SUCH_DATA_PART;
/// Thrown by MergeTreeRelParser::parse when the mergetree directory has no parts.
extern const int LOGICAL_ERROR;
/// Thrown by MergeTreeRelParser::getCHFunctionName for an unmapped substrait function.
extern const int UNKNOWN_FUNCTION;
/// Thrown by MergeTreeRelParser::collectColumns for an unsupported expression type.
extern const int UNKNOWN_TYPE;

}
}

namespace local_engine
{
using namespace DB;

/// Returns the smallest primary-key position among the columns used by a condition.
///
/// @param condition_table_columns  names of the columns referenced by the condition
/// @param primary_key_positions    column name -> position inside the primary key
/// @return the minimal position found, or `max(Int64) - 1` when none of the columns
///         appears in `primary_key_positions`.
static Int64 findMinPosition(const NameSet & condition_table_columns, const NameToIndexMap & primary_key_positions)
{
    /// Sentinel meaning "not part of the primary key".
    Int64 best_position = std::numeric_limits<Int64>::max() - 1;

    for (const auto & column_name : condition_table_columns)
    {
        const auto found = primary_key_positions.find(column_name);
        if (found == primary_key_positions.end())
            continue;

        const auto position = static_cast<Int64>(found->second);
        if (position < best_position)
            best_position = position;
    }

    return best_position;
}

/// Translates a substrait read relation backed by a MergeTree extension table into a
/// read-from-parts step and appends it to `query_plan`.
///
/// Steps performed:
///   1. decode the MergeTree table description from the extension table detail;
///   2. build the header from the base schema, or from the first column of the first
///      part on disk when the schema is empty;
///   3. obtain (or create and load) the storage through StorageMergeTreeFactory;
///   4. if the relation has a filter, turn it into prewhere info and resolve the
///      columns that can be treated as non-nullable;
///   5. select the parts with `min_block >= table.min_block` and
///      `max_block < table.max_block` and add the read step;
///   6. optionally add a remove-nullable step for the resolved columns.
///
/// @param query_plan  plan that receives the new step(s); returned to the caller.
/// @param rel_        substrait relation; its `read()` part is used.
/// @return `query_plan` with the read step (and possibly a remove-nullable step) added.
/// @throws DB::Exception LOGICAL_ERROR when the schema is empty and the mergetree
///         directory has no parts; NO_SUCH_DATA_PART when no part matches the block range.
DB::QueryPlanPtr
MergeTreeRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel_, std::list<const substrait::Rel *> & /*rel_stack_*/)
{
    const auto & rel = rel_.read();
    assert(rel.has_extension_table());
    google::protobuf::StringValue table;
    table.ParseFromString(rel.extension_table().detail().value());
    auto merge_tree_table = local_engine::parseMergeTreeTableString(table.value());
    DB::Block header;
    if (rel.has_base_schema() && rel.base_schema().names_size())
    {
        header = TypeParser::buildBlockFromNamedStruct(rel.base_schema());
    }
    else
    {
        // For count(*) case, there will be an empty base_schema, so we try to read at least one column
        auto all_parts_dir = MergeTreeUtil::getAllMergeTreeParts(std::filesystem::path("/") / merge_tree_table.relative_path);
        if (all_parts_dir.empty())
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Empty mergetree directory: {}", merge_tree_table.relative_path);
        auto part_names_types_list = MergeTreeUtil::getSchemaFromMergeTreePart(all_parts_dir[0]);
        NamesAndTypesList one_column_name_type;
        one_column_name_type.push_back(part_names_types_list.front());
        header = BlockUtil::buildHeader(one_column_name_type);
        LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "Try to read ({}) instead of empty header", header.dumpNames());
    }
    auto names_and_types_list = header.getNamesAndTypesList();
    auto storage_factory = StorageMergeTreeFactory::instance();
    auto metadata = buildMetaData(names_and_types_list, context);
    query_context.metadata = metadata;

    // The factory invokes the creator callback only when it has to build the storage
    // (presumably when it is not cached yet — see StorageMergeTreeFactory::getStorage).
    auto storage = storage_factory.getStorage(
        StorageID(merge_tree_table.database, merge_tree_table.table),
        metadata->getColumns(),
        [&]() -> CustomStorageMergeTreePtr
        {
            auto custom_storage_merge_tree = std::make_shared<CustomStorageMergeTree>(
                StorageID(merge_tree_table.database, merge_tree_table.table),
                merge_tree_table.relative_path,
                *metadata,
                false,
                global_context,
                "",
                MergeTreeData::MergingParams(),
                buildMergeTreeSettings());
            custom_storage_merge_tree->loadDataParts(false, std::nullopt);
            return custom_storage_merge_tree;
        });

    // Remember the compressed size of every column; getColumnsSize() reads this map.
    for (const auto & [name, sizes] : storage->getColumnSizes())
        column_sizes[name] = sizes.data_compressed;

    query_context.storage_snapshot = std::make_shared<StorageSnapshot>(*storage, metadata);
    query_context.custom_storage_merge_tree = storage;
    auto query_info = buildQueryInfo(names_and_types_list);

    std::set<String> non_nullable_columns;
    if (rel.has_filter())
    {
        NonNullableColumnsResolver non_nullable_columns_resolver(header, *getPlanParser(), rel.filter());
        non_nullable_columns = non_nullable_columns_resolver.resolve();
        query_info->prewhere_info = parsePreWhereInfo(rel.filter(), header);
    }
    auto data_parts = query_context.custom_storage_merge_tree->getAllDataPartsVector();
    // `auto` keeps the declared type of the bounds, so no implicit narrowing can happen here.
    const auto min_block = merge_tree_table.min_block;
    const auto max_block = merge_tree_table.max_block;
    MergeTreeData::DataPartsVector selected_parts;
    // Keep parts whose block range satisfies min_block <= part.min and part.max < max_block.
    // The predicate takes the pointer by const reference to avoid a ref-count round trip per part.
    std::copy_if(
        std::begin(data_parts),
        std::end(data_parts),
        std::back_inserter(selected_parts),
        [min_block, max_block](const MergeTreeData::DataPartPtr & part)
        { return part->info.min_block >= min_block && part->info.max_block < max_block; });
    if (selected_parts.empty())
        throw Exception(ErrorCodes::NO_SUCH_DATA_PART, "part {} to {} not found.", min_block, max_block);
    auto read_step = query_context.custom_storage_merge_tree->reader.readFromParts(
        selected_parts,
        /* alter_conversions = */ {},
        names_and_types_list.getNames(),
        query_context.storage_snapshot,
        *query_info,
        context,
        context->getSettingsRef().max_block_size,
        1);

    // Record the raw step pointer before ownership moves into the plan.
    steps.emplace_back(read_step.get());
    query_plan->addStep(std::move(read_step));
    if (!non_nullable_columns.empty())
    {
        // Only columns that are actually produced by the read step can be converted.
        // The header is only inspected before another step is added, so a reference suffices.
        const auto & input_header = query_plan->getCurrentDataStream().header;
        std::erase_if(non_nullable_columns, [&input_header](const auto & item) -> bool { return !input_header.has(item); });
        auto * remove_null_step = getPlanParser()->addRemoveNullableStep(*query_plan, non_nullable_columns);
        if (remove_null_step)
            steps.emplace_back(remove_null_step);
    }
    return query_plan;
}

/// Builds the PrewhereInfo for a filter expression.
///
/// The actions DAG and the name of the filter column come from optimizePrewhereAction().
/// The resulting info requests filtering (`need_filter`) and removal of the filter column
/// (`remove_prewhere_column`); afterwards every column of `input` is offered to
/// `tryRestoreColumn` on the DAG.
///
/// @param rel    filter expression of the read relation
/// @param input  header of the columns being read
PrewhereInfoPtr MergeTreeRelParser::parsePreWhereInfo(const substrait::Expression & rel, Block & input)
{
    auto result = std::make_shared<PrewhereInfo>();

    std::string prewhere_column;
    result->prewhere_actions = optimizePrewhereAction(rel, prewhere_column, input);
    result->prewhere_column_name = prewhere_column;
    result->need_filter = true;
    result->remove_prewhere_column = true;

    auto & actions = result->prewhere_actions;
    actions->projectInput(false);
    for (const auto & column_name : input.getNames())
        actions->tryRestoreColumn(column_name);

    return result;
}

/// Builds the prewhere actions DAG for a filter expression.
///
/// The filter is split into its top-level `and` operands (see analyzeExpressions), the
/// operands are sorted and then re-assembled: a single operand is used directly,
/// several operands are combined with a new `and` function node.
///
/// @param rel          filter expression
/// @param filter_name  [out] result name of the final filter column
/// @param block        header describing the input columns of the DAG
/// @return the actions DAG reduced to what is needed to compute `filter_name`.
DB::ActionsDAGPtr MergeTreeRelParser::optimizePrewhereAction(const substrait::Expression & rel, std::string & filter_name, Block & block)
{
    Conditions res;
    std::set<Int64> pk_positions;
    analyzeExpressions(res, rel, pk_positions, block);

    // Length of the gap-free primary-key prefix (0, 1, 2, ...) covered by the conditions;
    // stays -1 when position 0 is not covered.
    Int64 min_valid_pk_pos = -1;
    for (auto pk_pos : pk_positions)
    {
        if (pk_pos != min_valid_pk_pos + 1)
            break;
        min_valid_pk_pos = pk_pos;
    }

    // TODO need to test
    // Conditions outside of that prefix are reset to the "not in primary key" sentinel.
    for (auto & cond : res)
        if (cond.min_position_in_primary_key > min_valid_pk_pos)
            cond.min_position_in_primary_key = std::numeric_limits<Int64>::max() - 1;

    // filter less size column first
    // (the exact order is defined by Condition's comparison, which is not visible here)
    res.sort();
    auto filter_action = std::make_shared<ActionsDAG>(block.getNamesAndTypesList());

    if (res.size() == 1)
    {
        parseToAction(filter_action, res.back().node, filter_name);
    }
    else
    {
        DB::ActionsDAG::NodeRawConstPtrs args;
        args.reserve(res.size());

        // Iterate by reference: copying a Condition per iteration is unnecessary.
        for (const auto & cond : res)
        {
            String ignore;
            parseToAction(filter_action, cond.node, ignore);
            // The node added last is taken as the result of the operand just parsed.
            args.emplace_back(&filter_action->getNodes().back());
        }

        auto function_builder = FunctionFactory::instance().get("and", context);
        std::string args_name = join(args, ',');
        filter_name = "and(" + args_name + ")";
        const auto * and_function = &filter_action->addFunction(function_builder, args, filter_name);
        filter_action->addOrReplaceInOutputs(*and_function);
    }

    filter_action->removeUnusedActions(Names{filter_name}, false, true);
    return filter_action;
}

/// Adds the actions computing `rel` to `filter_action` and reports the result column.
///
/// Scalar functions are delegated to the plan parser's parseFunctionWithDAG; any other
/// expression goes through parseExpression and its node is registered as a DAG output.
///
/// @param filter_action  DAG that receives the new nodes
/// @param rel            expression to translate
/// @param filter_name    [out] result name of the translated expression
void MergeTreeRelParser::parseToAction(ActionsDAGPtr & filter_action, const substrait::Expression & rel, std::string & filter_name)
{
    if (rel.has_scalar_function())
    {
        getPlanParser()->parseFunctionWithDAG(rel, filter_name, filter_action, true);
        return;
    }

    const auto * expression_node = parseExpression(filter_action, rel);
    filter_action->addOrReplaceInOutputs(*expression_node);
    filter_name = expression_node->result_name;
}

/// Flattens a filter into its top-level conjuncts.
///
/// An `and` scalar function is descended into recursively; every other expression
/// becomes one Condition annotated with the columns it references, their summed size
/// and its minimal primary-key position.
///
/// @param res           [out] receives one Condition per conjunct
/// @param rel           expression to analyze
/// @param pk_positions  [out] receives the minimal primary-key position of each conjunct
/// @param block         header used to resolve field references to column names
void MergeTreeRelParser::analyzeExpressions(
    Conditions & res, const substrait::Expression & rel, std::set<Int64> & pk_positions, Block & block)
{
    if (rel.has_scalar_function() && getCHFunctionName(rel.scalar_function()) == "and")
    {
        // Iterate by const reference: copying an argument would deep-copy its expression subtree.
        for (const auto & argument : rel.scalar_function().arguments())
            analyzeExpressions(res, argument.value(), pk_positions, block);
    }
    else
    {
        Condition cond(rel);
        collectColumns(rel, cond.table_columns, block);
        cond.columns_size = getColumnsSize(cond.table_columns);

        // TODO: get primary_key_names
        // Until then the map is empty, so findMinPosition always yields its sentinel value.
        const NameToIndexMap primary_key_names_positions;
        cond.min_position_in_primary_key = findMinPosition(cond.table_columns, primary_key_names_positions);
        pk_positions.emplace(cond.min_position_in_primary_key);

        res.emplace_back(std::move(cond));
    }
}


/// Sums the recorded sizes of the given columns.
///
/// Sizes come from `column_sizes`, which parse() fills with the compressed data size of
/// each storage column. Columns without an entry contribute nothing.
///
/// @param columns  names of the columns to account for
/// @return the total recorded size of the known columns.
UInt64 MergeTreeRelParser::getColumnsSize(const NameSet & columns)
{
    UInt64 size = 0;
    for (const auto & column : columns)
    {
        // Single lookup instead of contains() followed by operator[].
        if (const auto it = column_sizes.find(column); it != column_sizes.end())
            size += it->second;
    }

    return size;
}

/// Collects the names of the table columns referenced by an expression.
///
/// Field references are resolved to names through their index in `block`; indexes
/// outside of the block are ignored. Casts, if/then pairs, scalar function arguments and
/// the value of a non-empty IN-list are searched recursively; literals add nothing.
///
/// @param rel      expression to inspect
/// @param columns  [out] set receiving the referenced column names
/// @param block    header used to map field indexes to column names
/// @throws Exception UNKNOWN_TYPE for any other expression type.
void MergeTreeRelParser::collectColumns(const substrait::Expression & rel, NameSet & columns, Block & block)
{
    using RexType = substrait::Expression::RexTypeCase;

    switch (rel.rex_type_case())
    {
        case RexType::kLiteral:
            break;

        case RexType::kSelection: {
            const size_t field_index = rel.selection().direct_reference().struct_field().field();
            const Names column_names = block.getNames();
            if (field_index < column_names.size())
                columns.insert(column_names[field_index]);
            break;
        }

        case RexType::kCast:
            collectColumns(rel.cast().input(), columns, block);
            break;

        case RexType::kIfThen: {
            // NOTE(review): only the if/then pairs are visited; an `else` branch, if the
            // expression has one, is not inspected — confirm this is intended.
            for (const auto & clause : rel.if_then().ifs())
            {
                collectColumns(clause.if_(), columns, block);
                collectColumns(clause.then(), columns, block);
            }
            break;
        }

        case RexType::kScalarFunction: {
            for (const auto & argument : rel.scalar_function().arguments())
                collectColumns(argument.value(), columns, block);
            break;
        }

        case RexType::kSingularOrList: {
            const auto & in_list = rel.singular_or_list();
            /// An empty option list always evaluates to false, so its value is not inspected.
            if (!in_list.options().empty())
                collectColumns(in_list.value(), columns, block);
            break;
        }

        default:
            throw Exception(
                ErrorCodes::UNKNOWN_TYPE,
                "Unsupported spark expression type {} : {}",
                magic_enum::enum_name(rel.rex_type_case()),
                rel.DebugString());
    }
}


/// Resolves the ClickHouse function name for a substrait scalar function.
///
/// The function reference is looked up in the plan parser's `function_mapping`; the part
/// of the signature before the first ':' (the whole signature when there is none) is then
/// translated through SCALAR_FUNCTIONS.
///
/// @param substrait_func  scalar function taken from the plan
/// @return the mapped function name.
/// @throws Exception UNKNOWN_FUNCTION when SCALAR_FUNCTIONS has no entry for the name.
///         The `.at()` lookup itself fails if the reference is missing from the mapping.
String MergeTreeRelParser::getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func)
{
    const auto reference_key = std::to_string(substrait_func.function_reference());
    const auto & signature = getPlanParser()->function_mapping.at(reference_key);
    const auto substrait_name = signature.substr(0, signature.find(':'));

    const auto mapped = SCALAR_FUNCTIONS.find(substrait_name);
    if (mapped != SCALAR_FUNCTIONS.end())
        return mapped->second;

    throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unsupported substrait function on mergetree prewhere parser: {}", substrait_name);
}

}
Loading

0 comments on commit e29a44e

Please sign in to comment.