From 06512e98c8e550ee34fd21a686f1f564fcfdf81e Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Sat, 12 Oct 2024 17:54:53 +0800 Subject: [PATCH 01/10] Support hyper log log plus plus(HLL++) Signed-off-by: Chong Gao --- cpp/CMakeLists.txt | 1 + cpp/include/cudf/aggregation.hpp | 82 ++- .../cudf/detail/aggregation/aggregation.hpp | 70 ++ cpp/include/cudf/hashing/detail/xxhash_64.cuh | 294 ++++++++ cpp/src/aggregation/aggregation.cpp | 48 ++ cpp/src/groupby/sort/aggregate.cpp | 34 + .../sort/group_hyper_log_log_plus_plus.cu | 663 ++++++++++++++++++ cpp/src/groupby/sort/group_reductions.hpp | 15 + cpp/src/hash/xxhash_64.cu | 268 +------ cpp/tests/CMakeLists.txt | 1 + cpp/tests/groupby/hllpp_tests.cpp | 73 ++ .../main/java/ai/rapids/cudf/Aggregation.java | 74 +- .../ai/rapids/cudf/GroupByAggregation.java | 8 + .../ai/rapids/cudf/ReductionAggregation.java | 8 + java/src/main/native/src/AggregationJni.cpp | 24 +- .../test/java/ai/rapids/cudf/TableTest.java | 15 +- 16 files changed, 1372 insertions(+), 306 deletions(-) create mode 100644 cpp/include/cudf/hashing/detail/xxhash_64.cuh create mode 100644 cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu create mode 100644 cpp/tests/groupby/hllpp_tests.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 506f6c185f5..27269bbb200 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -422,6 +422,7 @@ add_library( src/groupby/sort/group_correlation.cu src/groupby/sort/group_count.cu src/groupby/sort/group_histogram.cu + src/groupby/sort/group_hyper_log_log_plus_plus.cu src/groupby/sort/group_m2.cu src/groupby/sort/group_max.cu src/groupby/sort/group_min.cu diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index f5f514d26d9..355e4f59f60 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -84,43 +84,45 @@ class aggregation { * @brief Possible aggregation operations */ enum Kind { - SUM, ///< sum reduction - PRODUCT, ///< product reduction - MIN, ///< min reduction - MAX, ///< max reduction - COUNT_VALID, ///< count number of valid elements - COUNT_ALL, ///< count number of elements - ANY, ///< any reduction - ALL, ///< all reduction - SUM_OF_SQUARES, ///< sum of squares reduction - MEAN, ///< arithmetic mean reduction - M2, ///< sum of squares of differences from the mean - VARIANCE, ///< variance - STD, ///< standard deviation - MEDIAN, ///< median reduction - QUANTILE, ///< compute specified quantile(s) - ARGMAX, ///< Index of max element - ARGMIN, ///< Index of min element - NUNIQUE, ///< count number of unique elements - NTH_ELEMENT, ///< get the nth element - ROW_NUMBER, ///< get row-number of current index (relative to rolling window) - EWMA, ///< get exponential weighted moving average at current index - RANK, ///< get rank of current index - COLLECT_LIST, ///< collect values into a list - COLLECT_SET, ///< collect values into a list without duplicate entries - LEAD, ///< window function, accesses row at specified offset following current row - LAG, ///< window function, accesses row at specified offset preceding current row - PTX, ///< PTX UDF based reduction - CUDA, ///< CUDA UDF based reduction - MERGE_LISTS, ///< merge multiple lists values into one list - MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries - MERGE_M2, ///< merge partial values of M2 aggregation, - COVARIANCE, ///< covariance between two sets of elements - CORRELATION, ///< correlation between two sets of elements - TDIGEST, ///< create a tdigest from a set of 
input values - MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together - HISTOGRAM, ///< compute frequency of each element - MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation, + SUM, ///< sum reduction + PRODUCT, ///< product reduction + MIN, ///< min reduction + MAX, ///< max reduction + COUNT_VALID, ///< count number of valid elements + COUNT_ALL, ///< count number of elements + ANY, ///< any reduction + ALL, ///< all reduction + SUM_OF_SQUARES, ///< sum of squares reduction + MEAN, ///< arithmetic mean reduction + M2, ///< sum of squares of differences from the mean + VARIANCE, ///< variance + STD, ///< standard deviation + MEDIAN, ///< median reduction + QUANTILE, ///< compute specified quantile(s) + ARGMAX, ///< Index of max element + ARGMIN, ///< Index of min element + NUNIQUE, ///< count number of unique elements + NTH_ELEMENT, ///< get the nth element + ROW_NUMBER, ///< get row-number of current index (relative to rolling window) + EWMA, ///< get exponential weighted moving average at current index + RANK, ///< get rank of current index + COLLECT_LIST, ///< collect values into a list + COLLECT_SET, ///< collect values into a list without duplicate entries + LEAD, ///< window function, accesses row at specified offset following current row + LAG, ///< window function, accesses row at specified offset preceding current row + PTX, ///< PTX UDF based reduction + CUDA, ///< CUDA UDF based reduction + MERGE_LISTS, ///< merge multiple lists values into one list + MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries + MERGE_M2, ///< merge partial values of M2 aggregation, + COVARIANCE, ///< covariance between two sets of elements + CORRELATION, ///< correlation between two sets of elements + TDIGEST, ///< create a tdigest from a set of input values + MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together + HISTOGRAM, ///< compute frequency of each element + MERGE_HISTOGRAM, ///< merge partial values of HISTOGRAM aggregation + HLLPP, ///< approximating the number of distinct items by using hyper log log plus plus (HLLPP) + MERGE_HLLPP ///< merge partial values of HLLPP aggregation }; aggregation() = delete; @@ -770,5 +772,11 @@ std::unique_ptr make_tdigest_aggregation(int max_centroids = 1000); template std::unique_ptr make_merge_tdigest_aggregation(int max_centroids = 1000); +template +std::unique_ptr make_hyper_log_log_aggregation(int num_registers_per_sketch); + +template +std::unique_ptr make_merge_hyper_log_log_aggregation(int const num_registers_per_sketch); + /** @} */ // end of group } // namespace CUDF_EXPORT cudf diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 6661a461b8b..e9a37cf5217 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -104,6 +104,10 @@ class simple_aggregations_collector { // Declares the interface for the simple class tdigest_aggregation const& agg); virtual std::vector> visit( data_type col_type, class merge_tdigest_aggregation const& agg); + virtual std::vector> visit( + data_type col_type, class hyper_log_log_aggregation const& agg); + virtual std::vector> visit( + data_type col_type, class merge_hyper_log_log_aggregation const& agg); }; class aggregation_finalizer { // Declares the interface for the finalizer @@ -144,6 +148,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer virtual void 
visit(class tdigest_aggregation const& agg);
  virtual void visit(class merge_tdigest_aggregation const& agg);
  virtual void visit(class ewma_aggregation const& agg);
+  virtual void visit(class hyper_log_log_aggregation const& agg);
+  virtual void visit(class merge_hyper_log_log_aggregation const& agg);
 };
 
 /**
@@ -1186,6 +1192,54 @@ class merge_tdigest_aggregation final : public groupby_aggregation, public reduc
   void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
 };
 
+/**
+ * @brief Derived aggregation class for specifying HLLPP aggregation
+ */
+class hyper_log_log_aggregation final : public groupby_aggregation, public reduce_aggregation {
+ public:
+  explicit hyper_log_log_aggregation(int const precision_)
+    : aggregation{HLLPP}, precision(precision_)
+  {
+  }
+
+  int const precision;
+
+  [[nodiscard]] std::unique_ptr clone() const override
+  {
+    return std::make_unique(*this);
+  }
+  std::vector> get_simple_aggregations(
+    data_type col_type, simple_aggregations_collector& collector) const override
+  {
+    return collector.visit(col_type, *this);
+  }
+  void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
+};
+
+/**
+ * @brief Derived aggregation class for specifying MERGE_HLLPP aggregation
+ */
+class merge_hyper_log_log_aggregation final : public groupby_aggregation,
+                                              public reduce_aggregation {
+ public:
+  explicit merge_hyper_log_log_aggregation(int const precision_)
+    : aggregation{MERGE_HLLPP}, precision(precision_)
+  {
+  }
+  int const precision;
+
+  [[nodiscard]] std::unique_ptr clone() const override
+  {
+    return std::make_unique(*this);
+  }
+  std::vector> get_simple_aggregations(
+    data_type col_type, simple_aggregations_collector& collector) const override
+  {
+    return collector.visit(col_type, *this);
+  }
+  void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
+};
+
 /**
  * @brief Sentinel value used for `ARGMAX` aggregation.
  *
@@ -1319,6 +1373,12 @@ struct target_type_impl {
   using type = double;
 };
 
+// Always use list for HLLPP
+template
+struct target_type_impl {
+  using type = list_view;
+};
+
 // Always use `double` for VARIANCE
 template
 struct target_type_impl {
@@ -1426,6 +1486,12 @@ struct target_type_impl {
   using type = struct_view;
 };
 
+// Always use list for MERGE_HLLPP
+template
+struct target_type_impl {
+  using type = list_view;
+};
+
 // Use list for MERGE_HISTOGRAM
 template
 struct target_type_impl {
@@ -1579,6 +1645,10 @@ CUDF_HOST_DEVICE inline decltype(auto) aggregation_dispatcher(aggregation::Kind
     return f.template operator()(std::forward(args)...);
     case aggregation::EWMA:
       return f.template operator()(std::forward(args)...);
+    case aggregation::HLLPP:
+      return f.template operator()(std::forward(args)...);
+    case aggregation::MERGE_HLLPP:
+      return f.template operator()(std::forward(args)...);
     default: {
 #ifndef __CUDA_ARCH__
       CUDF_FAIL("Unsupported aggregation.");
diff --git a/cpp/include/cudf/hashing/detail/xxhash_64.cuh b/cpp/include/cudf/hashing/detail/xxhash_64.cuh
new file mode 100644
index 00000000000..eaf85dae5e9
--- /dev/null
+++ b/cpp/include/cudf/hashing/detail/xxhash_64.cuh
@@ -0,0 +1,294 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace cudf::hashing::detail { +using hash_value_type = uint64_t; + +template +struct XXHash_64 { + using result_type = hash_value_type; + + constexpr XXHash_64() = default; + constexpr XXHash_64(hash_value_type seed) : m_seed(seed) {} + + __device__ inline uint32_t getblock32(std::byte const* data, std::size_t offset) const + { + // Read a 4-byte value from the data pointer as individual bytes for safe + // unaligned access (very likely for string types). + auto block = reinterpret_cast(data + offset); + return block[0] | (block[1] << 8) | (block[2] << 16) | (block[3] << 24); + } + + __device__ inline uint64_t getblock64(std::byte const* data, std::size_t offset) const + { + uint64_t result = getblock32(data, offset + 4); + result = result << 32; + return result | getblock32(data, offset); + } + + result_type __device__ inline operator()(Key const& key) const { return compute(key); } + + template + result_type __device__ inline compute(T const& key) const + { + auto data = device_span(reinterpret_cast(&key), sizeof(T)); + return compute_bytes(data); + } + + result_type __device__ inline compute_remaining_bytes(device_span& in, + std::size_t offset, + result_type h64) const + { + // remaining data can be processed in 8-byte chunks + if ((in.size() % 32) >= 8) { + for (; offset <= in.size() - 8; offset += 8) { + uint64_t k1 = getblock64(in.data(), offset) * prime2; + + k1 = rotate_bits_left(k1, 31) * prime1; + h64 ^= k1; + h64 = rotate_bits_left(h64, 27) * prime1 + prime4; + } + } + + // remaining data can be processed in 4-byte chunks + if ((in.size() % 8) >= 4) { + for (; offset <= in.size() - 4; offset += 4) { + h64 ^= (getblock32(in.data(), offset) & 0xfffffffful) * prime1; + h64 = rotate_bits_left(h64, 23) * prime2 + prime3; + } + } + + // and the rest + if (in.size() % 4) { + while (offset < in.size()) { + h64 ^= (std::to_integer(in[offset]) & 0xff) * prime5; + h64 = rotate_bits_left(h64, 11) * prime1; + ++offset; + } + } + return h64; + } + + result_type __device__ compute_bytes(device_span& in) const + { + uint64_t offset = 0; + uint64_t h64; + // data can be processed in 32-byte chunks + if (in.size() >= 32) { + auto limit = in.size() - 32; + uint64_t v1 = m_seed + prime1 + prime2; + uint64_t v2 = m_seed + prime2; + uint64_t v3 = m_seed; + uint64_t v4 = m_seed - prime1; + + do { + // pipeline 4*8byte computations + v1 += getblock64(in.data(), offset) * prime2; + v1 = rotate_bits_left(v1, 31); + v1 *= prime1; + offset += 8; + v2 += getblock64(in.data(), offset) * prime2; + v2 = rotate_bits_left(v2, 31); + v2 *= prime1; + offset += 8; + v3 += getblock64(in.data(), offset) * prime2; + v3 = rotate_bits_left(v3, 31); + v3 *= prime1; + offset += 8; + v4 += getblock64(in.data(), offset) * prime2; + v4 = rotate_bits_left(v4, 31); + v4 *= prime1; + offset += 8; + } while (offset <= limit); + + h64 = rotate_bits_left(v1, 1) + rotate_bits_left(v2, 7) + rotate_bits_left(v3, 12) + + rotate_bits_left(v4, 18); + + v1 *= prime2; + v1 = 
rotate_bits_left(v1, 31); + v1 *= prime1; + h64 ^= v1; + h64 = h64 * prime1 + prime4; + + v2 *= prime2; + v2 = rotate_bits_left(v2, 31); + v2 *= prime1; + h64 ^= v2; + h64 = h64 * prime1 + prime4; + + v3 *= prime2; + v3 = rotate_bits_left(v3, 31); + v3 *= prime1; + h64 ^= v3; + h64 = h64 * prime1 + prime4; + + v4 *= prime2; + v4 = rotate_bits_left(v4, 31); + v4 *= prime1; + h64 ^= v4; + h64 = h64 * prime1 + prime4; + } else { + h64 = m_seed + prime5; + } + + h64 += in.size(); + + h64 = compute_remaining_bytes(in, offset, h64); + + return finalize(h64); + } + + constexpr __host__ __device__ std::uint64_t finalize(std::uint64_t h) const noexcept + { + h ^= h >> 33; + h *= prime2; + h ^= h >> 29; + h *= prime3; + h ^= h >> 32; + return h; + } + + private: + hash_value_type m_seed{}; + static constexpr uint64_t prime1 = 0x9e3779b185ebca87ul; + static constexpr uint64_t prime2 = 0xc2b2ae3d27d4eb4ful; + static constexpr uint64_t prime3 = 0x165667b19e3779f9ul; + static constexpr uint64_t prime4 = 0x85ebca77c2b2ae63ul; + static constexpr uint64_t prime5 = 0x27d4eb2f165667c5ul; +}; + +template <> +hash_value_type __device__ inline XXHash_64::operator()(bool const& key) const +{ + return compute(static_cast(key)); +} + +template <> +hash_value_type __device__ inline XXHash_64::operator()(float const& key) const +{ + return compute(normalize_nans(key)); +} + +template <> +hash_value_type __device__ inline XXHash_64::operator()(double const& key) const +{ + return compute(normalize_nans(key)); +} + +template <> +hash_value_type __device__ inline XXHash_64::operator()( + cudf::string_view const& key) const +{ + auto const len = key.size_bytes(); + auto data = device_span(reinterpret_cast(key.data()), len); + return compute_bytes(data); +} + +template <> +hash_value_type __device__ inline XXHash_64::operator()( + numeric::decimal32 const& key) const +{ + return compute(key.value()); +} + +template <> +hash_value_type __device__ inline XXHash_64::operator()( + numeric::decimal64 const& key) const +{ + return compute(key.value()); +} + +template <> +hash_value_type __device__ inline XXHash_64::operator()( + numeric::decimal128 const& key) const +{ + return compute(key.value()); +} + +/** + * @brief Computes the hash value of a row in the given table. + * + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + */ +template +class xxhash_64_device_row_hasher { + public: + xxhash_64_device_row_hasher(Nullate nulls, table_device_view const& t, hash_value_type seed) + : _check_nulls(nulls), _table(t), _seed(seed) + { + } + + __device__ auto operator()(size_type row_index) const noexcept + { + return cudf::detail::accumulate( + _table.begin(), + _table.end(), + _seed, + [row_index, nulls = _check_nulls] __device__(auto hash, auto column) { + return cudf::type_dispatcher( + column.type(), element_hasher_adapter{}, column, row_index, nulls, hash); + }); + } + + /** + * @brief Computes the hash value of an element in the given column. 
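+   * Null elements hash to std::numeric_limits<hash_value_type>::max(), and columns of
+   * non-hashable types trap with CUDF_UNREACHABLE at device runtime (see the two
+   * operator() overloads below).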
+ */ + class element_hasher_adapter { + public: + template ())> + __device__ hash_value_type operator()(column_device_view const& col, + size_type const row_index, + Nullate const _check_nulls, + hash_value_type const _seed) const noexcept + { + if (_check_nulls && col.is_null(row_index)) { + return std::numeric_limits::max(); + } + auto const hasher = XXHash_64{_seed}; + return hasher(col.element(row_index)); + } + + template ())> + __device__ hash_value_type operator()(column_device_view const&, + size_type const, + Nullate const, + hash_value_type const) const noexcept + { + CUDF_UNREACHABLE("Unsupported type for XXHash_64"); + } + }; + + Nullate const _check_nulls; + table_device_view const _table; + hash_value_type const _seed; +}; + +} // namespace cudf::hashing::detail \ No newline at end of file diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index a60a7f63882..ca47f1d4d60 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -237,6 +237,18 @@ std::vector> simple_aggregations_collector::visit( return visit(col_type, static_cast(agg)); } +std::vector> simple_aggregations_collector::visit( + data_type col_type, hyper_log_log_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, merge_hyper_log_log_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + // aggregation_finalizer ---------------------------------------- void aggregation_finalizer::visit(aggregation const& agg) {} @@ -410,6 +422,16 @@ void aggregation_finalizer::visit(merge_tdigest_aggregation const& agg) visit(static_cast(agg)); } +void aggregation_finalizer::visit(hyper_log_log_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(merge_hyper_log_log_aggregation const& agg) +{ + visit(static_cast(agg)); +} + } // namespace detail std::vector> aggregation::get_simple_aggregations( @@ -917,6 +939,32 @@ make_merge_tdigest_aggregation(int max_centroids); template CUDF_EXPORT std::unique_ptr make_merge_tdigest_aggregation(int max_centroids); +/// Factory to create a HLLPP aggregation +template +std::unique_ptr make_hyper_log_log_aggregation(int const precision) +{ + return std::make_unique(precision); +} +template CUDF_EXPORT std::unique_ptr make_hyper_log_log_aggregation( + int precision); +template CUDF_EXPORT std::unique_ptr +make_hyper_log_log_aggregation(int precision); +template CUDF_EXPORT std::unique_ptr +make_hyper_log_log_aggregation(int precision); + +/// Factory to create a MERGE_HLLPP aggregation +template +std::unique_ptr make_merge_hyper_log_log_aggregation(int const precision) +{ + return std::make_unique(precision); +} +template CUDF_EXPORT std::unique_ptr make_merge_hyper_log_log_aggregation( + int const precision); +template CUDF_EXPORT std::unique_ptr +make_merge_hyper_log_log_aggregation(int const precision); +template CUDF_EXPORT std::unique_ptr +make_merge_hyper_log_log_aggregation(int const precision); + namespace detail { namespace { struct target_type_functor { diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index 3041e261945..814577fa2dd 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -749,6 +749,22 @@ void aggregate_result_functor::operator()(aggregation cons mr)); } +template <> +void aggregate_result_functor::operator()(aggregation const& agg) +{ + if (cache.has_result(values, agg)) { 
return; }
+
+  int const precision = dynamic_cast(agg).precision;
+  cache.add_result(values,
+                   agg,
+                   detail::group_hyper_log_log_plus_plus(get_grouped_values(),
+                                                         helper.num_groups(stream),
+                                                         helper.group_labels(stream),
+                                                         precision,
+                                                         stream,
+                                                         mr));
+}
+
 /**
  * @brief Generate a merged tdigest column from a grouped set of input tdigest columns.
  *
@@ -791,6 +807,24 @@ void aggregate_result_functor::operator()(aggregatio
     mr));
 }
 
+template <>
+void aggregate_result_functor::operator()(aggregation const& agg)
+{
+  if (cache.has_result(values, agg)) { return; }
+
+  int const precision =
+    dynamic_cast(agg).precision;
+
+  cache.add_result(values,
+                   agg,
+                   detail::group_merge_hyper_log_log_plus_plus(get_grouped_values(),
+                                                               helper.num_groups(stream),
+                                                               helper.group_labels(stream),
+                                                               precision,
+                                                               stream,
+                                                               mr));
+}
+
 } // namespace detail
 
 // Sort-based groupby
diff --git a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
new file mode 100644
index 00000000000..05580c1a685
--- /dev/null
+++ b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
@@ -0,0 +1,663 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include  // TODO #include once available
+#include
+#include
+#include
+#include
+
+namespace cudf {
+namespace groupby {
+namespace detail {
+namespace {
+
+// The number of bits required by a register value. A register stores a count of leading
+// zeros; the XXHash64 value is 64 bits, so 6 bits are enough to store a register value.
+constexpr int REGISTER_VALUE_BITS = 6;
+
+// MASK binary 6 bits: 111111
+constexpr uint64_t MASK = (1L << REGISTER_VALUE_BITS) - 1L;
+
+// One long stores 10 register values
+constexpr int REGISTERS_PER_LONG = 64 / REGISTER_VALUE_BITS;
+
+// XXHash seed
+constexpr int64_t SEED = 42L;
+
+/**
+ *
+ * Computes register values from hash values and partially merges them per group.
+ * It splits the input into segments of num_hashs_per_thread elements.
+ * Each thread scans its segment, keeps the max register value for each register index
+ * within the same group, outputs the gathered result whenever it meets a new group,
+ * and at the end saves a buffer for the last group in its segment.
+ *
+ * In this way memory usage stays small: only `num_threads` per-thread caches are needed.
+ *
+ * num_threads = div_round_up(num_hashs, num_hashs_per_thread).
+ *
+ * After register values are computed, results are partially grouped as follows.
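+ *
+ * For one 64-bit hash with precision p, the top p bits select the register index:
+ *   reg_idx = hash >> (64 - p)
+ * and the register value is the count of leading zeros in the remaining bits plus one:
+ *   reg_v = countl_zero((hash << p) | w_padding) + 1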
+ *
+ * e.g.: num_registers_per_sketch = 512 and num_hashs_per_thread = 4;
+ *
+ * Input:
+ * register_index register_value group_label
+ * [
+ *   (0, 1),  0
+ *   (0, 2),  0
+ *   (1, 1),  1  // meets a new group, outputs result for g0
+ *   (1, 9),  1  // thread 0 outputs when it scans to here
+ *   (1, 1),  1
+ *   (1, 1),  1
+ *   (1, 5),  1
+ *   (1, 1),  1  // thread 1 outputs; also outputs the result for g1
+ * ]
+ * Output e.g.:
+ *
+ * group_lables_thread_cache:
+ * [
+ *   g0
+ *   g1
+ * ]
+ * Has num_threads rows.
+ *
+ * registers_thread_cache:
+ * [
+ *   512 values: [0, 9, 0, ... ]  // register values for group 1
+ *   512 values: [0, 5, 0, ... ]  // register values for group 1
+ * ]
+ * Has num_threads rows; each row corresponds to a row of `group_lables_thread_cache`.
+ *
+ * registers_output_cache:
+ * [
+ *   512 values: [2, 0, 0, ... ]  // register values for group 0
+ *   512 values: [0, 5, 0, ... ]  // register values for group 1
+ * ]
+ * Has num_groups rows.
+ *
+ * The next kernel merges the partial results into the final result.
+ */
+template
+CUDF_KERNEL void partial_group_sketches_from_hashs_kernel(
+  column_device_view hashs,
+  cudf::device_span group_lables,
+  int64_t const precision,            // num of bits for register addressing, e.g.: 9
+  int* const registers_output_cache,  // num is num_groups * num_registers_per_sketch
+  int* const registers_thread_cache,  // num is num_threads * num_registers_per_sketch
+  size_type* const group_lables_thread_cache  // saves the group label for each thread
+)
+{
+  auto const tid = cudf::detail::grid_1d::global_thread_id();
+  int64_t const num_hashs = hashs.size();
+  if (tid * num_hashs_per_thread >= hashs.size()) { return; }
+
+  // 2^precision = num_registers_per_sketch
+  int64_t num_registers_per_sketch = 1L << precision;
+  // e.g.: integer in binary: 1 0000 0000
+  uint64_t const w_padding = 1ULL << (precision - 1);
+  // e.g.: 64 - 9 = 55
+  int const idx_shift = 64 - precision;
+
+  auto const hash_first = tid * num_hashs_per_thread;
+  auto const hash_end = cuda::std::min((tid + 1) * num_hashs_per_thread, num_hashs);
+
+  // init sketches for each thread
+  int* const sketch_ptr = registers_thread_cache + tid * num_registers_per_sketch;
+  for (auto i = 0; i < num_registers_per_sketch; i++) {
+    sketch_ptr[i] = 0;
+  }
+
+  size_type prev_group = group_lables[hash_first];
+  for (auto hash_idx = hash_first; hash_idx < hash_end; hash_idx++) {
+    size_type curr_group = group_lables[hash_idx];
+
+    // cast to unsigned so that >> shifts in zeros instead of preserving the sign bit
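+    // (the top `precision` bits of the unsigned hash become the register index below)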
uint64_t const hash = static_cast(hashs.element(hash_idx));
+    auto const reg_idx = hash >> idx_shift;
+    int const reg_v =
+      static_cast(cuda::std::countl_zero((hash << precision) | w_padding) + 1ULL);
+
+    if (curr_group == prev_group) {
+      // still in the same group, update the max value
+      if (reg_v > sketch_ptr[reg_idx]) { sketch_ptr[reg_idx] = reg_v; }
+    } else {
+      // meets a new group, save output for the previous group
+      for (auto i = 0; i < num_registers_per_sketch; i++) {
+        registers_output_cache[prev_group * num_registers_per_sketch + i] = sketch_ptr[i];
+      }
+
+      // reset cache
+      for (auto i = 0; i < num_registers_per_sketch; i++) {
+        sketch_ptr[i] = 0;
+      }
+
+      // save the max value
+      sketch_ptr[reg_idx] = reg_v;
+    }
+
+    // special logic for the last hash in this thread's segment
+    if (hash_idx == hash_end - 1) {
+      // meets the last hash in the segment
+      if (hash_idx == num_hashs - 1) {
+        // this segment is the last one
+        for (auto i = 0; i < num_registers_per_sketch; i++) {
+          registers_output_cache[curr_group * num_registers_per_sketch + i] = sketch_ptr[i];
+        }
+      } else {
+        // not the last segment, probe one item forward.
+        if (curr_group != group_lables[hash_idx + 1]) {
+          for (auto i = 0; i < num_registers_per_sketch; i++) {
+            registers_output_cache[curr_group * num_registers_per_sketch + i] = sketch_ptr[i];
+          }
+        }
+      }
+    }
+
+    prev_group = curr_group;
+  }
+
+  // save the group label for this thread
+  group_lables_thread_cache[tid] = group_lables[hash_end - 1];
+}
+
+/*
+ *
+ * Merge sketches vertically.
+ *
+ * For all registers at the same index, one thread merges them by taking the max value.
+ * num_threads = num_registers_per_sketch.
+ *
+ * Input e.g.:
+ *
+ * group_lables_thread_cache:
+ * [
+ *   g0
+ *   g0
+ *   g1
+ *   ...
+ *   gN
+ * ]
+ * Has num_threads rows.
+ *
+ * registers_thread_cache:
+ * [
+ *   r0_g0, r1_g0, r2_g0, r3_g0, ... , r511_g0 // register values for group 0
+ *   r0_g0, r1_g0, r2_g0, r3_g0, ... , r511_g0 // register values for group 0
+ *   r0_g1, r1_g1, r2_g1, r3_g1, ... , r511_g1 // register values for group 1
+ *   ...
+ *   r0_gN, r1_gN, r2_gN, r3_gN, ... , r511_gN // register values for group N
+ * ]
+ * Has num_threads rows; each row corresponds to a row of `group_lables_thread_cache`.
+ *
+ * registers_output_cache:
+ * [
+ *   r0_g0, r1_g0, r2_g0, r3_g0, ... , r511_g0 // register values for group 0
+ *   r0_g1, r1_g1, r2_g1, r3_g1, ... , r511_g1 // register values for group 1
+ *   ...
+ *   r0_gN, r1_gN, r2_gN, r3_gN, ... , r511_gN // register values for group N
+ * ]
+ * Has num_groups rows.
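+ *
+ * Each thread owns one register index (tid). It walks the partial rows in order,
+ * keeps a running max per group, and merges that max into registers_output_cache
+ * whenever the group label changes (and once more for the final group).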
+ *
+ */
+CUDF_KERNEL void merge_sketches_vertically(int64_t num_sketches,
+                                           int64_t num_registers_per_sketch,
+                                           int* const registers_output_cache,
+                                           int const* const registers_thread_cache,
+                                           size_type const* const group_lables_thread_cache)
+{
+  // register index is tid
+  auto const tid = cudf::detail::grid_1d::global_thread_id();
+  if (tid >= num_registers_per_sketch) { return; }
+  int reg_max = 0;
+  int prev_group = group_lables_thread_cache[0];
+  for (auto i = 0; i < num_sketches; i++) {
+    int curr_group = group_lables_thread_cache[i];
+    int curr_reg_v = registers_thread_cache[tid + i * num_registers_per_sketch];
+    if (curr_group == prev_group) {
+      if (curr_reg_v > reg_max) { reg_max = curr_reg_v; }
+    } else {
+      // meets a new group, store the result for the previous group
+      int64_t reg_idx = prev_group * num_registers_per_sketch + tid;
+      int output_reg_v = registers_output_cache[reg_idx];
+      if (reg_max > output_reg_v) { registers_output_cache[reg_idx] = reg_max; }
+
+      // restart the running max from the first row of the new group
+      reg_max = curr_reg_v;
+    }
+    prev_group = curr_group;
+  }
+
+  // handle the last group in this thread
+  int64_t reg_idx = prev_group * num_registers_per_sketch + tid;
+  int curr_reg_v = registers_output_cache[reg_idx];
+  if (reg_max > curr_reg_v) { registers_output_cache[reg_idx] = reg_max; }
+}
+
+/**
+ *
+ * Compacts register values: packs 10 register values (6 bits each) into one long.
+ * Number of threads is num_groups * num_longs_per_sketch.
+ *
+ * e.g.:
+ * Input:
+ * registers_output_cache:
+ * [
+ *   r0_g0, r1_g0, r2_g0, r3_g0, ... , r511_g0 // register values for group 0
+ *   r0_g1, r1_g1, r2_g1, r3_g1, ... , r511_g1 // register values for group 1
+ *   ...
+ *   r0_gN, r1_gN, r2_gN, r3_gN, ... , r511_gN // register values for group N
+ * ]
+ * Has num_groups rows.
+ *
+ * Output:
+ * 52 long columns
+ *
+ * If r0 to r9 integers are all: 00000000-00000000-00000000-00100001, trailing 6 bits: 100-001,
+ * the compacted long is: 100001-100001-100001-100001-100001-100001-100001-100001-100001-100001
+ */
+CUDF_KERNEL void compact_kernel(int64_t const num_groups,
+                                int64_t const num_registers_per_sketch,
+                                cudf::device_span sketches_output,
+                                // num_groups * num_registers_per_sketch integers
+                                cudf::device_span registers_output_cache)
+{
+  int64_t const tid = cudf::detail::grid_1d::global_thread_id();
+
+  int64_t const num_longs_per_sketch = sketches_output.size();
+  if (tid >= num_groups * num_longs_per_sketch) { return; }
+
+  int64_t const group_idx = tid / num_longs_per_sketch;
+  int64_t const long_idx = tid % num_longs_per_sketch;
+
+  int64_t const reg_begin_idx =
+    group_idx * num_registers_per_sketch + long_idx * REGISTERS_PER_LONG;
+  int64_t num_regs = REGISTERS_PER_LONG;
+  if (long_idx == num_longs_per_sketch - 1) {
+    num_regs = num_registers_per_sketch % REGISTERS_PER_LONG;
+  }
+
+  int64_t ten_registers = 0;
+  for (auto i = 0; i < num_regs; i++) {
+    int64_t reg_v = registers_output_cache[reg_begin_idx + i];
+    int64_t tmp = reg_v << (REGISTER_VALUE_BITS * i);
+    ten_registers |= tmp;
+  }
+
+  sketches_output[long_idx][group_idx] = ten_registers;
+}
+
+std::unique_ptr group_hllpp(column_view const& input,
+                            int64_t const num_groups,
+                            cudf::device_span group_lables,
+                            int64_t const precision,
+                            rmm::cuda_stream_view stream,
+                            rmm::device_async_resource_ref mr)
+{
+  int64_t num_registers_per_sketch = 1 << precision;
+
+  // 1. 
compute all the hashes
+  auto hash_col =
+    make_numeric_column(data_type{type_id::INT64}, input.size(), mask_state::ALL_VALID, stream, mr);
+  auto input_table = cudf::table_view{{input}};
+  auto d_input_table = cudf::table_device_view::create(input_table, stream);
+  bool const nullable = input.has_nulls();
+  thrust::tabulate(
+    rmm::exec_policy(stream),
+    hash_col->mutable_view().begin(),
+    hash_col->mutable_view().end(),
+    cudf::hashing::detail::xxhash_64_device_row_hasher(nullable, *d_input_table, SEED));
+  auto d_hashs = cudf::column_device_view::create(hash_col->view(), stream);
+
+  // 2. execute partial group by
+  constexpr int64_t block_size = 256;
+  constexpr int64_t num_hashs_per_thread = 32;  // handles 32 items per thread
+  int64_t total_threads_partial_group =
+    cudf::util::div_rounding_up_safe(static_cast(input.size()), num_hashs_per_thread);
+  int64_t num_blocks_p1 = cudf::util::div_rounding_up_safe(total_threads_partial_group, block_size);
+
+  auto sketches_output =
+    rmm::device_uvector(num_groups * num_registers_per_sketch, stream, mr);
+  auto registers_thread_cache = rmm::device_uvector(
+    total_threads_partial_group * num_registers_per_sketch, stream, mr);
+  auto group_lables_thread_cache =
+    rmm::device_uvector(total_threads_partial_group, stream, mr);
+
+  partial_group_sketches_from_hashs_kernel
+    <<>>(*d_hashs,
+                 group_lables,
+                 precision,
+                 sketches_output.begin(),
+                 registers_thread_cache.begin(),
+                 group_lables_thread_cache.begin());
+
+  // 3. merge the intermediate results
+  auto num_merge_threads = num_registers_per_sketch;
+  auto num_merge_blocks = cudf::util::div_rounding_up_safe(num_merge_threads, block_size);
+  merge_sketches_vertically<<>>(
+    total_threads_partial_group,  // num_sketches
+    num_registers_per_sketch,
+    sketches_output.begin(),
+    registers_thread_cache.begin(),
+    group_lables_thread_cache.begin());
+
+  // 4. create output columns
+  auto num_long_cols = num_registers_per_sketch / REGISTERS_PER_LONG + 1;
+  auto const results_iter = cudf::detail::make_counting_transform_iterator(0, [&](int i) {
+    return make_numeric_column(
+      data_type{type_id::INT64}, num_groups, mask_state::ALL_VALID, stream, mr);
+  });
+  auto children = std::vector>(results_iter, results_iter + num_long_cols);
+  auto d_results = [&] {
+    auto host_results_pointer_iter =
+      thrust::make_transform_iterator(children.begin(), [](auto const& results_column) {
+        return results_column->mutable_view().template data();
+      });
+    auto host_results_pointers =
+      std::vector(host_results_pointer_iter, host_results_pointer_iter + children.size());
+    return cudf::detail::make_device_uvector_async(host_results_pointers, stream, mr);
+  }();
+  auto result = cudf::make_structs_column(num_groups,
+                                          std::move(children),
+                                          0,                     // null count
+                                          rmm::device_buffer{},  // null mask
+                                          stream);
+
+  // 5. 
compact sketches
+  auto num_phase3_threads = num_groups * num_long_cols;
+  auto num_phase3_blocks = cudf::util::div_rounding_up_safe(num_phase3_threads, block_size);
+  compact_kernel<<>>(
+    num_groups, num_registers_per_sketch, d_results, sketches_output);
+
+  return result;
+}
+
+__device__ inline int get_register_value(int64_t const long_10_registers, int reg_idx)
+{
+  int64_t shift_mask = MASK << (REGISTER_VALUE_BITS * reg_idx);
+  int64_t v = (long_10_registers & shift_mask) >> (REGISTER_VALUE_BITS * reg_idx);
+  return static_cast(v);
+}
+
+/**
+ * Partially groups sketches in long columns, similar to `partial_group_sketches_from_hashs_kernel`.
+ * It splits the longs into segments, each with `num_longs_per_threads` elements.
+ * e.g.: num_registers_per_sketch = 512.
+ * Each sketch uses 52 (512 / 10 + 1) longs.
+ *
+ * Input:
+ *           col_0  col_1      col_51
+ * sketch_0: long,  long, ..., long
+ * sketch_1: long,  long, ..., long
+ * sketch_2: long,  long, ..., long
+ *
+ * num_threads = 52 * div_round_up(num_sketches_input, num_longs_per_threads)
+ * Each thread scans and merges num_longs_per_threads longs,
+ * and outputs the max register value when it meets a new group.
+ * For the last long in a thread, it outputs the result into `registers_thread_cache`.
+ *
+ * Output:
+ *
+ * group_lables_thread_cache:
+ * [
+ *   g0
+ *   g0
+ *   g1
+ *   ...
+ *   gN
+ * ]
+ * Has num_threads rows.
+ *
+ * registers_thread_cache:
+ * [
+ *   r0_g0, r1_g0, r2_g0, r3_g0, ... , r511_g0 // register values for group 0
+ *   r0_g0, r1_g0, r2_g0, r3_g0, ... , r511_g0 // register values for group 0
+ *   r0_g1, r1_g1, r2_g1, r3_g1, ... , r511_g1 // register values for group 1
+ *   ...
+ *   r0_gN, r1_gN, r2_gN, r3_gN, ... , r511_gN // register values for group N
+ * ]
+ * Has num_threads rows; each row corresponds to a row of `group_lables_thread_cache`.
+ *
+ * registers_output_cache:
+ * [
+ *   r0_g0, r1_g0, r2_g0, r3_g0, ... , r511_g0 // register values for group 0
+ *   r0_g1, r1_g1, r2_g1, r3_g1, ... , r511_g1 // register values for group 1
+ *   ...
+ *   r0_gN, r1_gN, r2_gN, r3_gN, ... , r511_gN // register values for group N
+ * ]
+ * Has num_groups rows.
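+ *
+ * Register i of a packed long occupies bits [6*i, 6*i + 6); see get_register_value
+ * above: v = (packed >> (6 * i)) & 0b111111.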
+ * + */ +template +CUDF_KERNEL void partial_group_long_sketches_kernel( + cudf::device_span sketches_input, + int64_t const num_sketches_input, + int64_t const num_threads_per_col, + int64_t const num_registers_per_sketch, + int64_t const num_groups, + cudf::device_span group_lables, + // num_groups * num_registers_per_sketch integers + int* const registers_output_cache, + // num_threads * num_registers_per_sketch integers + int* const registers_thread_cache, + // num_threads integers + size_type* const group_lables_thread_cache) +{ + auto const tid = cudf::detail::grid_1d::global_thread_id(); + auto const num_long_cols = sketches_input.size(); + if (tid >= num_threads_per_col * num_long_cols) { return; } + + auto const long_idx = tid / num_threads_per_col; + auto const thread_idx_in_cols = tid % num_threads_per_col; + int64_t const* const longs_ptr = sketches_input[long_idx]; + + int* const registers_thread_ptr = + registers_thread_cache + thread_idx_in_cols * num_registers_per_sketch; + auto const sketch_first = thread_idx_in_cols * num_longs_per_threads; + + auto const sketch_end = cuda::std::min(sketch_first + num_longs_per_threads, num_sketches_input); + + int num_regs = REGISTERS_PER_LONG; + if (long_idx == num_long_cols - 1) { num_regs = num_registers_per_sketch % REGISTERS_PER_LONG; } + + for (auto i = 0; i < num_regs; i++) { + size_type prev_group = group_lables[sketch_first]; + int max_reg_v = 0; + int reg_idx_in_sketch = long_idx * REGISTERS_PER_LONG + i; + for (auto sketch_idx = sketch_first; sketch_idx < sketch_end; sketch_idx++) { + size_type curr_group = group_lables[sketch_idx]; + + int64_t output_idx_for_prev_group = num_registers_per_sketch * prev_group + reg_idx_in_sketch; + + int curr_reg_v = get_register_value(longs_ptr[sketch_idx], i); + if (curr_group == prev_group) { + // still in the same group, update the max value + if (curr_reg_v > max_reg_v) { max_reg_v = curr_reg_v; } + } else { + // meets new group, save output for the previous group + registers_output_cache[output_idx_for_prev_group] = max_reg_v; + + // reset the cache + max_reg_v = curr_reg_v; + } + + // special logic for the last sketch in this thread + if (sketch_idx == sketch_end - 1) { + // last long in the segment + int64_t output_idx_for_curr_group = + num_registers_per_sketch * curr_group + reg_idx_in_sketch; + if (sketch_idx == num_sketches_input - 1) { + // last segment + registers_output_cache[output_idx_for_curr_group] = max_reg_v; + max_reg_v = curr_reg_v; + } else { + if (curr_group != group_lables[sketch_idx + 1]) { + // look one more forward + registers_output_cache[output_idx_for_curr_group] = max_reg_v; + max_reg_v = curr_reg_v; + } + } + } + + prev_group = curr_group; + } + + // For each thread, output register values + registers_thread_ptr[reg_idx_in_sketch] = max_reg_v; + } + if (long_idx == 0) { + group_lables_thread_cache[thread_idx_in_cols] = group_lables[sketch_end - 1]; + } +} + +std::unique_ptr merge_hyper_log_log( + column_view const& hll_input, // struct column + int64_t const num_groups, + cudf::device_span group_lables, + int64_t const precision, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + int64_t num_registers_per_sketch = 1 << precision; + int64_t const num_sketches = hll_input.size(); + int64_t const num_long_cols = num_registers_per_sketch / REGISTERS_PER_LONG + 1; + constexpr int64_t num_longs_per_threads = 32; + constexpr int64_t block_size = 256; + + int64_t num_threads_per_col_phase1 = + cudf::util::div_rounding_up_safe(num_sketches, 
num_longs_per_threads); + int64_t num_threads_phase1 = num_threads_per_col_phase1 * num_long_cols; + int64_t num_blocks = cudf::util::div_rounding_up_safe(num_threads_phase1, block_size); + auto registers_output_cache = + rmm::device_uvector(num_registers_per_sketch * num_groups, stream, mr); + { + auto registers_thread_cache = + rmm::device_uvector(num_registers_per_sketch * num_threads_phase1, stream, mr); + auto group_lables_thread_cache = rmm::device_uvector(num_threads_phase1, stream, mr); + + cudf::structs_column_view scv(hll_input); + auto const input_iter = cudf::detail::make_counting_transform_iterator( + 0, [&](int i) { return scv.get_sliced_child(i, stream).begin(); }); + auto input_cols = std::vector(input_iter, input_iter + num_long_cols); + auto d_inputs = cudf::detail::make_device_uvector_async(input_cols, stream, mr); + // 1st kernel: partially group + partial_group_long_sketches_kernel + <<>>(d_inputs, + num_sketches, + num_threads_per_col_phase1, + num_registers_per_sketch, + num_groups, + group_lables, + registers_output_cache.begin(), + registers_thread_cache.begin(), + group_lables_thread_cache.begin()); + auto const num_phase2_threads = num_registers_per_sketch; + auto const num_phase2_blocks = cudf::util::div_rounding_up_safe(num_phase2_threads, block_size); + // 2nd kernel: vertical merge + merge_sketches_vertically<<>>( + num_threads_per_col_phase1, // num_sketches + num_registers_per_sketch, + registers_output_cache.begin(), + registers_thread_cache.begin(), + group_lables_thread_cache.begin()); + } + + // create output columns + auto const results_iter = cudf::detail::make_counting_transform_iterator(0, [&](int i) { + return make_numeric_column( + data_type{type_id::INT64}, num_groups, mask_state::ALL_VALID, stream, mr); + }); + auto results = std::vector>(results_iter, results_iter + num_long_cols); + auto d_sketches_output = [&] { + auto host_results_pointer_iter = + thrust::make_transform_iterator(results.begin(), [](auto const& results_column) { + return results_column->mutable_view().template data(); + }); + auto host_results_pointers = + std::vector(host_results_pointer_iter, host_results_pointer_iter + results.size()); + return cudf::detail::make_device_uvector_async(host_results_pointers, stream, mr); + }(); + + auto num_phase3_threads = num_groups * num_long_cols; + auto num_phase3_blocks = cudf::util::div_rounding_up_safe(num_phase3_threads, block_size); + // 3rd kernel: compact + compact_kernel<<>>( + num_groups, num_registers_per_sketch, d_sketches_output, registers_output_cache); + return make_structs_column(num_groups, std::move(results), 0, rmm::device_buffer{}); +} + +} // namespace + +/** + * Compute hyper log log against the input values and merge the sketches in the same group. + * Output is a struct column with multiple long columns which is consistent with Spark. + */ +std::unique_ptr group_hyper_log_log_plus_plus( + column_view const& input, + int64_t const num_groups, + cudf::device_span group_lables, + int64_t const precision, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_EXPECTS(precision >= 4 && precision <= 18, "HLL++ requires precision in range: [4, 18]"); + auto input_type = + cudf::is_dictionary(input.type()) ? dictionary_column_view(input).keys().type() : input.type(); + + return group_hllpp(input, num_groups, group_lables, precision, stream, mr); +} + +/** + * Merge sketches in the same group. + * Input is a struct column with multiple long columns which is consistent with Spark. 
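+ * e.g.: precision = 9 gives 512 registers per sketch, packed into
+ * 512 / 10 + 1 = 52 INT64 child columns (10 six-bit registers per long).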
+ */
+std::unique_ptr group_merge_hyper_log_log_plus_plus(
+  column_view const& values,
+  int64_t const num_groups,
+  cudf::device_span group_lables,
+  int64_t const precision,
+  rmm::cuda_stream_view stream,
+  rmm::device_async_resource_ref mr)
+{
+  CUDF_EXPECTS(precision >= 4 && precision <= 18, "HLL++ requires precision in range: [4, 18]");
+  CUDF_EXPECTS(values.type().id() == type_id::STRUCT,
+               "HyperLogLogPlusPlus buffer type must be a STRUCT of long columns.");
+  for (auto i = 0; i < values.num_children(); i++) {
+    CUDF_EXPECTS(values.child(i).type().id() == type_id::INT64,
+                 "HyperLogLogPlusPlus buffer type must be a STRUCT of long columns.");
+  }
+  return merge_hyper_log_log(values, num_groups, group_lables, precision, stream, mr);
+}
+
+} // namespace detail
+} // namespace groupby
+} // namespace cudf
diff --git a/cpp/src/groupby/sort/group_reductions.hpp b/cpp/src/groupby/sort/group_reductions.hpp
index f8a531094c6..ae9c441b75d 100644
--- a/cpp/src/groupby/sort/group_reductions.hpp
+++ b/cpp/src/groupby/sort/group_reductions.hpp
@@ -539,6 +539,21 @@ std::unique_ptr group_correlation(column_view const& covariance,
                                   rmm::cuda_stream_view stream,
                                   rmm::device_async_resource_ref mr);
 
+std::unique_ptr group_hyper_log_log_plus_plus(
+  column_view const& input,
+  int64_t const num_groups,
+  cudf::device_span group_lables,
+  int64_t const precision,
+  rmm::cuda_stream_view stream,
+  rmm::device_async_resource_ref mr);
+
+std::unique_ptr group_merge_hyper_log_log_plus_plus(
+  column_view const& values,
+  int64_t const num_groups,
+  cudf::device_span group_lables,
+  int64_t const precision,
+  rmm::cuda_stream_view stream,
+  rmm::device_async_resource_ref mr);
 } // namespace detail
 } // namespace groupby
 } // namespace cudf
diff --git a/cpp/src/hash/xxhash_64.cu b/cpp/src/hash/xxhash_64.cu
index fad8383210b..c3cc9d87d74 100644
--- a/cpp/src/hash/xxhash_64.cu
+++ b/cpp/src/hash/xxhash_64.cu
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -31,271 +32,6 @@
 namespace cudf {
 namespace hashing {
 namespace detail {
-namespace {
-
-using hash_value_type = uint64_t;
-
-template
-struct XXHash_64 {
-  using result_type = hash_value_type;
-
-  constexpr XXHash_64() = default;
-  constexpr XXHash_64(hash_value_type seed) : m_seed(seed) {}
-
-  __device__ inline uint32_t getblock32(std::byte const* data, std::size_t offset) const
-  {
-    // Read a 4-byte value from the data pointer as individual bytes for safe
-    // unaligned access (very likely for string types).
- auto block = reinterpret_cast(data + offset); - return block[0] | (block[1] << 8) | (block[2] << 16) | (block[3] << 24); - } - - __device__ inline uint64_t getblock64(std::byte const* data, std::size_t offset) const - { - uint64_t result = getblock32(data, offset + 4); - result = result << 32; - return result | getblock32(data, offset); - } - - result_type __device__ inline operator()(Key const& key) const { return compute(key); } - - template - result_type __device__ inline compute(T const& key) const - { - auto data = device_span(reinterpret_cast(&key), sizeof(T)); - return compute_bytes(data); - } - - result_type __device__ inline compute_remaining_bytes(device_span& in, - std::size_t offset, - result_type h64) const - { - // remaining data can be processed in 8-byte chunks - if ((in.size() % 32) >= 8) { - for (; offset <= in.size() - 8; offset += 8) { - uint64_t k1 = getblock64(in.data(), offset) * prime2; - - k1 = rotate_bits_left(k1, 31) * prime1; - h64 ^= k1; - h64 = rotate_bits_left(h64, 27) * prime1 + prime4; - } - } - - // remaining data can be processed in 4-byte chunks - if ((in.size() % 8) >= 4) { - for (; offset <= in.size() - 4; offset += 4) { - h64 ^= (getblock32(in.data(), offset) & 0xfffffffful) * prime1; - h64 = rotate_bits_left(h64, 23) * prime2 + prime3; - } - } - - // and the rest - if (in.size() % 4) { - while (offset < in.size()) { - h64 ^= (std::to_integer(in[offset]) & 0xff) * prime5; - h64 = rotate_bits_left(h64, 11) * prime1; - ++offset; - } - } - return h64; - } - - result_type __device__ compute_bytes(device_span& in) const - { - uint64_t offset = 0; - uint64_t h64; - // data can be processed in 32-byte chunks - if (in.size() >= 32) { - auto limit = in.size() - 32; - uint64_t v1 = m_seed + prime1 + prime2; - uint64_t v2 = m_seed + prime2; - uint64_t v3 = m_seed; - uint64_t v4 = m_seed - prime1; - - do { - // pipeline 4*8byte computations - v1 += getblock64(in.data(), offset) * prime2; - v1 = rotate_bits_left(v1, 31); - v1 *= prime1; - offset += 8; - v2 += getblock64(in.data(), offset) * prime2; - v2 = rotate_bits_left(v2, 31); - v2 *= prime1; - offset += 8; - v3 += getblock64(in.data(), offset) * prime2; - v3 = rotate_bits_left(v3, 31); - v3 *= prime1; - offset += 8; - v4 += getblock64(in.data(), offset) * prime2; - v4 = rotate_bits_left(v4, 31); - v4 *= prime1; - offset += 8; - } while (offset <= limit); - - h64 = rotate_bits_left(v1, 1) + rotate_bits_left(v2, 7) + rotate_bits_left(v3, 12) + - rotate_bits_left(v4, 18); - - v1 *= prime2; - v1 = rotate_bits_left(v1, 31); - v1 *= prime1; - h64 ^= v1; - h64 = h64 * prime1 + prime4; - - v2 *= prime2; - v2 = rotate_bits_left(v2, 31); - v2 *= prime1; - h64 ^= v2; - h64 = h64 * prime1 + prime4; - - v3 *= prime2; - v3 = rotate_bits_left(v3, 31); - v3 *= prime1; - h64 ^= v3; - h64 = h64 * prime1 + prime4; - - v4 *= prime2; - v4 = rotate_bits_left(v4, 31); - v4 *= prime1; - h64 ^= v4; - h64 = h64 * prime1 + prime4; - } else { - h64 = m_seed + prime5; - } - - h64 += in.size(); - - h64 = compute_remaining_bytes(in, offset, h64); - - return finalize(h64); - } - - constexpr __host__ __device__ std::uint64_t finalize(std::uint64_t h) const noexcept - { - h ^= h >> 33; - h *= prime2; - h ^= h >> 29; - h *= prime3; - h ^= h >> 32; - return h; - } - - private: - hash_value_type m_seed{}; - static constexpr uint64_t prime1 = 0x9e3779b185ebca87ul; - static constexpr uint64_t prime2 = 0xc2b2ae3d27d4eb4ful; - static constexpr uint64_t prime3 = 0x165667b19e3779f9ul; - static constexpr uint64_t prime4 = 0x85ebca77c2b2ae63ul; - 
static constexpr uint64_t prime5 = 0x27d4eb2f165667c5ul; -}; - -template <> -hash_value_type __device__ inline XXHash_64::operator()(bool const& key) const -{ - return compute(static_cast(key)); -} - -template <> -hash_value_type __device__ inline XXHash_64::operator()(float const& key) const -{ - return compute(normalize_nans(key)); -} - -template <> -hash_value_type __device__ inline XXHash_64::operator()(double const& key) const -{ - return compute(normalize_nans(key)); -} - -template <> -hash_value_type __device__ inline XXHash_64::operator()( - cudf::string_view const& key) const -{ - auto const len = key.size_bytes(); - auto data = device_span(reinterpret_cast(key.data()), len); - return compute_bytes(data); -} - -template <> -hash_value_type __device__ inline XXHash_64::operator()( - numeric::decimal32 const& key) const -{ - return compute(key.value()); -} - -template <> -hash_value_type __device__ inline XXHash_64::operator()( - numeric::decimal64 const& key) const -{ - return compute(key.value()); -} - -template <> -hash_value_type __device__ inline XXHash_64::operator()( - numeric::decimal128 const& key) const -{ - return compute(key.value()); -} - -/** - * @brief Computes the hash value of a row in the given table. - * - * @tparam Nullate A cudf::nullate type describing whether to check for nulls. - */ -template -class device_row_hasher { - public: - device_row_hasher(Nullate nulls, table_device_view const& t, hash_value_type seed) - : _check_nulls(nulls), _table(t), _seed(seed) - { - } - - __device__ auto operator()(size_type row_index) const noexcept - { - return cudf::detail::accumulate( - _table.begin(), - _table.end(), - _seed, - [row_index, nulls = _check_nulls] __device__(auto hash, auto column) { - return cudf::type_dispatcher( - column.type(), element_hasher_adapter{}, column, row_index, nulls, hash); - }); - } - - /** - * @brief Computes the hash value of an element in the given column. 
- */ - class element_hasher_adapter { - public: - template ())> - __device__ hash_value_type operator()(column_device_view const& col, - size_type const row_index, - Nullate const _check_nulls, - hash_value_type const _seed) const noexcept - { - if (_check_nulls && col.is_null(row_index)) { - return std::numeric_limits::max(); - } - auto const hasher = XXHash_64{_seed}; - return hasher(col.element(row_index)); - } - - template ())> - __device__ hash_value_type operator()(column_device_view const&, - size_type const, - Nullate const, - hash_value_type const) const noexcept - { - CUDF_UNREACHABLE("Unsupported type for XXHash_64"); - } - }; - - Nullate const _check_nulls; - table_device_view const _table; - hash_value_type const _seed; -}; - -} // namespace - std::unique_ptr xxhash_64(table_view const& input, uint64_t seed, rmm::cuda_stream_view stream, @@ -318,7 +54,7 @@ std::unique_ptr xxhash_64(table_view const& input, thrust::tabulate(rmm::exec_policy(stream), output_view.begin(), output_view.end(), - device_row_hasher(nullable, *input_view, seed)); + xxhash_64_device_row_hasher(nullable, *input_view, seed)); return output; } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 666a7d4ba4b..6ae2f4f095f 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -132,6 +132,7 @@ ConfigureTest( groupby/groupby_test_util.cpp groupby/groups_tests.cpp groupby/histogram_tests.cpp + groupby/hllpp_tests.cpp groupby/keys_tests.cpp groupby/lists_tests.cpp groupby/m2_tests.cpp diff --git a/cpp/tests/groupby/hllpp_tests.cpp b/cpp/tests/groupby/hllpp_tests.cpp new file mode 100644 index 00000000000..3e40322850e --- /dev/null +++ b/cpp/tests/groupby/hllpp_tests.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include + +#include +#include +#include + +using namespace cudf::test::iterators; + +namespace { +constexpr cudf::test::debug_output_level verbosity{cudf::test::debug_output_level::FIRST_ERROR}; +constexpr int32_t null{0}; // Mark for null elements +constexpr double NaN{std::numeric_limits::quiet_NaN()}; // Mark for NaN double elements + +template +using keys_col = cudf::test::fixed_width_column_wrapper; + +template +using vals_col = cudf::test::fixed_width_column_wrapper; + +template +using M2s_col = cudf::test::fixed_width_column_wrapper; + +auto compute_HLL(cudf::column_view const& keys, cudf::column_view const& values) +{ + std::vector requests; + requests.emplace_back(); + requests[0].values = values; + requests[0].aggregations.emplace_back( + cudf::make_hyper_log_log_aggregation(9)); + auto gb_obj = cudf::groupby::groupby(cudf::table_view({keys})); + auto result = gb_obj.aggregate(requests); + return std::pair(std::move(result.first->release()[0]), std::move(result.second[0].results[0])); +} +} // namespace + +template +struct GroupbyHLLTypedTest : public cudf::test::BaseFixture {}; + +using TestTypes = cudf::test::Concat, + cudf::test::FloatingPointTypes>; +TYPED_TEST_SUITE(GroupbyHLLTypedTest, TestTypes); + +TYPED_TEST(GroupbyHLLTypedTest, SimpleInput) +{ + using T = TypeParam; + + // key = 1: vals = [0, 3, 6] + // key = 2: vals = [1, 4, 5, 9] + // key = 3: vals = [2, 7, 8] + auto const keys = keys_col{1, 2, 3, 1, 2, 2, 1, 3, 3, 2}; + auto const vals = vals_col{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + + compute_HLL(keys, vals); +} diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 379750bb0b7..754c1e7b594 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -70,7 +70,9 @@ enum Kind { TDIGEST(31), // This can take a delta argument for accuracy level MERGE_TDIGEST(32), // This can take a delta argument for accuracy level HISTOGRAM(33), - MERGE_HISTOGRAM(34); + MERGE_HISTOGRAM(34), + HLLPP(35), + MERGE_HLLPP(36); final int nativeId; @@ -912,6 +914,66 @@ public boolean equals(Object other) { } } + private static final class HLLAggregation extends Aggregation { + private final int num_registers_per_sketch; + + public HLLAggregation(Kind kind, int num_registers_per_sketch) { + super(kind); + this.num_registers_per_sketch = num_registers_per_sketch; + } + + @Override + long createNativeInstance() { + return Aggregation.createHLLAgg(kind.nativeId, num_registers_per_sketch); + } + + @Override + public int hashCode() { + return 31 * kind.hashCode() + num_registers_per_sketch; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other instanceof HLLAggregation) { + HLLAggregation o = (HLLAggregation) other; + return o.num_registers_per_sketch == this.num_registers_per_sketch; + } + return false; + } + } + + static final class MergeHLLAggregation extends Aggregation { + private final int num_registers_per_sketch; + + public MergeHLLAggregation(Kind kind, int num_registers_per_sketch) { + super(kind); + this.num_registers_per_sketch = num_registers_per_sketch; + } + + @Override + long createNativeInstance() { + return Aggregation.createHLLAgg(kind.nativeId, num_registers_per_sketch); + } + + @Override + public int hashCode() { + return 31 * kind.hashCode() + num_registers_per_sketch; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + 
+        return true;
+      } else if (other instanceof MergeHLLAggregation) {
+        MergeHLLAggregation o = (MergeHLLAggregation) other;
+        return o.numRegistersPerSketch == this.numRegistersPerSketch;
+      }
+      return false;
+    }
+  }
+
   static TDigestAggregation createTDigest(int delta) {
     return new TDigestAggregation(Kind.TDIGEST, delta);
   }
@@ -940,6 +1002,14 @@ static MergeHistogramAggregation mergeHistogram() {
     return new MergeHistogramAggregation();
   }
 
+  static HLLAggregation HLLPP(int numRegistersPerSketch) {
+    return new HLLAggregation(Kind.HLLPP, numRegistersPerSketch);
+  }
+
+  static MergeHLLAggregation mergeHLLPP(int numRegistersPerSketch) {
+    return new MergeHLLAggregation(Kind.MERGE_HLLPP, numRegistersPerSketch);
+  }
+
   /**
    * Create one of the aggregations that only needs a kind, no other parameters. This does not
    * work for all types and for code safety reasons each kind is added separately.
@@ -990,4 +1060,6 @@ static MergeHistogramAggregation mergeHistogram() {
    * Create a TDigest aggregation.
    */
   private static native long createTDigestAgg(int kind, int delta);
+
+  private static native long createHLLAgg(int kind, int numRegistersPerSketch);
 }
diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
index 0fae33927b6..5fcba0c1619 100644
--- a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
+++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
@@ -337,4 +337,12 @@ public static GroupByAggregation histogram() {
   public static GroupByAggregation mergeHistogram() {
     return new GroupByAggregation(Aggregation.mergeHistogram());
   }
+
+  public static GroupByAggregation HLLPP(int numRegistersPerSketch) {
+    return new GroupByAggregation(Aggregation.HLLPP(numRegistersPerSketch));
+  }
+
+  public static GroupByAggregation mergeHLLPP(int numRegistersPerSketch) {
+    return new GroupByAggregation(Aggregation.mergeHLLPP(numRegistersPerSketch));
+  }
 }
diff --git a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java
index ba8ae379bae..02dc2e33c0b 100644
--- a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java
+++ b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java
@@ -304,4 +304,12 @@ public static ReductionAggregation histogram() {
   public static ReductionAggregation mergeHistogram() {
     return new ReductionAggregation(Aggregation.mergeHistogram());
   }
+
+  public static ReductionAggregation HLLPP(int numRegistersPerSketch) {
+    return new ReductionAggregation(Aggregation.HLLPP(numRegistersPerSketch));
+  }
+
+  public static ReductionAggregation mergeHLLPP(int numRegistersPerSketch) {
+    return new ReductionAggregation(Aggregation.mergeHLLPP(numRegistersPerSketch));
+  }
 }
diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp
index c40f1c55500..2407f12d048 100644
--- a/java/src/main/native/src/AggregationJni.cpp
+++ b/java/src/main/native/src/AggregationJni.cpp
@@ -296,4 +296,27 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createMergeSetsAgg(JNIEn
   CATCH_STD(env, 0);
 }
 
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createHLLAgg(JNIEnv* env,
+                                                                     jclass class_object,
+                                                                     jint kind,
+                                                                     jint precision)
+{
+  try {
+    cudf::jni::auto_set_device(env);
+    std::unique_ptr<cudf::aggregation> ret;
+    // These numbers come from Aggregation.java and must stay in sync
+    switch (kind) {
+      case 35:  // HLLPP
+        ret = cudf::make_hyper_log_log_aggregation<cudf::aggregation>(precision);
+        break;
+      case 36:  // MERGE_HLLPP
+        ret = cudf::make_merge_hyper_log_log_aggregation<cudf::aggregation>(precision);
+        break;
+      default: throw std::logic_error("Unsupported HyperLogLog++ Aggregation Operation");
+    }
+    return reinterpret_cast<jlong>(ret.release());
+  }
+  CATCH_STD(env, 0);
+}
+
 } // extern "C"
diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java
index c7fcb1756b6..5a0a6b5cea4 100644
--- a/java/src/test/java/ai/rapids/cudf/TableTest.java
+++ b/java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -10016,4 +10016,15 @@ void testSample() {
       }
     }
   }
+
+  @Test
+  void testGroupByHLL() {
+    // A trivial smoke test: HLLPP(9) should run end-to-end without throwing.
+    try (Table input = new Table.TestBuilder().column(1, 2, 3, 1, 2, 2, 1, 3, 3, 2)
+        .column(0, 1, -2, 3, -4, -5, -6, 7, -8, 9)
+        .build();
+         Table result = input.groupBy(0).aggregate(GroupByAggregation.HLLPP(9).onColumn(1))) {
+      assertEquals(3, result.getRowCount());
+    }
+  }
 }

From abb4cad2d98dcde549ed50b71bed661e356a8330 Mon Sep 17 00:00:00 2001
From: Chong Gao
Date: Fri, 1 Nov 2024 11:39:37 +0800
Subject: [PATCH 02/10] Improve: use shared memory

---
 .../sort/group_hyper_log_log_plus_plus.cu     | 51 +++++++++++--------
 1 file changed, 31 insertions(+), 20 deletions(-)

diff --git a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
index 05580c1a685..92e9c878d9e 100644
--- a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
+++ b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
@@ -231,28 +231,35 @@ CUDF_KERNEL void partial_group_sketches_from_hashs_kernel(
 * Has num_groups rows.
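 *
 * As an illustrative sketch only (not part of the kernel contract), the merge
 * this kernel performs is equivalent to, for the register column `tid` and the
 * i-th cached partial sketch:
 *
 *   int v = registers_thread_cache[tid + i * num_registers_per_sketch];
 *   int64_t out = group_lables_thread_cache[i] * num_registers_per_sketch + tid;
 *   registers_output_cache[out] = max(registers_output_cache[out], v);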
 *
 */
+template <int block_size>
 CUDF_KERNEL void merge_sketches_vertically(int64_t num_sketches,
                                            int64_t num_registers_per_sketch,
                                            int* const registers_output_cache,
                                            int const* const registers_thread_cache,
                                            size_type const* const group_lables_thread_cache)
 {
-  // register idx is tid
+  extern __shared__ int8_t shared_data[];
   auto const tid = cudf::detail::grid_1d::global_thread_id();
-  int reg_max    = 0;
-  int prev_group = group_lables_thread_cache[0];
+  int shared_idx = tid % block_size;
+
+  // register idx is tid
+  shared_data[shared_idx] = static_cast<int8_t>(0);
+  int prev_group          = group_lables_thread_cache[0];
   for (auto i = 0; i < num_sketches; i++) {
     int curr_group = group_lables_thread_cache[i];
-    int curr_reg_v = registers_thread_cache[tid + i * num_registers_per_sketch];
+    int8_t curr_reg_v =
+      static_cast<int8_t>(registers_thread_cache[tid + i * num_registers_per_sketch]);
     if (curr_group == prev_group) {
-      if (curr_reg_v > reg_max) { reg_max = curr_reg_v; }
+      if (curr_reg_v > shared_data[shared_idx]) { shared_data[shared_idx] = curr_reg_v; }
     } else {
       // meets a new group, store the result for previous group
       int64_t reg_idx = prev_group * num_registers_per_sketch + tid;
       int curr_reg_v  = registers_output_cache[reg_idx];
-      if (reg_max > curr_reg_v) { registers_output_cache[reg_idx] = reg_max; }
+      if (shared_data[shared_idx] > curr_reg_v) {
+        registers_output_cache[reg_idx] = shared_data[shared_idx];
+      }
 
-      reg_max = curr_reg_v;
+      shared_data[shared_idx] = curr_reg_v;
     }
     prev_group = curr_group;
   }
@@ -260,7 +267,9 @@ CUDF_KERNEL void merge_sketches_vertically(int64_t num_sketches,
   // handles the last register in this thread
   int64_t reg_idx = prev_group * num_registers_per_sketch + tid;
   int curr_reg_v  = registers_output_cache[reg_idx];
-  if (reg_max > curr_reg_v) { registers_output_cache[reg_idx] = reg_max; }
+  if (shared_data[shared_idx] > curr_reg_v) {
+    registers_output_cache[reg_idx] = shared_data[shared_idx];
+  }
 }
 
 /**
@@ -363,12 +372,13 @@ std::unique_ptr<column> group_hllpp(column_view const& input,
   // 3. merge the intermidate result
   auto num_merge_threads = num_registers_per_sketch;
   auto num_merge_blocks  = cudf::util::div_rounding_up_safe(num_merge_threads, block_size);
-  merge_sketches_vertically<<<num_merge_blocks, block_size, 0, stream.value()>>>(
-    total_threads_partial_group,  // num_sketches
-    num_registers_per_sketch,
-    sketches_output.begin(),
-    registers_thread_cache.begin(),
-    group_lables_thread_cache.begin());
+  merge_sketches_vertically<block_size>
+    <<<num_merge_blocks, block_size, block_size, stream.value()>>>(
+      total_threads_partial_group,  // num_sketches
+      num_registers_per_sketch,
+      sketches_output.begin(),
+      registers_thread_cache.begin(),
+      group_lables_thread_cache.begin());
 
   // 4. create output columns
   auto num_long_cols = num_registers_per_sketch / REGISTERS_PER_LONG + 1;
@@ -583,12 +593,13 @@ std::unique_ptr<column> merge_hyper_log_log(
   auto const num_phase2_threads = num_registers_per_sketch;
   auto const num_phase2_blocks  = cudf::util::div_rounding_up_safe(num_phase2_threads, block_size);
   // 2nd kernel: vertical merge
-  merge_sketches_vertically<<<num_phase2_blocks, block_size, 0, stream.value()>>>(
-    num_threads_per_col_phase1,  // num_sketches
-    num_registers_per_sketch,
-    registers_output_cache.begin(),
-    registers_thread_cache.begin(),
-    group_lables_thread_cache.begin());
+  merge_sketches_vertically<block_size>
+    <<<num_phase2_blocks, block_size, block_size, stream.value()>>>(
+      num_threads_per_col_phase1,  // num_sketches
+      num_registers_per_sketch,
+      registers_output_cache.begin(),
+      registers_thread_cache.begin(),
+      group_lables_thread_cache.begin());
 }
 
 // create output columns

From 77ea21c4d196c955ec6d4579f0b12634940d8e24 Mon Sep 17 00:00:00 2001
From: Chong Gao
Date: Tue, 12 Nov 2024 17:38:26 +0800
Subject: [PATCH 03/10] Reduction for hllpp

---
 .../hyper_log_log_plus_plus.hpp               |  45 ++++
 .../sort/group_hyper_log_log_plus_plus.cu     | 231 +++++++++++++++++-
 cpp/src/reductions/reductions.cpp             |  10 +-
 3 files changed, 273 insertions(+), 13 deletions(-)
 create mode 100644 cpp/include/cudf/detail/hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp

diff --git a/cpp/include/cudf/detail/hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp b/cpp/include/cudf/detail/hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp
new file mode 100644
index 00000000000..71f27cd1a36
--- /dev/null
+++ b/cpp/include/cudf/detail/hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cudf/column/column_view.hpp>
+#include <cudf/scalar/scalar.hpp>
+#include <rmm/cuda_stream_view.hpp>
+
+namespace cudf {
+namespace groupby::detail {
+
+/**
+ * Compute the hashes of the input column, then generate a scalar that is a sketch in long array
+ * format
+ */
+std::unique_ptr<scalar> reduce_hyper_log_log_plus_plus(column_view const& input,
+                                                       int64_t const precision,
+                                                       rmm::cuda_stream_view stream,
+                                                       rmm::device_async_resource_ref mr);
+
+/**
+ * Merge sketches in long array format into one merged sketch.
+ * Input is a struct column with multiple long columns which is consistent with Spark.
+ */
+std::unique_ptr<scalar> reduce_merge_hyper_log_log_plus_plus(column_view const& input,
+                                                             int64_t const precision,
+                                                             rmm::cuda_stream_view stream,
+                                                             rmm::device_async_resource_ref mr);
+
+} // namespace groupby::detail
+} // namespace cudf
diff --git a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
index 92e9c878d9e..7ed1ae876c5 100644
--- a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
+++ b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
@@ -56,6 +56,9 @@ constexpr int REGISTERS_PER_LONG = 64 / REGISTER_VALUE_BITS;
 // XXHash seed
 constexpr int64_t SEED = 42L;
 
+// max precision, if require a precision bigger than 18, then use 18.
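+//
+// As a worked example of the sizing implied by these constants: precision 9
+// gives 2^9 = 512 registers per sketch, which pack into 512 / 10 + 1 = 52
+// longs; the maximum precision 18 gives 2^18 = 262144 registers, i.e. 26215
+// longs per sketch.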
+constexpr int MAX_PRECISION = 18;
+
 /**
  *
  * Computes register values from hash values and partially groups from register values.
@@ -238,7 +241,7 @@ CUDF_KERNEL void merge_sketches_vertically(int64_t num_sketches,
                                            int const* const registers_thread_cache,
                                            size_type const* const group_lables_thread_cache)
 {
-  extern __shared__ int8_t shared_data[];
+  __shared__ int8_t shared_data[block_size];
   auto const tid = cudf::detail::grid_1d::global_thread_id();
   int shared_idx = tid % block_size;
 
@@ -626,6 +629,172 @@ std::unique_ptr<column> merge_hyper_log_log(
   return make_structs_column(num_groups, std::move(results), 0, rmm::device_buffer{});
 }
 
+/**
+ * launch only 1 block
+ */
+template <int block_size>
+CUDF_KERNEL void reduce_hllpp_kernel(column_device_view hashs, int32_t* const output, int precision)
+{
+  // Registers live in dynamic shared memory: there are 2^precision of them, which
+  // can exceed block_size, so a fixed block_size-sized array would overflow here.
+  extern __shared__ int32_t shared_data[];
+
+  auto const tid                          = cudf::detail::grid_1d::global_thread_id();
+  auto const num_hashs                    = hashs.size();
+  uint64_t const num_registers_per_sketch = 1L << precision;
+  int const idx_shift                     = 64 - precision;
+  uint64_t const w_padding                = 1ULL << (precision - 1);
+
+  // init tmp data
+  for (int i = tid; i < num_registers_per_sketch; i += block_size) {
+    shared_data[i] = 0;
+  }
+  __syncthreads();
+
+  // update max reg value
+  for (int i = tid; i < num_hashs; i += block_size) {
+    uint64_t const hash    = static_cast<uint64_t>(hashs.element<int64_t>(i));
+    uint64_t const reg_idx = hash >> idx_shift;
+    int const reg_v =
+      static_cast<int>(cuda::std::countl_zero((hash << precision) | w_padding) + 1ULL);
+    cuda::atomic_ref<int32_t, cuda::thread_scope_block> register_ref(shared_data[reg_idx]);
+    register_ref.fetch_max(reg_v, cuda::memory_order_relaxed);
+  }
+  __syncthreads();
+
+  // copy to output
+  for (int i = tid; i < num_registers_per_sketch; i += block_size) {
+    output[i] = shared_data[i];
+  }
+}
+
+std::unique_ptr<scalar> reduce_hllpp(column_view const& input,
+                                     int64_t const precision,
+                                     rmm::cuda_stream_view stream,
+                                     rmm::device_async_resource_ref mr)
+{
+  int64_t num_registers_per_sketch = 1L << precision;
+  // 1. compute all the hashs
+  auto hash_col =
+    make_numeric_column(data_type{type_id::INT64}, input.size(), mask_state::ALL_VALID, stream, mr);
+  auto input_table    = cudf::table_view{{input}};
+  auto d_input_table  = cudf::table_device_view::create(input_table, stream);
+  bool const nullable = input.has_nulls();
+  thrust::tabulate(
+    rmm::exec_policy(stream),
+    hash_col->mutable_view().begin<int64_t>(),
+    hash_col->mutable_view().end<int64_t>(),
+    cudf::hashing::detail::xxhash_64_device_row_hasher(nullable, *d_input_table, SEED));
+  auto d_hashs = cudf::column_device_view::create(hash_col->view(), stream);
+
+  // 2. reduce
+  rmm::device_uvector<int32_t> output_tmp(num_registers_per_sketch, stream, mr);
+  constexpr int64_t block_size = 256;
+  // max shared memory is 2^18 * 4 = 1M
+  auto const shared_mem_size = num_registers_per_sketch * sizeof(int32_t);
+  reduce_hllpp_kernel<block_size>
+    <<<1, block_size, shared_mem_size, stream.value()>>>(*d_hashs, output_tmp.begin(), precision);
+
+  // 3. compact to longs
+  auto num_long_cols      = num_registers_per_sketch / REGISTERS_PER_LONG + 1;
+  auto const results_iter = cudf::detail::make_counting_transform_iterator(0, [&](int i) {
+    return make_numeric_column(
+      data_type{type_id::INT64}, 1 /**num_groups*/, mask_state::ALL_VALID, stream, mr);
+  });
+  auto children =
+    std::vector<std::unique_ptr<column>>(results_iter, results_iter + num_long_cols);
+  auto d_results = [&] {
+    auto host_results_pointer_iter =
+      thrust::make_transform_iterator(children.begin(), [](auto const& results_column) {
+        return results_column->mutable_view().template data<int64_t>();
+      });
+    auto host_results_pointers =
+      std::vector<int64_t*>(host_results_pointer_iter, host_results_pointer_iter + children.size());
+    return cudf::detail::make_device_uvector_async(host_results_pointers, stream, mr);
+  }();
+  auto const num_compact_threads = num_long_cols;
+  auto const num_compact_blocks = cudf::util::div_rounding_up_safe(num_compact_threads, block_size);
+  compact_kernel<<<num_compact_blocks, block_size, 0, stream.value()>>>(
+    1 /**num_groups*/, num_registers_per_sketch, d_results, output_tmp);
+
+  // 4. create scalar
+  auto host_results_view_iter = thrust::make_transform_iterator(
+    children.begin(), [](auto const& results_column) { return results_column->view(); });
+  auto views =
+    std::vector<column_view>(host_results_view_iter, host_results_view_iter + num_long_cols);
+  auto table_view = cudf::table_view{views};
+  auto table      = cudf::table(table_view);
+  return std::make_unique<cudf::struct_scalar>(std::move(table), true, stream, mr);
+}
+
+CUDF_KERNEL void reduce_merge_hll_kernel_vertically(
+  cudf::device_span<int64_t const* const> sketch_longs,
+  size_type num_sketches,
+  int num_registers_per_sketch,
+  int* const output)
+{
+  auto const tid = cudf::detail::grid_1d::global_thread_id();
+  if (tid >= num_registers_per_sketch) { return; }
+  auto long_idx        = tid / REGISTERS_PER_LONG;
+  auto reg_idx_in_long = tid % REGISTERS_PER_LONG;
+  int max = 0;
+  for (auto row_idx = 0; row_idx < num_sketches; row_idx++) {
+    int reg_v = get_register_value(sketch_longs[long_idx][row_idx], reg_idx_in_long);
+    if (reg_v > max) { max = reg_v; }
+  }
+  output[tid] = max;
+}
+
+std::unique_ptr<scalar> reduce_merge_hllpp(column_view const& input,
+                                           int64_t const precision,
+                                           rmm::cuda_stream_view stream,
+                                           rmm::device_async_resource_ref mr)
+{
+  // create device input
+  int64_t num_registers_per_sketch = 1 << precision;
+  auto num_long_cols               = num_registers_per_sketch / REGISTERS_PER_LONG + 1;
+  cudf::structs_column_view scv(input);
+  auto const input_iter = cudf::detail::make_counting_transform_iterator(
+    0, [&](int i) { return scv.get_sliced_child(i, stream).begin<int64_t>(); });
+  auto input_cols = std::vector<int64_t const*>(input_iter, input_iter + num_long_cols);
+  auto d_inputs   = cudf::detail::make_device_uvector_async(input_cols, stream, mr);
+
+  // create one row output
+  auto const results_iter = cudf::detail::make_counting_transform_iterator(0, [&](int i) {
+    return make_numeric_column(
+      data_type{type_id::INT64}, 1 /** num_rows */, mask_state::ALL_VALID, stream, mr);
+  });
+  auto children =
+    std::vector<std::unique_ptr<column>>(results_iter, results_iter + num_long_cols);
+  auto d_results = [&] {
+    auto host_results_pointer_iter =
+      thrust::make_transform_iterator(children.begin(), [](auto const& results_column) {
+        return results_column->mutable_view().template data<int64_t>();
+      });
+    auto host_results_pointers =
+      std::vector<int64_t*>(host_results_pointer_iter, host_results_pointer_iter + children.size());
+    return cudf::detail::make_device_uvector_async(host_results_pointers, stream, mr);
+  }();
+
+  // execute merge kernel
+  auto num_threads             = num_registers_per_sketch;
+  constexpr int64_t block_size = 256;
+  auto num_blocks              = cudf::util::div_rounding_up_safe(num_threads, block_size);
+  auto output_cache = rmm::device_uvector<int32_t>(num_registers_per_sketch, stream, mr);
+  reduce_merge_hll_kernel_vertically<<<num_blocks, block_size, 0, stream.value()>>>(
+    d_inputs, input.size(), num_registers_per_sketch, output_cache.begin());
+
+  // compact to longs
+  auto const num_compact_threads = num_long_cols;
+  auto const num_compact_blocks = cudf::util::div_rounding_up_safe(num_compact_threads, block_size);
+  compact_kernel<<<num_compact_blocks, block_size, 0, stream.value()>>>(
+    1 /** num_groups **/, num_registers_per_sketch, d_results, output_cache);
+
+  // create scalar
+  auto host_results_view_iter = thrust::make_transform_iterator(
+    children.begin(), [](auto const& results_column) { return results_column->view(); });
+  auto views =
+    std::vector<column_view>(host_results_view_iter, host_results_view_iter + num_long_cols);
+  auto table_view = cudf::table_view{views};
+  auto table      = cudf::table(table_view);
+  return std::make_unique<cudf::struct_scalar>(std::move(table), true, stream, mr);
+}
+
 } // namespace
 
 /**
@@ -640,11 +809,9 @@ std::unique_ptr<column> group_hyper_log_log_plus_plus(
   rmm::cuda_stream_view stream,
   rmm::device_async_resource_ref mr)
 {
-  CUDF_EXPECTS(precision >= 4 && precision <= 18, "HLL++ requires precision in range: [4, 18]");
-  auto input_type =
-    cudf::is_dictionary(input.type()) ? dictionary_column_view(input).keys().type() : input.type();
-
-  return group_hllpp(input, num_groups, group_lables, precision, stream, mr);
+  CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4.");
+  auto adjust_precision = precision > MAX_PRECISION ? MAX_PRECISION : precision;
+  return group_hllpp(input, num_groups, group_lables, adjust_precision, stream, mr);
 }
 
 /**
@@ -652,21 +819,61 @@ std::unique_ptr<column> group_hyper_log_log_plus_plus(
  * Input is a struct column with multiple long columns which is consistent with Spark.
  */
 std::unique_ptr<column> group_merge_hyper_log_log_plus_plus(
-  column_view const& values,
+  column_view const& input,
   int64_t const num_groups,
   cudf::device_span<size_type const> group_lables,
   int64_t const precision,
   rmm::cuda_stream_view stream,
   rmm::device_async_resource_ref mr)
 {
-  CUDF_EXPECTS(precision >= 4 && precision <= 18, "HLL++ requires precision in range: [4, 18]");
-  CUDF_EXPECTS(values.type().id() == type_id::STRUCT,
+  CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4.");
+  CUDF_EXPECTS(input.type().id() == type_id::STRUCT,
+               "HyperLogLogPlusPlus buffer type must be a STRUCT of long columns.");
+  for (auto i = 0; i < input.num_children(); i++) {
+    CUDF_EXPECTS(input.child(i).type().id() == type_id::INT64,
+                 "HyperLogLogPlusPlus buffer type must be a STRUCT of long columns.");
+  }
+  auto adjust_precision   = precision > MAX_PRECISION ? MAX_PRECISION : precision;
+  auto expected_num_longs = (1 << adjust_precision) / REGISTERS_PER_LONG + 1;
+  CUDF_EXPECTS(input.num_children() == expected_num_longs,
+               "The num of long columns in input is incorrect.");
+  return merge_hyper_log_log(input, num_groups, group_lables, adjust_precision, stream, mr);
+}
+
+/**
+ * Compute the hashs of the input column, then generate a sketch stored in a struct of long scalar.
+ */
+std::unique_ptr<scalar> reduce_hyper_log_log_plus_plus(column_view const& input,
+                                                       int64_t const precision,
+                                                       rmm::cuda_stream_view stream,
+                                                       rmm::device_async_resource_ref mr)
+{
+  CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4.");
+  auto adjust_precision = precision > MAX_PRECISION ? MAX_PRECISION : precision;
+  return reduce_hllpp(input, adjust_precision, stream, mr);
+}
+
+/**
+ * Merge all sketches in the input column into one sketch.
+ * Input is a struct column with multiple long columns which is consistent with Spark.
+ */
+std::unique_ptr<scalar> reduce_merge_hyper_log_log_plus_plus(column_view const& input,
+                                                             int64_t const precision,
+                                                             rmm::cuda_stream_view stream,
+                                                             rmm::device_async_resource_ref mr)
+{
+  CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4.");
+  CUDF_EXPECTS(input.type().id() == type_id::STRUCT,
                "HyperLogLogPlusPlus buffer type must be a STRUCT of long columns.");
-  for (auto i = 0; i < values.num_children(); i++) {
-    CUDF_EXPECTS(values.child(i).type().id() == type_id::INT64,
+  for (auto i = 0; i < input.num_children(); i++) {
+    CUDF_EXPECTS(input.child(i).type().id() == type_id::INT64,
                  "HyperLogLogPlusPlus buffer type must be a STRUCT of long columns.");
   }
-  return merge_hyper_log_log(values, num_groups, group_lables, precision, stream, mr);
+  auto adjust_precision   = precision > MAX_PRECISION ? MAX_PRECISION : precision;
+  auto expected_num_longs = (1 << adjust_precision) / REGISTERS_PER_LONG + 1;
+  CUDF_EXPECTS(input.num_children() == expected_num_longs,
+               "The num of long columns in input is incorrect.");
+  return reduce_merge_hllpp(input, adjust_precision, stream, mr);
 }
 
 } // namespace detail
diff --git a/cpp/src/reductions/reductions.cpp b/cpp/src/reductions/reductions.cpp
index 75ebc078930..68a33ad9fc1 100644
--- a/cpp/src/reductions/reductions.cpp
+++ b/cpp/src/reductions/reductions.cpp
@@ -29,7 +29,7 @@
 #include <cudf/structs/structs_column_view.hpp>
 #include <cudf/table/table.hpp>
 #include <cudf/types.hpp>
-
+#include <cudf/detail/hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp>
 #include <cudf/utilities/error.hpp>
 #include <cudf/utilities/type_dispatcher.hpp>
 
@@ -144,6 +144,14 @@ struct reduce_dispatch_functor {
       auto td_agg = static_cast<cudf::detail::merge_tdigest_aggregation const&>(agg);
       return tdigest::detail::reduce_merge_tdigest(col, td_agg.max_centroids, stream, mr);
     }
+    case aggregation::HLLPP: {
+      auto hllpp_agg = static_cast<cudf::detail::hyper_log_log_aggregation const&>(agg);
+      return cudf::groupby::detail::reduce_hyper_log_log_plus_plus(col, hllpp_agg.precision, stream, mr);
+    }
+    case aggregation::MERGE_HLLPP: {
+      auto hllpp_agg = static_cast<cudf::detail::merge_hyper_log_log_aggregation const&>(agg);
+      return cudf::groupby::detail::reduce_merge_hyper_log_log_plus_plus(col, hllpp_agg.precision, stream, mr);
+    }
     default: CUDF_FAIL("Unsupported reduction operator");
   }
 }

From d3b6066f341d8763448d14bf59eacec61463deed Mon Sep 17 00:00:00 2001
From: Chong Gao
Date: Fri, 22 Nov 2024 18:34:42 +0800
Subject: [PATCH 04/10] Improve reduction

---
 .../sort/group_hyper_log_log_plus_plus.cu     | 66 ++++++++++++-------
 1 file changed, 43 insertions(+), 23 deletions(-)

diff --git a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
index 7ed1ae876c5..c17ed39a428 100644
--- a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
+++ b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
@@ -43,8 +43,14 @@ namespace groupby {
 namespace detail {
 namespace {
 
-// The number of bits required by register value. Register value stores num of zeros.
-// XXHash64 value is 64 bits, it's safe to use 6 bits to store a register value.
+/**
+ * The number of bits that is required for register value.
+ *
+ * This number is determined by the maximum number of leading binary zeros a hashcode can
+ * produce. This is equal to the number of bits the hashcode returns. The current
+ * implementation uses a 64-bit hashcode, this means 6-bits are (at most) needed to store the
+ * number of leading zeros.
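+ *
+ * A quick sanity check of that claim (using the w_padding = 1 << (precision - 1)
+ * term from the kernels below): the stored value is
+ * countl_zero((hash << precision) | w_padding) + 1, which is at most
+ * (64 - precision) + 1, i.e. at most 61 for the minimum precision of 4, and
+ * therefore always fits in 6 bits.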
+ */ constexpr int REGISTER_VALUE_BITS = 6; // MASK binary 6 bits: 111111 @@ -630,10 +636,14 @@ std::unique_ptr merge_hyper_log_log( } /** - * launch only 1 block + * Launch only 1 block, uses max 1M(2^18 *sizeof(int)) shared memory. + * For each hash, get a pair: (register index, register value). + * Use shared memory to speedup the fetch max atomic operation. */ template -CUDF_KERNEL void reduce_hllpp_kernel(column_device_view hashs, int32_t* const output, int precision) +CUDF_KERNEL void reduce_hllpp_kernel(column_device_view hashs, + cudf::device_span output, + int precision) { __shared__ int32_t shared_data[block_size]; @@ -649,10 +659,12 @@ CUDF_KERNEL void reduce_hllpp_kernel(column_device_view hashs, int32_t* const ou } __syncthreads(); - // update max reg value + // update max reg value for the reg index for (int i = tid; i < num_hashs; i += block_size) { - uint64_t const hash = static_cast(hashs.element(i)); + uint64_t const hash = static_cast(hashs.element(i)); + // use unsigned int to avoid insert 1 for the highest bit when do right shift uint64_t const reg_idx = hash >> idx_shift; + // get the leading zeros int const reg_v = static_cast(cuda::std::countl_zero((hash << precision) | w_padding) + 1ULL); cuda::atomic_ref register_ref(shared_data[reg_idx]); @@ -660,9 +672,22 @@ CUDF_KERNEL void reduce_hllpp_kernel(column_device_view hashs, int32_t* const ou } __syncthreads(); - // copy to output - for (int i = tid; i < num_registers_per_sketch; i += block_size) { - output[i] = shared_data[i]; + // compact from register values (int array) to long array + // each long holds 10 integers, note reg value < 64 which means the bits from 7 to highest are all + // 0. + if (tid * REGISTERS_PER_LONG < num_registers_per_sketch) { + int start = tid * REGISTERS_PER_LONG; + int end = (tid + 1) * REGISTERS_PER_LONG; + if (end > num_registers_per_sketch) { end = num_registers_per_sketch; } + + int64_t ret = 0; + for (int i = 0; i < end - start; i++) { + int shift = i * REGISTER_VALUE_BITS; + int64_t reg = shared_data[start + i]; + ret |= (reg << shift); + } + + output[tid][0] = ret; } } @@ -685,15 +710,7 @@ std::unique_ptr reduce_hllpp(column_view const& input, cudf::hashing::detail::xxhash_64_device_row_hasher(nullable, *d_input_table, SEED)); auto d_hashs = cudf::column_device_view::create(hash_col->view(), stream); - // 2. reduce - rmm::device_uvector output_tmp(num_registers_per_sketch, stream, mr); - constexpr int64_t block_size = 256; - // max shared memory is 2^18 * 4 = 1M - auto const shared_mem_size = num_registers_per_sketch * sizeof(int32_t); - reduce_hllpp_kernel - <<<1, block_size, shared_mem_size, stream.value()>>>(*d_hashs, output_tmp.begin(), precision); - - // 3. compact to longs + // 2. generate long columns, the size of each long column is 1 auto num_long_cols = num_registers_per_sketch / REGISTERS_PER_LONG + 1; auto const results_iter = cudf::detail::make_counting_transform_iterator(0, [&](int i) { return make_numeric_column( @@ -709,12 +726,15 @@ std::unique_ptr reduce_hllpp(column_view const& input, std::vector(host_results_pointer_iter, host_results_pointer_iter + children.size()); return cudf::detail::make_device_uvector_async(host_results_pointers, stream, mr); }(); - auto const num_compact_threads = num_long_cols; - auto const num_compact_blocks = cudf::util::div_rounding_up_safe(num_compact_threads, block_size); - compact_kernel<<>>( - 1 /**num_groups*/, num_registers_per_sketch, d_results, output_tmp); - // 4. create scalar + // 2. 
reduce and generate compacted long values + constexpr int64_t block_size = 256; + // max shared memory is 2^18 * 4 = 1M + auto const shared_mem_size = num_registers_per_sketch * sizeof(int32_t); + reduce_hllpp_kernel + <<<1, block_size, shared_mem_size, stream.value()>>>(*d_hashs, d_results, precision); + + // 3. create struct scalar auto host_results_view_iter = thrust::make_transform_iterator( children.begin(), [](auto const& results_column) { return results_column->view(); }); auto views = From 57efaff852450c979cf104e0b6397d7663fbbf54 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 26 Nov 2024 13:46:08 +0800 Subject: [PATCH 05/10] Refine code; Add comments --- .../sort/group_hyper_log_log_plus_plus.cu | 231 ++++++++---------- 1 file changed, 104 insertions(+), 127 deletions(-) diff --git a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu index c17ed39a428..07957a37347 100644 --- a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu +++ b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu @@ -44,7 +44,7 @@ namespace detail { namespace { /** - * The number of bits that is required for register value. + * The number of bits that is required for a HLLPP register value. * * This number is determined by the maximum number of leading binary zeros a hashcode can * produce. This is equal to the number of bits the hashcode returns. The current @@ -53,13 +53,13 @@ namespace { */ constexpr int REGISTER_VALUE_BITS = 6; -// MASK binary 6 bits: 111111 +// MASK binary 6 bits: 111-111 constexpr uint64_t MASK = (1L << REGISTER_VALUE_BITS) - 1L; -// One long stores 10 register values +// This value is 10, one long stores 10 register values constexpr int REGISTERS_PER_LONG = 64 / REGISTER_VALUE_BITS; -// XXHash seed +// XXHash seed, consistent with Spark constexpr int64_t SEED = 42L; // max precision, if require a precision bigger than 18, then use 18. @@ -67,37 +67,41 @@ constexpr int MAX_PRECISION = 18; /** * - * Computes register values from hash values and partially groups from register values. - * It splits input into multiple segments with num_hashs_per_thread length. - * Each thread scans in its segment, find the max register values for all the values - * at the same register index at the same group, outputs gathered result when meets a new group, - * and in the end each thread saves a buffer for the last group in the segment. + * Computes register values from hash values and partially groups register values. + * It splits input into multiple segments with each segment has num_hashs_per_thread length. + * The input is sorted by group labels, each segment contains several consecutive groups. + * Each thread scans in its segment, find the max register values for all the register values + * at the same register index at the same group, outputs gathered result of previous group + * when meets a new group, and in the end each thread saves a buffer for the last group + * in the segment. * - * In this way, we can save memory usage, only need to cache `num_threads` caches. + * In this way, we can save memory usage, only need to cache + * (num_hashs / num_hashs_per_thread) sketches. * * num_threads = div_round_up(num_hashs, num_hashs_per_thread). * - * After register values are computed. 
* * e.g.: num_registers_per_sketch = 512 and num_hashs_per_thread = 4; * * Input: * register_index register_value group_lable * [ + * ------------------ segment 0 begin -------------------------------------- * (0, 1), 0 * (0, 2), 0 - * (1, 1), 1 // meets a new group, outputs result for g0 - * (1, 9), 1 // outputs for thread 0 when scan to here + * (1, 1), 1 // meets a new group, outputs result for g0 + * (1, 9), 1 // outputs for thread 0 when scan to here + * ------------------ segment 1 begin -------------------------------------- * (1, 1), 1 * (1, 1), 1 * (1, 5), 1 - * (1, 1), 1 // outputs for thread 1; Output result for g1 + * (1, 1), 1 // outputs for thread 1; Output result for g1 * ] * Output e.g.: * * group_lables_thread_cache: * [ - * g0 + * g1 * g1 * ] * Has num_threads rows. @@ -116,7 +120,7 @@ constexpr int MAX_PRECISION = 18; * ] * Has num_groups rows. * - * The next kernel will merge the partial result to final result + * The next kernel will merge the registers_output_cache and registers_thread_cache */ template CUDF_KERNEL void partial_group_sketches_from_hashs_kernel( @@ -162,31 +166,26 @@ CUDF_KERNEL void partial_group_sketches_from_hashs_kernel( // still in the same group, update the max value if (reg_v > sketch_ptr[reg_idx]) { sketch_ptr[reg_idx] = reg_v; } } else { - // meets new group, save output for the previous group + // meets new group, save output for the previous group and reset for (auto i = 0; i < num_registers_per_sketch; i++) { registers_output_cache[prev_group * num_registers_per_sketch + i] = sketch_ptr[i]; + sketch_ptr[i] = 0; } - - // reset cache - for (auto i = 0; i < num_registers_per_sketch; i++) { - sketch_ptr[i] = 0; - } - - // save the max value + // save the result for current group sketch_ptr[reg_idx] = reg_v; } - // special logic for the last sketch in this thread if (hash_idx == hash_end - 1) { // meets the last hash in the segment if (hash_idx == num_hashs - 1) { - // this segment is the last one + // meets the last segment, special logic: assume meets new group for (auto i = 0; i < num_registers_per_sketch; i++) { registers_output_cache[curr_group * num_registers_per_sketch + i] = sketch_ptr[i]; } } else { - // not the last segment, proble one item forward. + // not the last segment, probe one item forward. if (curr_group != group_lables[hash_idx + 1]) { + // meets a new group by checking the next item in the next segment for (auto i = 0; i < num_registers_per_sketch; i++) { registers_output_cache[curr_group * num_registers_per_sketch + i] = sketch_ptr[i]; } @@ -202,7 +201,7 @@ CUDF_KERNEL void partial_group_sketches_from_hashs_kernel( } /* - * + * Merge registers_output_cache and registers_thread_cache produced in the above kernel * Merge sketches vertically. * * For all register at the same index, starts a thread to merge the max value. @@ -239,6 +238,7 @@ CUDF_KERNEL void partial_group_sketches_from_hashs_kernel( * ] * Has num_groups rows. 
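+ * (Note on the example above: each entry of group_lables_thread_cache is the
+ * label of the last group touched by that thread's segment, which is why
+ * thread 0 records g1 rather than g0.)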
* + * First find the max value in registers_thread_cache and then merge to registers_output_cache */ template CUDF_KERNEL void merge_sketches_vertically(int64_t num_sketches, @@ -257,15 +257,15 @@ CUDF_KERNEL void merge_sketches_vertically(int64_t num_sketches, for (auto i = 0; i < num_sketches; i++) { int curr_group = group_lables_thread_cache[i]; int8_t curr_reg_v = - static_cast(registers_thread_cache[tid + i * num_registers_per_sketch]); + static_cast(registers_thread_cache[i * num_registers_per_sketch + tid]); if (curr_group == prev_group) { if (curr_reg_v > shared_data[shared_idx]) { shared_data[shared_idx] = curr_reg_v; } } else { // meets a new group, store the result for previous group - int64_t reg_idx = prev_group * num_registers_per_sketch + tid; - int curr_reg_v = registers_output_cache[reg_idx]; - if (shared_data[shared_idx] > curr_reg_v) { - registers_output_cache[reg_idx] = shared_data[shared_idx]; + int64_t result_reg_idx = prev_group * num_registers_per_sketch + tid; + int result_curr_reg_v = registers_output_cache[result_reg_idx]; + if (shared_data[shared_idx] > result_curr_reg_v) { + registers_output_cache[result_reg_idx] = shared_data[shared_idx]; } shared_data[shared_idx] = curr_reg_v; @@ -282,11 +282,14 @@ CUDF_KERNEL void merge_sketches_vertically(int64_t num_sketches, } /** + * Compact register values, compact 10 registers values + * (each register value is 6 bits) in to a long. + * This is consistent with Spark. + * Output: long columns which will be composed into a struct column * - * Compact register values, compact 10 registers values (each is 6 bits) in to a long. - * Number of threads is num_groups * num_longs_per_sketch + * Number of threads is num_groups * num_long_cols. * - * e.g.: + * e.g., num_registers_per_sketch is 512, precision is 9: * Input: * registers_output_cache: * [ @@ -309,20 +312,17 @@ CUDF_KERNEL void compact_kernel(int64_t const num_groups, // num_groups * num_registers_per_sketch integers cudf::device_span registers_output_cache) { - int64_t const tid = cudf::detail::grid_1d::global_thread_id(); - - int64_t const num_longs_per_sketch = sketches_output.size(); - if (tid >= num_groups * num_longs_per_sketch) { return; } + int64_t const tid = cudf::detail::grid_1d::global_thread_id(); + int64_t const num_long_cols = num_registers_per_sketch / REGISTERS_PER_LONG + 1; + if (tid >= num_groups * num_long_cols) { return; } - int64_t const group_idx = tid / num_longs_per_sketch; - int64_t const long_idx = tid % num_longs_per_sketch; + int64_t const group_idx = tid / num_long_cols; + int64_t const long_idx = tid % num_long_cols; int64_t const reg_begin_idx = group_idx * num_registers_per_sketch + long_idx * REGISTERS_PER_LONG; int64_t num_regs = REGISTERS_PER_LONG; - if (long_idx == num_longs_per_sketch - 1) { - num_regs = num_registers_per_sketch % REGISTERS_PER_LONG; - } + if (long_idx == num_long_cols - 1) { num_regs = num_registers_per_sketch % REGISTERS_PER_LONG; } int64_t ten_registers = 0; for (auto i = 0; i < num_regs; i++) { @@ -344,10 +344,12 @@ std::unique_ptr group_hllpp(column_view const& input, int64_t num_registers_per_sketch = 1 << precision; // 1. 
compute all the hashs + // TODO: mask_state::ALL_VALID => unallocate auto hash_col = make_numeric_column(data_type{type_id::INT64}, input.size(), mask_state::ALL_VALID, stream, mr); - auto input_table = cudf::table_view{{input}}; - auto d_input_table = cudf::table_device_view::create(input_table, stream); + auto input_table = cudf::table_view{{input}}; + auto d_input_table = cudf::table_device_view::create(input_table, stream); + // TODO: has_nulls has nested null? bool const nullable = input.has_nulls(); thrust::tabulate( rmm::exec_policy(stream), @@ -362,32 +364,32 @@ std::unique_ptr group_hllpp(column_view const& input, int64_t total_threads_partial_group = cudf::util::div_rounding_up_safe(static_cast(input.size()), num_hashs_per_thread); int64_t num_blocks_p1 = cudf::util::div_rounding_up_safe(total_threads_partial_group, block_size); - auto sketches_output = rmm::device_uvector(num_groups * num_registers_per_sketch, stream, mr); - auto registers_thread_cache = rmm::device_uvector( - total_threads_partial_group * num_registers_per_sketch, stream, mr); - auto group_lables_thread_cache = - rmm::device_uvector(total_threads_partial_group, stream, mr); - - partial_group_sketches_from_hashs_kernel - <<>>(*d_hashs, - group_lables, - precision, - sketches_output.begin(), - registers_thread_cache.begin(), - group_lables_thread_cache.begin()); - - // 3. merge the intermidate result - auto num_merge_threads = num_registers_per_sketch; - auto num_merge_blocks = cudf::util::div_rounding_up_safe(num_merge_threads, block_size); - merge_sketches_vertically - <<>>( - total_threads_partial_group, // num_sketches - num_registers_per_sketch, - sketches_output.begin(), - registers_thread_cache.begin(), - group_lables_thread_cache.begin()); + { + auto registers_thread_cache = rmm::device_uvector( + total_threads_partial_group * num_registers_per_sketch, stream, mr); + auto group_lables_thread_cache = + rmm::device_uvector(total_threads_partial_group, stream, mr); + partial_group_sketches_from_hashs_kernel + <<>>(*d_hashs, + group_lables, + precision, + sketches_output.begin(), + registers_thread_cache.begin(), + group_lables_thread_cache.begin()); + + // 3. merge the intermidate result + auto num_merge_threads = num_registers_per_sketch; + auto num_merge_blocks = cudf::util::div_rounding_up_safe(num_merge_threads, block_size); + merge_sketches_vertically + <<>>( + total_threads_partial_group, // num_sketches + num_registers_per_sketch, + sketches_output.begin(), + registers_thread_cache.begin(), + group_lables_thread_cache.begin()); + } // 4. create output columns auto num_long_cols = num_registers_per_sketch / REGISTERS_PER_LONG + 1; @@ -411,7 +413,7 @@ std::unique_ptr group_hllpp(column_view const& input, rmm::device_buffer{}, // null mask stream); - // 4. compact sketches + // 5. compact sketches auto num_phase3_threads = num_groups * num_long_cols; auto num_phase3_blocks = cudf::util::div_rounding_up_safe(num_phase3_threads, block_size); compact_kernel<<>>( @@ -444,37 +446,9 @@ __device__ inline int get_register_value(int64_t const long_10_registers, int re * and output the max register value when meets a new group. * For the last long in a thread, outputs the result into `registers_thread_cache`. * - * Output: - * - * group_lables_thread_cache: - * [ - * g0 - * g0 - * g1 - * ... - * gN - * ] - * Has num_threads rows. - * - * registers_thread_cache: - * [ - * r0_g0, r1_g0, r2_g0, r3_g0, ... , r511_g0 // register values for group 0 - * r0_g0, r1_g0, r2_g0, r3_g0, ... 
, r511_g0 // register values for group 0 - * r0_g1, r1_g1, r2_g1, r3_g1, ... , r511_g1 // register values for group 1 - * ... - * r0_gN, r1_gN, r2_gN, r3_gN, ... , r511_gN // register values for group N - * ] - * Has num_threads rows, each row is corresponding to `group_lables_thread_cache` - * - * registers_output_cache: - * [ - * r0_g0, r1_g0, r2_g0, r3_g0, ... , r511_g0 // register values for group 0 - * r0_g1, r1_g1, r2_g1, r3_g1, ... , r511_g1 // register values for group 1 - * ... - * r0_gN, r1_gN, r2_gN, r3_gN, ... , r511_gN // register values for group N - * ] - * Has num_groups rows. - * + * By split inputs into segments like `partial_group_sketches_from_hashs_kernel` and + * do partial merge, it will use less memory. Then the kernel merge_sketches_vertically + * can be used to merge the intermidate results: registers_output_cache, registers_thread_cache */ template CUDF_KERNEL void partial_group_long_sketches_kernel( @@ -501,8 +475,8 @@ CUDF_KERNEL void partial_group_long_sketches_kernel( int* const registers_thread_ptr = registers_thread_cache + thread_idx_in_cols * num_registers_per_sketch; - auto const sketch_first = thread_idx_in_cols * num_longs_per_threads; + auto const sketch_first = thread_idx_in_cols * num_longs_per_threads; auto const sketch_end = cuda::std::min(sketch_first + num_longs_per_threads, num_sketches_input); int num_regs = REGISTERS_PER_LONG; @@ -514,35 +488,31 @@ CUDF_KERNEL void partial_group_long_sketches_kernel( int reg_idx_in_sketch = long_idx * REGISTERS_PER_LONG + i; for (auto sketch_idx = sketch_first; sketch_idx < sketch_end; sketch_idx++) { size_type curr_group = group_lables[sketch_idx]; - - int64_t output_idx_for_prev_group = num_registers_per_sketch * prev_group + reg_idx_in_sketch; - - int curr_reg_v = get_register_value(longs_ptr[sketch_idx], i); + int curr_reg_v = get_register_value(longs_ptr[sketch_idx], i); if (curr_group == prev_group) { // still in the same group, update the max value if (curr_reg_v > max_reg_v) { max_reg_v = curr_reg_v; } } else { // meets new group, save output for the previous group - registers_output_cache[output_idx_for_prev_group] = max_reg_v; + int64_t output_idx_prev = num_registers_per_sketch * prev_group + reg_idx_in_sketch; + registers_output_cache[output_idx_prev] = max_reg_v; - // reset the cache + // reset max_reg_v = curr_reg_v; } - // special logic for the last sketch in this thread if (sketch_idx == sketch_end - 1) { - // last long in the segment - int64_t output_idx_for_curr_group = - num_registers_per_sketch * curr_group + reg_idx_in_sketch; + // last item in the segment + int64_t output_idx_curr = num_registers_per_sketch * curr_group + reg_idx_in_sketch; if (sketch_idx == num_sketches_input - 1) { // last segment - registers_output_cache[output_idx_for_curr_group] = max_reg_v; - max_reg_v = curr_reg_v; + registers_output_cache[output_idx_curr] = max_reg_v; + max_reg_v = curr_reg_v; } else { if (curr_group != group_lables[sketch_idx + 1]) { - // look one more forward - registers_output_cache[output_idx_for_curr_group] = max_reg_v; - max_reg_v = curr_reg_v; + // look the first item in the next segment + registers_output_cache[output_idx_curr] = max_reg_v; + max_reg_v = curr_reg_v; } } } @@ -550,14 +520,19 @@ CUDF_KERNEL void partial_group_long_sketches_kernel( prev_group = curr_group; } - // For each thread, output register values + // For each thread, output current max value registers_thread_ptr[reg_idx_in_sketch] = max_reg_v; } + if (long_idx == 0) { group_lables_thread_cache[thread_idx_in_cols] = 
group_lables[sketch_end - 1]; } } +/** + * Merge for struct column. Each long contains 10 register values. + * Merge all rows in the same group. + */ std::unique_ptr merge_hyper_log_log( column_view const& hll_input, // struct column int64_t const num_groups, @@ -581,7 +556,8 @@ std::unique_ptr merge_hyper_log_log( { auto registers_thread_cache = rmm::device_uvector(num_registers_per_sketch * num_threads_phase1, stream, mr); - auto group_lables_thread_cache = rmm::device_uvector(num_threads_phase1, stream, mr); + auto group_lables_thread_cache = + rmm::device_uvector(num_threads_per_col_phase1, stream, mr); cudf::structs_column_view scv(hll_input); auto const input_iter = cudf::detail::make_counting_transform_iterator( @@ -627,11 +603,12 @@ std::unique_ptr merge_hyper_log_log( return cudf::detail::make_device_uvector_async(host_results_pointers, stream, mr); }(); + // 3rd kernel: compact auto num_phase3_threads = num_groups * num_long_cols; auto num_phase3_blocks = cudf::util::div_rounding_up_safe(num_phase3_threads, block_size); - // 3rd kernel: compact compact_kernel<<>>( num_groups, num_registers_per_sketch, d_sketches_output, registers_output_cache); + return make_structs_column(num_groups, std::move(results), 0, rmm::device_buffer{}); } @@ -818,7 +795,7 @@ std::unique_ptr reduce_merge_hllpp(column_view const& input, } // namespace /** - * Compute hyper log log against the input values and merge the sketches in the same group. + * Compute hyper log log for the input values and merge the sketches in the same group. * Output is a struct column with multiple long columns which is consistent with Spark. */ std::unique_ptr group_hyper_log_log_plus_plus( @@ -829,7 +806,7 @@ std::unique_ptr group_hyper_log_log_plus_plus( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4."); + CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision >= 4."); auto adjust_precision = precision > MAX_PRECISION ? MAX_PRECISION : precision; return group_hllpp(input, num_groups, group_lables, adjust_precision, stream, mr); } @@ -846,7 +823,7 @@ std::unique_ptr group_merge_hyper_log_log_plus_plus( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4."); + CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision >= 4."); CUDF_EXPECTS(input.type().id() == type_id::STRUCT, "HyperLogLogPlusPlus buffer type must be a STRUCT of long columns."); for (auto i = 0; i < input.num_children(); i++) { @@ -868,7 +845,7 @@ std::unique_ptr reduce_hyper_log_log_plus_plus(column_view const& input, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4."); + CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision >= 4."); auto adjust_precision = precision > MAX_PRECISION ? 
MAX_PRECISION : precision; return reduce_hllpp(input, adjust_precision, stream, mr); } @@ -882,7 +859,7 @@ std::unique_ptr reduce_merge_hyper_log_log_plus_plus(column_view const& rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4."); + CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision >= 4."); CUDF_EXPECTS(input.type().id() == type_id::STRUCT, "HyperLogLogPlusPlus buffer type must be a STRUCT of long columns."); for (auto i = 0; i < input.num_children(); i++) { From 57be29e05caae84eba3e24bdf2ce5ebeba690d4f Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 26 Nov 2024 15:13:56 +0800 Subject: [PATCH 06/10] Adjust configs to get better performance --- cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu index 07957a37347..f7ebc292400 100644 --- a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu +++ b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu @@ -31,7 +31,6 @@ #include #include -#include #include // TODO #include once available #include #include @@ -360,7 +359,7 @@ std::unique_ptr group_hllpp(column_view const& input, // 2. execute partial group by constexpr int64_t block_size = 256; - constexpr int64_t num_hashs_per_thread = 32; // handles 32 items per thread + constexpr int64_t num_hashs_per_thread = 256; // handles 32 items per thread int64_t total_threads_partial_group = cudf::util::div_rounding_up_safe(static_cast(input.size()), num_hashs_per_thread); int64_t num_blocks_p1 = cudf::util::div_rounding_up_safe(total_threads_partial_group, block_size); @@ -544,7 +543,7 @@ std::unique_ptr merge_hyper_log_log( int64_t num_registers_per_sketch = 1 << precision; int64_t const num_sketches = hll_input.size(); int64_t const num_long_cols = num_registers_per_sketch / REGISTERS_PER_LONG + 1; - constexpr int64_t num_longs_per_threads = 32; + constexpr int64_t num_longs_per_threads = 256; constexpr int64_t block_size = 256; int64_t num_threads_per_col_phase1 = From 51ead98780fbc490e502507e4ba02da33c0789c2 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 26 Nov 2024 15:17:46 +0800 Subject: [PATCH 07/10] Update code comments; Minor changes --- cpp/include/cudf/aggregation.hpp | 4 ++-- cpp/src/groupby/sort/aggregate.cpp | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index 355e4f59f60..09018d58c5c 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -121,8 +121,8 @@ class aggregation { MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together HISTOGRAM, ///< compute frequency of each element MERGE_HISTOGRAM, ///< merge partial values of HISTOGRAM aggregation - HLLPP, ///< approximating the number of distinct items by using hyper log log plus plus (HLLPP) - MERGE_HLLPP ///< merge partial values of HLLPP aggregation + HLLPP, ///< approximating the number of distinct items by using HyperLogLogPlusPlus (HLLPP) + MERGE_HLLPP ///< merge partial values of HyperLogLogPlusPlus aggregation }; aggregation() = delete; diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index 814577fa2dd..13766c9a557 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -814,7 +814,6 @@ void 
aggregate_result_functor::operator()(aggregation int const precision = dynamic_cast(agg).precision; - cache.add_result(values, agg, detail::group_merge_hyper_log_log_plus_plus(get_grouped_values(), From 8e0ff011122bbab5a305fbfe20490e3859baed50 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 26 Nov 2024 15:58:25 +0800 Subject: [PATCH 08/10] Format code --- .../hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp | 1 + cpp/src/reductions/reductions.cpp | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/cpp/include/cudf/detail/hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp b/cpp/include/cudf/detail/hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp index 71f27cd1a36..ebbe08044bc 100644 --- a/cpp/include/cudf/detail/hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp +++ b/cpp/include/cudf/detail/hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp @@ -18,6 +18,7 @@ #include #include + #include namespace cudf { diff --git a/cpp/src/reductions/reductions.cpp b/cpp/src/reductions/reductions.cpp index 68a33ad9fc1..1b475bb87e0 100644 --- a/cpp/src/reductions/reductions.cpp +++ b/cpp/src/reductions/reductions.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -29,7 +30,7 @@ #include #include #include -#include + #include #include @@ -146,11 +147,13 @@ struct reduce_dispatch_functor { } case aggregation::HLLPP: { auto hllpp_agg = static_cast(agg); - return cudf::groupby::detail::reduce_hyper_log_log_plus_plus(col, hllpp_agg.precision, stream, mr); + return cudf::groupby::detail::reduce_hyper_log_log_plus_plus( + col, hllpp_agg.precision, stream, mr); } case aggregation::MERGE_HLLPP: { auto hllpp_agg = static_cast(agg); - return cudf::groupby::detail::reduce_merge_hyper_log_log_plus_plus(col, hllpp_agg.precision, stream, mr); + return cudf::groupby::detail::reduce_merge_hyper_log_log_plus_plus( + col, hllpp_agg.precision, stream, mr); } default: CUDF_FAIL("Unsupported reduction operator"); } From 884efe0c106fccaf79d19c530eb56fea756914d0 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 26 Nov 2024 18:35:35 +0800 Subject: [PATCH 09/10] Use has_nested_nulls; fix compile error --- cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu index f7ebc292400..c7dc5ee34e5 100644 --- a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu +++ b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu @@ -31,6 +31,7 @@ #include #include +#include #include // TODO #include once available #include #include @@ -343,13 +344,11 @@ std::unique_ptr group_hllpp(column_view const& input, int64_t num_registers_per_sketch = 1 << precision; // 1. compute all the hashs - // TODO: mask_state::ALL_VALID => unallocate auto hash_col = make_numeric_column(data_type{type_id::INT64}, input.size(), mask_state::ALL_VALID, stream, mr); - auto input_table = cudf::table_view{{input}}; - auto d_input_table = cudf::table_device_view::create(input_table, stream); - // TODO: has_nulls has nested null? 
- bool const nullable = input.has_nulls(); + auto input_table = cudf::table_view{{input}}; + auto d_input_table = cudf::table_device_view::create(input_table, stream); + bool const nullable = has_nested_nulls(input_table); thrust::tabulate( rmm::exec_policy(stream), hash_col->mutable_view().begin(), From 41f4ea27344e4cabe91657d13959eb5f9601c484 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Wed, 27 Nov 2024 14:35:04 +0800 Subject: [PATCH 10/10] update xxhash64 for hllpp --- cpp/include/cudf/hashing/detail/xxhash_64.cuh | 279 +++--------------- .../hashing/detail/xxhash_64_for_hllpp.cuh | 94 ++++++ .../sort/group_hyper_log_log_plus_plus.cu | 18 +- cpp/src/hash/xxhash_64.cu | 67 ++++- 4 files changed, 210 insertions(+), 248 deletions(-) create mode 100644 cpp/include/cudf/hashing/detail/xxhash_64_for_hllpp.cuh diff --git a/cpp/include/cudf/hashing/detail/xxhash_64.cuh b/cpp/include/cudf/hashing/detail/xxhash_64.cuh index eaf85dae5e9..b00e8297ac9 100644 --- a/cpp/include/cudf/hashing/detail/xxhash_64.cuh +++ b/cpp/include/cudf/hashing/detail/xxhash_64.cuh @@ -13,282 +13,87 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include +#pragma once -namespace cudf::hashing::detail { -using hash_value_type = uint64_t; +#include "hash_functions.cuh" -template -struct XXHash_64 { - using result_type = hash_value_type; +#include +#include +#include - constexpr XXHash_64() = default; - constexpr XXHash_64(hash_value_type seed) : m_seed(seed) {} +#include +#include - __device__ inline uint32_t getblock32(std::byte const* data, std::size_t offset) const - { - // Read a 4-byte value from the data pointer as individual bytes for safe - // unaligned access (very likely for string types). 
-    auto block = reinterpret_cast<uint8_t const*>(data + offset);
-    return block[0] | (block[1] << 8) | (block[2] << 16) | (block[3] << 24);
-  }
-
-  __device__ inline uint64_t getblock64(std::byte const* data, std::size_t offset) const
-  {
-    uint64_t result = getblock32(data, offset + 4);
-    result          = result << 32;
-    return result | getblock32(data, offset);
-  }
+namespace cudf::hashing::detail {

-  result_type __device__ inline operator()(Key const& key) const { return compute(key); }
+template <typename Key>
+struct XXHash_64 : public cuco::xxhash_64<Key> {
+  using result_type = typename cuco::xxhash_64<Key>::result_type;

-  template <typename T>
-  result_type __device__ inline compute(T const& key) const
+  __device__ result_type operator()(Key const& key) const
   {
-    auto data = device_span<std::byte const>(reinterpret_cast<std::byte const*>(&key), sizeof(T));
-    return compute_bytes(data);
+    return cuco::xxhash_64<Key>::operator()(key);
   }

-  result_type __device__ inline compute_remaining_bytes(device_span<std::byte const>& in,
-                                                        std::size_t offset,
-                                                        result_type h64) const
+  template <typename Extent>
+  __device__ result_type compute_hash(cuda::std::byte const* bytes, Extent size) const
   {
-    // remaining data can be processed in 8-byte chunks
-    if ((in.size() % 32) >= 8) {
-      for (; offset <= in.size() - 8; offset += 8) {
-        uint64_t k1 = getblock64(in.data(), offset) * prime2;
-
-        k1 = rotate_bits_left(k1, 31) * prime1;
-        h64 ^= k1;
-        h64 = rotate_bits_left(h64, 27) * prime1 + prime4;
-      }
-    }
-
-    // remaining data can be processed in 4-byte chunks
-    if ((in.size() % 8) >= 4) {
-      for (; offset <= in.size() - 4; offset += 4) {
-        h64 ^= (getblock32(in.data(), offset) & 0xfffffffful) * prime1;
-        h64 = rotate_bits_left(h64, 23) * prime2 + prime3;
-      }
-    }
-
-    // and the rest
-    if (in.size() % 4) {
-      while (offset < in.size()) {
-        h64 ^= (std::to_integer<uint8_t>(in[offset]) & 0xff) * prime5;
-        h64 = rotate_bits_left(h64, 11) * prime1;
-        ++offset;
-      }
-    }
-    return h64;
+    return cuco::xxhash_64<Key>::compute_hash(bytes, size);
   }
-
-  result_type __device__ compute_bytes(device_span<std::byte const>& in) const
-  {
-    uint64_t offset = 0;
-    uint64_t h64;
-    // data can be processed in 32-byte chunks
-    if (in.size() >= 32) {
-      auto limit  = in.size() - 32;
-      uint64_t v1 = m_seed + prime1 + prime2;
-      uint64_t v2 = m_seed + prime2;
-      uint64_t v3 = m_seed;
-      uint64_t v4 = m_seed - prime1;
-
-      do {
-        // pipeline 4*8byte computations
-        v1 += getblock64(in.data(), offset) * prime2;
-        v1 = rotate_bits_left(v1, 31);
-        v1 *= prime1;
-        offset += 8;
-        v2 += getblock64(in.data(), offset) * prime2;
-        v2 = rotate_bits_left(v2, 31);
-        v2 *= prime1;
-        offset += 8;
-        v3 += getblock64(in.data(), offset) * prime2;
-        v3 = rotate_bits_left(v3, 31);
-        v3 *= prime1;
-        offset += 8;
-        v4 += getblock64(in.data(), offset) * prime2;
-        v4 = rotate_bits_left(v4, 31);
-        v4 *= prime1;
-        offset += 8;
-      } while (offset <= limit);
-
-      h64 = rotate_bits_left(v1, 1) + rotate_bits_left(v2, 7) + rotate_bits_left(v3, 12) +
-            rotate_bits_left(v4, 18);
-
-      v1 *= prime2;
-      v1 = rotate_bits_left(v1, 31);
-      v1 *= prime1;
-      h64 ^= v1;
-      h64 = h64 * prime1 + prime4;
-
-      v2 *= prime2;
-      v2 = rotate_bits_left(v2, 31);
-      v2 *= prime1;
-      h64 ^= v2;
-      h64 = h64 * prime1 + prime4;
-
-      v3 *= prime2;
-      v3 = rotate_bits_left(v3, 31);
-      v3 *= prime1;
-      h64 ^= v3;
-      h64 = h64 * prime1 + prime4;
-
-      v4 *= prime2;
-      v4 = rotate_bits_left(v4, 31);
-      v4 *= prime1;
-      h64 ^= v4;
-      h64 = h64 * prime1 + prime4;
-    } else {
-      h64 = m_seed + prime5;
-    }
-
-    h64 += in.size();
-
-    h64 = compute_remaining_bytes(in, offset, h64);
-
-    return finalize(h64);
-  }
-
-  constexpr __host__ __device__ std::uint64_t finalize(std::uint64_t h) const noexcept
-  {
-    h ^= h >> 33;
-    h *= prime2;
-    h ^= h >> 29;
-    h *= prime3;
-    h ^= h >> 32;
-    return h;
-  }
-
- private:
-  hash_value_type m_seed{};
-  static constexpr uint64_t prime1 = 0x9e3779b185ebca87ul;
-  static constexpr uint64_t prime2 = 0xc2b2ae3d27d4eb4ful;
-  static constexpr uint64_t prime3 = 0x165667b19e3779f9ul;
-  static constexpr uint64_t prime4 = 0x85ebca77c2b2ae63ul;
-  static constexpr uint64_t prime5 = 0x27d4eb2f165667c5ul;
 };

 template <>
-hash_value_type __device__ inline XXHash_64<bool>::operator()(bool const& key) const
+XXHash_64<bool>::result_type __device__ inline XXHash_64<bool>::operator()(bool const& key) const
 {
-  return compute(static_cast<uint8_t>(key));
+  return this->compute_hash(reinterpret_cast<cuda::std::byte const*>(&key), sizeof(key));
 }

 template <>
-hash_value_type __device__ inline XXHash_64<float>::operator()(float const& key) const
+XXHash_64<float>::result_type __device__ inline XXHash_64<float>::operator()(float const& key) const
 {
-  return compute(normalize_nans(key));
+  return cuco::xxhash_64<float>::operator()(normalize_nans(key));
 }

 template <>
-hash_value_type __device__ inline XXHash_64<double>::operator()(double const& key) const
+XXHash_64<double>::result_type __device__ inline XXHash_64<double>::operator()(
+  double const& key) const
 {
-  return compute(normalize_nans(key));
+  return cuco::xxhash_64<double>::operator()(normalize_nans(key));
 }

 template <>
-hash_value_type __device__ inline XXHash_64<cudf::string_view>::operator()(
-  cudf::string_view const& key) const
+XXHash_64<cudf::string_view>::result_type
+  __device__ inline XXHash_64<cudf::string_view>::operator()(cudf::string_view const& key) const
 {
-  auto const len = key.size_bytes();
-  auto data = device_span<std::byte const>(reinterpret_cast<std::byte const*>(key.data()), len);
-  return compute_bytes(data);
+  return this->compute_hash(reinterpret_cast<cuda::std::byte const*>(key.data()), key.size_bytes());
 }

 template <>
-hash_value_type __device__ inline XXHash_64<numeric::decimal32>::operator()(
-  numeric::decimal32 const& key) const
+XXHash_64<numeric::decimal32>::result_type
+  __device__ inline XXHash_64<numeric::decimal32>::operator()(numeric::decimal32 const& key) const
 {
-  return compute(key.value());
+  auto const val = key.value();
+  auto const len = sizeof(val);
+  return this->compute_hash(reinterpret_cast<cuda::std::byte const*>(&val), len);
 }

 template <>
-hash_value_type __device__ inline XXHash_64<numeric::decimal64>::operator()(
-  numeric::decimal64 const& key) const
+XXHash_64<numeric::decimal64>::result_type
+  __device__ inline XXHash_64<numeric::decimal64>::operator()(numeric::decimal64 const& key) const
 {
-  return compute(key.value());
+  auto const val = key.value();
+  auto const len = sizeof(val);
+  return this->compute_hash(reinterpret_cast<cuda::std::byte const*>(&val), len);
 }

 template <>
-hash_value_type __device__ inline XXHash_64<numeric::decimal128>::operator()(
-  numeric::decimal128 const& key) const
+XXHash_64<numeric::decimal128>::result_type
+  __device__ inline XXHash_64<numeric::decimal128>::operator()(numeric::decimal128 const& key) const
 {
-  return compute(key.value());
+  auto const val = key.value();
+  auto const len = sizeof(val);
+  return this->compute_hash(reinterpret_cast<cuda::std::byte const*>(&val), len);
 }

-/**
- * @brief Computes the hash value of a row in the given table.
- *
- * @tparam Nullate A cudf::nullate type describing whether to check for nulls.
- */ -template -class xxhash_64_device_row_hasher { - public: - xxhash_64_device_row_hasher(Nullate nulls, table_device_view const& t, hash_value_type seed) - : _check_nulls(nulls), _table(t), _seed(seed) - { - } - - __device__ auto operator()(size_type row_index) const noexcept - { - return cudf::detail::accumulate( - _table.begin(), - _table.end(), - _seed, - [row_index, nulls = _check_nulls] __device__(auto hash, auto column) { - return cudf::type_dispatcher( - column.type(), element_hasher_adapter{}, column, row_index, nulls, hash); - }); - } - - /** - * @brief Computes the hash value of an element in the given column. - */ - class element_hasher_adapter { - public: - template ())> - __device__ hash_value_type operator()(column_device_view const& col, - size_type const row_index, - Nullate const _check_nulls, - hash_value_type const _seed) const noexcept - { - if (_check_nulls && col.is_null(row_index)) { - return std::numeric_limits::max(); - } - auto const hasher = XXHash_64{_seed}; - return hasher(col.element(row_index)); - } - - template ())> - __device__ hash_value_type operator()(column_device_view const&, - size_type const, - Nullate const, - hash_value_type const) const noexcept - { - CUDF_UNREACHABLE("Unsupported type for XXHash_64"); - } - }; - - Nullate const _check_nulls; - table_device_view const _table; - hash_value_type const _seed; -}; - -} // namespace cudf::hashing::detail \ No newline at end of file +} // namespace cudf::hashing::detail diff --git a/cpp/include/cudf/hashing/detail/xxhash_64_for_hllpp.cuh b/cpp/include/cudf/hashing/detail/xxhash_64_for_hllpp.cuh new file mode 100644 index 00000000000..ef308ca71d7 --- /dev/null +++ b/cpp/include/cudf/hashing/detail/xxhash_64_for_hllpp.cuh @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +/** + * This file is for HyperLogLogPlusPlus, it returns seed when input is null. + * This is a temp file, TODO use xxhash_64 in JNI repo to handle NaN Inf like Spark does. + */ +namespace cudf::hashing::detail { + +using hash_value_type = uint64_t; + +/** + * @brief Computes the hash value of a row in the given table. + * + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + */ +template +class xxhash_64_hllpp_row_hasher { + public: + xxhash_64_hllpp_row_hasher(Nullate nulls, table_device_view const& t, hash_value_type seed) + : _check_nulls(nulls), _table(t), _seed(seed) + { + } + + __device__ auto operator()(size_type row_index) const noexcept + { + return cudf::detail::accumulate( + _table.begin(), + _table.end(), + _seed, + [row_index, nulls = _check_nulls] __device__(auto hash, auto column) { + return cudf::type_dispatcher( + column.type(), element_hasher_adapter{}, column, row_index, nulls, hash); + }); + } + + /** + * @brief Computes the hash value of an element in the given column. 
+ */ + class element_hasher_adapter { + public: + template ())> + __device__ hash_value_type operator()(column_device_view const& col, + size_type const row_index, + Nullate const _check_nulls, + hash_value_type const _seed) const noexcept + { + if (_check_nulls && col.is_null(row_index)) { return _seed; } + auto const hasher = XXHash_64{_seed}; + return hasher(col.element(row_index)); + } + + template ())> + __device__ hash_value_type operator()(column_device_view const&, + size_type const, + Nullate const, + hash_value_type const) const noexcept + { + CUDF_UNREACHABLE("Unsupported type for XXHash_64"); + } + }; + + Nullate const _check_nulls; + table_device_view const _table; + hash_value_type const _seed; +}; + +} // namespace cudf::hashing::detail diff --git a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu index c7dc5ee34e5..7b30ba4110a 100644 --- a/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu +++ b/cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include #include @@ -346,14 +346,14 @@ std::unique_ptr group_hllpp(column_view const& input, // 1. compute all the hashs auto hash_col = make_numeric_column(data_type{type_id::INT64}, input.size(), mask_state::ALL_VALID, stream, mr); - auto input_table = cudf::table_view{{input}}; - auto d_input_table = cudf::table_device_view::create(input_table, stream); - bool const nullable = has_nested_nulls(input_table); + auto input_table = cudf::table_view{{input}}; + auto d_input_table = cudf::table_device_view::create(input_table, stream); + bool nullable = has_nested_nulls(input_table); thrust::tabulate( rmm::exec_policy(stream), hash_col->mutable_view().begin(), hash_col->mutable_view().end(), - cudf::hashing::detail::xxhash_64_device_row_hasher(nullable, *d_input_table, SEED)); + cudf::hashing::detail::xxhash_64_hllpp_row_hasher(nullable, *d_input_table, SEED)); auto d_hashs = cudf::column_device_view::create(hash_col->view(), stream); // 2. execute partial group by @@ -675,14 +675,14 @@ std::unique_ptr reduce_hllpp(column_view const& input, // 1. compute all the hashs auto hash_col = make_numeric_column(data_type{type_id::INT64}, input.size(), mask_state::ALL_VALID, stream, mr); - auto input_table = cudf::table_view{{input}}; - auto d_input_table = cudf::table_device_view::create(input_table, stream); - bool const nullable = input.has_nulls(); + auto input_table = cudf::table_view{{input}}; + auto d_input_table = cudf::table_device_view::create(input_table, stream); + bool nullable = has_nested_nulls(input_table); thrust::tabulate( rmm::exec_policy(stream), hash_col->mutable_view().begin(), hash_col->mutable_view().end(), - cudf::hashing::detail::xxhash_64_device_row_hasher(nullable, *d_input_table, SEED)); + cudf::hashing::detail::xxhash_64_hllpp_row_hasher(nullable, *d_input_table, SEED)); auto d_hashs = cudf::column_device_view::create(hash_col->view(), stream); // 2. generate long columns, the size of each long column is 1 diff --git a/cpp/src/hash/xxhash_64.cu b/cpp/src/hash/xxhash_64.cu index c3cc9d87d74..bdbe13b1ffb 100644 --- a/cpp/src/hash/xxhash_64.cu +++ b/cpp/src/hash/xxhash_64.cu @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -32,6 +31,70 @@ namespace cudf { namespace hashing { namespace detail { +namespace { + +using hash_value_type = uint64_t; + +/** + * @brief Computes the hash value of a row in the given table. 
+ * + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + */ +template +class device_row_hasher { + public: + device_row_hasher(Nullate nulls, table_device_view const& t, hash_value_type seed) + : _check_nulls(nulls), _table(t), _seed(seed) + { + } + + __device__ auto operator()(size_type row_index) const noexcept + { + return cudf::detail::accumulate( + _table.begin(), + _table.end(), + _seed, + [row_index, nulls = _check_nulls] __device__(auto hash, auto column) { + return cudf::type_dispatcher( + column.type(), element_hasher_adapter{}, column, row_index, nulls, hash); + }); + } + + /** + * @brief Computes the hash value of an element in the given column. + */ + class element_hasher_adapter { + public: + template ())> + __device__ hash_value_type operator()(column_device_view const& col, + size_type const row_index, + Nullate const _check_nulls, + hash_value_type const _seed) const noexcept + { + if (_check_nulls && col.is_null(row_index)) { + return std::numeric_limits::max(); + } + auto const hasher = XXHash_64{_seed}; + return hasher(col.element(row_index)); + } + + template ())> + __device__ hash_value_type operator()(column_device_view const&, + size_type const, + Nullate const, + hash_value_type const) const noexcept + { + CUDF_UNREACHABLE("Unsupported type for XXHash_64"); + } + }; + + Nullate const _check_nulls; + table_device_view const _table; + hash_value_type const _seed; +}; + +} // namespace + std::unique_ptr xxhash_64(table_view const& input, uint64_t seed, rmm::cuda_stream_view stream, @@ -54,7 +117,7 @@ std::unique_ptr xxhash_64(table_view const& input, thrust::tabulate(rmm::exec_policy(stream), output_view.begin(), output_view.end(), - xxhash_64_device_row_hasher(nullable, *input_view, seed)); + device_row_hasher(nullable, *input_view, seed)); return output; }