Skip to content

Commit

Permalink
improve performance
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinyhZou committed Dec 22, 2023
1 parent 2372f68 commit fa11505
Showing 1 changed file with 96 additions and 45 deletions.
141 changes: 96 additions & 45 deletions cpp-ch/local-engine/Functions/SparkFunctionFloor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,80 @@
#include <Functions/FunctionFactory.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypesNumber.h>
#include <bit>

using namespace DB;

namespace local_engine
{

/// Floor implementation for floating-point columns that additionally reports
/// non-finite results through a null map.
///
/// Every element of the input is passed through
/// FloatRoundingComputation<T, RoundingMode::Floor, scale_mode>; afterwards each
/// result that is NaN or +/-infinity is replaced by 0 and its row is flagged in
/// the null map.
template <typename T, ScaleMode scale_mode>
struct SparkFloatFloorImpl
{
private:
    static_assert(!is_decimal<T>);
    using Op = FloatRoundingComputation<T, RoundingMode::Floor, scale_mode>;
    /// Scratch buffer holding exactly one Op::compute batch (Op::data_count elements).
    using Data = std::array<T, Op::data_count>;

public:
    /// Floors every element of `in` into `out` and fills `null_map`.
    ///
    /// @param in        source values.
    /// @param scale     scale factor forwarded to Op::prepare.
    /// @param out       destination values; must hold at least in.size() elements.
    /// @param null_map  per-row flags; must hold at least in.size() elements.
    ///                  Entry i is set to 1 (and out[i] to 0) when the floored
    ///                  value of row i is NaN or infinite; other entries are
    ///                  left untouched.
    static NO_INLINE void apply(const PaddedPODArray<T> & in, size_t scale, PaddedPODArray<T> & out, PaddedPODArray<UInt8> & null_map)
    {
        const size_t rows = in.size();

        /// Phase 1: compute the floored values. The __restrict pointers are kept
        /// in their own scope so that the per-row pass below does not alias them.
        {
            auto mm_scale = Op::prepare(scale);
            constexpr size_t data_count = std::tuple_size<Data>();
            const T * end_in = in.data() + rows;
            /// End of the last complete batch.
            const T * limit = in.data() + rows / data_count * data_count;
            const T * __restrict p_in = in.data();
            T * __restrict p_out = out.data();

            while (p_in < limit)
            {
                Op::compute(p_in, mm_scale, p_out);
                p_in += data_count;
                p_out += data_count;
            }

            if (p_in < end_in)
            {
                /// Incomplete trailing batch: go through zero-initialized scratch
                /// buffers so Op::compute never reads or writes past the columns.
                Data tmp_src{{}};
                Data tmp_dst;
                const size_t tail_size_bytes = (end_in - p_in) * sizeof(*p_in);
                memcpy(&tmp_src, p_in, tail_size_bytes);
                Op::compute(reinterpret_cast<T *>(&tmp_src), mm_scale, reinterpret_cast<T *>(&tmp_dst));
                memcpy(p_out, &tmp_dst, tail_size_bytes);
            }
        }

        /// Phase 2: inspect every row individually. Checking only once per batch
        /// would skip all but the first element of each batch and would index the
        /// null map by batch number instead of row number whenever data_count > 1.
        T * out_data = out.data();
        UInt8 * null_data = null_map.data();
        for (size_t i = 0; i < rows; ++i)
            checkAndSetNullable(out_data[i], null_data[i]);
    }

