diff --git a/cpp-ch/local-engine/Parser/FilterRelParser.cpp b/cpp-ch/local-engine/Parser/FilterRelParser.cpp new file mode 100644 index 000000000000..19facf3bff96 --- /dev/null +++ b/cpp-ch/local-engine/Parser/FilterRelParser.cpp @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FilterRelParser.h" +#include +#include + +namespace local_engine +{ +FilterRelParser::FilterRelParser(SerializedPlanParser * plan_paser_) + : RelParser(plan_paser_) +{ +} +DB::QueryPlanPtr FilterRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list & /*rel_stack_*/) +{ + ExpressionsRewriter rewriter(getPlanParser()); + substrait::Rel final_rel = rel; + rewriter.rewrite(final_rel); + + const auto & filter_rel = rel.filter(); + std::string filter_name; + + auto input_header = query_plan->getCurrentDataStream().header; + DB::ActionsDAGPtr actions_dag = std::make_shared(input_header.getColumnsWithTypeAndName()); + const auto condition_node = parseExpression(actions_dag, filter_rel.condition()); + if (filter_rel.condition().has_scalar_function()) + { + actions_dag->addOrReplaceInOutputs(*condition_node); + } + filter_name = condition_node->result_name; + + bool remove_filter_column = true; + auto input_names = query_plan->getCurrentDataStream().header.getNames(); + DB::NameSet input_with_condition(input_names.begin(), input_names.end()); + if (input_with_condition.contains(condition_node->result_name)) + remove_filter_column = false; + else + input_with_condition.insert(condition_node->result_name); + + actions_dag->removeUnusedActions(input_with_condition); + NonNullableColumnsResolver non_nullable_columns_resolver(input_header, *getPlanParser(), filter_rel.condition()); + auto non_nullable_columns = non_nullable_columns_resolver.resolve(); + + auto filter_step = std::make_unique(query_plan->getCurrentDataStream(), actions_dag, filter_name, remove_filter_column); + filter_step->setStepDescription("WHERE"); + steps.emplace_back(filter_step.get()); + query_plan->addStep(std::move(filter_step)); + + // remove nullable + auto * remove_null_step = getPlanParser()->addRemoveNullableStep(*query_plan, non_nullable_columns); + if (remove_null_step) + { + steps.emplace_back(remove_null_step); + } + + return query_plan; +} + +void registerFilterRelParser(RelParserFactory & factory) +{ + auto builder + = [](SerializedPlanParser * plan_parser) -> std::unique_ptr { return std::make_unique(plan_parser); }; + factory.registerBuilder(substrait::Rel::RelTypeCase::kFilter, builder); +} +} diff --git a/cpp-ch/local-engine/Parser/FilterRelParser.h b/cpp-ch/local-engine/Parser/FilterRelParser.h new file mode 100644 index 000000000000..a7151595243f --- /dev/null +++ b/cpp-ch/local-engine/Parser/FilterRelParser.h @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace local_engine +{ +class FilterRelParser : public RelParser +{ +public: + explicit FilterRelParser(SerializedPlanParser * plan_paser_); + ~FilterRelParser() override = default; + + const substrait::Rel & getSingleInput(const substrait::Rel & rel) override { return rel.filter().input(); } + + DB::QueryPlanPtr + parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list & rel_stack_) override; +private: + // Poco::Logger * logger = &Poco::Logger::get("ProjectRelParser"); +}; +} diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp b/cpp-ch/local-engine/Parser/RelParser.cpp index fb761a0ec1de..9c5ea815980f 100644 --- a/cpp-ch/local-engine/Parser/RelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParser.cpp @@ -101,6 +101,7 @@ void registerExpandRelParser(RelParserFactory & factory); void registerAggregateParser(RelParserFactory & factory); void registerProjectRelParser(RelParserFactory & factory); void registerJoinRelParser(RelParserFactory & factory); +void registerFilterRelParser(RelParserFactory & factory); void registerRelParsers() { @@ -111,5 +112,6 @@ void registerRelParsers() registerAggregateParser(factory); registerProjectRelParser(factory); registerJoinRelParser(factory); + registerFilterRelParser(factory); } } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 8dfcd0133cda..596442526628 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -536,49 +536,6 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list query_plan->addStep(std::move(limit_step)); break; } - case substrait::Rel::RelTypeCase::kFilter: { - rel_stack.push_back(&rel); - const auto & filter = rel.filter(); - query_plan = parseOp(filter.input(), rel_stack); - rel_stack.pop_back(); - std::string filter_name; - - ActionsDAGPtr actions_dag = nullptr; - if (filter.condition().has_scalar_function()) - { - actions_dag = parseFunction(query_plan->getCurrentDataStream().header, filter.condition(), filter_name, nullptr, true); - } - else - { - actions_dag = std::make_shared(blockToNameAndTypeList(query_plan->getCurrentDataStream().header)); - const auto * node = parseExpression(actions_dag, filter.condition()); - filter_name = node->result_name; - } - - bool remove_filter_column = true; - auto input = query_plan->getCurrentDataStream().header.getNames(); - NameSet input_with_condition(input.begin(), input.end()); - if (input_with_condition.contains(filter_name)) - remove_filter_column = false; - else - input_with_condition.emplace(filter_name); - - actions_dag->removeUnusedActions(input_with_condition); - NonNullableColumnsResolver non_nullable_columns_resolver(query_plan->getCurrentDataStream().header, *this, filter.condition()); - auto non_nullable_columns = non_nullable_columns_resolver.resolve(); - auto filter_step - = std::make_unique(query_plan->getCurrentDataStream(), actions_dag, filter_name, remove_filter_column); - filter_step->setStepDescription("WHERE"); - steps.emplace_back(filter_step.get()); - query_plan->addStep(std::move(filter_step)); - // remove nullable - auto * remove_null_step = addRemoveNullableStep(*query_plan, non_nullable_columns); - if (remove_null_step) - { - steps.emplace_back(remove_null_step); - } - break; - } case substrait::Rel::RelTypeCase::kRead: { const auto & read = rel.read(); assert(read.has_local_files() || read.has_extension_table() && "Only support local parquet files or merge tree read rel"); @@ -609,6 +566,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list } break; } + case substrait::Rel::RelTypeCase::kFilter: case substrait::Rel::RelTypeCase::kGenerate: case substrait::Rel::RelTypeCase::kProject: case substrait::Rel::RelTypeCase::kAggregate: diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 2d3980ad9aac..faa9f86e924f 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -296,6 +296,8 @@ class SerializedPlanParser static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); + IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set & columns); + static ContextMutablePtr global_context; static Context::ConfigurationPtr config; static SharedContextHolder shared_context; @@ -359,7 +361,6 @@ class SerializedPlanParser static std::pair parseLiteral(const substrait::Expression_Literal & literal); void wrapNullable( const std::vector & columns, ActionsDAGPtr actions_dag, std::map & nullable_measure_names); - IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set & columns); static std::pair convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field); int name_no = 0;