Skip to content

Commit

Permalink
refactor for rel parsers
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Sep 11, 2024
1 parent 5f65501 commit 1072a7e
Show file tree
Hide file tree
Showing 32 changed files with 345 additions and 230 deletions.
1 change: 1 addition & 0 deletions cpp-ch/local-engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ add_subdirectory(proto)
add_headers_and_sources(builder Builder)
add_headers_and_sources(join Join)
add_headers_and_sources(parser Parser)
add_headers_and_sources(parser Parser/RelParsers)
add_headers_and_sources(rewriter Rewriter)
add_headers_and_sources(storages Storages)
add_headers_and_sources(storages Storages/Output)
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
#include <Functions/registerFunctions.h>
#include <IO/SharedThreadPools.h>
#include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Parser/RelParser.h>
#include <Parser/RelParsers/RelParser.h>
#include <Parser/SerializedPlanParser.h>
#include <Parser/SubstraitParserUtils.h>
#include <Planner/PlannerActionsVisitor.h>
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <Compression/CompressedReadBuffer.h>
#include <Interpreters/TableJoin.h>
#include <Join/StorageJoinFromReadBuffer.h>
#include <Parser/JoinRelParser.h>
#include <Parser/RelParsers/JoinRelParser.h>
#include <Parser/TypeParser.h>
#include <QueryPipeline/ProfileInfo.h>
#include <Shuffle/ShuffleReader.h>
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeTuple.h>
#include <Functions/FunctionHelpers.h>
#include <Parser/RelParser.h>
#include <Parser/RelParsers/RelParser.h>
#include <Parser/TypeParser.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
#pragma once
#include <Parser/AggregateFunctionParser.h>
#include <Parser/RelParser.h>
#include <Parser/RelParsers/RelParser.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>

Expand All @@ -30,7 +30,7 @@ class AggregateRelParser : public RelParser
~AggregateRelParser() override = default;
DB::QueryPlanPtr
parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_) override;
const substrait::Rel & getSingleInput(const substrait::Rel & rel) override { return rel.aggregate().input(); }
std::optional<const substrait::Rel *> getSingleInput(const substrait::Rel & rel) override { return &rel.aggregate().input(); }

private:
struct AggregateInfo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* limitations under the License.
*/
#include "CrossRelParser.h"
#include <optional>

#include <Interpreters/CollectJoinOnKeysVisitor.h>
#include <Interpreters/GraceHashJoin.h>
Expand Down Expand Up @@ -73,7 +74,7 @@ CrossRelParser::parse(DB::QueryPlanPtr /*query_plan*/, const substrait::Rel & /*
throw Exception(ErrorCodes::LOGICAL_ERROR, "join node has 2 inputs, can't call parse().");
}

const substrait::Rel & CrossRelParser::getSingleInput(const substrait::Rel & /*rel*/)
std::optional<const substrait::Rel *> CrossRelParser::getSingleInput(const substrait::Rel & /*rel*/)
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "join node has 2 inputs, can't call getSingleInput().");
}
Expand Down Expand Up @@ -194,7 +195,8 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB:
else
{
JoinPtr hash_join = std::make_shared<HashJoin>(table_join, right->getCurrentDataStream().header.cloneEmpty());
QueryPlanStepPtr join_step = std::make_unique<DB::JoinStep>(left->getCurrentDataStream(), right->getCurrentDataStream(), hash_join, 8192, 1, false);
QueryPlanStepPtr join_step
= std::make_unique<DB::JoinStep>(left->getCurrentDataStream(), right->getCurrentDataStream(), hash_join, 8192, 1, false);
join_step->setStepDescription("CROSS_JOIN");
steps.emplace_back(join_step.get());
std::vector<QueryPlanPtr> plans;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#pragma once

#include <memory>
#include <Parser/RelParser.h>
#include <optional>
#include <Parser/RelParsers/RelParser.h>
#include <substrait/algebra.pb.h>

namespace DB
Expand All @@ -42,7 +43,7 @@ class CrossRelParser : public RelParser

DB::QueryPlanPtr parseOp(const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack) override;

const substrait::Rel & getSingleInput(const substrait::Rel & rel) override;
std::optional<const substrait::Rel *> getSingleInput(const substrait::Rel & rel) override;

private:
std::unordered_map<std::string, std::string> & function_mapping;
Expand All @@ -55,7 +56,11 @@ class CrossRelParser : public RelParser
void addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right);
void addPostFilter(DB::QueryPlan & query_plan, const substrait::CrossRel & join);
bool applyJoinFilter(
DB::TableJoin & table_join, const substrait::CrossRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition);
DB::TableJoin & table_join,
const substrait::CrossRel & join_rel,
DB::QueryPlan & left,
DB::QueryPlan & right,
bool allow_mixed_condition);
};

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <Core/ColumnWithTypeAndName.h>
#include <Operator/ExpandStep.h>
#include <Parser/ExpandField.h>
#include <Parser/RelParser.h>
#include <Parser/RelParsers/RelParser.h>
#include <Parser/SerializedPlanParser.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Common/logger_useful.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
* limitations under the License.
*/
#pragma once
#include <Parser/RelParser.h>
#include <optional>
#include <Parser/RelParsers/RelParser.h>


