Skip to content

Commit

Permalink
[CH] Refactor: don't using namespace DB in header (#8300)
Browse files Browse the repository at this point in the history
* Dont't include <Storages/MergeTree/SparkStorageMergeTree.h> in SerializedPlanParser

* try not using namespace DB in header

* fix build
  • Loading branch information
baibaichen authored Dec 23, 2024
1 parent d241d47 commit b0e9a04
Show file tree
Hide file tree
Showing 150 changed files with 1,012 additions and 1,053 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,12 @@ extern const int BAD_ARGUMENTS;

namespace local_engine
{
using namespace DB;

struct AggregateFunctionGroupBloomFilterData
{
bool initted = false;
// small default value because BloomFilter has no default ctor
BloomFilter bloom_filter = BloomFilter(100, 2, 0);
DB::BloomFilter bloom_filter = DB::BloomFilter(100, 2, 0);
static const char * name() { return "groupBloomFilter"; }

void read(DB::ReadBuffer & in)
Expand All @@ -58,7 +57,7 @@ struct AggregateFunctionGroupBloomFilterData
}
else
{
bloom_filter = BloomFilter(BloomFilterParameters(filter_size, filter_hashes, seed));
bloom_filter = DB::BloomFilter(DB::BloomFilterParameters(filter_size, filter_hashes, seed));
auto & v = bloom_filter.getFilter();
in.readStrict(reinterpret_cast<char *>(v.data()), v.size() * sizeof(v[0]));
initted = true;
Expand Down Expand Up @@ -89,12 +88,12 @@ struct AggregateFunctionGroupBloomFilterData
// For groupFunctionBloomFilter, we don't actually care about the final Int result(currently always return BF byte size).
// We just need its intermediate state, ,i.e. groupFunctionFilterState.
template <typename T, typename Data>
class AggregateFunctionGroupBloomFilter final : public IAggregateFunctionDataHelper<Data, AggregateFunctionGroupBloomFilter<T, Data>>
class AggregateFunctionGroupBloomFilter final : public DB::IAggregateFunctionDataHelper<Data, AggregateFunctionGroupBloomFilter<T, Data>>
{
public:
explicit AggregateFunctionGroupBloomFilter(
const DataTypes & argument_types_, const Array & parameters_, size_t filter_size_, size_t filter_hashes_, size_t seed_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionGroupBloomFilter<T, Data>>(argument_types_, parameters_, createResultType())
const DB::DataTypes & argument_types_, const DB::Array & parameters_, size_t filter_size_, size_t filter_hashes_, size_t seed_)
: DB::IAggregateFunctionDataHelper<Data, AggregateFunctionGroupBloomFilter<T, Data>>(argument_types_, parameters_, createResultType())
, filter_size(filter_size_)
, filter_hashes(filter_hashes_)
, seed(seed_)
Expand All @@ -103,32 +102,32 @@ class AggregateFunctionGroupBloomFilter final : public IAggregateFunctionDataHel

String getName() const override { return Data::name(); }

static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
static DB::DataTypePtr createResultType() { return std::make_shared<DB::DataTypeNumber<T>>(); }

bool allocatesMemoryInArena() const override { return false; }

void checkFilterSize(size_t filter_size_) const
{
if (filter_size_ <= 0)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "filter_size being LTE 0 means bloom filter is not properly initialized");
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "filter_size being LTE 0 means bloom filter is not properly initialized");
}
}

