Skip to content

Commit

Permalink
[CH] Support replicaterows #6308
Browse files Browse the repository at this point in the history
What changes were proposed in this pull request?
Support the ReplicateRows expression from Spark in the ClickHouse (CH) backend.

How was this patch tested?
unit tests
  • Loading branch information
liuneng1994 authored Jul 3, 2024
1 parent d589aa3 commit 47fa44f
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,4 +211,19 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite {
compareResultsAgainstVanillaSpark(query_sql, true, { _ => })
spark.sql("drop table test")
}

// Runs an INTERSECT ALL query over two small parquet tables through runQueryAndCompare and
// passes the resulting DataFrame to checkFallbackOperators with an expected count of 0.
test("intersect all") {
  try {
    spark.sql("create table t1 (a int, b string) using parquet")
    spark.sql("insert into t1 values (1, '1'),(2, '2'),(3, '3'),(4, '4'),(5, '5'),(6, '6')")
    spark.sql("create table t2 (a int, b string) using parquet")
    spark.sql("insert into t2 values (4, '4'),(5, '5'),(6, '6'),(7, '7'),(8, '8'),(9, '9')")
    runQueryAndCompare(
      """
        |SELECT a,b FROM t1 INTERSECT ALL SELECT a,b FROM t2
        |""".stripMargin
    )(df => checkFallbackOperators(df, 0))
  } finally {
    // Clean up even when an assertion above fails, so the tables do not leak into later tests.
    // "if exists" keeps the cleanup from masking the original failure when a table was never created.
    spark.sql("drop table if exists t1")
    spark.sql("drop table if exists t2")
  }
}

}
108 changes: 108 additions & 0 deletions cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ReplicateRowsStep.h"

#include <iostream>

#include <QueryPipeline/QueryPipelineBuilder.h>

namespace DB
{
namespace ErrorCodes
{
/// Declaration only (extern); the definition is provided elsewhere.
/// NOTE(review): not referenced by the code visible in this file - confirm whether it is still needed.
extern const int LOGICAL_ERROR;
}
}

namespace local_engine
{
/// Traits of ReplicateRowsStep: the number of streams is preserved, while sorting and
/// the number of rows are not.
static DB::ITransformingStep::Traits getTraits()
{
    DB::ITransformingStep::Traits traits{
        {
            .preserves_number_of_streams = true,
            .preserves_sorting = false,
        },
        {
            .preserves_number_of_rows = false,
        }};
    return traits;
}

/// Builds the step on top of `input_stream`. The output header is derived with
/// transformHeader(), i.e. the input header without its first column.
ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream)
    : DB::ITransformingStep(
          input_stream,
          ReplicateRowsStep::transformHeader(input_stream.header),
          getTraits())
{
}

/// Computes the output header for row replication.
///
/// @param input  Header of the incoming block. Column 0 is skipped here; ReplicateRowsTransform
///               reads it as the per-row replication count.
/// @return       A block holding columns 1..N-1 of `input`, in their original order.
///               Empty when `input` has at most one column.
DB::Block ReplicateRowsStep::transformHeader(const DB::Block& input)
{
    DB::Block output;
    /// size_t index: avoids comparing a signed int with the unsigned column count.
    const size_t num_columns = input.columns();
    for (size_t i = 1; i < num_columns; ++i)
    {
        output.insert(input.getByPosition(i));
    }
    return output;
}

/// Registers a ReplicateRowsTransform with the pipeline via addSimpleTransform().
/// The build settings are not used.
void ReplicateRowsStep::transformPipeline(
    DB::QueryPipelineBuilder & pipeline,
    const DB::BuildQueryPipelineSettings & /*settings*/)
{
    /// The factory needs no state from this step, so it captures nothing.
    const auto make_transform = [](const DB::Block & stream_header)
    {
        return std::make_shared<ReplicateRowsTransform>(stream_header);
    };
    pipeline.addSimpleTransform(make_transform);
}

void ReplicateRowsStep::updateOutputStream()
{
output_stream = createOutputStream(input_streams.front(), transformHeader(input_streams.front().header), getDataStreamTraits());
}

/// Creates the transform for a stream with header `input_header_`. The output header is
/// the input header without its first column (see ReplicateRowsStep::transformHeader).
ReplicateRowsTransform::ReplicateRowsTransform(const DB::Block & input_header_)
    : DB::ISimpleTransform(
          input_header_,
          ReplicateRowsStep::transformHeader(input_header_),
          /// NOTE(review): presumably ISimpleTransform's skip_empty_chunks flag - confirm against ISimpleTransform.h.
          true)
{
}

