From 93b76aa11f69d57b4dcf17d86ac802403203588e Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Tue, 26 Sep 2023 09:56:16 +0800 Subject: [PATCH] convert agg type to fixed string --- .../Builder/SerializedPlanBuilder.cpp | 2 +- .../Shuffle/CachedShuffleWriter.cpp | 5 ++ .../local-engine/Shuffle/NativeSplitter.cpp | 2 + .../local-engine/Shuffle/ShuffleSplitter.cpp | 6 +- .../IO/AggregateSerializationUtils.cpp | 83 +++++++++++++++++++ .../Storages/IO/AggregateSerializationUtils.h | 30 +++++++ .../Storages/SourceFromJavaIter.cpp | 23 ++++- 7 files changed, 148 insertions(+), 3 deletions(-) create mode 100644 cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp create mode 100644 cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h diff --git a/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp b/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp index 22499794a5001..1dd3cbfbfce3c 100644 --- a/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp +++ b/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp @@ -259,7 +259,7 @@ std::shared_ptr SerializedPlanBuilder::buildType(const DB::Data res->mutable_i32()->set_nullability(type_nullability); else if (which.isInt64()) res->mutable_i64()->set_nullability(type_nullability); - else if (which.isString() || which.isAggregateFunction()) + else if (which.isString() || which.isAggregateFunction() || which.isFixedString()) res->mutable_binary()->set_nullability(type_nullability); /// Spark Binary type is more similiar to CH String type else if (which.isFloat32()) res->mutable_fp32()->set_nullability(type_nullability); diff --git a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp index 4f3ae465078e4..1d5a13fcb166b 100644 --- a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp @@ -17,6 +17,7 @@ #include "CachedShuffleWriter.h" #include #include +#include #include #include 
#include @@ -93,6 +94,10 @@ CachedShuffleWriter::CachedShuffleWriter(const String & short_name, SplitOptions void CachedShuffleWriter::split(DB::Block & block) { + Stopwatch split_time_watch; + split_time_watch.start(); + block = convertAggregateStateInBlock(block); + split_result.total_split_time += split_time_watch.elapsedNanoseconds(); initOutputIfNeeded(block); Stopwatch compute_pid_time_watch; diff --git a/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp b/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp index 4b8b56f99c572..1ddd5311f90e5 100644 --- a/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp +++ b/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp @@ -30,6 +30,7 @@ #include #include #include +#include namespace local_engine { @@ -43,6 +44,7 @@ void NativeSplitter::split(DB::Block & block) { return; } + block = convertAggregateStateInBlock(block); if (!output_header.columns()) [[unlikely]] { if (output_columns_indicies.empty()) diff --git a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp index 2f7d144102845..957644a01ed01 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp +++ b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include #include @@ -40,6 +40,10 @@ void ShuffleSplitter::split(DB::Block & block) return; } computeAndCountPartitionId(block); + Stopwatch split_time_watch; + split_time_watch.start(); + block = convertAggregateStateInBlock(block); + split_result.total_split_time += split_time_watch.elapsedNanoseconds(); splitBlockByPartition(block); } SplitResult ShuffleSplitter::stop() diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp new file mode 100644 index 0000000000000..c00adc58ca6e0 --- /dev/null +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AggregateSerializationUtils.h" +#include + +#include +#include +#include + + +using namespace DB; + +namespace local_engine +{ +DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col) /* Returns col unchanged unless col.type is an aggregate-function type; otherwise builds a new column of width sizeOfData() bytes per row, keeping col.name. The result type template argument is not visible in this patch text - presumably DataTypeFixedString, per the function name; confirm. */ +{ + if (!isAggregateFunction(col.type)) + { + return col; + } + const auto *aggregate_col = checkAndGetColumn(*col.column); /* NOTE(review): dereferenced on the next line with no null check - confirm checkAndGetColumn cannot return nullptr for the columns that reach here */ + size_t state_size = aggregate_col->getAggregateFunction()->sizeOfData(); + auto res_type = std::make_shared(state_size); + auto res_col = res_type->createColumn(); + res_col->reserve(aggregate_col->size()); + for (const auto & item : aggregate_col->getData()) + { + res_col->insertData(item, state_size); /* each element of getData() is inserted with length state_size - presumably a raw byte copy of the in-memory state; confirm this is valid for every aggregate function that can reach here */ + } + return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name); +} +DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type) /* Builds a column of the given type from col: for every row it allocates a state in the result column arena, calls create() on it, memcpys the row bytes over it and appends the pointer. Returns the new column with the given type and col.name. */ +{ + chassert(isAggregateFunction(type)); /* NOTE(review): this is the only visible guard on type; agg_type below is used unchecked - confirm whether chassert is active in release builds */ + auto res_col = type->createColumn(); + const auto * agg_type = checkAndGetDataType(type.get()); + ColumnAggregateFunction & real_column = typeid_cast(*res_col); + auto & arena = real_column.createOrGetArena(); + ColumnAggregateFunction::Container & vec = real_column.getData(); + + vec.reserve(col.column->size()); + auto 
agg_function = agg_type->getFunction(); + size_t size_of_state = agg_function->sizeOfData(); + size_t align_of_state = agg_function->alignOfData(); + + for (size_t i = 0; i < col.column->size(); ++i) + { + AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state); + + agg_function->create(place); /* state is initialised here and then overwritten by the memcpy below */ + + auto value = col.column->getDataAt(i); + memcpy(place, value.data, value.size); /* NOTE(review): copies value.size bytes into an allocation of size_of_state bytes; nothing visible here checks that value.size does not exceed size_of_state */ + + vec.push_back(place); + } + return DB::ColumnWithTypeAndName(std::move(res_col), type, col.name); +} +DB::Block convertAggregateStateInBlock(DB::Block block) /* Passes every column of block through convertAggregateStateToFixedString and returns the collected columns. */ +{ + ColumnsWithTypeAndName columns; + for (const auto & item : block.getColumnsWithTypeAndName()) + { + columns.emplace_back(convertAggregateStateToFixedString(item)); + } + return columns; /* the column list is converted to the declared DB::Block return type by this return statement */ +} +} + diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h new file mode 100644 index 0000000000000..62d5127d8a90f --- /dev/null +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include +#include + +namespace local_engine { + +DB::Block convertAggregateStateInBlock(DB::Block block); /* returns a block whose columns have each been passed through convertAggregateStateToFixedString */ + +DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col); /* columns whose type is not an aggregate function are returned unchanged; aggregate-state columns are re-encoded with sizeOfData() bytes per row */ + +DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type); /* rebuilds a column of the given aggregate-function type by copying each row of col into a newly created state; the result keeps col.name */ + +} + diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp index 74d7d7054db70..37625503631fa 100644 --- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp +++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp @@ -17,7 +17,7 @@ #include "SourceFromJavaIter.h" #include #include -#include +#include #include #include #include @@ -25,6 +25,8 @@ #include #include +using namespace DB; + namespace local_engine { jclass SourceFromJavaIter::serialized_record_batch_iterator_class = nullptr; @@ -70,8 +72,27 @@ DB::Chunk SourceFromJavaIter::generate() result = BlockUtil::buildRowCountChunk(rows); } } + Columns converted_columns; + auto output_header = outputs.front().getHeader(); + for (size_t i = 0; i < result.getNumColumns(); ++i) + { + if (isAggregateFunction(output_header.getByPosition(i).type)) + { + auto col = data->getByPosition(i); + col.column = result.getColumns().at(i); + auto converted = convertFixedStringToAggregateState(col, output_header.getByPosition(i).type); + converted_columns.emplace_back(converted.column); + } + else + { + converted_columns.emplace_back(result.getColumns().at(i)); + } + } + result.setColumns(converted_columns, result.getNumRows()); } CLEAN_JNIENV + + return result; } SourceFromJavaIter::~SourceFromJavaIter()