Skip to content

Commit

Permalink
0903
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Sep 11, 2024
1 parent fd56ea4 commit 1950ac4
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <exception>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/WindowGroupLimitFunctions.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <Processors/Transforms/WindowTransform.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>

namespace DB::ErrorCodes
{
extern const int BAD_ARGUMENTS;
}

namespace local_engine
{
WindowFunctionTopRowNumber::WindowFunctionTopRowNumber(const String name, const DB::DataTypes & arg_types, const DB::Array & parameters_)
: DB::WindowFunction(name, arg_types, parameters_, std::make_shared<DB::DataTypeUInt64>())
{
if (parameters.size() != 1)
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs a limit parameter", name);
limit = parameters[0].safeGet<UInt64>();
LOG_ERROR(getLogger("WindowFunctionTopRowNumber"), "xxx {} limit: {}", name, limit);
}

void WindowFunctionTopRowNumber::windowInsertResultInto(const DB::WindowTransform * transform, size_t function_index) const
{
LOG_ERROR(
getLogger("WindowFunctionTopRowNumber"),
"xxx current row number: {}, current_row: {}@{}, partition_ended: {}",
transform->current_row_number,
transform->current_row.block,
transform->current_row.row,
transform->partition_ended);
/// If the rank value is larger then limit, and current block only contains rows which are all belong to one partition.
/// We cant drop this block directly.
if (!transform->partition_ended && !transform->current_row.row && transform->current_row_number > limit)
{
/// It's safe to make it mutable here. but it's still too dangerous, it may be changed in the future and make it unsafe.
auto * mutable_transform = const_cast<DB::WindowTransform *>(transform);
DB::WindowTransformBlock & current_block = mutable_transform->blockAt(mutable_transform->current_row);
current_block.rows = 0;
auto clear_columns = [](DB::Columns & cols)
{
DB::Columns new_cols;
for (const auto & col : cols)
{
new_cols.push_back(std::move(col->cloneEmpty()));
}
cols = new_cols;
};
clear_columns(current_block.original_input_columns);
clear_columns(current_block.input_columns);
clear_columns(current_block.casted_columns);
mutable_transform->current_row.block += 1;
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "{} is not implemented", name);
}
else
{
auto & to_col = *transform->blockAt(transform->current_row).output_columns[function_index];
assert_cast<DB::ColumnUInt64 &>(to_col).getData().push_back(transform->current_row_number);
}
}

void registerWindowGroupLimitFunctions(DB::AggregateFunctionFactory & factory)
{
const DB::AggregateFunctionProperties properties
= {.returns_default_when_only_null = true, .is_order_dependent = true, .is_window_function = true};
factory.registerFunction(
"top_row_number",
{[](const String & name, const DB::DataTypes & args_type, const DB::Array & parameters, const DB::Settings *)
{ return std::make_shared<WindowFunctionTopRowNumber>(name, args_type, parameters); },
properties},
DB::AggregateFunctionFactory::Case::Insensitive);
}
}
33 changes: 33 additions & 0 deletions cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <AggregateFunctions/WindowFunction.h>

namespace local_engine
{
class WindowFunctionTopRowNumber : public DB::WindowFunction
{
public:
explicit WindowFunctionTopRowNumber(const String name, const DB::DataTypes & arg_types_, const DB::Array & parameters_);
~WindowFunctionTopRowNumber() override = default;

void windowInsertResultInto(const DB::WindowTransform * transform, size_t function_index) const override;
bool allocatesMemoryInArena() const override { return false; }

private:
size_t limit = 0;
};
}
41 changes: 23 additions & 18 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,9 +560,7 @@ std::map<std::string, std::string> BackendInitializerUtil::getBackendConfMap(std
}

std::vector<String> BackendInitializerUtil::wrapDiskPathConfig(
const String & path_prefix,
const String & path_suffix,
Poco::Util::AbstractConfiguration & config)
const String & path_prefix, const String & path_suffix, Poco::Util::AbstractConfiguration & config)
{
std::vector<String> changed_paths;
if (path_prefix.empty() && path_suffix.empty())
Expand Down Expand Up @@ -657,9 +655,7 @@ DB::Context::ConfigurationPtr BackendInitializerUtil::initConfig(std::map<std::s
auto path_need_clean = wrapDiskPathConfig("", "/" + pid, *config);
std::lock_guard lock(BackendFinalizerUtil::paths_mutex);
BackendFinalizerUtil::paths_need_to_clean.insert(
BackendFinalizerUtil::paths_need_to_clean.end(),
path_need_clean.begin(),
path_need_clean.end());
BackendFinalizerUtil::paths_need_to_clean.end(), path_need_clean.begin(), path_need_clean.end());
}
return config;
}
Expand All @@ -683,7 +679,9 @@ void BackendInitializerUtil::initEnvs(DB::Context::ConfigurationPtr config)
{
const std::string config_timezone = config->getString("timezone");
const String mapped_timezone = DateTimeUtil::convertTimeZone(config_timezone);
if (0 != setenv("TZ", mapped_timezone.data(), 1)) // NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other setenv/getenv
if (0
!= setenv(
"TZ", mapped_timezone.data(), 1)) // NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other setenv/getenv
throw Poco::Exception("Cannot setenv TZ variable");

tzset();
Expand Down Expand Up @@ -807,8 +805,7 @@ void BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
{
auto mem_gb = task_memory / static_cast<double>(1_GiB);
// 2.8x+5, Heuristics calculate the block size of external sort, [8,16]
settings.prefer_external_sort_block_bytes = std::max(std::min(
static_cast<size_t>(2.8*mem_gb + 5), 16ul), 8ul) * 1024 * 1024;
settings.prefer_external_sort_block_bytes = std::max(std::min(static_cast<size_t>(2.8 * mem_gb + 5), 16ul), 8ul) * 1024 * 1024;
}
}
}
Expand Down Expand Up @@ -848,10 +845,14 @@ void BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config)

