diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala
index 26e9972812214..4295c7072fd83 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala
@@ -211,4 +211,19 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite {
     compareResultsAgainstVanillaSpark(query_sql, true, { _ => })
     spark.sql("drop table test")
   }
+
+  test("intersect all") {
+    spark.sql("create table t1 (a int, b string) using parquet")
+    spark.sql("insert into t1 values (1, '1'),(2, '2'),(3, '3'),(4, '4'),(5, '5'),(6, '6')")
+    spark.sql("create table t2 (a int, b string) using parquet")
+    spark.sql("insert into t2 values (4, '4'),(5, '5'),(6, '6'),(7, '7'),(8, '8'),(9, '9')")
+    runQueryAndCompare(
+      """
+        |SELECT a,b FROM t1 INTERSECT ALL SELECT a,b FROM t2
+        |""".stripMargin
+    )(df => checkFallbackOperators(df, 0))
+    spark.sql("drop table t1")
+    spark.sql("drop table t2")
+  }
+
 }
diff --git a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
new file mode 100644
index 0000000000000..d0c587e3a0d9b
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ReplicateRowsStep.h"
+
+#include <QueryPipeline/QueryPipelineBuilder.h>
+
+namespace local_engine
+{
+static DB::ITransformingStep::Traits getTraits()
+{
+    return DB::ITransformingStep::Traits
+    {
+        {
+            .preserves_number_of_streams = true,
+            .preserves_sorting = false,
+        },
+        {
+            .preserves_number_of_rows = false,
+        }
+    };
+}
+
+ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream)
+    : ITransformingStep(input_stream, transformHeader(input_stream.header), getTraits())
+{
+}
+
+DB::Block ReplicateRowsStep::transformHeader(const DB::Block & input)
+{
+    DB::Block output;
+    // Column 0 is the replication count; the output carries the remaining columns only.
+    for (size_t i = 1; i < input.columns(); ++i)
+    {
+        output.insert(input.getByPosition(i));
+    }
+    return output;
+}
+
+void ReplicateRowsStep::transformPipeline(
+    DB::QueryPipelineBuilder & pipeline,
+    const DB::BuildQueryPipelineSettings & /*settings*/)
+{
+    pipeline.addSimpleTransform(
+        [&](const DB::Block & header)
+        {
+            return std::make_shared<ReplicateRowsTransform>(header);
+        });
+}
+
+void ReplicateRowsStep::updateOutputStream()
+{
+    output_stream = createOutputStream(input_streams.front(), transformHeader(input_streams.front().header), getDataStreamTraits());
+}
+
+ReplicateRowsTransform::ReplicateRowsTransform(const DB::Block & input_header_)
+    : ISimpleTransform(input_header_, ReplicateRowsStep::transformHeader(input_header_), true)
+{
+}
+
+void ReplicateRowsTransform::transform(DB::Chunk & chunk)
+{
+    // First column holds how many times each input row must be emitted.
+    const auto & replica_column = chunk.getColumns().front();
+    const size_t num_rows = replica_column->size();
+    size_t total_rows = 0;
+    for (size_t i = 0; i < num_rows; ++i)
+    {
+        total_rows += replica_column->get64(i);
+    }
+
+    auto columns = chunk.detachColumns();
+    DB::MutableColumns mutable_columns;
+    for (size_t i = 1; i < columns.size(); ++i)
+    {
+        mutable_columns.push_back(columns[i]->cloneEmpty());
+        mutable_columns.back()->reserve(total_rows);
+        DB::ColumnPtr src_col = columns[i];
+        DB::MutableColumnPtr & cur = mutable_columns.back();
+        for (size_t j = 0; j < num_rows; ++j)
+        {
+            cur->insertManyFrom(*src_col, j, replica_column->getUInt(j));
+        }
+    }
+
+    chunk.setColumns(std::move(mutable_columns), total_rows);
+}
+}
diff --git a/cpp-ch/local-engine/Operator/ReplicateRowsStep.h b/cpp-ch/local-engine/Operator/ReplicateRowsStep.h
new file mode 100644
index 0000000000000..c49c68f578652
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/ReplicateRowsStep.h
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <Processors/ISimpleTransform.h>
+#include <Processors/QueryPlan/ITransformingStep.h>
+
+namespace local_engine
+{
+
+/// Drops the leading replication-count column and repeats every remaining row that many times (Spark INTERSECT/EXCEPT ALL).
+class ReplicateRowsStep : public DB::ITransformingStep
+{
+public:
+    explicit ReplicateRowsStep(const DB::DataStream & input_stream);
+
+    static DB::Block transformHeader(const DB::Block & input);
+
+    String getName() const override { return "ReplicateRowsStep"; }
+    void transformPipeline(DB::QueryPipelineBuilder & pipeline,
+        const DB::BuildQueryPipelineSettings & settings) override;
+private:
+    void updateOutputStream() override;
+};
+
+class ReplicateRowsTransform : public DB::ISimpleTransform
+{
+public:
+    explicit ReplicateRowsTransform(const DB::Block & input_header_);
+
+    String getName() const override { return "ReplicateRowsTransform"; }
+    void transform(DB::Chunk &) override;
+
+};
+}
diff --git a/cpp-ch/local-engine/Parser/ProjectRelParser.cpp b/cpp-ch/local-engine/Parser/ProjectRelParser.cpp
index eb190101f170c..2f75ac396dfe3 100644
--- a/cpp-ch/local-engine/Parser/ProjectRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/ProjectRelParser.cpp
@@ -21,6 +21,7 @@
 #include <Core/Block.h>
 #include <Interpreters/ActionsDAG.h>
 #include <Processors/QueryPlan/ExpressionStep.h>
