Skip to content

Commit

Permalink
[GLUTEN-7717][CH][ARM] Fix compile issue for SparkFunctionRoundHalfUp (#7718)
Browse files: browse the repository at this point in the history

* [GLUTEN-7717][CH] [ARM]fix compile issue for SparkFunctionRoundHalfUp

* fix hand type issue
  • Loading branch information
loudongfeng authored Oct 30, 2024
1 parent 045e33e commit 24feeed
Showing 1 changed file with 52 additions and 11 deletions.
63 changes: 52 additions & 11 deletions cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ namespace local_engine
{
using namespace DB;

template <typename T>
template <typename T, Vectorize vectorize>
class BaseFloatRoundingHalfUpComputation;

#ifdef __SSE4_1__

/// vectorized implementation for x86

template <>
class BaseFloatRoundingHalfUpComputation<Float32>
class BaseFloatRoundingHalfUpComputation<Float32, Vectorize::Yes>
{
public:
using ScalarType = Float32;
Expand Down Expand Up @@ -59,7 +63,7 @@ class BaseFloatRoundingHalfUpComputation<Float32>
};

template <>
class BaseFloatRoundingHalfUpComputation<Float64>
class BaseFloatRoundingHalfUpComputation<Float64, Vectorize::Yes>
{
public:
using ScalarType = Float64;
Expand All @@ -86,13 +90,43 @@ class BaseFloatRoundingHalfUpComputation<Float64>
static VectorType prepare(size_t scale) { return load1(scale); }
};

/// end __SSE4_1__
#endif

/// Sequential (non-vectorized) implementation: one value is processed per step.
/// Used for ARM, and also for scalar arguments.
///
/// Here the "vector" type is simply T, so every primitive below operates on a single element.
///
/// NOTE(review): apply() delegates to roundWithMode(), which is defined outside this block.
/// Its tie-breaking rule for RoundingMode::Round is not visible here - confirm that it gives
/// half-up results, as the class name implies.
template <typename T>
class BaseFloatRoundingHalfUpComputation<T, Vectorize::No>
{
public:
    using ScalarType = T;
    using VectorType = T;

    /// Number of elements held by one VectorType value.
    static constexpr size_t data_count = 1;

    /// Reads the single element pointed to by `src`.
    static VectorType load(const ScalarType * src) { return *src; }

    /// "Broadcasts" a scalar; with one lane this is the identity.
    static VectorType load1(const ScalarType value) { return value; }

    /// Writes `value` to `*dst` and returns the stored value.
    static VectorType store(ScalarType * dst, ScalarType value)
    {
        *dst = value;
        return *dst;
    }

    static VectorType multiply(VectorType lhs, VectorType rhs) { return lhs * rhs; }

    static VectorType divide(VectorType lhs, VectorType rhs) { return lhs / rhs; }

    /// Rounds `value` according to the compile-time rounding `mode`.
    template <RoundingMode mode>
    static VectorType apply(VectorType value)
    {
        return roundWithMode(value, mode);
    }

    /// Converts the integral scale factor into the VectorType used by multiply()/divide().
    static VectorType prepare(size_t scale)
    {
        return load1(static_cast<ScalarType>(scale));
    }
};


/** Implementation of low-level round-off functions for floating-point values.
*/
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation<T>
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, Vectorize vectorize>
class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation<T, vectorize>
{
using Base = BaseFloatRoundingHalfUpComputation<T>;
using Base = BaseFloatRoundingHalfUpComputation<T, vectorize>;

public:
static inline void compute(const T * __restrict in, const typename Base::VectorType & scale, T * __restrict out)
Expand Down Expand Up @@ -124,15 +158,22 @@ struct FloatRoundingHalfUpImpl
private:
static_assert(!is_decimal<T>);

using Op = FloatRoundingHalfUpComputation<T, rounding_mode, scale_mode>;
using Data = std::array<T, Op::data_count>;
template <Vectorize vectorize =
#ifdef __SSE4_1__
Vectorize::Yes
#else
Vectorize::No
#endif
>
using Op = FloatRoundingHalfUpComputation<T, rounding_mode, scale_mode, vectorize>;
using Data = std::array<T, Op<>::data_count>;
using ColumnType = ColumnVector<T>;
using Container = typename ColumnType::Container;

public:
static NO_INLINE void apply(const Container & in, size_t scale, Container & out)
{
auto mm_scale = Op::prepare(scale);
auto mm_scale = Op<>::prepare(scale);

const size_t data_count = std::tuple_size<Data>();

Expand All @@ -144,7 +185,7 @@ struct FloatRoundingHalfUpImpl

while (p_in < limit)
{
Op::compute(p_in, mm_scale, p_out);
Op<>::compute(p_in, mm_scale, p_out);
p_in += data_count;
p_out += data_count;
}
Expand All @@ -157,7 +198,7 @@ struct FloatRoundingHalfUpImpl
size_t tail_size_bytes = (end_in - p_in) * sizeof(*p_in);

memcpy(&tmp_src, p_in, tail_size_bytes);
Op::compute(reinterpret_cast<T *>(&tmp_src), mm_scale, reinterpret_cast<T *>(&tmp_dst));
Op<>::compute(reinterpret_cast<T *>(&tmp_src), mm_scale, reinterpret_cast<T *>(&tmp_dst));
memcpy(p_out, &tmp_dst, tail_size_bytes);
}
}
Expand Down

0 comments on commit 24feeed

Please sign in to comment.