namespace local_engine
Expand All @@ -29,6 +30,6 @@ class ExpandRelParser : public RelParser
DB::QueryPlanPtr
parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_) override;

const substrait::Rel & getSingleInput(const substrait::Rel & rel) override { return rel.expand().input(); }
std::optional<const substrait::Rel *> getSingleInput(const substrait::Rel & rel) override { return &rel.expand().input(); }
};
}
49 changes: 49 additions & 0 deletions cpp-ch/local-engine/Parser/RelParsers/FetchRelParser.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include <optional>
#include <Parser/SerializedPlanParser.h>
#include <Processors/QueryPlan/LimitStep.h>
#include "RelParser.h"
namespace local_engine
{
/// Rel parser for substrait fetch rels (`rel.fetch()`): appends a DB::LimitStep,
/// built from the fetch rel's count and offset, to the incoming query plan.
class FetchRelParser : public RelParser
{
public:
    explicit FetchRelParser(SerializedPlanParser * plan_parser_) : RelParser(plan_parser_) { }
    ~FetchRelParser() override = default;

    /// Adds a "LIMIT" step on top of `query_plan` and returns that same plan.
    ///
    /// @param query_plan plan produced for the input of this rel; it is dereferenced
    ///                   unconditionally, so it must not be null.
    /// @param rel        rel whose `fetch()` message supplies count and offset.
    /// The rel stack argument is unused by this parser.
    ///
    /// Marked `override` so the compiler checks this signature against the base-class
    /// virtual, matching how the other rel parsers declare `parse`.
    DB::QueryPlanPtr parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> &) override
    {
        const auto & limit = rel.fetch();
        auto limit_step = std::make_unique<DB::LimitStep>(query_plan->getCurrentDataStream(), limit.count(), limit.offset());
        limit_step->setStepDescription("LIMIT");
        // Record the raw pointer before ownership of the step moves into the plan.
        // NOTE(review): `steps` is presumably a member inherited from RelParser - confirm.
        steps.push_back(limit_step.get());
        query_plan->addStep(std::move(limit_step));
        return query_plan;
    }

    /// Returns a pointer to the input rel of the fetch rel.
    std::optional<const substrait::Rel *> getSingleInput(const substrait::Rel & rel) override { return &rel.fetch().input(); }
};

/// Registers a builder for FetchRelParser in `factory` under the
/// substrait::Rel::RelTypeCase::kFetch rel type.
void registerFetchRelParser(RelParserFactory & factory)
{
    factory.registerBuilder(
        substrait::Rel::RelTypeCase::kFetch,
        [](SerializedPlanParser * parser) { return std::make_unique<FetchRelParser>(parser); });
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

#pragma once

#include <Parser/RelParser.h>
#include <optional>
#include <Parser/RelParsers/RelParser.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>

Expand All @@ -29,11 +30,12 @@ class FilterRelParser : public RelParser
explicit FilterRelParser(SerializedPlanParser * plan_paser_);
~FilterRelParser() override = default;

const substrait::Rel & getSingleInput(const substrait::Rel & rel) override { return rel.filter().input(); }
std::optional<const substrait::Rel *> getSingleInput(const substrait::Rel & rel) override { return &rel.filter().input(); }

DB::QueryPlanPtr
parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_) override;