/// Expands every input row according to its replication count.
///
/// The first column of `chunk` gives, for each row, how many times that row is emitted.
/// That column is dropped; every other column is rebuilt with row j repeated count[j] times,
/// and the chunk's columns are replaced with the result.
///
/// NOTE(review): counts are read with getUInt(), so a negative count would be interpreted as a
/// very large unsigned value - presumably the planner guarantees non-negative counts; confirm.
void ReplicateRowsTransform::transform(DB::Chunk & chunk)
{
    /// Hold our own reference to the count column: it is still needed after detachColumns().
    const DB::ColumnPtr replica_column = chunk.getColumns().front();
    const size_t input_rows = replica_column->size();

    /// Use the same accessor here and in the insert loop below, so the declared row count
    /// always equals the number of rows actually inserted.
    size_t total_rows = 0;
    for (size_t row = 0; row < input_rows; ++row)
    {
        total_rows += replica_column->getUInt(row);
    }

    const auto columns = chunk.detachColumns();
    DB::MutableColumns mutable_columns;
    /// Index 0 is the count column itself and is not part of the output.
    for (size_t i = 1; i < columns.size(); ++i)
    {
        const DB::IColumn & src_col = *columns[i];
        DB::MutableColumnPtr dst_col = src_col.cloneEmpty();
        dst_col->reserve(total_rows);
        for (size_t row = 0; row < input_rows; ++row)
        {
            dst_col->insertManyFrom(src_col, row, replica_column->getUInt(row));
        }
        mutable_columns.push_back(std::move(dst_col));
    }

    chunk.setColumns(std::move(mutable_columns), total_rows);
}
}
48 changes: 48 additions & 0 deletions cpp-ch/local-engine/Operator/ReplicateRowsStep.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <Processors/ISimpleTransform.h>
#include <Processors/QueryPlan/ITransformingStep.h>

namespace local_engine
{

/// Query plan step that replicates rows: the first input column is consumed as the per-row
/// replication count and is removed from the output header.
class ReplicateRowsStep : public DB::ITransformingStep
{
public:
    /// explicit: a DataStream must never convert implicitly into a plan step.
    explicit ReplicateRowsStep(const DB::DataStream & input_stream);

    /// Returns `input` without its first column.
    static DB::Block transformHeader(const DB::Block & input);

    String getName() const override { return "ReplicateRowsStep"; }

    /// Adds a ReplicateRowsTransform to the pipeline; `settings` is unused.
    void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings) override;

private:
    /// Recomputes the output stream from the current input stream.
    void updateOutputStream() override;
};

/// Chunk-level transform behind ReplicateRowsStep: repeats each row as many times as the
/// value in the chunk's first column says, and drops that first column.
class ReplicateRowsTransform : public DB::ISimpleTransform
{
public:
    /// explicit: a Block must never convert implicitly into a transform.
    explicit ReplicateRowsTransform(const DB::Block & input_header_);

    String getName() const override { return "ReplicateRowsTransform"; }

    /// Rewrites the chunk in place with the replicated rows.
    void transform(DB::Chunk &) override;
};
}
34 changes: 34 additions & 0 deletions cpp-ch/local-engine/Parser/ProjectRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Rewriter/ExpressionRewriter.h>
#include <Common/CHUtil.h>
#include <Operator/ReplicateRowsStep.h>

using namespace DB;

namespace local_engine
{
Expand Down Expand Up @@ -109,15 +112,46 @@ ProjectRelParser::SplittedActionsDAGs ProjectRelParser::splitActionsDAGInGenerat
return res;
}

/// Tells whether the generator of `rel` is the "replicaterows" scalar function.
///
/// @param rel  Generate relation to inspect.
/// @return     true only when the generator is a scalar function whose registered signature
///             matches "replicaterows"; false otherwise.
bool ProjectRelParser::isReplicateRows(substrait::GenerateRel rel)
{
    const auto & generator = rel.generator();
    /// Without this guard, scalar_function() would return a default message whose function
    /// reference is 0, and that reference may belong to an unrelated registered function.
    if (!generator.has_scalar_function())
        return false;
    return plan_parser->isFunction(generator.scalar_function(), "replicaterows");
}

/// Translates a replicaterows GenerateRel into two plan steps appended to `query_plan`:
///   1. an ExpressionStep ("Before ReplicateRows") evaluating the function's arguments;
///   2. a ReplicateRowsStep ("ReplicateRows") running on that step's output.
/// Both steps are also recorded in `steps`. Returns the extended plan.
DB::QueryPlanPtr ProjectRelParser::parseReplicateRows(DB::QueryPlanPtr query_plan, substrait::GenerateRel generate_rel)
{
    const auto & replicate_function = generate_rel.generator().scalar_function();

    /// Collect the argument expressions of the function, in declaration order.
    std::vector<substrait::Expression> expressions;
    for (const auto & argument : replicate_function.arguments())
        expressions.emplace_back(argument.value());

    auto input_header = query_plan->getCurrentDataStream().header;
    auto arguments_dag = expressionsToActionsDAG(expressions, input_header);

    auto arguments_step = std::make_unique<DB::ExpressionStep>(query_plan->getCurrentDataStream(), arguments_dag);
    arguments_step->setStepDescription("Before ReplicateRows");
    steps.emplace_back(arguments_step.get());
    query_plan->addStep(std::move(arguments_step));

    /// Created only after the expression step was added, so it sees that step's output stream.
    auto replicate_step = std::make_unique<ReplicateRowsStep>(query_plan->getCurrentDataStream());
    replicate_step->setStepDescription("ReplicateRows");
    steps.emplace_back(replicate_step.get());
    query_plan->addStep(std::move(replicate_step));

    return query_plan;
}