void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
void add(DB::AggregateDataPtr __restrict place, const DB::IColumn ** columns, size_t row_num, DB::Arena *) const override
{
if unlikely (!this->data(place).initted)
{
checkFilterSize(filter_size);
this->data(place).bloom_filter = BloomFilter(BloomFilterParameters(filter_size, filter_hashes, seed));
this->data(place).bloom_filter = DB::BloomFilter(DB::BloomFilterParameters(filter_size, filter_hashes, seed));
this->data(place).initted = true;
}

T x = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
T x = assert_cast<const DB::ColumnVector<T> &>(*columns[0]).getData()[row_num];
this->data(place).bloom_filter.add(reinterpret_cast<const char *>(&x), sizeof(T));
}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
void merge(DB::AggregateDataPtr __restrict place, DB::ConstAggregateDataPtr rhs, DB::Arena *) const override
{
// Skip un-initted values
if (!this->data(rhs).initted)
Expand All @@ -141,7 +140,7 @@ class AggregateFunctionGroupBloomFilter final : public IAggregateFunctionDataHel
{
// We use filter_other's size/hashes/seed to avoid passing these parameters around to construct AggregateFunctionGroupBloomFilter.
checkFilterSize(bloom_other.getSize());
this->data(place).bloom_filter = BloomFilter(BloomFilterParameters(bloom_other.getSize(), bloom_other.getHashes(), bloom_other.getSeed()));
this->data(place).bloom_filter = DB::BloomFilter(DB::BloomFilterParameters(bloom_other.getSize(), bloom_other.getHashes(), bloom_other.getSeed()));
this->data(place).initted = true;
}
auto & filter_self = this->data(place).bloom_filter.getFilter();
Expand All @@ -154,19 +153,19 @@ class AggregateFunctionGroupBloomFilter final : public IAggregateFunctionDataHel
}
}

void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
void serialize(DB::ConstAggregateDataPtr __restrict place, DB::WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).write(buf);
}

void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
void deserialize(DB::AggregateDataPtr __restrict place, DB::ReadBuffer & buf, std::optional<size_t> /* version */, DB::Arena *) const override
{
this->data(place).read(buf);
}

void insertResultInto(AggregateDataPtr __restrict /*place*/, IColumn & to, Arena *) const override
void insertResultInto(DB::AggregateDataPtr __restrict /*place*/, DB::IColumn & to, DB::Arena *) const override
{
assert_cast<ColumnVector<T> &>(to).getData().push_back(static_cast<T>(filter_size));
assert_cast<DB::ColumnVector<T> &>(to).getData().push_back(static_cast<T>(filter_size));
}

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,24 @@ extern const int ILLEGAL_TYPE_OF_ARGUMENT;