global_context->setMarkCache(mark_cache_policy, mark_cache_size, mark_cache_size_ratio);

String index_uncompressed_cache_policy = config->getString("index_uncompressed_cache_policy", DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
size_t index_uncompressed_cache_size = config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
double index_uncompressed_cache_size_ratio = config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
global_context->setIndexUncompressedCache(index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);
String index_uncompressed_cache_policy
= config->getString("index_uncompressed_cache_policy", DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
size_t index_uncompressed_cache_size
= config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
double index_uncompressed_cache_size_ratio
= config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
global_context->setIndexUncompressedCache(
index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);

String index_mark_cache_policy = config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY);
size_t index_mark_cache_size = config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE);
Expand Down Expand Up @@ -890,6 +891,7 @@ extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCom
extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &);
extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &);
extern void registerFunctions(FunctionFactory &);
extern void registerWindowGroupLimitFunctions(AggregateFunctionFactory &);

void registerAllFunctions()
{
Expand All @@ -899,6 +901,7 @@ void registerAllFunctions()
auto & agg_factory = AggregateFunctionFactory::instance();
registerAggregateFunctionsBloomFilter(agg_factory);
registerAggregateFunctionSparkAvg(agg_factory);
registerWindowGroupLimitFunctions(agg_factory);
{
/// register aggregate function combinators from local_engine
auto & factory = AggregateFunctionCombinatorFactory::instance();
Expand Down Expand Up @@ -1023,11 +1026,13 @@ void BackendFinalizerUtil::finalizeGlobally()
StorageMergeTreeFactory::clear();
QueryContext::resetGlobal();
std::lock_guard lock(paths_mutex);
std::ranges::for_each(paths_need_to_clean, [](const auto & path)
{
if (fs::exists(path))
fs::remove_all(path);
});
std::ranges::for_each(
paths_need_to_clean,
[](const auto & path)
{
if (fs::exists(path))
fs::remove_all(path);
});
paths_need_to_clean.clear();
}

Expand Down
17 changes: 7 additions & 10 deletions cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace DB::ErrorCodes
extern const int BAD_ARGUMENTS;
}

const static String FUNCTION_ROW_NUM = "row_number";
const static String FUNCTION_ROW_NUM = "top_row_number";
const static String FUNCTION_RANK = "top_rank";
const static String FUNCTION_DENSE_RANK = "top_dense_rank";

Expand Down Expand Up @@ -68,22 +68,18 @@ WindowGroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait
current_plan->addStep(std::move(post_project_step));

LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx output header: {}", current_plan->getCurrentDataStream().header.dumpStructure());
bool x = true;
if (x)
{
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Invalide rel");
}
return std::move(current_plan);
}

DB::WindowFrame WindowGroupLimitRelParser::buildWindowFrame(const String & function_name)
{
// We only need first rows, so let the begin type is unbounded is OK
DB::WindowFrame frame;
if (function_name == FUNCTION_ROW_NUM)
{
frame.type = DB::WindowFrame::FrameType::ROWS;
frame.begin_type = DB::WindowFrame::BoundaryType::Offset;
frame.begin_offset = 1;
frame.begin_type = DB::WindowFrame::BoundaryType::Unbounded;
frame.begin_offset = 0;
frame.begin_preceding = true;
frame.end_type = DB::WindowFrame::BoundaryType::Current;
frame.end_offset = 0;
Expand Down Expand Up @@ -123,7 +119,7 @@ DB::WindowDescription WindowGroupLimitRelParser::buildWindowDescription(const su
ss << win_desc.frame.toString();
win_desc.window_name = ss.str();

win_desc.window_functions.emplace_back(buildWindowFunctionDescription(window_function_name));
win_desc.window_functions.emplace_back(buildWindowFunctionDescription(window_function_name, static_cast<size_t>(win_rel_def.limit())));

return win_desc;
}
Expand Down Expand Up @@ -153,7 +149,7 @@ WindowGroupLimitRelParser::parsePartitionBy(const google::protobuf::RepeatedPtrF
return sort_desc;
}

DB::WindowFunctionDescription WindowGroupLimitRelParser::buildWindowFunctionDescription(const String & function_name)
DB::WindowFunctionDescription WindowGroupLimitRelParser::buildWindowFunctionDescription(const String & function_name, size_t limit)
{
DB::WindowFunctionDescription desc;
desc.column_name = function_name;
Expand All @@ -162,6 +158,7 @@ DB::WindowFunctionDescription WindowGroupLimitRelParser::buildWindowFunctionDesc
DB::Names func_args;
DB::DataTypes func_args_types;
DB::Array func_params;
func_params.push_back(limit);
auto func_ptr = RelParser::getAggregateFunction(function_name, func_args_types, func_properties, func_params);
desc.argument_names = func_args;
desc.argument_types = func_args_types;
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@ class WindowGroupLimitRelParser : public RelParser

DB::SortDescription parsePartitionBy(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);

static DB::WindowFunctionDescription buildWindowFunctionDescription(const String & function_name);
static DB::WindowFunctionDescription buildWindowFunctionDescription(const String & function_name, size_t limit);
};
}

0 comments on commit 1950ac4

Please sign in to comment.