Skip to content

Commit

Permalink
support replicaterows
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Jul 2, 2024
1 parent eb1b913 commit 89371e4
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,4 +211,20 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite {
compareResultsAgainstVanillaSpark(query_sql, true, { _ => })
spark.sql("drop table test")
}


test("intersect all") {
spark.sql("create table t1 (a int, b string) using parquet")
spark.sql("insert into t1 values (1, '1'),(2, '2'),(3, '3'),(4, '4'),(5, '5'),(6, '6')")
spark.sql("create table t2 (a int, b string) using parquet")
spark.sql("insert into t2 values (4, '4'),(5, '5'),(6, '6'),(7, '7'),(8, '8'),(9, '9')")
runQueryAndCompare(
"""
|SELECT a,b FROM t1 INTERSECT ALL SELECT a,b FROM t2
|""".stripMargin
)(df => checkFallbackOperators(df, 0))
spark.sql("drop table t1")
spark.sql("drop table t2")
}

}
92 changes: 92 additions & 0 deletions cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "ReplicateRowsStep.h"

#include <iostream>

#include <QueryPipeline/QueryPipelineBuilder.h>

namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
}

namespace local_engine
{
static DB::ITransformingStep::Traits getTraits()
{
return DB::ITransformingStep::Traits
{
{
.preserves_number_of_streams = true,
.preserves_sorting = false,
},
{
.preserves_number_of_rows = false,
}
};
}

ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream)
: ITransformingStep(input_stream, transformHeader(input_stream.header), getTraits())
{
}

DB::Block ReplicateRowsStep::transformHeader(const DB::Block& input)
{
DB::Block output;
for (int i = 1; i < input.columns(); i++)
{
output.insert(input.getByPosition(i));
}
return output;
}

void ReplicateRowsStep::transformPipeline(
DB::QueryPipelineBuilder & pipeline,
const DB::BuildQueryPipelineSettings & /*settings*/)
{
pipeline.addSimpleTransform(
[&](const DB::Block & header)
{
return std::make_shared<ReplicateRowsTransform>(header);
});
}

void ReplicateRowsStep::updateOutputStream()
{
output_stream = createOutputStream(input_streams.front(), transformHeader(input_streams.front().header), getDataStreamTraits());
}

ReplicateRowsTransform::ReplicateRowsTransform(const DB::Block & input_header_)
: ISimpleTransform(input_header_, ReplicateRowsStep::transformHeader(input_header_), true)
{
}

void ReplicateRowsTransform::transform(DB::Chunk & chunk)
{
auto replica_column = chunk.getColumns().front();
size_t total_rows = 0;
for (int i = 0; i < replica_column->size(); i++)
{
total_rows += replica_column->get64(i);
}

auto columns = chunk.detachColumns();
DB::MutableColumns mutable_columns;
for (int i = 1; i < columns.size(); i++)
{
mutable_columns.push_back(columns[i]->cloneEmpty());
mutable_columns.back()->reserve(total_rows);
DB::ColumnPtr src_col = columns[i];
DB::MutableColumnPtr & cur = mutable_columns.back();
for (int j = 0; j < replica_column->size(); j++)
{
cur->insertManyFrom(*src_col, j, replica_column->getUInt(j));
}
}

chunk.setColumns(std::move(mutable_columns), total_rows);
}
}
32 changes: 32 additions & 0 deletions cpp-ch/local-engine/Operator/ReplicateRowsStep.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include <Processors/ISimpleTransform.h>
#include <Processors/QueryPlan/ITransformingStep.h>

namespace local_engine
{

class ReplicateRowsStep : public DB::ITransformingStep
{
public:
ReplicateRowsStep(const DB::DataStream& input_stream);

static DB::Block transformHeader(const DB::Block& input);

String getName() const override { return "ReplicateRowsStep"; }
void transformPipeline(DB::QueryPipelineBuilder& pipeline,
const DB::BuildQueryPipelineSettings& settings) override;
private:
void updateOutputStream() override;
};

class ReplicateRowsTransform : public DB::ISimpleTransform
{
public:
ReplicateRowsTransform(const DB::Block& input_header_);

String getName() const override { return "ReplicateRowsTransform"; }
void transform(DB::Chunk&) override;

};
}
34 changes: 34 additions & 0 deletions cpp-ch/local-engine/Parser/ProjectRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Rewriter/ExpressionRewriter.h>
#include <Common/CHUtil.h>
#include <Operator/ReplicateRowsStep.h>

using namespace DB;

namespace local_engine
{
Expand Down Expand Up @@ -109,15 +112,46 @@ ProjectRelParser::SplittedActionsDAGs ProjectRelParser::splitActionsDAGInGenerat
return res;
}

bool ProjectRelParser::isReplicateRows(substrait::GenerateRel rel)
{
return plan_parser->isFunction(rel.generator().scalar_function(), "replicaterows");
}

