From f305598439659d3749ca18003edceb7c1c587254 Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Tue, 26 Sep 2023 09:56:16 +0800 Subject: [PATCH 01/14] convert agg type to fixed string --- .../Builder/SerializedPlanBuilder.cpp | 2 +- .../Shuffle/CachedShuffleWriter.cpp | 5 ++ .../local-engine/Shuffle/NativeSplitter.cpp | 2 + .../local-engine/Shuffle/ShuffleSplitter.cpp | 6 +- .../IO/AggregateSerializationUtils.cpp | 83 +++++++++++++++++++ .../Storages/IO/AggregateSerializationUtils.h | 30 +++++++ .../Storages/SourceFromJavaIter.cpp | 23 ++++- 7 files changed, 148 insertions(+), 3 deletions(-) create mode 100644 cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp create mode 100644 cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h diff --git a/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp b/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp index 22499794a500..1dd3cbfbfce3 100644 --- a/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp +++ b/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp @@ -259,7 +259,7 @@ std::shared_ptr SerializedPlanBuilder::buildType(const DB::Data res->mutable_i32()->set_nullability(type_nullability); else if (which.isInt64()) res->mutable_i64()->set_nullability(type_nullability); - else if (which.isString() || which.isAggregateFunction()) + else if (which.isString() || which.isAggregateFunction() || which.isFixedString()) res->mutable_binary()->set_nullability(type_nullability); /// Spark Binary type is more similiar to CH String type else if (which.isFloat32()) res->mutable_fp32()->set_nullability(type_nullability); diff --git a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp index 676231e048b9..2abcc376df01 100644 --- a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp @@ -17,6 +17,7 @@ #include "CachedShuffleWriter.h" #include #include +#include #include #include #include @@ -93,6 +94,10 @@ CachedShuffleWriter::CachedShuffleWriter(const String & short_name, SplitOptions void CachedShuffleWriter::split(DB::Block & block) { + Stopwatch split_time_watch; + split_time_watch.start(); + block = convertAggregateStateInBlock(block); + split_result.total_split_time += split_time_watch.elapsedNanoseconds(); initOutputIfNeeded(block); Stopwatch compute_pid_time_watch; diff --git a/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp b/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp index eec4e05fffb9..b46ab426f3df 100644 --- a/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp +++ b/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp @@ -30,6 +30,7 @@ #include #include #include +#include namespace local_engine { @@ -43,6 +44,7 @@ void NativeSplitter::split(DB::Block & block) { return; } + block = convertAggregateStateInBlock(block); if (!output_header.columns()) [[unlikely]] { if (output_columns_indicies.empty()) diff --git a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp index 266c343a59e7..281fb1821868 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp +++ b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include #include @@ -40,6 +40,10 @@ void ShuffleSplitter::split(DB::Block & block) return; } computeAndCountPartitionId(block); + Stopwatch split_time_watch; + split_time_watch.start(); + block = convertAggregateStateInBlock(block); + split_result.total_split_time += split_time_watch.elapsedNanoseconds(); splitBlockByPartition(block); } SplitResult ShuffleSplitter::stop() diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp new file mode 100644 index 000000000000..c00adc58ca6e --- /dev/null +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AggregateSerializationUtils.h" +#include + +#include +#include +#include + + +using namespace DB; + +namespace local_engine +{ +DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col) +{ + if (!isAggregateFunction(col.type)) + { + return col; + } + const auto *aggregate_col = checkAndGetColumn(*col.column); + size_t state_size = aggregate_col->getAggregateFunction()->sizeOfData(); + auto res_type = std::make_shared(state_size); + auto res_col = res_type->createColumn(); + res_col->reserve(aggregate_col->size()); + for (const auto & item : aggregate_col->getData()) + { + res_col->insertData(item, state_size); + } + return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name); +} +DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type) +{ + chassert(isAggregateFunction(type)); + auto res_col = type->createColumn(); + const auto * agg_type = checkAndGetDataType(type.get()); + ColumnAggregateFunction & real_column = typeid_cast(*res_col); + auto & arena = real_column.createOrGetArena(); + ColumnAggregateFunction::Container & vec = real_column.getData(); + + vec.reserve(col.column->size()); + auto agg_function = agg_type->getFunction(); + size_t size_of_state = agg_function->sizeOfData(); + size_t align_of_state = agg_function->alignOfData(); + + for (size_t i = 0; i < col.column->size(); ++i) + { + AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state); + + agg_function->create(place); + + auto value = col.column->getDataAt(i); + memcpy(place, value.data, value.size); + + vec.push_back(place); + } + return DB::ColumnWithTypeAndName(std::move(res_col), type, col.name); +} +DB::Block convertAggregateStateInBlock(DB::Block block) +{ + ColumnsWithTypeAndName columns; + for (const auto & item : block.getColumnsWithTypeAndName()) + { + columns.emplace_back(convertAggregateStateToFixedString(item)); + } + return columns; +} +} + diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h new file mode 100644 index 000000000000..62d5127d8a90 --- /dev/null +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace local_engine { + +DB::Block convertAggregateStateInBlock(DB::Block block); + +DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col); + +DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type); + +} + diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp index 74d7d7054db7..37625503631f 100644 --- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp +++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp @@ -17,7 +17,7 @@ #include "SourceFromJavaIter.h" #include #include -#include +#include #include #include #include @@ -25,6 +25,8 @@ #include #include +using namespace DB; + namespace local_engine { jclass SourceFromJavaIter::serialized_record_batch_iterator_class = nullptr; @@ -70,8 +72,27 @@ DB::Chunk SourceFromJavaIter::generate() result = BlockUtil::buildRowCountChunk(rows); } } + Columns converted_columns; + auto output_header = outputs.front().getHeader(); + for (size_t i = 0; i < result.getNumColumns(); ++i) + { + if (isAggregateFunction(output_header.getByPosition(i).type)) + { + auto col = data->getByPosition(i); + col.column = result.getColumns().at(i); + auto converted = convertFixedStringToAggregateState(col, output_header.getByPosition(i).type); + converted_columns.emplace_back(converted.column); + } + else + { + converted_columns.emplace_back(result.getColumns().at(i)); + } + } + result.setColumns(converted_columns, result.getNumRows()); } CLEAN_JNIENV + + return result; } SourceFromJavaIter::~SourceFromJavaIter() From d16d5597819327f4e702a3bbe433622e728ae40a Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Tue, 26 Sep 2023 16:49:00 +0800 Subject: [PATCH 02/14] optimize shuffle read --- .../Shuffle/CachedShuffleWriter.cpp | 3 +- .../local-engine/Shuffle/PartitionWriter.cpp | 7 +- cpp-ch/local-engine/Shuffle/ShuffleReader.cpp | 4 +- cpp-ch/local-engine/Shuffle/ShuffleReader.h | 3 +- .../local-engine/Storages/IO/NativeReader.cpp | 154 ++++++++++++++++++ .../local-engine/Storages/IO/NativeReader.h | 47 ++++++ .../local-engine/Storages/IO/NativeWriter.cpp | 84 ++++++++++ .../local-engine/Storages/IO/NativeWriter.h | 50 ++++++ .../Storages/SourceFromJavaIter.cpp | 23 +-- 9 files changed, 345 insertions(+), 30 deletions(-) create mode 100644 cpp-ch/local-engine/Storages/IO/NativeReader.cpp create mode 100644 cpp-ch/local-engine/Storages/IO/NativeReader.h create mode 100644 cpp-ch/local-engine/Storages/IO/NativeWriter.cpp create mode 100644 cpp-ch/local-engine/Storages/IO/NativeWriter.h diff --git a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp index 2abcc376df01..16d56e9bb8e8 100644 --- a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp @@ -94,11 +94,11 @@ CachedShuffleWriter::CachedShuffleWriter(const String & short_name, SplitOptions void CachedShuffleWriter::split(DB::Block & block) { + initOutputIfNeeded(block); Stopwatch split_time_watch; split_time_watch.start(); block = convertAggregateStateInBlock(block); split_result.total_split_time += split_time_watch.elapsedNanoseconds(); - initOutputIfNeeded(block); Stopwatch compute_pid_time_watch; compute_pid_time_watch.start(); @@ -110,7 +110,6 @@ void CachedShuffleWriter::split(DB::Block & block) { out_block.insert(block.getByPosition(output_columns_indicies[col])); } - partition_writer->write(partition_info, out_block); if (options.spill_threshold > 0 && partition_writer->totalCacheSize() > options.spill_threshold) diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp index 7a6bcb78ccf1..3623e493bf48 100644 --- a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp @@ -29,6 +29,7 @@ #include #include #include +#include using namespace DB; @@ -73,7 +74,7 @@ void LocalPartitionWriter::evictPartitions(bool for_memory_spill) WriteBufferFromFile output(file, shuffle_writer->options.io_buffer_size); auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), {}); CompressedWriteBuffer compressed_output(output, codec, shuffle_writer->options.io_buffer_size); - NativeWriter writer(compressed_output, 0, shuffle_writer->output_header); + NativeWriter writer(compressed_output, shuffle_writer->output_header); SpillInfo info; info.spilled_file = file; size_t partition_id = 0; @@ -122,7 +123,7 @@ std::vector LocalPartitionWriter::mergeSpills(WriteBuffer& data_file) { auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), {}); CompressedWriteBuffer compressed_output(data_file, codec, shuffle_writer->options.io_buffer_size); - NativeWriter writer(compressed_output, 0, shuffle_writer->output_header); + NativeWriter writer(compressed_output, shuffle_writer->output_header); std::vector partition_length; partition_length.resize(shuffle_writer->options.partition_nums, 0); @@ -229,7 +230,7 @@ void CelebornPartitionWriter::evictPartitions(bool for_memory_spill) WriteBufferFromOwnString output; auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), {}); CompressedWriteBuffer compressed_output(output, codec, shuffle_writer->options.io_buffer_size); - NativeWriter writer(compressed_output, 0, shuffle_writer->output_header); + NativeWriter writer(compressed_output, shuffle_writer->output_header); size_t raw_size = partition.spill(writer); compressed_output.sync(); Stopwatch push_time_watch; diff --git a/cpp-ch/local-engine/Shuffle/ShuffleReader.cpp b/cpp-ch/local-engine/Shuffle/ShuffleReader.cpp index 32c9e4cecf08..50165ca6661a 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleReader.cpp +++ b/cpp-ch/local-engine/Shuffle/ShuffleReader.cpp @@ -37,11 +37,11 @@ local_engine::ShuffleReader::ShuffleReader(std::unique_ptr in_, bool { compressed_in = std::make_unique(*in); configureCompressedReadBuffer(static_cast(*compressed_in)); - input_stream = std::make_unique(*compressed_in, 0); + input_stream = std::make_unique(*compressed_in); } else { - input_stream = std::make_unique(*in, 0); + input_stream = std::make_unique(*in); } } Block * local_engine::ShuffleReader::read() diff --git a/cpp-ch/local-engine/Shuffle/ShuffleReader.h b/cpp-ch/local-engine/Shuffle/ShuffleReader.h index fccc2b0e5755..082e75a26ca6 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleReader.h +++ b/cpp-ch/local-engine/Shuffle/ShuffleReader.h @@ -19,6 +19,7 @@ #include #include #include +#include namespace DB { @@ -42,7 +43,7 @@ class ShuffleReader : BlockIterator private: std::unique_ptr in; std::unique_ptr compressed_in; - std::unique_ptr input_stream; + std::unique_ptr input_stream; DB::Block header; }; diff --git a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp new file mode 100644 index 000000000000..6080e8486b5a --- /dev/null +++ b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "NativeReader.h" + +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int INCORRECT_INDEX; + extern const int LOGICAL_ERROR; + extern const int CANNOT_READ_ALL_DATA; + extern const int INCORRECT_DATA; + extern const int TOO_LARGE_ARRAY_SIZE; +} +} + +using namespace DB; + +namespace local_engine +{ +void NativeReader::readData(const ISerialization & serialization, ColumnPtr & column, ReadBuffer & istr, size_t rows, double avg_value_size_hint) +{ + ISerialization::DeserializeBinaryBulkSettings settings; + settings.getter = [&](ISerialization::SubstreamPath) -> ReadBuffer * { return &istr; }; + settings.avg_value_size_hint = avg_value_size_hint; + settings.position_independent_encoding = false; + settings.native_format = true; + + ISerialization::DeserializeBinaryBulkStatePtr state; + + serialization.deserializeBinaryBulkStatePrefix(settings, state); + serialization.deserializeBinaryBulkWithMultipleStreams(column, rows, settings, state, nullptr); + + if (column->size() != rows) + throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, + "Cannot read all data in NativeReader. Rows read: {}. Rows expected: {}", column->size(), rows); +} + +void NativeReader::readAggData(const DB::DataTypeAggregateFunction & data_type, DB::ColumnPtr & column, DB::ReadBuffer & istr, size_t rows) +{ + ColumnAggregateFunction & real_column = typeid_cast(*column->assumeMutable()); + auto & arena = real_column.createOrGetArena(); + ColumnAggregateFunction::Container & vec = real_column.getData(); + + vec.reserve(rows); + auto agg_function = data_type.getFunction(); + size_t size_of_state = agg_function->sizeOfData(); + size_t align_of_state = agg_function->alignOfData(); + + for (size_t i = 0; i < rows; ++i) + { + AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state); + + agg_function->create(place); + + auto n = istr.read(place, size_of_state); + chassert(n == size_of_state); + vec.push_back(place); + } +} + + +Block NativeReader::getHeader() const +{ + return header; +} + +Block NativeReader::read() +{ + Block res; + + const DataTypeFactory & data_type_factory = DataTypeFactory::instance(); + + if (istr.eof()) + { + return res; + } + + /// Dimensions + size_t columns = 0; + size_t rows = 0; + + readVarUInt(columns, istr); + readVarUInt(rows, istr); + + if (columns > 1'000'000uz) + throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Suspiciously many columns in Native format: {}", columns); + if (rows > 1'000'000'000'000uz) + throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Suspiciously many rows in Native format: {}", rows); + + if (columns == 0 && !header && rows != 0) + throw Exception(ErrorCodes::INCORRECT_DATA, "Zero columns but {} rows in Native format.", rows); + + for (size_t i = 0; i < columns; ++i) + { + ColumnWithTypeAndName column; + + column.name = "col_" + std::to_string(i); + + /// Type + String type_name; + readBinary(type_name, istr); + column.type = data_type_factory.get(type_name); + bool is_agg_state_type = isAggregateFunction(column.type); + SerializationPtr serialization = column.type->getDefaultSerialization(); + + /// Data + ColumnPtr read_column = column.type->createColumn(*serialization); + + double avg_value_size_hint = avg_value_size_hints.empty() ? 0 : avg_value_size_hints[i]; + if (rows) /// If no rows, nothing to read. + { + if (is_agg_state_type) + { + const DataTypeAggregateFunction * agg_type = checkAndGetDataType(column.type.get()); + readAggData(*agg_type, read_column, istr, rows); + } + else + { + readData(*serialization, read_column, istr, rows, avg_value_size_hint); + } + } + column.column = std::move(read_column); + + res.insert(std::move(column)); + } + + if (res.rows() != rows) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Row count mismatch after deserialization, got: {}, expected: {}", res.rows(), rows); + + return res; +} + +} diff --git a/cpp-ch/local-engine/Storages/IO/NativeReader.h b/cpp-ch/local-engine/Storages/IO/NativeReader.h new file mode 100644 index 000000000000..3cd9375eed22 --- /dev/null +++ b/cpp-ch/local-engine/Storages/IO/NativeReader.h @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace local_engine +{ + +class NativeReader +{ +public: + NativeReader(DB::ReadBuffer & istr_) : istr(istr_) {} + + static void readData(const DB::ISerialization & serialization, DB::ColumnPtr & column, DB::ReadBuffer & istr, size_t rows, double avg_value_size_hint); + static void readAggData(const DB::DataTypeAggregateFunction & data_type, DB::ColumnPtr & column, DB::ReadBuffer & istr, size_t rows); + + DB::Block getHeader() const; + + DB::Block read(); + +private: + DB::ReadBuffer & istr; + DB::Block header; + + DB::PODArray avg_value_size_hints; + + void updateAvgValueSizeHints(const DB::Block & block); +}; + +} diff --git a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp new file mode 100644 index 000000000000..2cbdad0ae36b --- /dev/null +++ b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "NativeWriter.h" +#include +#include +#include +#include + +using namespace DB; + +namespace local_engine +{ +void NativeWriter::flush() +{ + ostr.next(); +} + +static void writeData(const ISerialization & serialization, const ColumnPtr & column, WriteBuffer & ostr, UInt64 offset, UInt64 limit) +{ + /** If there are columns-constants - then we materialize them. + * (Since the data type does not know how to serialize / deserialize constants.) + */ + ColumnPtr full_column = column->convertToFullColumnIfConst(); + + ISerialization::SerializeBinaryBulkSettings settings; + settings.getter = [&ostr](ISerialization::SubstreamPath) -> WriteBuffer * { return &ostr; }; + settings.position_independent_encoding = false; + settings.low_cardinality_max_dictionary_size = 0; + + ISerialization::SerializeBinaryBulkStatePtr state; + serialization.serializeBinaryBulkStatePrefix(*full_column, settings, state); + serialization.serializeBinaryBulkWithMultipleStreams(*full_column, offset, limit, settings, state); + serialization.serializeBinaryBulkStateSuffix(settings, state); +} + +size_t NativeWriter::write(const DB::Block & block) +{ + size_t written_before = ostr.count(); + + block.checkNumberOfRows(); + + /// Dimensions + size_t columns = block.columns(); + size_t rows = block.rows(); + + writeVarUInt(columns, ostr); + writeVarUInt(rows, ostr); + + for (size_t i = 0; i < columns; ++i) + { + auto column = block.safeGetByPosition(i); + /// agg state will convert to fixedString, need write actual agg state type + auto original_type = header.safeGetByPosition(i).type; + /// Type + String type_name = original_type->getName(); + + writeStringBinary(type_name, ostr); + + SerializationPtr serialization = column.type->getDefaultSerialization(); + column.column = recursiveRemoveSparse(column.column); + /// Data + if (rows) /// Zero items of data is always represented as zero number of bytes. + writeData(*serialization, column.column, ostr, 0, 0); + } + + size_t written_after = ostr.count(); + size_t written_size = written_after - written_before; + return written_size; +} +} diff --git a/cpp-ch/local-engine/Storages/IO/NativeWriter.h b/cpp-ch/local-engine/Storages/IO/NativeWriter.h new file mode 100644 index 000000000000..6815d89d25dd --- /dev/null +++ b/cpp-ch/local-engine/Storages/IO/NativeWriter.h @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace DB +{ +class WriteBuffer; +class CompressedWriteBuffer; +} + +namespace local_engine +{ + +class NativeWriter +{ +public: + NativeWriter( + DB::WriteBuffer & ostr_, const DB::Block & header_): ostr(ostr_), header(header_) + {} + + DB::Block getHeader() const { return header; } + + /// Returns the number of bytes written. + size_t write(const DB::Block & block); + void flush(); + + +private: + DB::WriteBuffer & ostr; + DB::Block header; +}; +} diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp index 37625503631f..74d7d7054db7 100644 --- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp +++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp @@ -17,7 +17,7 @@ #include "SourceFromJavaIter.h" #include #include -#include +#include #include #include #include @@ -25,8 +25,6 @@ #include #include -using namespace DB; - namespace local_engine { jclass SourceFromJavaIter::serialized_record_batch_iterator_class = nullptr; @@ -72,27 +70,8 @@ DB::Chunk SourceFromJavaIter::generate() result = BlockUtil::buildRowCountChunk(rows); } } - Columns converted_columns; - auto output_header = outputs.front().getHeader(); - for (size_t i = 0; i < result.getNumColumns(); ++i) - { - if (isAggregateFunction(output_header.getByPosition(i).type)) - { - auto col = data->getByPosition(i); - col.column = result.getColumns().at(i); - auto converted = convertFixedStringToAggregateState(col, output_header.getByPosition(i).type); - converted_columns.emplace_back(converted.column); - } - else - { - converted_columns.emplace_back(result.getColumns().at(i)); - } - } - result.setColumns(converted_columns, result.getNumRows()); } CLEAN_JNIENV - - return result; } SourceFromJavaIter::~SourceFromJavaIter() From 95d7e5b81bae8275a8c09661220dd1b0fb302e3a Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Wed, 27 Sep 2023 20:17:49 +0800 Subject: [PATCH 03/14] fix bug --- .../Join/StorageJoinFromReadBuffer.cpp | 4 ++-- .../local-engine/Shuffle/NativeSplitter.cpp | 1 - .../local-engine/Shuffle/ShuffleSplitter.cpp | 19 +++++++++++++------ cpp-ch/local-engine/Shuffle/ShuffleSplitter.h | 5 +++-- cpp-ch/local-engine/Shuffle/ShuffleWriter.cpp | 4 ++-- cpp-ch/local-engine/Shuffle/ShuffleWriter.h | 4 ++-- .../local-engine/Storages/IO/NativeReader.cpp | 12 ++++++++++-- .../local-engine/Storages/IO/NativeWriter.cpp | 13 +++++++++++-- .../local-engine/Storages/IO/NativeWriter.h | 2 +- 9 files changed, 44 insertions(+), 20 deletions(-) diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp index 3ef616b6ce0f..83c37c7ad752 100644 --- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp +++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp @@ -16,7 +16,7 @@ */ #include "StorageJoinFromReadBuffer.h" -#include +#include #include #include #include @@ -42,7 +42,7 @@ using namespace DB; void restore(DB::ReadBuffer & in, IJoin & join, const Block & sample_block) { - NativeReader block_stream(in, 0); + local_engine::NativeReader block_stream(in); ProfileInfo info; { diff --git a/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp b/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp index b46ab426f3df..d05f6633ca8a 100644 --- a/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp +++ b/cpp-ch/local-engine/Shuffle/NativeSplitter.cpp @@ -44,7 +44,6 @@ void NativeSplitter::split(DB::Block & block) { return; } - block = convertAggregateStateInBlock(block); if (!output_header.columns()) [[unlikely]] { if (output_columns_indicies.empty()) diff --git a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp index 281fb1821868..e45b9d32218e 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp +++ b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp @@ -39,6 +39,7 @@ void ShuffleSplitter::split(DB::Block & block) { return; } + initOutputIfNeeded(block); computeAndCountPartitionId(block); Stopwatch split_time_watch; split_time_watch.start(); @@ -74,12 +75,12 @@ SplitResult ShuffleSplitter::stop() stopped = true; return split_result; } -void ShuffleSplitter::splitBlockByPartition(DB::Block & block) + +void ShuffleSplitter::initOutputIfNeeded(Block & block) { - Stopwatch split_time_watch; - split_time_watch.start(); - if (!output_header.columns()) [[unlikely]] + if (output_header.columns() == 0) [[unlikely]] { + output_header = block.cloneEmpty(); if (output_columns_indicies.empty()) { output_header = block.cloneEmpty(); @@ -90,7 +91,7 @@ void ShuffleSplitter::splitBlockByPartition(DB::Block & block) } else { - DB::ColumnsWithTypeAndName cols; + ColumnsWithTypeAndName cols; for (const auto & index : output_columns_indicies) { cols.push_back(block.getByPosition(index)); @@ -98,6 +99,12 @@ void ShuffleSplitter::splitBlockByPartition(DB::Block & block) output_header = DB::Block(cols); } } +} + +void ShuffleSplitter::splitBlockByPartition(DB::Block & block) +{ + Stopwatch split_time_watch; + split_time_watch.start(); DB::Block out_block; for (size_t col = 0; col < output_header.columns(); ++col) { @@ -152,7 +159,7 @@ void ShuffleSplitter::spillPartition(size_t partition_id) { partition_write_buffers[partition_id] = getPartitionWriteBuffer(partition_id); partition_outputs[partition_id] - = std::make_unique(*partition_write_buffers[partition_id], 0, partition_buffer[partition_id].getHeader()); + = std::make_unique(*partition_write_buffers[partition_id], output_header); } DB::Block result = partition_buffer[partition_id].releaseColumns(); if (result.rows() > 0) diff --git a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.h b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.h index e9a59c339cae..aad53508b81b 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.h +++ b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.h @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include #include #include @@ -101,6 +101,7 @@ class ShuffleSplitter : public ShuffleWriterBase private: void init(); + void initOutputIfNeeded(DB::Block & block); void splitBlockByPartition(DB::Block & block); void spillPartition(size_t partition_id); std::string getPartitionTempFile(size_t partition_id); @@ -111,7 +112,7 @@ class ShuffleSplitter : public ShuffleWriterBase bool stopped = false; PartitionInfo partition_info; std::vector partition_buffer; - std::vector> partition_outputs; + std::vector> partition_outputs; std::vector> partition_write_buffers; std::vector> partition_cached_write_buffers; std::vector compressed_buffers; diff --git a/cpp-ch/local-engine/Shuffle/ShuffleWriter.cpp b/cpp-ch/local-engine/Shuffle/ShuffleWriter.cpp index 8fdbac37fd60..dddf0b895fdf 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/ShuffleWriter.cpp @@ -41,11 +41,11 @@ void ShuffleWriter::write(const Block & block) { if (compression_enable) { - native_writer = std::make_unique(*compressed_out, 0, block.cloneEmpty()); + native_writer = std::make_unique(*compressed_out, block.cloneEmpty()); } else { - native_writer = std::make_unique(*write_buffer, 0, block.cloneEmpty()); + native_writer = std::make_unique(*write_buffer, block.cloneEmpty()); } } if (block.rows() > 0) diff --git a/cpp-ch/local-engine/Shuffle/ShuffleWriter.h b/cpp-ch/local-engine/Shuffle/ShuffleWriter.h index 459bf4e93ad7..98f67d1ccadb 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleWriter.h +++ b/cpp-ch/local-engine/Shuffle/ShuffleWriter.h @@ -16,7 +16,7 @@ */ #pragma once #include -#include +#include namespace local_engine { @@ -32,7 +32,7 @@ class ShuffleWriter private: std::unique_ptr compressed_out; std::unique_ptr write_buffer; - std::unique_ptr native_writer; + std::unique_ptr native_writer; bool compression_enable; }; } diff --git a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp index 6080e8486b5a..f0c47b9b1f5e 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp @@ -21,6 +21,7 @@ #include #include #include +#include namespace DB { @@ -120,7 +121,14 @@ Block NativeReader::read() /// Type String type_name; readBinary(type_name, istr); - column.type = data_type_factory.get(type_name); + bool agg_opt_column = false; + String real_type_name = type_name; + if (type_name.ends_with(NativeWriter::AGG_STATE_SUFFIX)) + { + agg_opt_column = true; + real_type_name = type_name.substr(0, type_name.length() - NativeWriter::AGG_STATE_SUFFIX.length()); + } + column.type = data_type_factory.get(real_type_name); bool is_agg_state_type = isAggregateFunction(column.type); SerializationPtr serialization = column.type->getDefaultSerialization(); @@ -130,7 +138,7 @@ Block NativeReader::read() double avg_value_size_hint = avg_value_size_hints.empty() ? 0 : avg_value_size_hints[i]; if (rows) /// If no rows, nothing to read. { - if (is_agg_state_type) + if (is_agg_state_type && agg_opt_column) { const DataTypeAggregateFunction * agg_type = checkAndGetDataType(column.type.get()); readAggData(*agg_type, read_column, istr, rows); diff --git a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp index 2cbdad0ae36b..1c24f4c1a5e2 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp @@ -24,6 +24,8 @@ using namespace DB; namespace local_engine { + +const String NativeWriter::AGG_STATE_SUFFIX= "#optagg"; void NativeWriter::flush() { ostr.next(); @@ -67,8 +69,15 @@ size_t NativeWriter::write(const DB::Block & block) auto original_type = header.safeGetByPosition(i).type; /// Type String type_name = original_type->getName(); - - writeStringBinary(type_name, ostr); + if (isAggregateFunction(original_type) + && header.safeGetByPosition(i).column->getDataType() != block.safeGetByPosition(i).column->getDataType()) + { + writeStringBinary(type_name + AGG_STATE_SUFFIX, ostr); + } + else + { + writeStringBinary(type_name, ostr); + } SerializationPtr serialization = column.type->getDefaultSerialization(); column.column = recursiveRemoveSparse(column.column); diff --git a/cpp-ch/local-engine/Storages/IO/NativeWriter.h b/cpp-ch/local-engine/Storages/IO/NativeWriter.h index 6815d89d25dd..a958f4484dce 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeWriter.h +++ b/cpp-ch/local-engine/Storages/IO/NativeWriter.h @@ -32,12 +32,12 @@ namespace local_engine class NativeWriter { public: + static const String AGG_STATE_SUFFIX; NativeWriter( DB::WriteBuffer & ostr_, const DB::Block & header_): ostr(ostr_), header(header_) {} DB::Block getHeader() const { return header; } - /// Returns the number of bytes written. size_t write(const DB::Block & block); void flush(); From 71e4a0b38a183eef250e710ceaf00d8d05bd698a Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Thu, 28 Sep 2023 16:13:53 +0800 Subject: [PATCH 04/14] add white list for agg opt --- ...seTPCHColumnarShuffleParquetAQESuite.scala | 39 +++++++++++++++++++ .../Builder/SerializedPlanBuilder.cpp | 2 +- .../IO/AggregateSerializationUtils.cpp | 19 ++++++++- 3 files changed, 57 insertions(+), 3 deletions(-) diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala index 38192a32b465..237bac4ad1f4 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala @@ -290,4 +290,43 @@ class GlutenClickHouseTPCHColumnarShuffleParquetAQESuite } } } + + test("collect_set") { + val sql = + """ + |select a, b from ( + |select n_regionkey as a, collect_set(if(n_regionkey=0, n_name, null)) as set from nation group by n_regionkey) + |lateral view explode(set) as b + |order by a, b + |""".stripMargin + runQueryAndCompare(sql)(checkOperatorMatch[CHHashAggregateExecTransformer]) + } + + test("test 'aggregate function collect_list'") { + val df = runQueryAndCompare( + "select l_orderkey,from_unixtime(l_orderkey, 'yyyy-MM-dd HH:mm:ss') " + + "from lineitem order by l_orderkey desc limit 10" + )(checkOperatorMatch[ProjectExecTransformer]) + checkLengthAndPlan(df, 10) + } + + test("test max string") { + withSQLConf(("spark.gluten.sql.columnar.force.hashagg", "true")) { + val sql = + """ + |SELECT + | l_returnflag, + | l_linestatus, + | max(l_comment) + |FROM + | lineitem + |WHERE + | l_shipdate <= date'1998-09-02' - interval 1 day + |GROUP BY + | l_returnflag, + | l_linestatus + |""".stripMargin + runQueryAndCompare(sql, noFallBack = false) { df => } + } + } } diff --git a/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp b/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp index 1dd3cbfbfce3..92e5c564110d 100644 --- a/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp +++ b/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp @@ -259,7 +259,7 @@ std::shared_ptr SerializedPlanBuilder::buildType(const DB::Data res->mutable_i32()->set_nullability(type_nullability); else if (which.isInt64()) res->mutable_i64()->set_nullability(type_nullability); - else if (which.isString() || which.isAggregateFunction() || which.isFixedString()) + else if (which.isStringOrFixedString() || which.isAggregateFunction()) res->mutable_binary()->set_nullability(type_nullability); /// Spark Binary type is more similiar to CH String type else if (which.isFloat32()) res->mutable_fp32()->set_nullability(type_nullability); diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp index c00adc58ca6e..d53d565c42e7 100644 --- a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -26,6 +27,14 @@ using namespace DB; namespace local_engine { + +bool isFixedSizeStateAggregateFunction(const String& name) +{ + // TODO max(String) should exclude, but fallback now + static const std::set function_set = {"min", "max", "sum", "count", "avg"}; + return function_set.contains(name); +} + DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col) { if (!isAggregateFunction(col.type)) @@ -33,13 +42,19 @@ DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeA return col; } const auto *aggregate_col = checkAndGetColumn(*col.column); + // only support known fixed size aggregate function + if (!isFixedSizeStateAggregateFunction(aggregate_col->getAggregateFunction()->getName())) + { + return col; + } size_t state_size = aggregate_col->getAggregateFunction()->sizeOfData(); auto res_type = std::make_shared(state_size); auto res_col = res_type->createColumn(); - res_col->reserve(aggregate_col->size()); + PaddedPODArray & column_chars_t = assert_cast(*res_col).getChars(); + column_chars_t.reserve(aggregate_col->size() * state_size); for (const auto & item : aggregate_col->getData()) { - res_col->insertData(item, state_size); + column_chars_t.insert_assume_reserved(item, item + state_size); } return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name); } From ffc034a9075bbf4eb3a2bc568c6f379d492f1e3a Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Thu, 28 Sep 2023 16:59:33 +0800 Subject: [PATCH 05/14] fix style --- .../GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala index 237bac4ad1f4..0814c3c8c7d7 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHColumnarShuffleParquetAQESuite.scala @@ -295,7 +295,8 @@ class GlutenClickHouseTPCHColumnarShuffleParquetAQESuite val sql = """ |select a, b from ( - |select n_regionkey as a, collect_set(if(n_regionkey=0, n_name, null)) as set from nation group by n_regionkey) + |select n_regionkey as a, collect_set(if(n_regionkey=0, n_name, null)) + | as set from nation group by n_regionkey) |lateral view explode(set) as b |order by a, b |""".stripMargin From c285d6970429d6d589e5c8cd978d101ec48119fe Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Tue, 31 Oct 2023 09:56:50 +0800 Subject: [PATCH 06/14] optimize shuffle more --- .../local-engine/Shuffle/PartitionWriter.cpp | 2 +- cpp-ch/local-engine/Shuffle/PartitionWriter.h | 2 +- .../IO/AggregateSerializationUtils.cpp | 28 ++++++++++++++++++- .../Storages/IO/AggregateSerializationUtils.h | 2 ++ .../local-engine/Storages/IO/NativeReader.cpp | 7 +++-- .../local-engine/Storages/IO/NativeWriter.cpp | 20 +++++++++++-- 6 files changed, 52 insertions(+), 9 deletions(-) diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp index 3623e493bf48..20f47d869ec2 100644 --- a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp @@ -303,7 +303,7 @@ void Partition::clear() blocks.clear(); } -size_t Partition::spill(DB::NativeWriter & writer) +size_t Partition::spill(NativeWriter & writer) { std::unique_lock lock(mtx, std::try_to_lock); if (lock.owns_lock()) diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.h b/cpp-ch/local-engine/Shuffle/PartitionWriter.h index a0d83c194b05..0a457e39415e 100644 --- a/cpp-ch/local-engine/Shuffle/PartitionWriter.h +++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.h @@ -48,7 +48,7 @@ class Partition void addBlock(DB::Block & block); bool empty() const; void clear(); - size_t spill(DB::NativeWriter & writer); + size_t spill(NativeWriter & writer); private: std::vector blocks; diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp index d53d565c42e7..49b8f274ad2a 100644 --- a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp @@ -19,8 +19,10 @@ #include #include +#include #include #include +#include using namespace DB; @@ -58,6 +60,30 @@ DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeA } return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name); } + +DB::ColumnWithTypeAndName convertAggregateStateToString(DB::ColumnWithTypeAndName col) +{ + if (!isAggregateFunction(col.type)) + { + return col; + } + const auto *aggregate_col = checkAndGetColumn(*col.column); + auto res_type = std::make_shared(); + auto res_col = res_type->createColumn(); + PaddedPODArray & column_chars = assert_cast(*res_col).getChars(); + column_chars.reserve(aggregate_col->size() * 60); + IColumn::Offsets & column_offsets = assert_cast(*res_col).getOffsets(); + auto value_writer = WriteBufferFromVector>(column_chars); + column_offsets.reserve(aggregate_col->size()); + for (const auto & item : aggregate_col->getData()) + { + aggregate_col->getAggregateFunction()->serialize(item, value_writer); + writeChar('\0', value_writer); + column_offsets.emplace_back(value_writer.count()); + } + return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name); +} + DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type) { chassert(isAggregateFunction(type)); @@ -90,7 +116,7 @@ DB::Block convertAggregateStateInBlock(DB::Block block) ColumnsWithTypeAndName columns; for (const auto & item : block.getColumnsWithTypeAndName()) { - columns.emplace_back(convertAggregateStateToFixedString(item)); + columns.emplace_back(convertAggregateStateToString(item)); } return columns; } diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h index 62d5127d8a90..fb9bd4ae5e87 100644 --- a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h @@ -24,6 +24,8 @@ DB::Block convertAggregateStateInBlock(DB::Block block); DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col); +DB::ColumnWithTypeAndName convertAggregateStateToString(DB::ColumnWithTypeAndName col); + DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type); } diff --git a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp index f0c47b9b1f5e..cf6d11f1695f 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp @@ -73,9 +73,10 @@ void NativeReader::readAggData(const DB::DataTypeAggregateFunction & data_type, AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state); agg_function->create(place); - - auto n = istr.read(place, size_of_state); - chassert(n == size_of_state); +// UInt64 size; +// readVarUInt(size, istr); + agg_function->deserialize(place, istr); + istr.ignore(); vec.push_back(place); } } diff --git a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp index 1c24f4c1a5e2..fd7d11fd54b6 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp @@ -19,6 +19,8 @@ #include #include #include +#include + using namespace DB; @@ -69,8 +71,9 @@ size_t NativeWriter::write(const DB::Block & block) auto original_type = header.safeGetByPosition(i).type; /// Type String type_name = original_type->getName(); - if (isAggregateFunction(original_type) - && header.safeGetByPosition(i).column->getDataType() != block.safeGetByPosition(i).column->getDataType()) + bool is_agg_opt = isAggregateFunction(original_type) + && header.safeGetByPosition(i).column->getDataType() != block.safeGetByPosition(i).column->getDataType(); + if (is_agg_opt) { writeStringBinary(type_name + AGG_STATE_SUFFIX, ostr); } @@ -83,7 +86,18 @@ size_t NativeWriter::write(const DB::Block & block) column.column = recursiveRemoveSparse(column.column); /// Data if (rows) /// Zero items of data is always represented as zero number of bytes. - writeData(*serialization, column.column, ostr, 0, 0); + { + if (is_agg_opt) + { + const auto * str_col = static_cast(column.column.get()); + const PaddedPODArray & column_chars = str_col->getChars(); + ostr.write(column_chars.raw_data(), str_col->getOffsets().back()); + } + else + { + writeData(*serialization, column.column, ostr, 0, 0); + } + } } size_t written_after = ostr.count(); From 739377c8fe3bc27b25d6bd5f95756402ce921e6a Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Tue, 31 Oct 2023 15:23:55 +0800 Subject: [PATCH 07/14] combine two optimization --- .../IO/AggregateSerializationUtils.cpp | 26 ++++++++++++++++--- .../Storages/IO/AggregateSerializationUtils.h | 3 +++ .../local-engine/Storages/IO/NativeReader.cpp | 15 ++++++++--- .../local-engine/Storages/IO/NativeWriter.cpp | 6 ++++- 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp index 49b8f274ad2a..5a5712cc1bb9 100644 --- a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp @@ -32,11 +32,20 @@ namespace local_engine bool isFixedSizeStateAggregateFunction(const String& name) { - // TODO max(String) should exclude, but fallback now static const std::set function_set = {"min", "max", "sum", "count", "avg"}; return function_set.contains(name); } +bool isFixedSizeArguments(DataTypes data_types) +{ + return data_types.front()->isValueRepresentedByNumber(); +} + +bool isFixedSizeAggregateFunction(DB::AggregateFunctionPtr function) +{ + return isFixedSizeStateAggregateFunction(function->getName()) && isFixedSizeArguments(function->getArgumentTypes()); +} + DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col) { if (!isAggregateFunction(col.type)) @@ -45,7 +54,7 @@ DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeA } const auto *aggregate_col = checkAndGetColumn(*col.column); // only support known fixed size aggregate function - if (!isFixedSizeStateAggregateFunction(aggregate_col->getAggregateFunction()->getName())) + if (!isFixedSizeAggregateFunction(aggregate_col->getAggregateFunction())) { return col; } @@ -116,7 +125,18 @@ DB::Block convertAggregateStateInBlock(DB::Block block) ColumnsWithTypeAndName columns; for (const auto & item : block.getColumnsWithTypeAndName()) { - columns.emplace_back(convertAggregateStateToString(item)); + if (isAggregateFunction(item.type)) + { + const auto *aggregate_col = checkAndGetColumn(*item.column); + if (isFixedSizeAggregateFunction(aggregate_col->getAggregateFunction())) + columns.emplace_back(convertAggregateStateToFixedString(item)); + else + columns.emplace_back(convertAggregateStateToString(item)); + } + else + { + columns.emplace_back(item); + } } return columns; } diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h index fb9bd4ae5e87..6df1ad2821d2 100644 --- a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h @@ -15,11 +15,14 @@ * limitations under the License. */ #pragma once +#include #include #include namespace local_engine { +bool isFixedSizeAggregateFunction(DB::AggregateFunctionPtr function); + DB::Block convertAggregateStateInBlock(DB::Block block); DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col); diff --git a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp index cf6d11f1695f..98fa4b7419b5 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace DB { @@ -73,10 +74,16 @@ void NativeReader::readAggData(const DB::DataTypeAggregateFunction & data_type, AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state); agg_function->create(place); -// UInt64 size; -// readVarUInt(size, istr); - agg_function->deserialize(place, istr); - istr.ignore(); + if (isFixedSizeAggregateFunction(agg_function)) + { + auto n = istr.read(place, size_of_state); + chassert(n == size_of_state); + } + else + { + agg_function->deserialize(place, istr, std::nullopt, &arena); + istr.ignore(); + } vec.push_back(place); } } diff --git a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp index fd7d11fd54b6..472a2fec77a5 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp @@ -20,6 +20,9 @@ #include #include #include +#include +#include +#include using namespace DB; @@ -87,7 +90,8 @@ size_t NativeWriter::write(const DB::Block & block) /// Data if (rows) /// Zero items of data is always represented as zero number of bytes. { - if (is_agg_opt) + const auto * agg_type = checkAndGetDataType(original_type.get()); + if (is_agg_opt && agg_type && !isFixedSizeAggregateFunction(agg_type->getFunction())) { const auto * str_col = static_cast(column.column.get()); const PaddedPODArray & column_chars = str_col->getChars(); From f59a09f04cf3d1f945f8c82831fe95078aa34e61 Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Thu, 2 Nov 2023 15:40:39 +0800 Subject: [PATCH 08/14] fix performance --- cpp-ch/local-engine/Storages/IO/NativeReader.cpp | 15 ++++++++++++--- cpp-ch/local-engine/Storages/IO/NativeReader.h | 1 + 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp index 98fa4b7419b5..d22e72092916 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp @@ -58,6 +58,7 @@ void NativeReader::readData(const ISerialization & serialization, ColumnPtr & co "Cannot read all data in NativeReader. Rows read: {}. Rows expected: {}", column->size(), rows); } +template void NativeReader::readAggData(const DB::DataTypeAggregateFunction & data_type, DB::ColumnPtr & column, DB::ReadBuffer & istr, size_t rows) { ColumnAggregateFunction & real_column = typeid_cast(*column->assumeMutable()); @@ -72,9 +73,8 @@ void NativeReader::readAggData(const DB::DataTypeAggregateFunction & data_type, for (size_t i = 0; i < rows; ++i) { AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state); - agg_function->create(place); - if (isFixedSizeAggregateFunction(agg_function)) + if constexpr (FIXED) { auto n = istr.read(place, size_of_state); chassert(n == size_of_state); @@ -84,6 +84,7 @@ void NativeReader::readAggData(const DB::DataTypeAggregateFunction & data_type, agg_function->deserialize(place, istr, std::nullopt, &arena); istr.ignore(); } + vec.push_back(place); } } @@ -149,7 +150,15 @@ Block NativeReader::read() if (is_agg_state_type && agg_opt_column) { const DataTypeAggregateFunction * agg_type = checkAndGetDataType(column.type.get()); - readAggData(*agg_type, read_column, istr, rows); + bool fixed = isFixedSizeAggregateFunction(agg_type->getFunction()); + if (fixed) + { + readAggData(*agg_type, read_column, istr, rows); + } + else + { + readAggData(*agg_type, read_column, istr, rows); + } } else { diff --git a/cpp-ch/local-engine/Storages/IO/NativeReader.h b/cpp-ch/local-engine/Storages/IO/NativeReader.h index 3cd9375eed22..d065fce347d4 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeReader.h +++ b/cpp-ch/local-engine/Storages/IO/NativeReader.h @@ -29,6 +29,7 @@ class NativeReader NativeReader(DB::ReadBuffer & istr_) : istr(istr_) {} static void readData(const DB::ISerialization & serialization, DB::ColumnPtr & column, DB::ReadBuffer & istr, size_t rows, double avg_value_size_hint); + template static void readAggData(const DB::DataTypeAggregateFunction & data_type, DB::ColumnPtr & column, DB::ReadBuffer & istr, size_t rows); DB::Block getHeader() const; From 0b4ef6391866fad1b8accdf4b5c809a1f768d9a4 Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Fri, 3 Nov 2023 16:35:10 +0800 Subject: [PATCH 09/14] fix ut failed --- .../execution/adaptive/GlutenAdaptiveQueryExecSuite.scala | 6 ++++-- .../execution/adaptive/GlutenAdaptiveQueryExecSuite.scala | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/adaptive/GlutenAdaptiveQueryExecSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/adaptive/GlutenAdaptiveQueryExecSuite.scala index 0e85bf546e7c..3ac53799f3f9 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/adaptive/GlutenAdaptiveQueryExecSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/adaptive/GlutenAdaptiveQueryExecSuite.scala @@ -437,8 +437,10 @@ class GlutenAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with GlutenSQL test("gluten Exchange reuse") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100", - SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + // magic threshold, ch backend has two bhj when threshold is 100 + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "90", + SQLConf.SHUFFLE_PARTITIONS.key -> "5" + ) { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT value FROM testData join testData2 ON key = a " + "join (SELECT value v from testData join testData3 ON key = a) on value = v") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/GlutenAdaptiveQueryExecSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/GlutenAdaptiveQueryExecSuite.scala index 6bbeb801825d..5b5ee83be49b 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/GlutenAdaptiveQueryExecSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/GlutenAdaptiveQueryExecSuite.scala @@ -437,8 +437,10 @@ class GlutenAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with GlutenSQL test("gluten Exchange reuse") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100", - SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + // magic threshold, ch backend has two bhj when threshold is 100 + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "90", + SQLConf.SHUFFLE_PARTITIONS.key -> "5" + ) { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT value FROM testData join testData2 ON key = a " + "join (SELECT value v from testData join testData3 ON key = a) on value = v") From 431272d83e8af2f40281cfcc83f1df83127c5197 Mon Sep 17 00:00:00 2001 From: LiuNeng <1398775315@qq.com> Date: Tue, 14 Nov 2023 17:01:10 +0800 Subject: [PATCH 10/14] Update AggregateSerializationUtils.cpp --- .../Storages/IO/AggregateSerializationUtils.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp index 5a5712cc1bb9..6ad5bcefc625 100644 --- a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp @@ -48,7 +48,7 @@ bool isFixedSizeAggregateFunction(DB::AggregateFunctionPtr function) DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col) { - if (!isAggregateFunction(col.type)) + if (!WhichDataType(col.type).isAggregateFunction()) { return col; } @@ -72,7 +72,7 @@ DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeA DB::ColumnWithTypeAndName convertAggregateStateToString(DB::ColumnWithTypeAndName col) { - if (!isAggregateFunction(col.type)) + if (!WhichDataType(col.type).isAggregateFunction()) { return col; } @@ -95,7 +95,7 @@ DB::ColumnWithTypeAndName convertAggregateStateToString(DB::ColumnWithTypeAndNam DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type) { - chassert(isAggregateFunction(type)); + chassert(WhichDataType(type).isAggregateFunction()); auto res_col = type->createColumn(); const auto * agg_type = checkAndGetDataType(type.get()); ColumnAggregateFunction & real_column = typeid_cast(*res_col); @@ -125,7 +125,7 @@ DB::Block convertAggregateStateInBlock(DB::Block block) ColumnsWithTypeAndName columns; for (const auto & item : block.getColumnsWithTypeAndName()) { - if (isAggregateFunction(item.type)) + if (WhichDataType(item.type).isAggregateFunction()) { const auto *aggregate_col = checkAndGetColumn(*item.column); if (isFixedSizeAggregateFunction(aggregate_col->getAggregateFunction())) From d329ff898efa8bdad0c1063dcb5df6391c805ada Mon Sep 17 00:00:00 2001 From: LiuNeng <1398775315@qq.com> Date: Tue, 14 Nov 2023 09:13:54 +0000 Subject: [PATCH 11/14] fix --- cpp-ch/local-engine/Storages/IO/NativeReader.cpp | 2 +- cpp-ch/local-engine/Storages/IO/NativeWriter.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp index d22e72092916..12955579c624 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp @@ -138,7 +138,7 @@ Block NativeReader::read() real_type_name = type_name.substr(0, type_name.length() - NativeWriter::AGG_STATE_SUFFIX.length()); } column.type = data_type_factory.get(real_type_name); - bool is_agg_state_type = isAggregateFunction(column.type); + bool is_agg_state_type = WhichDataType(column.type).isAggregateFunction(); SerializationPtr serialization = column.type->getDefaultSerialization(); /// Data diff --git a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp index 472a2fec77a5..2afc3e9d66ba 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp @@ -74,7 +74,7 @@ size_t NativeWriter::write(const DB::Block & block) auto original_type = header.safeGetByPosition(i).type; /// Type String type_name = original_type->getName(); - bool is_agg_opt = isAggregateFunction(original_type) + bool is_agg_opt = WhichDataType(original_type).isAggregateFunction(); && header.safeGetByPosition(i).column->getDataType() != block.safeGetByPosition(i).column->getDataType(); if (is_agg_opt) { From af0e256bf9bd81145d51f1a7aa9b4e192d6f7e9f Mon Sep 17 00:00:00 2001 From: LiuNeng <1398775315@qq.com> Date: Tue, 14 Nov 2023 10:36:08 +0000 Subject: [PATCH 12/14] fix --- cpp-ch/local-engine/Storages/IO/NativeWriter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp index 2afc3e9d66ba..39a0cb7b579b 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp @@ -74,7 +74,7 @@ size_t NativeWriter::write(const DB::Block & block) auto original_type = header.safeGetByPosition(i).type; /// Type String type_name = original_type->getName(); - bool is_agg_opt = WhichDataType(original_type).isAggregateFunction(); + bool is_agg_opt = WhichDataType(original_type).isAggregateFunction() && header.safeGetByPosition(i).column->getDataType() != block.safeGetByPosition(i).column->getDataType(); if (is_agg_opt) { From db955e4a2e3075f5edda5bcf85ada117c8d48a36 Mon Sep 17 00:00:00 2001 From: LiuNeng <1398775315@qq.com> Date: Tue, 14 Nov 2023 09:42:43 +0000 Subject: [PATCH 13/14] update --- cpp-ch/local-engine/Shuffle/PartitionWriter.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp index 20f47d869ec2..c933c4596641 100644 --- a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp @@ -57,7 +57,7 @@ void local_engine::PartitionWriter::write(const PartitionInfo& partition_info, D if (buffer.size() >= shuffle_writer->options.split_size) { Block block = buffer.releaseColumns(); - auto bytes = block.bytes(); + auto bytes = block.allocatedBytes(); total_partition_buffer_size += bytes; shuffle_writer->split_result.raw_partition_length[i] += bytes; partition_buffer[i].addBlock(block); @@ -131,7 +131,8 @@ std::vector LocalPartitionWriter::mergeSpills(WriteBuffer& data_file) spill_inputs.reserve(spill_infos.size()); for (const auto & spill : spill_infos) { - spill_inputs.emplace_back(std::make_shared(spill.spilled_file, shuffle_writer->options.io_buffer_size)); + // only use readBig + spill_inputs.emplace_back(std::make_shared(spill.spilled_file, 0)); } Stopwatch write_time_watch; From 80ec518f92ebffbf69cde0f7c1ad67ef06494abe Mon Sep 17 00:00:00 2001 From: liuneng <1398775315@qq.com> Date: Wed, 15 Nov 2023 12:02:16 +0800 Subject: [PATCH 14/14] some fix --- .../local-engine/Shuffle/PartitionWriter.cpp | 4 ++- .../IO/AggregateSerializationUtils.cpp | 25 +++++++++---------- .../Storages/IO/AggregateSerializationUtils.h | 10 ++++---- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp index c933c4596641..932917362b20 100644 --- a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp @@ -310,9 +310,11 @@ size_t Partition::spill(NativeWriter & writer) if (lock.owns_lock()) { size_t raw_size = 0; - for (const auto & block : blocks) + while (!blocks.empty()) { + auto & block = blocks.back(); raw_size += writer.write(block); + blocks.pop_back(); } blocks.clear(); return raw_size; diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp index 6ad5bcefc625..84c32f4565f7 100644 --- a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp @@ -36,23 +36,23 @@ bool isFixedSizeStateAggregateFunction(const String& name) return function_set.contains(name); } -bool isFixedSizeArguments(DataTypes data_types) +bool isFixedSizeArguments(const DataTypes& data_types) { - return data_types.front()->isValueRepresentedByNumber(); + return removeNullable(data_types.front())->isValueRepresentedByNumber(); } -bool isFixedSizeAggregateFunction(DB::AggregateFunctionPtr function) +bool isFixedSizeAggregateFunction(const DB::AggregateFunctionPtr& function) { return isFixedSizeStateAggregateFunction(function->getName()) && isFixedSizeArguments(function->getArgumentTypes()); } -DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col) +DB::ColumnWithTypeAndName convertAggregateStateToFixedString(const DB::ColumnWithTypeAndName& col) { - if (!WhichDataType(col.type).isAggregateFunction()) + const auto *aggregate_col = checkAndGetColumn(*col.column); + if (!aggregate_col) { return col; } - const auto *aggregate_col = checkAndGetColumn(*col.column); // only support known fixed size aggregate function if (!isFixedSizeAggregateFunction(aggregate_col->getAggregateFunction())) { @@ -70,17 +70,16 @@ DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeA return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name); } -DB::ColumnWithTypeAndName convertAggregateStateToString(DB::ColumnWithTypeAndName col) +DB::ColumnWithTypeAndName convertAggregateStateToString(const DB::ColumnWithTypeAndName& col) { - if (!WhichDataType(col.type).isAggregateFunction()) + const auto *aggregate_col = checkAndGetColumn(*col.column); + if (!aggregate_col) { return col; } - const auto *aggregate_col = checkAndGetColumn(*col.column); auto res_type = std::make_shared(); auto res_col = res_type->createColumn(); PaddedPODArray & column_chars = assert_cast(*res_col).getChars(); - column_chars.reserve(aggregate_col->size() * 60); IColumn::Offsets & column_offsets = assert_cast(*res_col).getOffsets(); auto value_writer = WriteBufferFromVector>(column_chars); column_offsets.reserve(aggregate_col->size()); @@ -93,7 +92,7 @@ DB::ColumnWithTypeAndName convertAggregateStateToString(DB::ColumnWithTypeAndNam return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name); } -DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type) +DB::ColumnWithTypeAndName convertFixedStringToAggregateState(const DB::ColumnWithTypeAndName & col, const DB::DataTypePtr & type) { chassert(WhichDataType(type).isAggregateFunction()); auto res_col = type->createColumn(); @@ -101,7 +100,6 @@ DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeA ColumnAggregateFunction & real_column = typeid_cast(*res_col); auto & arena = real_column.createOrGetArena(); ColumnAggregateFunction::Container & vec = real_column.getData(); - vec.reserve(col.column->size()); auto agg_function = agg_type->getFunction(); size_t size_of_state = agg_function->sizeOfData(); @@ -120,9 +118,10 @@ DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeA } return DB::ColumnWithTypeAndName(std::move(res_col), type, col.name); } -DB::Block convertAggregateStateInBlock(DB::Block block) +DB::Block convertAggregateStateInBlock(DB::Block& block) { ColumnsWithTypeAndName columns; + columns.reserve(block.columns()); for (const auto & item : block.getColumnsWithTypeAndName()) { if (WhichDataType(item.type).isAggregateFunction()) diff --git a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h index 6df1ad2821d2..6536982ef572 100644 --- a/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h +++ b/cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h @@ -21,15 +21,15 @@ namespace local_engine { -bool isFixedSizeAggregateFunction(DB::AggregateFunctionPtr function); +bool isFixedSizeAggregateFunction(const DB::AggregateFunctionPtr & function); -DB::Block convertAggregateStateInBlock(DB::Block block); +DB::Block convertAggregateStateInBlock(DB::Block& block); -DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col); +DB::ColumnWithTypeAndName convertAggregateStateToFixedString(const DB::ColumnWithTypeAndName & col); -DB::ColumnWithTypeAndName convertAggregateStateToString(DB::ColumnWithTypeAndName col); +DB::ColumnWithTypeAndName convertAggregateStateToString(const DB::ColumnWithTypeAndName & col); -DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type); +DB::ColumnWithTypeAndName convertFixedStringToAggregateState(const DB::ColumnWithTypeAndName & col, const DB::DataTypePtr & type); }