Skip to content

Commit

Permalink
Support hyper log log plus plus(HLL++)
Browse files Browse the repository at this point in the history
Signed-off-by: Chong Gao <[email protected]>
  • Loading branch information
Chong Gao committed Oct 29, 2024
1 parent 074ab74 commit bb2bfc2
Show file tree
Hide file tree
Showing 15 changed files with 1,512 additions and 40 deletions.
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ add_library(
src/groupby/sort/group_correlation.cu
src/groupby/sort/group_count.cu
src/groupby/sort/group_histogram.cu
src/groupby/sort/group_hyper_log_log.cu
src/groupby/sort/group_m2.cu
src/groupby/sort/group_max.cu
src/groupby/sort/group_min.cu
Expand All @@ -396,6 +397,7 @@ add_library(
src/groupby/sort/scan.cpp
src/groupby/sort/group_count_scan.cu
src/groupby/sort/group_max_scan.cu
src/groupby/sort/group_merge_hyper_log_log.cu
src/groupby/sort/group_min_scan.cu
src/groupby/sort/group_product_scan.cu
src/groupby/sort/group_rank_scan.cu
Expand Down
82 changes: 45 additions & 37 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,43 +84,45 @@ class aggregation {
* @brief Possible aggregation operations
*/
enum Kind {
SUM, ///< sum reduction
PRODUCT, ///< product reduction
MIN, ///< min reduction
MAX, ///< max reduction
COUNT_VALID, ///< count number of valid elements
COUNT_ALL, ///< count number of elements
ANY, ///< any reduction
ALL, ///< all reduction
SUM_OF_SQUARES, ///< sum of squares reduction
MEAN, ///< arithmetic mean reduction
M2, ///< sum of squares of differences from the mean
VARIANCE, ///< variance
STD, ///< standard deviation
MEDIAN, ///< median reduction
QUANTILE, ///< compute specified quantile(s)
ARGMAX, ///< Index of max element
ARGMIN, ///< Index of min element
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
EWMA, ///< get exponential weighted moving average at current index
RANK, ///< get rank of current index
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
CUDA, ///< CUDA UDF based reduction
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation,
COVARIANCE, ///< covariance between two sets of elements
CORRELATION, ///< correlation between two sets of elements
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
HISTOGRAM, ///< compute frequency of each element
MERGE_HISTOGRAM ///< merge partial values of HISTOGRAM aggregation,
SUM, ///< sum reduction
PRODUCT, ///< product reduction
MIN, ///< min reduction
MAX, ///< max reduction
COUNT_VALID, ///< count number of valid elements
COUNT_ALL, ///< count number of elements
ANY, ///< any reduction
ALL, ///< all reduction
SUM_OF_SQUARES, ///< sum of squares reduction
MEAN, ///< arithmetic mean reduction
M2, ///< sum of squares of differences from the mean
VARIANCE, ///< variance
STD, ///< standard deviation
MEDIAN, ///< median reduction
QUANTILE, ///< compute specified quantile(s)
ARGMAX, ///< Index of max element
ARGMIN, ///< Index of min element
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
EWMA, ///< get exponential weighted moving average at current index
RANK, ///< get rank of current index
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
CUDA, ///< CUDA UDF based reduction
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation,
COVARIANCE, ///< covariance between two sets of elements
CORRELATION, ///< correlation between two sets of elements
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST, ///< create a tdigest by merging multiple tdigests together
HISTOGRAM, ///< compute frequency of each element
MERGE_HISTOGRAM, ///< merge partial values of HISTOGRAM aggregation
HLLPP, ///< approximating the number of distinct items by using hyper log log plus plus (HLLPP)
MERGE_HLLPP ///< merge partial values of HLLPP aggregation
};

aggregation() = delete;
Expand Down Expand Up @@ -770,5 +772,11 @@ std::unique_ptr<Base> make_tdigest_aggregation(int max_centroids = 1000);
template <typename Base>
std::unique_ptr<Base> make_merge_tdigest_aggregation(int max_centroids = 1000);

template <typename Base = aggregation>
std::unique_ptr<Base> make_hyper_log_log_aggregation(int num_registers_per_sketch);

template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_hyper_log_log_aggregation(int const num_registers_per_sketch);

/** @} */ // end of group
} // namespace CUDF_EXPORT cudf
70 changes: 70 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ class simple_aggregations_collector { // Declares the interface for the simple
class tdigest_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(
data_type col_type, class merge_tdigest_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(
data_type col_type, class hyper_log_log_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(
data_type col_type, class merge_hyper_log_log_aggregation const& agg);
};

class aggregation_finalizer { // Declares the interface for the finalizer
Expand Down Expand Up @@ -144,6 +148,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer
virtual void visit(class tdigest_aggregation const& agg);
virtual void visit(class merge_tdigest_aggregation const& agg);
virtual void visit(class ewma_aggregation const& agg);
virtual void visit(class hyper_log_log_aggregation const& agg);
virtual void visit(class merge_hyper_log_log_aggregation const& agg);
};

/**
Expand Down Expand Up @@ -1186,6 +1192,54 @@ class merge_tdigest_aggregation final : public groupby_aggregation, public reduc
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived aggregation class for specifying TDIGEST aggregation
*/
class hyper_log_log_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
explicit hyper_log_log_aggregation(int const num_registers_per_sketch_)
: aggregation{HLLPP}, num_registers_per_sketch(num_registers_per_sketch_)
{
}

int const num_registers_per_sketch;

[[nodiscard]] std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<hyper_log_log_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived aggregation class for specifying MERGE_TDIGEST aggregation
*/
class merge_hyper_log_log_aggregation final : public groupby_aggregation,
public reduce_aggregation {
public:
explicit merge_hyper_log_log_aggregation(int const num_registers_per_sketch_)
: aggregation{MERGE_HLLPP}, num_registers_per_sketch(num_registers_per_sketch_)
{
}
int const num_registers_per_sketch;

[[nodiscard]] std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<merge_hyper_log_log_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Sentinel value used for `ARGMAX` aggregation.
*
Expand Down Expand Up @@ -1319,6 +1373,12 @@ struct target_type_impl<SourceType, aggregation::M2> {
using type = double;
};

// Always use list for HLLPP
template <typename SourceType>
struct target_type_impl<SourceType, aggregation::HLLPP> {
using type = list_view;
};

// Always use `double` for VARIANCE
template <typename SourceType>
struct target_type_impl<SourceType, aggregation::VARIANCE> {
Expand Down Expand Up @@ -1426,6 +1486,12 @@ struct target_type_impl<SourceType, aggregation::MERGE_M2> {
using type = struct_view;
};

// Always use list for MERGE_HLLPP
template <typename SourceType>
struct target_type_impl<SourceType, aggregation::MERGE_HLLPP> {
using type = list_view;
};

// Use list for MERGE_HISTOGRAM
template <typename SourceType>
struct target_type_impl<SourceType, aggregation::MERGE_HISTOGRAM> {
Expand Down Expand Up @@ -1579,6 +1645,10 @@ CUDF_HOST_DEVICE inline decltype(auto) aggregation_dispatcher(aggregation::Kind
return f.template operator()<aggregation::MERGE_TDIGEST>(std::forward<Ts>(args)...);
case aggregation::EWMA:
return f.template operator()<aggregation::EWMA>(std::forward<Ts>(args)...);
case aggregation::HLLPP:
return f.template operator()<aggregation::HLLPP>(std::forward<Ts>(args)...);
case aggregation::MERGE_HLLPP:
return f.template operator()<aggregation::MERGE_HLLPP>(std::forward<Ts>(args)...);
default: {
#ifndef __CUDA_ARCH__
CUDF_FAIL("Unsupported aggregation.");
Expand Down
48 changes: 48 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,18 @@ std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, hyper_log_log_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, merge_hyper_log_log_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}

// aggregation_finalizer ----------------------------------------

void aggregation_finalizer::visit(aggregation const& agg) {}
Expand Down Expand Up @@ -410,6 +422,16 @@ void aggregation_finalizer::visit(merge_tdigest_aggregation const& agg)
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(hyper_log_log_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(merge_hyper_log_log_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

} // namespace detail

std::vector<std::unique_ptr<aggregation>> aggregation::get_simple_aggregations(
Expand Down Expand Up @@ -917,6 +939,32 @@ make_merge_tdigest_aggregation<groupby_aggregation>(int max_centroids);
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
make_merge_tdigest_aggregation<reduce_aggregation>(int max_centroids);

/// Factory to create a HLLPP aggregation
template <typename Base>
std::unique_ptr<Base> make_hyper_log_log_aggregation(int const num_registers_per_sketch)
{
return std::make_unique<detail::hyper_log_log_aggregation>(num_registers_per_sketch);
}
template CUDF_EXPORT std::unique_ptr<aggregation> make_hyper_log_log_aggregation<aggregation>(
int num_registers_per_sketch);
template CUDF_EXPORT std::unique_ptr<groupby_aggregation>
make_hyper_log_log_aggregation<groupby_aggregation>(int num_registers_per_sketch);
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
make_hyper_log_log_aggregation<reduce_aggregation>(int num_registers_per_sketch);

/// Factory to create a MERGE_HLLPP aggregation
template <typename Base>
std::unique_ptr<Base> make_merge_hyper_log_log_aggregation(int const num_registers_per_sketch)
{
return std::make_unique<detail::merge_hyper_log_log_aggregation>(num_registers_per_sketch);
}
template CUDF_EXPORT std::unique_ptr<aggregation> make_merge_hyper_log_log_aggregation<aggregation>(
int const num_registers_per_sketch);
template CUDF_EXPORT std::unique_ptr<groupby_aggregation>
make_merge_hyper_log_log_aggregation<groupby_aggregation>(int const num_registers_per_sketch);
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
make_merge_hyper_log_log_aggregation<reduce_aggregation>(int const num_registers_per_sketch);

namespace detail {
namespace {
struct target_type_functor {
Expand Down
39 changes: 39 additions & 0 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,26 @@ void aggregate_result_functor::operator()<aggregation::TDIGEST>(aggregation cons
mr));
}

template <>
void aggregate_result_functor::operator()<aggregation::HLLPP>(aggregation const& agg)
{
if (cache.has_result(values, agg)) { return; }

int const num_registers_per_sketch =
dynamic_cast<cudf::detail::hyper_log_log_aggregation const&>(agg).num_registers_per_sketch;

printf("my-debug: dynamic cast, num is %d \n", num_registers_per_sketch);

cache.add_result(values,
agg,
detail::group_hyper_log_log(get_grouped_values(),
helper.num_groups(stream),
helper.group_labels(stream),
num_registers_per_sketch,
stream,
mr));
}

/**
* @brief Generate a merged tdigest column from a grouped set of input tdigest columns.
*
Expand Down Expand Up @@ -792,6 +812,25 @@ void aggregate_result_functor::operator()<aggregation::MERGE_TDIGEST>(aggregatio
mr));
}

template <>
void aggregate_result_functor::operator()<aggregation::MERGE_HLLPP>(aggregation const& agg)
{
if (cache.has_result(values, agg)) { return; }

int const num_registers_per_sketch =
dynamic_cast<cudf::detail::merge_hyper_log_log_aggregation const&>(agg)
.num_registers_per_sketch;

cache.add_result(values,
agg,
detail::group_merge_hyper_log_log(get_grouped_values(),
helper.num_groups(stream),
helper.group_labels(stream),
num_registers_per_sketch,
stream,
mr));
}

} // namespace detail

// Sort-based groupby
Expand Down
Loading

0 comments on commit bb2bfc2

Please sign in to comment.