DB::QueryPlanPtr ProjectRelParser::parseReplicateRows(DB::QueryPlanPtr query_plan, substrait::GenerateRel generate_rel)
{
std::vector<substrait::Expression> expressions;
for (int i = 0; i < generate_rel.generator().scalar_function().arguments_size(); ++i)
{
expressions.emplace_back(generate_rel.generator().scalar_function().arguments(i).value());
}
auto header = query_plan->getCurrentDataStream().header;
auto actions_dag = expressionsToActionsDAG(expressions, header);
auto before_replicate_rows = std::make_unique<DB::ExpressionStep>(query_plan->getCurrentDataStream(), actions_dag);
before_replicate_rows->setStepDescription("Before ReplicateRows");
steps.emplace_back(before_replicate_rows.get());
query_plan->addStep(std::move(before_replicate_rows));

auto replicate_rows_step = std::make_unique<ReplicateRowsStep>(query_plan->getCurrentDataStream());
replicate_rows_step->setStepDescription("ReplicateRows");
steps.emplace_back(replicate_rows_step.get());
query_plan->addStep(std::move(replicate_rows_step));
return query_plan;
}

DB::QueryPlanPtr
ProjectRelParser::parseGenerate(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & /*rel_stack_*/)
{
const auto & generate_rel = rel.generate();
if (isReplicateRows(generate_rel))
{
return parseReplicateRows(std::move(query_plan), generate_rel);
}
std::vector<substrait::Expression> expressions;
for (int i = 0; i < generate_rel.child_output_size(); ++i)
{
expressions.emplace_back(generate_rel.child_output(i));
}

expressions.emplace_back(generate_rel.generator());
auto header = query_plan->getCurrentDataStream().header;
auto actions_dag = expressionsToActionsDAG(expressions, header);
Expand Down
4 changes: 3 additions & 1 deletion cpp-ch/local-engine/Parser/ProjectRelParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <Core/SortDescription.h>
#include <Parser/RelParser.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>

namespace local_engine
{
Expand Down Expand Up @@ -50,6 +49,9 @@ class ProjectRelParser : public RelParser
/// Split actions_dag of generate rel into 3 parts: before array join + during array join + after array join
static SplittedActionsDAGs splitActionsDAGInGenerate(ActionsDAGPtr actions_dag);

bool isReplicateRows(substrait::GenerateRel rel);

DB::QueryPlanPtr parseReplicateRows(QueryPlanPtr query_plan, substrait::GenerateRel generate_rel);

const substrait::Rel & getSingleInput(const substrait::Rel & rel) override
{
Expand Down
1 change: 0 additions & 1 deletion cpp-ch/local-engine/Parser/RelParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class RelParser
static std::map<std::string, std::string> parseFormattedRelAdvancedOptimization(const substrait::extensions::AdvancedExtension &advanced_extension);
static std::string getStringConfig(const std::map<std::string, std::string> & configs, const std::string & key, const std::string & default_value = "");

private:
SerializedPlanParser * plan_parser;
};

Expand Down
6 changes: 6 additions & 0 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,12 @@ std::pair<DataTypePtr, Field> SerializedPlanParser::convertStructFieldType(const
#undef UINT_CONVERT
}

bool SerializedPlanParser::isFunction(substrait::Expression_ScalarFunction rel, String function_name)
{
auto func_signature = function_mapping[std::to_string(rel.function_reference())];
return func_signature.starts_with(function_name + ":");
}

ActionsDAGPtr SerializedPlanParser::parseFunction(
const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result)
{
Expand Down
3 changes: 3 additions & 0 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ class SerializedPlanParser
friend class NonNullableColumnsResolver;
friend class JoinRelParser;
friend class MergeTreeRelParser;
friend class ProjectRelParser;

std::unique_ptr<LocalExecutor> createExecutor(DB::QueryPlanPtr query_plan);

Expand Down Expand Up @@ -389,6 +390,8 @@ class SerializedPlanParser
const std::vector<String> & columns, ActionsDAGPtr actions_dag, std::map<std::string, std::string> & nullable_measure_names);
static std::pair<DB::DataTypePtr, DB::Field> convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field);

bool isFunction(substrait::Expression_ScalarFunction rel, String function_name);

int name_no = 0;
std::unordered_map<std::string, std::string> function_mapping;
std::vector<jobject> input_iters;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ object ExpressionMappings {
Sig[MonotonicallyIncreasingID](MONOTONICALLY_INCREASING_ID),
Sig[SparkPartitionID](SPARK_PARTITION_ID),
Sig[WidthBucket](WIDTH_BUCKET),
Sig[ReplicateRows](REPLICATE_ROWS),
// Decimal
Sig[UnscaledValue](UNSCALED_VALUE),
// Generator function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ object ExpressionNames {
final val SPARK_PARTITION_ID = "spark_partition_id"
final val MONOTONICALLY_INCREASING_ID = "monotonically_increasing_id"
final val WIDTH_BUCKET = "width_bucket"
final val REPLICATE_ROWS = "replicaterows"

// Directly use child expression transformer
final val KNOWN_NULLABLE = "known_nullable"
Expand Down

0 comments on commit 89371e4

Please sign in to comment.