private:
// Poco::Logger * logger = &Poco::Logger::get("ProjectRelParser");
};
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* limitations under the License.
*/
#include "JoinRelParser.h"
#include <optional>

#include <Core/Block.h>
#include <Functions/FunctionFactory.h>
Expand All @@ -27,10 +28,10 @@
#include <Interpreters/TableJoin.h>
#include <Join/BroadCastJoinBuilder.h>
#include <Join/StorageJoinFromReadBuffer.h>
#include <Operator/EarlyStopStep.h>
#include <Parser/AdvancedParametersParseUtil.h>
#include <Parser/SerializedPlanParser.h>
#include <Parsers/ASTIdentifier.h>
#include <Operator/EarlyStopStep.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/FilterStep.h>
#include <Processors/QueryPlan/JoinStep.h>
Expand All @@ -56,8 +57,8 @@ namespace local_engine
{
std::shared_ptr<DB::TableJoin> createDefaultTableJoin(substrait::JoinRel_JoinType join_type, bool is_existence_join, ContextPtr & context)
{
auto table_join = std::make_shared<TableJoin>(
context->getSettingsRef(), context->getGlobalTemporaryVolume(), context->getTempDataOnDisk());
auto table_join
= std::make_shared<TableJoin>(context->getSettingsRef(), context->getGlobalTemporaryVolume(), context->getTempDataOnDisk());

std::pair<DB::JoinKind, DB::JoinStrictness> kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type, is_existence_join);
table_join->setKind(kind_and_strictness.first);
Expand All @@ -79,7 +80,7 @@ JoinRelParser::parse(DB::QueryPlanPtr /*query_plan*/, const substrait::Rel & /*r
throw Exception(ErrorCodes::LOGICAL_ERROR, "join node has 2 inputs, can't call parse().");
}

const substrait::Rel & JoinRelParser::getSingleInput(const substrait::Rel & /*rel*/)
std::optional<const substrait::Rel *> JoinRelParser::getSingleInput(const substrait::Rel & /*rel*/)
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "join node has 2 inputs, can't call getSingleInput().");
}
Expand Down Expand Up @@ -282,13 +283,22 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q
auto input_header = left->getCurrentDataStream().header;
DB::ActionsDAG filter_is_not_null_dag{input_header.getColumnsWithTypeAndName()};
// when is_null_aware_anti_join is true, there is only one join key
const auto * key_field = filter_is_not_null_dag.getInputs()[join.expression().scalar_function().arguments().at(0).value().selection().direct_reference().struct_field().field()];
const auto * key_field = filter_is_not_null_dag.getInputs()[join.expression()
.scalar_function()
.arguments()
.at(0)
.value()
.selection()
.direct_reference()
.struct_field()
.field()];