DB::QueryPlanPtr
ProjectRelParser::parseGenerate(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> & /*rel_stack_*/)
{
const auto & generate_rel = rel.generate();
if (isReplicateRows(generate_rel))
{
return parseReplicateRows(std::move(query_plan), generate_rel);
}
std::vector<substrait::Expression> expressions;
for (int i = 0; i < generate_rel.child_output_size(); ++i)
{
expressions.emplace_back(generate_rel.child_output(i));
}

expressions.emplace_back(generate_rel.generator());
auto header = query_plan->getCurrentDataStream().header;
auto actions_dag = expressionsToActionsDAG(expressions, header);
Expand Down
4 changes: 3 additions & 1 deletion cpp-ch/local-engine/Parser/ProjectRelParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <Core/SortDescription.h>
#include <Parser/RelParser.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>

namespace local_engine
{
Expand Down Expand Up @@ -50,6 +49,9 @@ class ProjectRelParser : public RelParser
/// Split actions_dag of generate rel into 3 parts: before array join + during array join + after array join
static SplittedActionsDAGs splitActionsDAGInGenerate(ActionsDAGPtr actions_dag);

/// Returns whether the generator of `rel` is the "replicaterows" function.
bool isReplicateRows(substrait::GenerateRel rel);

/// Appends an expression step for the replicaterows arguments followed by a ReplicateRowsStep
/// to `query_plan`, and returns the plan.
DB::QueryPlanPtr parseReplicateRows(QueryPlanPtr query_plan, substrait::GenerateRel generate_rel);

const substrait::Rel & getSingleInput(const substrait::Rel & rel) override
{
Expand Down
1 change: 0 additions & 1 deletion cpp-ch/local-engine/Parser/RelParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class RelParser
static std::map<std::string, std::string> parseFormattedRelAdvancedOptimization(const substrait::extensions::AdvancedExtension &advanced_extension);
static std::string getStringConfig(const std::map<std::string, std::string> & configs, const std::string & key, const std::string & default_value = "");

private:
SerializedPlanParser * plan_parser;
};

Expand Down
6 changes: 6 additions & 0 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,12 @@ std::pair<DataTypePtr, Field> SerializedPlanParser::convertStructFieldType(const
#undef UINT_CONVERT
}

/// Checks whether a scalar function call refers to the function named `function_name`.
///
/// @param rel            Scalar function call; only its function reference is read.
/// @param function_name  Function name to match, without the trailing ":" and argument types.
/// @return               true when the signature registered in `function_mapping` for the
///                       reference starts with "<function_name>:"; false when it does not or
///                       when the reference is not registered.
bool SerializedPlanParser::isFunction(substrait::Expression_ScalarFunction rel, String function_name)
{
    /// find() instead of operator[]: a pure query must not insert an empty entry for an
    /// unknown reference, and the signature string does not need to be copied.
    const auto it = function_mapping.find(std::to_string(rel.function_reference()));
    if (it == function_mapping.end())
        return false;
    return it->second.starts_with(function_name + ":");
}

ActionsDAGPtr SerializedPlanParser::parseFunction(
const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result)
{
Expand Down
3 changes: 3 additions & 0 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ class SerializedPlanParser
friend class NonNullableColumnsResolver;
friend class JoinRelParser;
friend class MergeTreeRelParser;
friend class ProjectRelParser;

std::unique_ptr<LocalExecutor> createExecutor(DB::QueryPlanPtr query_plan);

Expand Down Expand Up @@ -391,6 +392,8 @@ class SerializedPlanParser
const std::vector<String> & columns, ActionsDAGPtr actions_dag, std::map<std::string, std::string> & nullable_measure_names);
static std::pair<DB::DataTypePtr, DB::Field> convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field);

/// Returns whether the signature registered in `function_mapping` for `rel`'s function
/// reference starts with "<function_name>:".
bool isFunction(substrait::Expression_ScalarFunction rel, String function_name);

int name_no = 0;
std::unordered_map<std::string, std::string> function_mapping;
std::vector<jobject> input_iters;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ object ExpressionMappings {
Sig[MonotonicallyIncreasingID](MONOTONICALLY_INCREASING_ID),
Sig[SparkPartitionID](SPARK_PARTITION_ID),
Sig[WidthBucket](WIDTH_BUCKET),
Sig[ReplicateRows](REPLICATE_ROWS),
// Decimal
Sig[UnscaledValue](UNSCALED_VALUE),
// Generator function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ object ExpressionNames {
final val SPARK_PARTITION_ID = "spark_partition_id"
final val MONOTONICALLY_INCREASING_ID = "monotonically_increasing_id"
final val WIDTH_BUCKET = "width_bucket"
final val REPLICATE_ROWS = "replicaterows"

// Directly use child expression transformer
final val KNOWN_NULLABLE = "known_nullable"
Expand Down

0 comments on commit 47fa44f

Please sign in to comment.