diff --git a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h index 0bd28b116d9a..432595e09140 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h @@ -27,11 +27,15 @@ namespace local_engine { using namespace DB; -template +template class BaseFloatRoundingHalfUpComputation; +#ifdef __SSE4_1__ + +/// vectorized implementation for x86 + template <> -class BaseFloatRoundingHalfUpComputation +class BaseFloatRoundingHalfUpComputation { public: using ScalarType = Float32; @@ -59,7 +63,7 @@ class BaseFloatRoundingHalfUpComputation }; template <> -class BaseFloatRoundingHalfUpComputation +class BaseFloatRoundingHalfUpComputation { public: using ScalarType = Float64; @@ -86,13 +90,43 @@ class BaseFloatRoundingHalfUpComputation static VectorType prepare(size_t scale) { return load1(scale); } }; +/// end __SSE4_1__ +#endif + +/// Sequential implementation for ARM. Also used for scalar arguments + +template +class BaseFloatRoundingHalfUpComputation +{ +public: + using ScalarType = T; + using VectorType = T; + static const size_t data_count = 1; + + static VectorType load(const ScalarType * in) { return *in; } + static VectorType load1(const ScalarType in) { return in; } + static VectorType store(ScalarType * out, ScalarType val) { return *out = val;} + static VectorType multiply(VectorType val, VectorType scale) { return val * scale; } + static VectorType divide(VectorType val, VectorType scale) { return val / scale; } + template + static VectorType apply(VectorType val) + { + return roundWithMode(val, mode); + } + + static VectorType prepare(size_t scale) + { + return load1(scale); + } +}; + /** Implementation of low-level round-off functions for floating-point values. */ -template -class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation +template +class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation { - using Base = BaseFloatRoundingHalfUpComputation; + using Base = BaseFloatRoundingHalfUpComputation; public: static inline void compute(const T * __restrict in, const typename Base::VectorType & scale, T * __restrict out) @@ -124,15 +158,22 @@ struct FloatRoundingHalfUpImpl private: static_assert(!is_decimal); - using Op = FloatRoundingHalfUpComputation; - using Data = std::array; + template + using Op = FloatRoundingHalfUpComputation; + using Data = std::array::data_count>; using ColumnType = ColumnVector; using Container = typename ColumnType::Container; public: static NO_INLINE void apply(const Container & in, size_t scale, Container & out) { - auto mm_scale = Op::prepare(scale); + auto mm_scale = Op<>::prepare(scale); const size_t data_count = std::tuple_size(); @@ -144,7 +185,7 @@ struct FloatRoundingHalfUpImpl while (p_in < limit) { - Op::compute(p_in, mm_scale, p_out); + Op<>::compute(p_in, mm_scale, p_out); p_in += data_count; p_out += data_count; } @@ -157,7 +198,7 @@ struct FloatRoundingHalfUpImpl size_t tail_size_bytes = (end_in - p_in) * sizeof(*p_in); memcpy(&tmp_src, p_in, tail_size_bytes); - Op::compute(reinterpret_cast(&tmp_src), mm_scale, reinterpret_cast(&tmp_dst)); + Op<>::compute(reinterpret_cast(&tmp_src), mm_scale, reinterpret_cast(&tmp_dst)); memcpy(p_out, &tmp_dst, tail_size_bytes); } }