+#include <Operator/ReplicateRowsStep.h>
 
 namespace local_engine
 {
@@ -109,15 +110,46 @@ ProjectRelParser::SplittedActionsDAGs ProjectRelParser::splitActionsDAGInGenerat
     return res;
 }
 
+bool ProjectRelParser::isReplicateRows(const substrait::GenerateRel & rel)
+{
+    return plan_parser->isFunction(rel.generator().scalar_function(), "replicaterows");
+}
+
+DB::QueryPlanPtr ProjectRelParser::parseReplicateRows(DB::QueryPlanPtr query_plan, const substrait::GenerateRel & generate_rel)
+{
+    std::vector<substrait::Expression> expressions;
+    for (int i = 0; i < generate_rel.generator().scalar_function().arguments_size(); ++i)
+    {
+        expressions.emplace_back(generate_rel.generator().scalar_function().arguments(i).value());
+    }
+    auto header = query_plan->getCurrentDataStream().header;
+    auto actions_dag = expressionsToActionsDAG(expressions, header);
+    auto before_replicate_rows = std::make_unique<DB::ExpressionStep>(query_plan->getCurrentDataStream(), actions_dag);
+    before_replicate_rows->setStepDescription("Before ReplicateRows");
+    steps.emplace_back(before_replicate_rows.get());
+    query_plan->addStep(std::move(before_replicate_rows));
+
+    auto replicate_rows_step = std::make_unique<ReplicateRowsStep>(query_plan->getCurrentDataStream());
+    replicate_rows_step->setStepDescription("ReplicateRows");
+    steps.emplace_back(replicate_rows_step.get());
+    query_plan->addStep(std::move(replicate_rows_step));
+    return query_plan;
+}
+
 DB::QueryPlanPtr
 ProjectRelParser::parseGenerate(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & /*rel_stack_*/)
 {
     const auto & generate_rel = rel.generate();
+    if (isReplicateRows(generate_rel))
+    {
+        return parseReplicateRows(std::move(query_plan), generate_rel);
+    }
     std::vector<substrait::Expression> expressions;
     for (int i = 0; i < generate_rel.child_output_size(); ++i)
     {
         expressions.emplace_back(generate_rel.child_output(i));
     }
+
     expressions.emplace_back(generate_rel.generator());
     auto header = query_plan->getCurrentDataStream().header;
     auto actions_dag = expressionsToActionsDAG(expressions, header);
diff --git a/cpp-ch/local-engine/Parser/ProjectRelParser.h b/cpp-ch/local-engine/Parser/ProjectRelParser.h
index ae56939144758..48a16d774d887 100644
--- a/cpp-ch/local-engine/Parser/ProjectRelParser.h
+++ b/cpp-ch/local-engine/Parser/ProjectRelParser.h
@@ -19,7 +19,6 @@
 #include <Interpreters/ActionsDAG.h>
 #include <Parser/RelParser.h>
 #include <Processors/QueryPlan/QueryPlan.h>
-#include <Parser/SerializedPlanParser.h>
 
 namespace local_engine
 {
@@ -50,6 +49,9 @@ class ProjectRelParser : public RelParser
     /// Split actions_dag of generate rel into 3 parts: before array join + during array join + after array join
     static SplittedActionsDAGs splitActionsDAGInGenerate(ActionsDAGPtr actions_dag);
 
+    bool isReplicateRows(const substrait::GenerateRel & rel);
+
+    DB::QueryPlanPtr parseReplicateRows(DB::QueryPlanPtr query_plan, const substrait::GenerateRel & generate_rel);
 
     const substrait::Rel & getSingleInput(const substrait::Rel & rel) override
     {
diff --git a/cpp-ch/local-engine/Parser/RelParser.h b/cpp-ch/local-engine/Parser/RelParser.h
index 6ca8af5359551..0228c2867a269 100644
--- a/cpp-ch/local-engine/Parser/RelParser.h
+++ b/cpp-ch/local-engine/Parser/RelParser.h
@@ -85,7 +85,7 @@ class RelParser
     static std::map<std::string, std::string> parseFormattedRelAdvancedOptimization(const substrait::extensions::AdvancedExtension &advanced_extension);
     static std::string getStringConfig(const std::map<std::string, std::string> & configs, const std::string & key, const std::string & default_value = "");
 
-private:
+protected:
     SerializedPlanParser * plan_parser;
 };
 
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 77819fd73e75d..581160a02dfb1 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -1153,6 +1153,12 @@ std::pair<DB::DataTypePtr, DB::Field> SerializedPlanParser::convertStructFieldTy
 #undef UINT_CONVERT
 }
 
+bool SerializedPlanParser::isFunction(const substrait::Expression_ScalarFunction & rel, const String & function_name)
+{
+    const auto it = function_mapping.find(std::to_string(rel.function_reference()));
+    return it != function_mapping.end() && it->second.starts_with(function_name + ":");
+}
+
 ActionsDAGPtr SerializedPlanParser::parseFunction(
     const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result)
 {
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index ad2b0d50ec6a8..f92453c1dfdec 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -255,6 +255,7 @@ class SerializedPlanParser
     friend class NonNullableColumnsResolver;
     friend class JoinRelParser;
     friend class MergeTreeRelParser;
+    friend class ProjectRelParser;
 
     std::unique_ptr<LocalExecutor> createExecutor(DB::QueryPlanPtr query_plan);
 
@@ -389,6 +390,8 @@ class SerializedPlanParser
         const std::vector<String> & columns, ActionsDAGPtr actions_dag, std::map<std::string, std::string> & nullable_measure_names);
     static std::pair<DB::DataTypePtr, DB::Field> convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field);
+    bool isFunction(const substrait::Expression_ScalarFunction & rel, const String & function_name);
+
     int name_no = 0;
     std::unordered_map<std::string, std::string> function_mapping;
     std::vector<DB::QueryPlanPtr> input_iters;
 
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index 806ec844de601..e7e9c7ffe9004 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -282,6 +282,7 @@ object ExpressionMappings {
     Sig[MonotonicallyIncreasingID](MONOTONICALLY_INCREASING_ID),
     Sig[SparkPartitionID](SPARK_PARTITION_ID),
     Sig[WidthBucket](WIDTH_BUCKET),
+    Sig[ReplicateRows](REPLICATE_ROWS),
     // Decimal
     Sig[UnscaledValue](UNSCALED_VALUE),
     // Generator function
diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 7060e297ea10e..278f119226457 100644
--- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -314,6 +314,7 @@ object ExpressionNames {
   final val SPARK_PARTITION_ID = "spark_partition_id"
   final val MONOTONICALLY_INCREASING_ID = "monotonically_increasing_id"
   final val WIDTH_BUCKET = "width_bucket"
+  final val REPLICATE_ROWS = "replicaterows"
 
   // Directly use child expression transformer
   final val KNOWN_NULLABLE = "known_nullable"