    /// If `t` is NaN or +/-infinity, overwrites it with 0 and sets `null_flag` to 1;
    /// otherwise leaves both arguments unchanged.
    static void checkAndSetNullable(T & t, UInt8 & null_flag)
    {
        if (t != t) // NaN is the only value that compares unequal to itself
        {
            t = 0;
            null_flag = 1;
        }
        else if constexpr (std::is_same<T, float>::value)
        {
            /// Drop the sign bit; infinity has all 8 exponent bits set and a zero mantissa.
            if ((std::bit_cast<uint32_t>(t) & 0x7FFFFFFFU) == 0x7F800000U)
            {
                t = 0;
                null_flag = 1;
            }
        }
        else if constexpr (std::is_same<T, double>::value)
        {
            /// Drop the sign bit; infinity has all 11 exponent bits set and a zero mantissa.
            if ((std::bit_cast<uint64_t>(t) & 0x7FFFFFFFFFFFFFFFULL) == 0x7FF0000000000000ULL)
            {
                t = 0;
                null_flag = 1;
            }
        }
    }
};

class SparkFunctionFloor : public DB::FunctionFloor
{
public:
Expand All @@ -40,66 +109,48 @@ class SparkFunctionFloor : public DB::FunctionFloor

DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows) const override
{
DB::ColumnPtr res = DB::FunctionFloor::executeImpl(arguments, result_type, input_rows);
if (res->isNullable())
{
const DB::ColumnNullable * nullable_col = assert_cast<const DB::ColumnNullable *>(res.get());
res = nullable_col->getNestedColumnPtr();
}
DB::MutableColumnPtr null_map_col = DB::ColumnUInt8::create(res->size(), 0);
DB::TypeIndex res_type_index = res->getDataType();
switch (res_type_index)
const ColumnWithTypeAndName & first_arg = arguments[0];
Scale scale_arg = getScaleArg(arguments);
switch(first_arg.type->getTypeId())
{
case DB::TypeIndex::Float32:
{
DB::MutableColumnPtr res_mutable = DB::IColumn::mutate(res);
checkAndSetNullable<DB::Float32>(res_mutable, null_map_col);
return DB::ColumnNullable::create(std::move(res_mutable), std::move(null_map_col));
}
case DB::TypeIndex::Float64:
{
DB::MutableColumnPtr res_mutable = DB::IColumn::mutate(res);
checkAndSetNullable<DB::Float64>(res_mutable, null_map_col);
return DB::ColumnNullable::create(std::move(res_mutable), std::move(null_map_col));
}
case TypeIndex::Float32:
return executeInternal<Float32>(first_arg.column, scale_arg);
case TypeIndex::Float64:
return executeInternal<Float64>(first_arg.column, scale_arg);
default:
return DB::ColumnNullable::create(std::move(res), std::move(null_map_col));
DB::ColumnPtr res = DB::FunctionFloor::executeImpl(arguments, result_type, input_rows);
DB::MutableColumnPtr null_map_col = DB::ColumnUInt8::create(first_arg.column->size(), 0);
return DB::ColumnNullable::create(std::move(res), std::move(null_map_col));
}

}

template<typename T>
static void checkAndSetNullable(DB::MutableColumnPtr & data_col_ptr, DB::MutableColumnPtr & null_map_ptr)
static ColumnPtr executeInternal(const ColumnPtr & col_arg, const Scale & scale_arg)
{
DB::PaddedPODArray<UInt8> & null_map = assert_cast<DB::ColumnUInt8 *>(null_map_ptr.get())->getData();
DB::PaddedPODArray<T> & data = assert_cast<DB::ColumnVector<T> *>(data_col_ptr.get())->getData();
for (size_t i = 0; i < data.size(); ++i)
const auto * col = checkAndGetColumn<ColumnVector<T>>(col_arg.get());
auto col_res = ColumnVector<T>::create(col->size());
MutableColumnPtr null_map_col = DB::ColumnUInt8::create(col->size(), 0);
PaddedPODArray<T> & vec_res = col_res->getData();
PaddedPODArray<UInt8> & null_map_data = assert_cast<ColumnVector<UInt8> *>(null_map_col.get())->getData();
if (!vec_res.empty())
{
const T t = data[i];
if (t != t) // means the element is nan
if (scale_arg == 0)
{
data[i] = 0;
null_map[i] = 1;
size_t scale = 1;
SparkFloatFloorImpl<T, ScaleMode::Zero>::apply(col->getData(), scale, vec_res, null_map_data);
}
else if constexpr (std::is_same<T, float>::value) // means the float type element is inf
else if (scale_arg > 0)
{
if ((std::bit_cast<uint32_t>(t) & 0b01111111111111111111111111111111) == 0b01111111100000000000000000000000)
{
data[i] = 0;
null_map[i] = 1;
}
size_t scale = intExp10(scale_arg);
SparkFloatFloorImpl<T, ScaleMode::Positive>::apply(col->getData(), scale, vec_res, null_map_data);
}
else if constexpr (std::is_same<T, double>::value) // means the double type element is inf
else
{
if ((std::bit_cast<uint64_t>(t) & 0b0111111111111111111111111111111111111111111111111111111111111111)
== 0b0111111111110000000000000000000000000000000000000000000000000000)
{
data[i] = 0;
null_map[i] = 1;
}
size_t scale = intExp10(-scale_arg);
SparkFloatFloorImpl<T, ScaleMode::Negative>::apply(col->getData(), scale, vec_res, null_map_data);
}
}
return DB::ColumnNullable::create(std::move(col_res), std::move(null_map_col));
}
};

}

0 comments on commit fa11505

Please sign in to comment.