namespace local_engine
{
using namespace DB;

struct Settings;

/**
* this class is copied from AggregateFunctionMerge with little enhancement.
* we use this PartialMerge for both spark PartialMerge and Final
*/
class AggregateFunctionPartialMerge final : public IAggregateFunctionHelper<AggregateFunctionPartialMerge>
class AggregateFunctionPartialMerge final : public DB::IAggregateFunctionHelper<AggregateFunctionPartialMerge>
{
private:
AggregateFunctionPtr nested_func;
DB::AggregateFunctionPtr nested_func;

public:
AggregateFunctionPartialMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionPartialMerge>({argument}, params_, createResultType(nested_)), nested_func(nested_)
AggregateFunctionPartialMerge(const DB::AggregateFunctionPtr & nested_, const DB::DataTypePtr & argument, const DB::Array & params_)
: DB::IAggregateFunctionHelper<AggregateFunctionPartialMerge>({argument}, params_, createResultType(nested_)), nested_func(nested_)
{
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
const DB::DataTypeAggregateFunction * data_type = typeid_cast<const DB::DataTypeAggregateFunction *>(argument.get());

if (!data_type || !nested_func->haveSameStateRepresentation(*data_type->getFunction()))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
throw DB::Exception(
DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function {}, "
"expected {} or equivalent type",
argument->getName(),
Expand All @@ -64,54 +60,54 @@ class AggregateFunctionPartialMerge final : public IAggregateFunctionHelper<Aggr

String getName() const override { return nested_func->getName() + "PartialMerge"; }

static DataTypePtr createResultType(const AggregateFunctionPtr & nested_) { return nested_->getResultType(); }
static DB::DataTypePtr createResultType(const DB::AggregateFunctionPtr & nested_) { return nested_->getResultType(); }

const DataTypePtr & getResultType() const override { return nested_func->getResultType(); }
const DB::DataTypePtr & getResultType() const override { return nested_func->getResultType(); }

bool isVersioned() const override { return nested_func->isVersioned(); }

size_t getDefaultVersion() const override { return nested_func->getDefaultVersion(); }

DataTypePtr getStateType() const override { return nested_func->getStateType(); }
DB::DataTypePtr getStateType() const override { return nested_func->getStateType(); }

void create(AggregateDataPtr __restrict place) const override { nested_func->create(place); }
void create(DB::AggregateDataPtr __restrict place) const override { nested_func->create(place); }

void destroy(AggregateDataPtr __restrict place) const noexcept override { nested_func->destroy(place); }
void destroy(DB::AggregateDataPtr __restrict place) const noexcept override { nested_func->destroy(place); }

bool hasTrivialDestructor() const override { return nested_func->hasTrivialDestructor(); }

size_t sizeOfData() const override { return nested_func->sizeOfData(); }

size_t alignOfData() const override { return nested_func->alignOfData(); }

void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
void add(DB::AggregateDataPtr __restrict place, const DB::IColumn ** columns, size_t row_num, DB::Arena * arena) const override
{
nested_func->merge(place, assert_cast<const ColumnAggregateFunction &>(*columns[0]).getData()[row_num], arena);
nested_func->merge(place, assert_cast<const DB::ColumnAggregateFunction &>(*columns[0]).getData()[row_num], arena);
}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
void merge(DB::AggregateDataPtr __restrict place, DB::ConstAggregateDataPtr rhs, DB::Arena * arena) const override
{
nested_func->merge(place, rhs, arena);
}

void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
void serialize(DB::ConstAggregateDataPtr __restrict place, DB::WriteBuffer & buf, std::optional<size_t> version) const override
{
nested_func->serialize(place, buf, version);
}

void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
void deserialize(DB::AggregateDataPtr __restrict place, DB::ReadBuffer & buf, std::optional<size_t> version, DB::Arena * arena) const override
{
nested_func->deserialize(place, buf, version, arena);
}

void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
void insertResultInto(DB::AggregateDataPtr __restrict place, DB::IColumn & to, DB::Arena * arena) const override
{
nested_func->insertResultInto(place, to, arena);
}

bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); }

AggregateFunctionPtr getNestedFunction() const override { return nested_func; }
DB::AggregateFunctionPtr getNestedFunction() const override { return nested_func; }
/// If the aggregate phase is `INTEMEDIATE_TO_INTERMEDIATE`, partial merge combinator is applied. In this case, the actual result column's
/// representation is `xxxPartialMerge`. It will make block structure check fail somewhere, since the expected column's represiontation is
/// `xxx` without partial merge. The represiontaions of `xxxPartialMerge` and `xxx` are the same actually.
Expand Down
25 changes: 12 additions & 13 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,8 @@ DB::Block BlockUtil::concatenateBlocksMemoryEfficiently(std::vector<DB::Block> &
for (const auto & block : blocks)
num_rows += block.rows();

Block out = blocks[0].cloneEmpty();
MutableColumns columns = out.mutateColumns();
DB::Block out = blocks[0].cloneEmpty();
DB::MutableColumns columns = out.mutateColumns();

for (size_t i = 0; i < columns.size(); ++i)
{
Expand All @@ -338,7 +338,7 @@ DB::Block BlockUtil::concatenateBlocksMemoryEfficiently(std::vector<DB::Block> &
size_t PODArrayUtil::adjustMemoryEfficientSize(size_t n)
{
/// According to definition of DEFUALT_BLOCK_SIZE
size_t padding_n = 2 * PADDING_FOR_SIMD - 1;
size_t padding_n = 2 * DB::PADDING_FOR_SIMD - 1;
size_t rounded_n = roundUpToPowerOfTwoOrZero(n);
size_t padded_n = n;
if (rounded_n > n + n / 2)
Expand Down Expand Up @@ -381,19 +381,19 @@ void PlanUtil::checkOuputType(const DB::QueryPlan & plan)
const DB::WhichDataType which(ch_type_without_nullable);
if (which.isDateTime64())
{
const auto * ch_type_datetime64 = checkAndGetDataType<DataTypeDateTime64>(ch_type_without_nullable.get());
const auto * ch_type_datetime64 = checkAndGetDataType<DB::DataTypeDateTime64>(ch_type_without_nullable.get());
if (ch_type_datetime64->getScale() != 6)
throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName());
throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName());
}
else if (which.isDecimal())
{
if (which.isDecimal256())
throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName());
throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName());

const auto scale = getDecimalScale(*ch_type_without_nullable);
const auto precision = getDecimalPrecision(*ch_type_without_nullable);
if (scale == 0 && precision == 0)
throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName());
throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName());
}
}
}
Expand All @@ -402,8 +402,7 @@ DB::IQueryPlanStep * PlanUtil::adjustQueryPlanHeader(DB::QueryPlan & plan, const
{
auto convert_actions_dag = DB::ActionsDAG::makeConvertingActions(
plan.getCurrentHeader().getColumnsWithTypeAndName(),
to_header.getColumnsWithTypeAndName(),
ActionsDAG::MatchColumnsMode::Name);
to_header.getColumnsWithTypeAndName(), DB::ActionsDAG::MatchColumnsMode::Name);
auto expression_step = std::make_unique<DB::ExpressionStep>(plan.getCurrentHeader(), std::move(convert_actions_dag));
expression_step->setStepDescription(step_desc);
auto * step_ptr = expression_step.get();
Expand Down Expand Up @@ -504,9 +503,9 @@ const DB::ColumnWithTypeAndName * NestedColumnExtractHelper::findColumn(const DB
const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
DB::ActionsDAG & actions_dag,
const DB::ActionsDAG::Node * node,
const DataTypePtr & cast_to_type,
const DB::DataTypePtr & cast_to_type,
const std::string & result_name,
CastType cast_type)
DB::CastType cast_type)
{
DB::ColumnWithTypeAndName type_name_col;
type_name_col.name = cast_to_type->getName();
Expand All @@ -515,7 +514,7 @@ const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
const auto * right_arg = &actions_dag.addColumn(std::move(type_name_col));
const auto * left_arg = node;
DB::CastDiagnostic diagnostic = {node->result_name, node->result_name};
ColumnWithTypeAndName left_column{nullptr, node->result_type, {}};
DB::ColumnWithTypeAndName left_column{nullptr, node->result_type, {}};
DB::ActionsDAG::NodeRawConstPtrs children = {left_arg, right_arg};
auto func_base_cast = createInternalCast(std::move(left_column), cast_to_type, cast_type, diagnostic);

Expand All @@ -527,7 +526,7 @@ const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded(
const DB::ActionsDAG::Node * node,
const DB::DataTypePtr & dst_type,
const std::string & result_name,
CastType cast_type)
DB::CastType cast_type)
{
if (node->result_type->equals(*dst_type))
return node;
Expand Down
13 changes: 5 additions & 8 deletions cpp-ch/local-engine/Functions/FunctionGetDateData.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Common/LocalDate.h>
#include <Common/DateLUT.h>
#include <Common/DateLUTImpl.h>
#pragma once
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnNullable.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/FunctionFactory.h>
#include <IO/ReadBufferFromMemory.h>
#include <IO/ReadHelpers.h>
#include <IO/parseDateTimeBestEffort.h>

using namespace DB;
#include <Common/DateLUT.h>
#include <Common/DateLUTImpl.h>

namespace DB
{
Expand Down Expand Up @@ -55,7 +52,7 @@ class FunctionGetDateData : public DB::IFunction
const auto * src_col = checkAndGetColumn<DB::ColumnString>(arg1.column.get());
size_t size = src_col->size();

using ColVecTo = ColumnVector<T>;
using ColVecTo = DB::ColumnVector<T>;
typename ColVecTo::MutablePtr result_column = ColVecTo::create(size, 0);
typename ColVecTo::Container & result_container = result_column->getData();
DB::ColumnUInt8::MutablePtr null_map = DB::ColumnUInt8::create(size, 0);
Expand Down
Loading

0 comments on commit b0e9a04

Please sign in to comment.