Skip to content

Commit

Permalink
convert agg type to fixed string
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Sep 27, 2023
1 parent 049928f commit 93b76aa
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ std::shared_ptr<substrait::Type> SerializedPlanBuilder::buildType(const DB::Data
res->mutable_i32()->set_nullability(type_nullability);
else if (which.isInt64())
res->mutable_i64()->set_nullability(type_nullability);
else if (which.isString() || which.isAggregateFunction())
else if (which.isString() || which.isAggregateFunction() || which.isFixedString())
res->mutable_binary()->set_nullability(type_nullability); /// Spark Binary type is more similiar to CH String type
else if (which.isFloat32())
res->mutable_fp32()->set_nullability(type_nullability);
Expand Down
5 changes: 5 additions & 0 deletions cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "CachedShuffleWriter.h"
#include <Poco/StringTokenizer.h>
#include <Common/Stopwatch.h>
#include <Storages/IO/AggregateSerializationUtils.h>
#include <Shuffle/PartitionWriter.h>
#include <jni/CelebornClient.h>
#include <jni/jni_common.h>
Expand Down Expand Up @@ -93,6 +94,10 @@ CachedShuffleWriter::CachedShuffleWriter(const String & short_name, SplitOptions

void CachedShuffleWriter::split(DB::Block & block)
{
Stopwatch split_time_watch;
split_time_watch.start();
block = convertAggregateStateInBlock(block);
split_result.total_split_time += split_time_watch.elapsedNanoseconds();
initOutputIfNeeded(block);

Stopwatch compute_pid_time_watch;
Expand Down
2 changes: 2 additions & 0 deletions cpp-ch/local-engine/Shuffle/NativeSplitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <Common/Exception.h>
#include <Common/JNIUtils.h>
#include <Common/logger_useful.h>
#include <Storages/IO/AggregateSerializationUtils.h>

namespace local_engine
{
Expand All @@ -43,6 +44,7 @@ void NativeSplitter::split(DB::Block & block)
{
return;
}
block = convertAggregateStateInBlock(block);
if (!output_header.columns()) [[unlikely]]
{
if (output_columns_indicies.empty())
Expand Down
6 changes: 5 additions & 1 deletion cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include <string>
#include <fcntl.h>
#include <Compression/CompressionFactory.h>
#include <Functions/FunctionFactory.h>
#include <Storages/IO/AggregateSerializationUtils.h>
#include <IO/BrotliWriteBuffer.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/WriteHelpers.h>
Expand All @@ -40,6 +40,10 @@ void ShuffleSplitter::split(DB::Block & block)
return;
}
computeAndCountPartitionId(block);
Stopwatch split_time_watch;
split_time_watch.start();
block = convertAggregateStateInBlock(block);
split_result.total_split_time += split_time_watch.elapsedNanoseconds();
splitBlockByPartition(block);
}
SplitResult ShuffleSplitter::stop()
Expand Down
83 changes: 83 additions & 0 deletions cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "AggregateSerializationUtils.h"
#include <Common/Arena.h>

#include <Columns/ColumnAggregateFunction.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeFixedString.h>


using namespace DB;

namespace local_engine
{
/// Re-encodes an aggregate-function state column as a FixedString column.
///
/// Columns whose type is not an aggregate function are returned untouched.
/// For aggregate columns every row becomes one FixedString cell holding the raw
/// bytes of the state; the cell width is the aggregate function's sizeOfData().
/// The column name is preserved.
///
/// NOTE(review): the states are copied byte-for-byte, so this presumably only
/// round-trips states that own no out-of-line memory — confirm for the
/// aggregate functions that can reach this path.
DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col)
{
    if (!isAggregateFunction(col.type))
        return col;

    /// A constant column wraps the real ColumnAggregateFunction; materialise it so the cast below sees the
    /// concrete column instead of the wrapper.
    const ColumnPtr full_column = col.column->convertToFullColumnIfConst();

    /// The reference form of typeid_cast throws on a type mismatch, whereas the pointer returned by
    /// checkAndGetColumn would be nullptr and was dereferenced unchecked.
    const auto & aggregate_col = typeid_cast<const ColumnAggregateFunction &>(*full_column);

    const size_t state_size = aggregate_col.getAggregateFunction()->sizeOfData();
    auto res_type = std::make_shared<DataTypeFixedString>(state_size);
    auto res_col = res_type->createColumn();
    res_col->reserve(aggregate_col.size());

    /// Each container element points at one state; copy exactly state_size bytes of it.
    for (const auto & state : aggregate_col.getData())
        res_col->insertData(state, state_size);

    return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name);
}
/// Rebuilds an aggregate-function state column from a column of raw state bytes.
///
/// @param col   Column whose cells hold the serialized state bytes (one state per row).
/// @param type  Target type; must be an aggregate-function type.
/// @return      A column of `type` with the same name as `col`, whose states are allocated in the
///              result column's arena and filled from the input cells.
DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type)
{
    chassert(isAggregateFunction(type));
    auto res_col = type->createColumn();
    const auto * agg_type = checkAndGetDataType<DataTypeAggregateFunction>(type.get());
    /// Reference typeid_cast: fails loudly if `type` did not create an aggregate-function column.
    ColumnAggregateFunction & real_column = typeid_cast<ColumnAggregateFunction &>(*res_col);
    auto & arena = real_column.createOrGetArena();
    ColumnAggregateFunction::Container & vec = real_column.getData();

    const size_t rows = col.column->size();
    vec.reserve(rows);

    auto agg_function = agg_type->getFunction();
    const size_t size_of_state = agg_function->sizeOfData();
    const size_t align_of_state = agg_function->alignOfData();

    for (size_t i = 0; i < rows; ++i)
    {
        AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state);

        agg_function->create(place);

        const auto value = col.column->getDataAt(i);
        /// The cell is expected to be exactly one state wide.
        chassert(value.size == size_of_state);
        /// Only size_of_state bytes were allocated for `place`; never copy more than that, even if the
        /// incoming cell is wider, so a malformed input cannot write past the allocation.
        const size_t bytes_to_copy = value.size < size_of_state ? value.size : size_of_state;
        memcpy(place, value.data, bytes_to_copy);

        vec.push_back(place);
    }
    return DB::ColumnWithTypeAndName(std::move(res_col), type, col.name);
}
/// Returns a copy of `block` in which every aggregate-function column has been replaced by its
/// FixedString representation (see convertAggregateStateToFixedString); all other columns are
/// carried over as they are, in the same order.
DB::Block convertAggregateStateInBlock(DB::Block block)
{
    ColumnsWithTypeAndName columns;
    columns.reserve(block.columns());
    for (const auto & item : block.getColumnsWithTypeAndName())
        columns.emplace_back(convertAggregateStateToFixedString(item));

    DB::Block converted(std::move(columns));
    /// Building a block from columns starts with default block info; carry the input's info over so
    /// the conversion does not silently drop it.
    converted.info = block.info;
    return converted;
}
}