auto result_node = filter_is_not_null_dag.tryFindInOutputs(key_field->result_name);
// add a function isNotNull to filter the null key on the left side
const auto * cond_node = plan_parser->toFunctionNode(filter_is_not_null_dag, "isNotNull", {result_node});
filter_is_not_null_dag.addOrReplaceInOutputs(*cond_node);
auto filter_step = std::make_unique<FilterStep>(left->getCurrentDataStream(), std::move(filter_is_not_null_dag), cond_node->result_name, true);
auto filter_step = std::make_unique<FilterStep>(
left->getCurrentDataStream(), std::move(filter_is_not_null_dag), cond_node->result_name, true);
left->addStep(std::move(filter_step));
}
// other case: is_empty_hash_table, don't need to handle
Expand Down Expand Up @@ -342,8 +352,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q
= couldRewriteToMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header);
if (is_multi_join_on_clauses && join_config.prefer_multi_join_on_clauses && join_opt_info.right_table_rows > 0
&& join_opt_info.partitions_num > 0
&& join_opt_info.right_table_rows / join_opt_info.partitions_num
< join_config.multi_join_on_clauses_build_side_rows_limit)
&& join_opt_info.right_table_rows / join_opt_info.partitions_num < join_config.multi_join_on_clauses_build_side_rows_limit)
{
query_plan = buildMultiOnClauseHashJoin(table_join, std::move(left), std::move(right), join_on_clauses);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <unordered_set>
#include <Core/Joins.h>
#include <Interpreters/TableJoin.h>
#include <Parser/RelParser.h>
#include <Parser/RelParsers/RelParser.h>
#include <substrait/algebra.pb.h>

namespace DB
Expand All @@ -44,7 +44,7 @@ class JoinRelParser : public RelParser

DB::QueryPlanPtr parseOp(const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack) override;

const substrait::Rel & getSingleInput(const substrait::Rel & rel) override;
std::optional<const substrait::Rel *> getSingleInput(const substrait::Rel & rel) override;

private:
std::unordered_map<std::string, std::string> & function_mapping;
Expand All @@ -69,8 +69,8 @@ class JoinRelParser : public RelParser

void existenceJoinPostProject(DB::QueryPlan & plan, const DB::Names & left_input_cols);

static std::unordered_set<DB::JoinTableSide> extractTableSidesFromExpression(
const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header);
static std::unordered_set<DB::JoinTableSide>
extractTableSidesFromExpression(const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header);

bool couldRewriteToMultiJoinOnClauses(
const DB::TableJoin::JoinOnClause & prefix_clause,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
#pragma once

#include <memory>
#include <optional>
#include <substrait/algebra.pb.h>

#include <Parser/RelParser.h>
#include <Parser/RelParsers/RelParser.h>

namespace DB
{
Expand Down Expand Up @@ -50,7 +51,7 @@ class MergeTreeRelParser : public RelParser
DB::QueryPlanPtr parseReadRel(
DB::QueryPlanPtr query_plan, const substrait::ReadRel & read_rel, const substrait::ReadRel::ExtensionTable & extension_table);

const substrait::Rel & getSingleInput(const substrait::Rel &) override
std::optional<const substrait::Rel *> getSingleInput(const substrait::Rel &) override
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "MergeTreeRelParser can't call getSingleInput().");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
* limitations under the License.
*/
#pragma once
#include <optional>
#include <Core/Block.h>
#include <Core/SortDescription.h>
#include <Parser/RelParser.h>
#include <Parser/RelParsers/RelParser.h>
#include <Poco/Logger.h>

namespace local_engine
Expand All @@ -29,7 +30,7 @@ class ProjectRelParser : public RelParser
{
ActionsDAG before_array_join; /// Optional
ActionsDAG array_join;
ActionsDAG after_array_join; /// Optional
ActionsDAG after_array_join; /// Optional
};

explicit ProjectRelParser(SerializedPlanParser * plan_paser_);
Expand All @@ -44,21 +45,21 @@ class ProjectRelParser : public RelParser
DB::QueryPlanPtr parseProject(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_);
DB::QueryPlanPtr parseGenerate(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_);

static const DB::ActionsDAG::Node * findArrayJoinNode(const ActionsDAG& actions_dag);
static const DB::ActionsDAG::Node * findArrayJoinNode(const ActionsDAG & actions_dag);

/// Split actions_dag of generate rel into 3 parts: before array join + during array join + after array join
static SplittedActionsDAGs splitActionsDAGInGenerate(const ActionsDAG& actions_dag);
static SplittedActionsDAGs splitActionsDAGInGenerate(const ActionsDAG & actions_dag);

bool isReplicateRows(substrait::GenerateRel rel);

DB::QueryPlanPtr parseReplicateRows(QueryPlanPtr query_plan, substrait::GenerateRel generate_rel);

const substrait::Rel & getSingleInput(const substrait::Rel & rel) override
std::optional<const substrait::Rel *> getSingleInput(const substrait::Rel & rel) override
{
if (rel.has_generate())
return rel.generate().input();
return &rel.generate().input();

return rel.project().input();
return &rel.project().input();
}
};
}
Loading

0 comments on commit 1072a7e

Please sign in to comment.