diff --git a/cpp-ch/local-engine/Operator/ExpandStep.cpp b/cpp-ch/local-engine/Operator/ExpandStep.cpp index 9f56d9fd9460..8770c4c405cc 100644 --- a/cpp-ch/local-engine/Operator/ExpandStep.cpp +++ b/cpp-ch/local-engine/Operator/ExpandStep.cpp @@ -52,16 +52,18 @@ ExpandStep::ExpandStep(const DB::DataStream & input_stream_, const ExpandField & output_header = getOutputStream().header; } -DB::Block ExpandStep::buildOutputHeader(const DB::Block & , const ExpandField & project_set_exprs_) +DB::Block ExpandStep::buildOutputHeader(const DB::Block &, const ExpandField & project_set_exprs_) { DB::ColumnsWithTypeAndName cols; const auto & types = project_set_exprs_.getTypes(); const auto & names = project_set_exprs_.getNames(); + chassert(names.size() == types.size()); + for (size_t i = 0; i < project_set_exprs_.getExpandCols(); ++i) - cols.push_back(DB::ColumnWithTypeAndName(types[i], names[i])); + cols.emplace_back(DB::ColumnWithTypeAndName(types[i], names[i])); - return DB::Block(cols); + return DB::Block(std::move(cols)); } void ExpandStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/) diff --git a/cpp-ch/local-engine/Operator/ExpandTransform.cpp b/cpp-ch/local-engine/Operator/ExpandTransform.cpp index f5787163c5a1..5100ad070638 100644 --- a/cpp-ch/local-engine/Operator/ExpandTransform.cpp +++ b/cpp-ch/local-engine/Operator/ExpandTransform.cpp @@ -15,19 +15,20 @@ * limitations under the License. */ #include +#include #include #include #include #include #include +#include #include -#include "ExpandTransorm.h" - -#include #include #include +#include "ExpandTransorm.h" + namespace DB { namespace ErrorCodes @@ -93,53 +94,42 @@ void ExpandTransform::work() if (expand_expr_iterator >= project_set_exprs.getExpandRows()) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "expand_expr_iterator >= project_set_exprs.getExpandRows()"); - const auto & original_cols = input_chunk.getColumns(); + const auto & input_header = getInputs().front().getHeader(); + const auto & input_columns = input_chunk.getColumns(); + const auto & types = project_set_exprs.getTypes(); + const auto & kinds = project_set_exprs.getKinds()[expand_expr_iterator]; + const auto & fields = project_set_exprs.getFields()[expand_expr_iterator]; size_t rows = input_chunk.getNumRows(); - DB::Columns cols; - for (size_t j = 0; j < project_set_exprs.getExpandCols(); ++j) + + DB::Columns columns(types.size()); + for (size_t col_i = 0; col_i < types.size(); ++col_i) { - const auto & type = project_set_exprs.getTypes()[j]; - const auto & kind = project_set_exprs.getKinds()[expand_expr_iterator][j]; - const auto & field = project_set_exprs.getFields()[expand_expr_iterator][j]; + const auto & type = types[col_i]; + const auto & kind = kinds[col_i]; + const auto & field = fields[col_i]; if (kind == EXPAND_FIELD_KIND_SELECTION) { - const auto & original_col = original_cols.at(field.get()); - if (type->isNullable() == original_col->isNullable()) - { - cols.push_back(original_col); - } - else if (type->isNullable() && !original_col->isNullable()) - { - auto null_map = DB::ColumnUInt8::create(rows, 0); - auto col = DB::ColumnNullable::create(original_col, std::move(null_map)); - cols.push_back(std::move(col)); - } - else - { - throw DB::Exception( - DB::ErrorCodes::LOGICAL_ERROR, - "Miss match nullable, column {} is nullable, but type {} is not nullable", - original_col->getName(), - type->getName()); - } + auto index = field.get(); + const auto & input_column = input_columns[index]; + + DB::ColumnWithTypeAndName input_arg; + input_arg.column = input_column; + input_arg.type = input_header.getByPosition(index).type; + /// input_column maybe non-Nullable + columns[col_i] = DB::castColumn(input_arg, type); } - else if (field.isNull()) + else if (kind == EXPAND_FIELD_KIND_LITERAL) { - // Add null column - auto null_map = DB::ColumnUInt8::create(rows, 1); - auto nested_type = DB::removeNullable(type); - auto col = DB::ColumnNullable::create(nested_type->createColumn()->cloneResized(rows), std::move(null_map)); - cols.push_back(std::move(col)); + /// Add const column with field value + auto column = type->createColumnConst(rows, field); + columns[col_i] = column; } else - { - // Add constant column: gid, gpos, etc. - auto col = type->createColumnConst(rows, field); - cols.push_back(col->convertToFullColumnIfConst()); - } + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknown ExpandFieldKind {}", magic_enum::enum_name(kind)); } - output_chunk = DB::Chunk(cols, rows); + + output_chunk = DB::Chunk(std::move(columns), rows); has_output = true; ++expand_expr_iterator; diff --git a/cpp-ch/local-engine/Operator/ExpandTransorm.h b/cpp-ch/local-engine/Operator/ExpandTransorm.h index 90bdf3dc13dc..f315ca5db35e 100644 --- a/cpp-ch/local-engine/Operator/ExpandTransorm.h +++ b/cpp-ch/local-engine/Operator/ExpandTransorm.h @@ -15,21 +15,21 @@ * limitations under the License. */ #pragma once -#include -#include + #include #include #include #include #include + namespace local_engine { // For handling substrait expand node. // The implementation in spark for groupingsets/rollup/cube is different from Clickhouse. -// We have to ways to support groupingsets/rollup/cube -// - rewrite the substrait plan in local engine and reuse the implementation of clickhouse. This +// We have two ways to support groupingsets/rollup/cube +// - Rewrite the substrait plan in local engine and reuse the implementation of clickhouse. This // may be more complex. -// - implement new transform to do the expandation. It's more simple, but may suffer some performance +// - Implement new transform to do the expandation. It's simpler, but may suffer some performance // issues. We try this first. class ExpandTransform : public DB::IProcessor {