30 changes: 30 additions & 0 deletions cpp-ch/local-engine/Storages/IO/AggregateSerializationUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <Core/Block.h>
#include <DataTypes/IDataType.h>

namespace local_engine {

/// Returns a block built from the columns of `block`, where each aggregate-function column is
/// replaced by the result of convertAggregateStateToFixedString and every other column is passed
/// through unchanged.
DB::Block convertAggregateStateInBlock(DB::Block block);

/// Converts an aggregate-function state column into a FixedString column of the same name.
/// Each row holds the raw bytes of one state; the FixedString width is the aggregate function's
/// sizeOfData(). A column whose type is not an aggregate function is returned as is.
DB::ColumnWithTypeAndName convertAggregateStateToFixedString(DB::ColumnWithTypeAndName col);

/// Inverse of convertAggregateStateToFixedString: builds a column of the aggregate-function type
/// `type` (keeping the name of `col`) by allocating one state per row in the result column's arena
/// and copying the bytes of the corresponding cell of `col` into it.
/// `type` must be an aggregate-function type.
DB::ColumnWithTypeAndName convertFixedStringToAggregateState(DB::ColumnWithTypeAndName col, DB::DataTypePtr type);

}

23 changes: 22 additions & 1 deletion cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
#include "SourceFromJavaIter.h"
#include <Columns/ColumnNullable.h>
#include <Core/ColumnsWithTypeAndName.h>
#include <DataTypes/DataTypesNumber.h>
#include <Storages/IO/AggregateSerializationUtils.h>
#include <Processors/Transforms/AggregatingTransform.h>
#include <jni/jni_common.h>
#include <Common/CHUtil.h>
#include <Common/DebugUtils.h>
#include <Common/Exception.h>
#include <Common/JNIUtils.h>

using namespace DB;

namespace local_engine
{
jclass SourceFromJavaIter::serialized_record_batch_iterator_class = nullptr;
Expand Down Expand Up @@ -70,8 +72,27 @@ DB::Chunk SourceFromJavaIter::generate()
result = BlockUtil::buildRowCountChunk(rows);
}
}
Columns converted_columns;
auto output_header = outputs.front().getHeader();
for (size_t i = 0; i < result.getNumColumns(); ++i)
{
if (isAggregateFunction(output_header.getByPosition(i).type))
{
auto col = data->getByPosition(i);
col.column = result.getColumns().at(i);
auto converted = convertFixedStringToAggregateState(col, output_header.getByPosition(i).type);
converted_columns.emplace_back(converted.column);
}
else
{
converted_columns.emplace_back(result.getColumns().at(i));
}
}
result.setColumns(converted_columns, result.getNumRows());
}
CLEAN_JNIENV


return result;
}
SourceFromJavaIter::~SourceFromJavaIter()
Expand Down

0 comments on commit 93b76aa

Please sign in to comment.