From d2b00eda0767ea41c01a027cebdae0e9ee938a6e Mon Sep 17 00:00:00 2001
From: Yenda Li
Date: Sun, 8 Dec 2024 21:17:48 -0800
Subject: [PATCH] feat: add classification functions (#11792)

Summary:
Add the classification functions from presto into velox:
https://prestodb.io/docs/current/functions/aggregate.html#classification-metrics-aggregate-functions

Classification functions all use `FixedDoubleHistogram`, which is a data
structure to represent buckets of weights. Bucket indices are evenly
distributed between the min and max values.

For all of the classification functions, the only difference is the
extraction phase. All other steps are the same. At a high level:
- addRawInput adds a value into either the true or false weight bucket,
  depending on the prediction value. The prediction is linearly mapped
  into a bucket based on (min, max and bucketCount) by normalizing the
  prediction between min and max.
- The schema of the intermediate states is
  [version header][bucket count][min][max][weights]

Differential Revision: D66684198
---
 velox/docs/functions/presto/aggregate.rst     | 195 +++++
 velox/docs/functions/presto/coverage.rst      |  10 +-
 .../prestosql/aggregates/AggregateNames.h     |   5 +
 .../aggregates/ClassificationAggregation.cpp  | 720 ++++++++++++++++++
 .../aggregates/RegisterAggregateFunctions.cpp |   5 +
 .../prestosql/aggregates/tests/CMakeLists.txt |   1 +
 .../tests/ClassificationAggregationTest.cpp   | 221 ++++++
 .../prestosql/fuzzer/WindowFuzzerTest.cpp     |   5 +
 8 files changed, 1157 insertions(+), 5 deletions(-)
 create mode 100644 velox/functions/prestosql/aggregates/ClassificationAggregation.cpp
 create mode 100644 velox/functions/prestosql/aggregates/tests/ClassificationAggregationTest.cpp

diff --git a/velox/docs/functions/presto/aggregate.rst b/velox/docs/functions/presto/aggregate.rst
index c8443653857f..38e798dc324f 100644
--- a/velox/docs/functions/presto/aggregate.rst
+++ b/velox/docs/functions/presto/aggregate.rst
@@ -411,6 +411,201 @@ __ https://www.cse.ust.hk/~raywong/comp5331/References/EfficientComputationOfFre
 As ``approx_percentile(x, w, percentages)``, but with a maximum rank error of
 ``accuracy``.
 
+Classification Metrics Aggregate Functions
+------------------------------------------
+
+The following functions each measure how some metric of a binary
+`confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix>`_ changes as a function of
+classification thresholds. They are meant to be used in conjunction.
+
+For example, to find the `precision-recall curve <https://en.wikipedia.org/wiki/Precision_and_recall>`_, use
+
+    .. code-block:: none
+
+        WITH
+            recall_precision AS (
+                SELECT
+                    CLASSIFICATION_RECALL(10000, correct, pred) AS recalls,
+                    CLASSIFICATION_PRECISION(10000, correct, pred) AS precisions
+                FROM
+                    classification_dataset
+            )
+        SELECT
+            recall,
+            precision
+        FROM
+            recall_precision
+        CROSS JOIN UNNEST(recalls, precisions) AS t(recall, precision)
+
+To get the corresponding thresholds for these values, use
+
+    .. code-block:: none
+
+        WITH
+            recall_precision AS (
+                SELECT
+                    CLASSIFICATION_THRESHOLDS(10000, correct, pred) AS thresholds,
+                    CLASSIFICATION_RECALL(10000, correct, pred) AS recalls,
+                    CLASSIFICATION_PRECISION(10000, correct, pred) AS precisions
+                FROM
+                    classification_dataset
+            )
+        SELECT
+            threshold,
+            recall,
+            precision
+        FROM
+            recall_precision
+        CROSS JOIN UNNEST(thresholds, recalls, precisions) AS t(threshold, recall, precision)
+
+To find the `ROC curve <https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_, use
+
+    .. code-block:: none
+
+        WITH
+            fallout_recall AS (
+                SELECT
+                    CLASSIFICATION_FALLOUT(10000, correct, pred) AS fallouts,
+                    CLASSIFICATION_RECALL(10000, correct, pred) AS recalls
+                FROM
+                    classification_dataset
+            )
+        SELECT
+            fallout,
+            recall
+        FROM
+            fallout_recall
+        CROSS JOIN UNNEST(fallouts, recalls) AS t(fallout, recall)
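+
+Each function evaluates its metric at a fixed set of evenly spaced
+thresholds determined by ``buckets``. For example (the values follow from
+the bucket definition used by these functions), with ``buckets = 5`` the
+candidate thresholds are 0.0, 0.2, 0.4, 0.6 and 0.8, and each returned
+array holds one entry per threshold, up to the last threshold at which any
+positive-outcome weight remains.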
+
+.. function:: classification_miss_rate(buckets, y, x, weight) -> array<double>
+
+    Computes the miss-rate with up to ``buckets`` number of buckets. Returns
+    an array of miss-rate values.
+
+    ``y`` should be a boolean outcome value; ``x`` should be predictions, each
+    between 0 and 1; ``weight`` should be a non-negative value, indicating the weight of the instance.
+
+    The
+    `miss-rate <https://en.wikipedia.org/wiki/Type_I_and_type_II_errors>`_
+    is defined as a sequence whose :math:`j`-th entry is
+
+    .. math ::
+
+        {
+            \sum_{i \;|\; x_i \leq t_j \bigwedge y_i = 1} \left[ w_i \right]
+            \over
+            \sum_{i \;|\; x_i \leq t_j \bigwedge y_i = 1} \left[ w_i \right]
+            +
+            \sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right]
+        },
+
+    where :math:`t_j` is the :math:`j`-th smallest threshold,
+    and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
+    entries of ``y``, ``x``, and ``weight``, respectively.
+
+.. function:: classification_miss_rate(buckets, y, x) -> array<double>
+
+    This function is equivalent to the variant of
+    :func:`!classification_miss_rate` that takes a ``weight``, with a per-item weight of ``1``.
+
+.. function:: classification_fall_out(buckets, y, x, weight) -> array<double>
+
+    Computes the fall-out with up to ``buckets`` number of buckets. Returns
+    an array of fall-out values.
+
+    ``y`` should be a boolean outcome value; ``x`` should be predictions, each
+    between 0 and 1; ``weight`` should be a non-negative value, indicating the weight of the instance.
+
+    The
+    `fall-out <https://en.wikipedia.org/wiki/False_positive_rate>`_
+    is defined as a sequence whose :math:`j`-th entry is
+
+    .. math ::
+
+        {
+            \sum_{i \;|\; x_i > t_j \bigwedge y_i = 0} \left[ w_i \right]
+            \over
+            \sum_{i \;|\; y_i = 0} \left[ w_i \right]
+        },
+
+    where :math:`t_j` is the :math:`j`-th smallest threshold,
+    and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
+    entries of ``y``, ``x``, and ``weight``, respectively.
+
+.. function:: classification_fall_out(buckets, y, x) -> array<double>
+
+    This function is equivalent to the variant of
+    :func:`!classification_fall_out` that takes a ``weight``, with a per-item weight of ``1``.
+
+.. function:: classification_precision(buckets, y, x, weight) -> array<double>
+
+    Computes the precision with up to ``buckets`` number of buckets. Returns
+    an array of precision values.
+
+    ``y`` should be a boolean outcome value; ``x`` should be predictions, each
+    between 0 and 1; ``weight`` should be a non-negative value, indicating the weight of the instance.
+
+    The
+    `precision <https://en.wikipedia.org/wiki/Precision_and_recall>`_
+    is defined as a sequence whose :math:`j`-th entry is
+
+    .. math ::
+
+        {
+            \sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right]
+            \over
+            \sum_{i \;|\; x_i > t_j} \left[ w_i \right]
+        },
+
+    where :math:`t_j` is the :math:`j`-th smallest threshold,
+    and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
+    entries of ``y``, ``x``, and ``weight``, respectively.
+
+.. function:: classification_precision(buckets, y, x) -> array<double>
+
+    This function is equivalent to the variant of
+    :func:`!classification_precision` that takes a ``weight``, with a per-item weight of ``1``.
+
+.. function:: classification_recall(buckets, y, x, weight) -> array<double>
+
+    Computes the recall with up to ``buckets`` number of buckets. Returns
+    an array of recall values.
+
+    ``y`` should be a boolean outcome value; ``x`` should be predictions, each
+    between 0 and 1; ``weight`` should be a non-negative value, indicating the weight of the instance.
+
+    The
+    `recall <https://en.wikipedia.org/wiki/Precision_and_recall>`_
+    is defined as a sequence whose :math:`j`-th entry is
+
+    .. math ::
+
+        {
+            \sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right]
+            \over
+            \sum_{i \;|\; y_i = 1} \left[ w_i \right]
+        },
+
+    where :math:`t_j` is the :math:`j`-th smallest threshold,
+    and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
+    entries of ``y``, ``x``, and ``weight``, respectively.
+
+.. function:: classification_recall(buckets, y, x) -> array<double>
+
+    This function is equivalent to the variant of
+    :func:`!classification_recall` that takes a ``weight``, with a per-item weight of ``1``.
+
+.. function:: classification_thresholds(buckets, y, x) -> array<double>
+
+    Computes the thresholds with up to ``buckets`` number of buckets. Returns
+    an array of threshold values.
+
+    ``y`` should be a boolean outcome value; ``x`` should be predictions, each
+    between 0 and 1.
+
+    The thresholds are defined as a sequence whose :math:`j`-th entry is the :math:`j`-th smallest threshold.
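+
+Note that, by the definitions above, the miss-rate and recall sequences are
+complementary at every threshold, since their numerators partition the
+positive-outcome weight :math:`\sum_{i \;|\; y_i = 1} \left[ w_i \right]`:
+
+    .. math ::
+
+        \mathrm{recall}_j = 1 - \mathrm{missrate}_j.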
+
 Statistical Aggregate Functions
 -------------------------------
 
diff --git a/velox/docs/functions/presto/coverage.rst b/velox/docs/functions/presto/coverage.rst
index 127297b7e197..a4df266e46b2 100644
--- a/velox/docs/functions/presto/coverage.rst
+++ b/velox/docs/functions/presto/coverage.rst
@@ -325,11 +325,11 @@ Here is a list of all scalar and aggregate Presto functions with functions that
     :func:`array_duplicates`      :func:`dow`               :func:`json_extract`         :func:`repeat`    st_union             :func:`bool_and`    :func:`rank`
     :func:`array_except`          :func:`doy`               :func:`json_extract_scalar`  :func:`replace`   st_within            :func:`bool_or`     :func:`row_number`
     :func:`array_frequency`       :func:`e`                 :func:`json_format`          replace_first     st_x                 :func:`checksum`
-    :func:`array_has_duplicates`  :func:`element_at`        :func:`json_parse`           :func:`reverse`   st_xmax              classification_fall_out
-    :func:`array_intersect`       :func:`empty_approx_set`  :func:`json_size`            rgb               st_xmin              classification_miss_rate
-    :func:`array_join`            :func:`ends_with`         key_sampling_percent         :func:`round`     st_y                 classification_precision
-    array_least_frequent          enum_key                  :func:`laplace_cdf`          :func:`rpad`      st_ymax              classification_recall
-    :func:`array_max`             :func:`exp`               :func:`last_day_of_month`    :func:`rtrim`     st_ymin              classification_thresholds
+    :func:`array_has_duplicates`  :func:`element_at`        :func:`json_parse`           :func:`reverse`   st_xmax              :func:`classification_fall_out`
+    :func:`array_intersect`       :func:`empty_approx_set`  :func:`json_size`            rgb               st_xmin              :func:`classification_miss_rate`
+    :func:`array_join`            :func:`ends_with`         key_sampling_percent         :func:`round`     st_y                 :func:`classification_precision`
+    array_least_frequent          enum_key                  :func:`laplace_cdf`          :func:`rpad`      st_ymax              :func:`classification_recall`
+    :func:`array_max`             :func:`exp`               :func:`last_day_of_month`    :func:`rtrim`     st_ymin              :func:`classification_thresholds`
     array_max_by                  expand_envelope           :func:`least`                scale_qdigest     :func:`starts_with`  convex_hull_agg
     :func:`array_min`             :func:`f_cdf`             :func:`length`               :func:`second`    :func:`strpos`       :func:`corr`
     array_min_by                  features                  :func:`levenshtein_distance` secure_rand       :func:`strrpos`      :func:`count`
diff --git a/velox/functions/prestosql/aggregates/AggregateNames.h b/velox/functions/prestosql/aggregates/AggregateNames.h
index 7cf2ff9810d8..27f539693d19 100644
--- a/velox/functions/prestosql/aggregates/AggregateNames.h
+++ b/velox/functions/prestosql/aggregates/AggregateNames.h
@@ -32,6 +32,11 @@ const char* const kBitwiseXor = "bitwise_xor_agg";
 const char* const kBoolAnd = "bool_and";
 const char* const kBoolOr = "bool_or";
 const char* const kChecksum = "checksum";
+const char* const kClassificationFallout = "classification_fall_out";
+const char* const kClassificationPrecision = "classification_precision";
+const char* const kClassificationRecall = "classification_recall";
+const char* const kClassificationMissRate = "classification_miss_rate";
+const char* const kClassificationThreshold = "classification_thresholds";
 const char* const kCorr = "corr";
 const char* const kCount = "count";
 const char* const kCountIf = "count_if";
diff --git a/velox/functions/prestosql/aggregates/ClassificationAggregation.cpp b/velox/functions/prestosql/aggregates/ClassificationAggregation.cpp
new file mode 100644
index 000000000000..c351d2a68481
--- /dev/null
+++ b/velox/functions/prestosql/aggregates/ClassificationAggregation.cpp
@@ -0,0 +1,720 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "velox/common/base/IOUtils.h"
+#include "velox/exec/Aggregate.h"
+#include "velox/functions/prestosql/aggregates/AggregateNames.h"
+#include "velox/vector/FlatVector.h"
+
+namespace facebook::velox::aggregate::prestosql {
+namespace {
+
+enum class ClassificationType {
+  kFallout = 0,
+  kPrecision = 1,
+  kRecall = 2,
+  kMissRate = 3,
+  kThresholds = 4,
+};
+
+/// Struct to represent the bucket of the FixedDoubleHistogram
+/// at a given index.
+struct Bucket {
+  Bucket(double _left, double _right, double _weight)
+      : left(_left), right(_right), weight(_weight) {}
+  const double left;
+  const double right;
+  const double weight;
+};
+
+/// Fixed-bucket histogram of weights as doubles. For each bucket, it stores
+/// the total weight accumulated.
+class FixedDoubleHistogram {
+ public:
+  explicit FixedDoubleHistogram(HashStringAllocator* allocator)
+      : weightsOne_(StlAllocator<double>(allocator)),
+        weightsTwo_(StlAllocator<double>(allocator)) {}
+
+  FixedDoubleHistogram(
+      int64_t bucketCount,
+      double min,
+      double max,
+      HashStringAllocator* allocator)
+      : weightsOne_(StlAllocator<double>(allocator)),
+        weightsTwo_(StlAllocator<double>(allocator)),
+        bucketCount_(bucketCount),
+        min_(min),
+        max_(max) {
+    resizeWeights();
+  }
+
+  void resizeWeights() {
+    validateParameters(bucketCount_, min_, max_);
+    weightsOne_.resize(std::min(bucketCount_, kMaxBucketCount));
+    weightsTwo_.resize(
+        std::max(static_cast<int64_t>(0), bucketCount_ - kMaxBucketCount));
+  }
+
+  /// API to support the case when the histogram is created without a
+  /// bucket count.
+  void tryInit(int64_t bucketCount) {
+    validateParameters(bucketCount, min_, max_);
+    if (bucketCount_ == -1) {
+      bucketCount_ = bucketCount;
+      resizeWeights();
+    }
+  }
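+
+  // Note on storage (see kMaxBucketCount below): weights are split across
+  // two vectors because bucketCount_ may exceed what a single std::vector
+  // can hold. An index i addresses weightsOne_[i] when i < kMaxBucketCount
+  // and weightsTwo_[i % kMaxBucketCount] otherwise; e.g. with an
+  // (illustrative) kMaxBucketCount of 4, index 6 lands in weightsTwo_[2].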
+
+  /// Set the weight of a bucket at a given index.
+  void setWeight(int64_t index, double weight) {
+    auto& weights = index >= kMaxBucketCount ? weightsTwo_ : weightsOne_;
+    totalWeights_ -= weights.at(index % kMaxBucketCount);
+    weights.at(index % kMaxBucketCount) = weight;
+    totalWeights_ += weight;
+    if (weight != 0) {
+      maxUsedIndex_ = std::max(maxUsedIndex_, index);
+    }
+  }
+
+  /// Add weight to a bucket based on the value of the prediction.
+  void add(double pred, double weight) {
+    auto index = getIndexForValue(bucketCount_, min_, max_, pred);
+    auto& weights = index >= kMaxBucketCount ? weightsTwo_ : weightsOne_;
+    weights[index % kMaxBucketCount] += weight;
+    totalWeights_ += weight;
+    maxUsedIndex_ = std::max(maxUsedIndex_, index);
+  }
+
+  /// Returns a bucket in this histogram at a given index.
+  Bucket getBucket(int64_t index) {
+    auto& weights = index >= kMaxBucketCount ? weightsTwo_ : weightsOne_;
+    return Bucket(
+        getLeftValueForIndex(bucketCount_, min_, max_, index),
+        getRightValueForIndex(bucketCount_, min_, max_, index),
+        weights[index % kMaxBucketCount]);
+  }
+
+  /// The size of the histogram is represented by maxUsedIndex_, the largest
+  /// bucket index with a non-zero accrued value. This avoids an O(n)
+  /// operation to compute the size of the histogram.
+  int64_t size() const {
+    return maxUsedIndex_ + 1;
+  }
+
+  int64_t bucketCount() const {
+    return bucketCount_;
+  }
+
+  /// The state of the histogram can be serialized into a buffer. The format
+  /// is [header][bucketCount][min][max][weights]. The header identifies the
+  /// version of the serialization format. bucketCount, min, and max are the
+  /// parameters of the histogram; weights holds one double per bucket.
+  size_t serialize(char* output) const {
+    VELOX_CHECK(output);
+    common::OutputByteStream stream(output);
+    size_t bytesUsed = 0;
+    stream.append(
+        reinterpret_cast<const char*>(&kSerializationVersionHeader),
+        sizeof(kSerializationVersionHeader));
+    bytesUsed += sizeof(kSerializationVersionHeader);
+
+    stream.append(
+        reinterpret_cast<const char*>(&bucketCount_), sizeof(bucketCount_));
+    bytesUsed += sizeof(bucketCount_);
+
+    stream.append(reinterpret_cast<const char*>(&min_), sizeof(min_));
+    bytesUsed += sizeof(min_);
+
+    stream.append(reinterpret_cast<const char*>(&max_), sizeof(max_));
+    bytesUsed += sizeof(max_);
+
+    for (auto weight : weightsOne_) {
+      stream.append(reinterpret_cast<const char*>(&weight), sizeof(weight));
+      bytesUsed += sizeof(weight);
+    }
+    for (auto weight : weightsTwo_) {
+      stream.append(reinterpret_cast<const char*>(&weight), sizeof(weight));
+      bytesUsed += sizeof(weight);
+    }
+
+    return bytesUsed;
+  }
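+
+  // Resulting layout, as a worked example for bucketCount = 3 (49 bytes in
+  // total; the sizes follow from the field types above):
+  //   offset  0: uint8_t version header (1 byte)
+  //   offset  1: int64_t bucketCount    (8 bytes)
+  //   offset  9: double  min            (8 bytes)
+  //   offset 17: double  max            (8 bytes)
+  //   offset 25: double  weights[0..2]  (3 * 8 = 24 bytes)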
+
+  /// Deserializes the histogram from a buffer.
+  static FixedDoubleHistogram deserialize(
+      common::InputByteStream& in,
+      size_t expectedSize,
+      HashStringAllocator* allocator) {
+    if (FOLLY_UNLIKELY(expectedSize < minDeserializedBufferSize())) {
+      VELOX_USER_FAIL(
+          "Cannot deserialize FixedDoubleHistogram. Expected size: {}, actual size: {}",
+          minDeserializedBufferSize(),
+          expectedSize);
+    }
+
+    uint8_t version;
+    in.copyTo(&version, 1);
+    VELOX_CHECK_EQ(version, kSerializationVersionHeader);
+
+    int64_t bucketCount;
+    double min;
+    double max;
+    in.copyTo(&bucketCount, 1);
+    in.copyTo(&min, 1);
+    in.copyTo(&max, 1);
+
+    auto ret = FixedDoubleHistogram(bucketCount, min, max, allocator);
+    for (size_t i = 0; i < bucketCount; ++i) {
+      double weight;
+      in.copyTo(&weight, 1);
+      ret.setWeight(i, weight);
+    }
+    const size_t bytesRead = sizeof(kSerializationVersionHeader) +
+        sizeof(bucketCount) + sizeof(min) + sizeof(max) +
+        (bucketCount * sizeof(double));
+    VELOX_CHECK_EQ(bytesRead, expectedSize);
+    return ret;
+  }
+
+  /// The minimum size of a valid buffer to deserialize a histogram.
+  static constexpr size_t minDeserializedBufferSize() {
+    return (
+        sizeof(kSerializationVersionHeader) + sizeof(int64_t) +
+        sizeof(double) +
+        /// 2 represents the minimum number of buckets.
+        sizeof(double) + 2 * sizeof(double));
+  }
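+
+  // Merging semantics, illustrated: two partial histograms with identical
+  // (bucketCount, min, max) merge by element-wise addition of weights, e.g.
+  //   {0, 2, 0, 1, 0} merged with {1, 0, 0, 1, 0} yields {1, 2, 0, 2, 0}.
+  // An uninitialized histogram (bucketCount_ == -1) instead adopts the
+  // other histogram's parameters and weights wholesale.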
+
+  /// Merges the current histogram with another histogram represented as a
+  /// buffer.
+  void mergeWith(
+      const char* data,
+      size_t expectedSize,
+      HashStringAllocator* allocator) {
+    auto input = common::InputByteStream(data);
+    auto histogram = deserialize(input, expectedSize, allocator);
+    /// This accounts for the case when the histogram is not initialized yet.
+    if (bucketCount_ == -1) {
+      bucketCount_ = histogram.bucketCount_;
+      min_ = histogram.min_;
+      max_ = histogram.max_;
+      weightsOne_ = std::move(histogram.weightsOne_);
+      weightsTwo_ = std::move(histogram.weightsTwo_);
+      totalWeights_ = histogram.totalWeights_;
+      maxUsedIndex_ = histogram.maxUsedIndex_;
+      return;
+    }
+
+    /// When merging histograms, all the parameters except for the values
+    /// accrued inside the buckets must be the same.
+    if (bucketCount_ != histogram.bucketCount_) {
+      VELOX_USER_FAIL(
+          "Cannot merge histograms with different bucket counts. "
+          "Left bucket count: {}, right bucket count: {}",
+          bucketCount_,
+          histogram.bucketCount_);
+    }
+
+    if (min_ != histogram.min_ || max_ != histogram.max_) {
+      VELOX_USER_FAIL(
+          "Cannot merge histograms with different min/max values. "
+          "Left min: {}, left max: {}, right min: {}, right max: {}",
+          min_,
+          max_,
+          histogram.min_,
+          histogram.max_);
+    }
+
+    for (size_t i = 0; i < bucketCount_; ++i) {
+      auto& weights = i >= kMaxBucketCount ? weightsTwo_ : weightsOne_;
+      auto& otherWeights =
+          i >= kMaxBucketCount ? histogram.weightsTwo_ : histogram.weightsOne_;
+      weights[i % kMaxBucketCount] += otherWeights[i % kMaxBucketCount];
+      totalWeights_ += otherWeights[i % kMaxBucketCount];
+    }
+    maxUsedIndex_ = std::max(maxUsedIndex_, histogram.maxUsedIndex_);
+  }
+
+  size_t serializationSize() const {
+    return sizeof(kSerializationVersionHeader) + sizeof(bucketCount_) +
+        sizeof(min_) + sizeof(max_) + (weightsOne_.size() * sizeof(double)) +
+        (weightsTwo_.size() * sizeof(double));
+  }
+
+  /// The total accrued weight across all buckets. The value is cached to
+  /// avoid recomputing it every time it is needed.
+  double totalWeights() const {
+    return totalWeights_;
+  }
+
+ private:
+  /// Returns the index of the bucket in the histogram that contains the
+  /// value. This is done by mapping the value to [min, max) and then mapping
+  /// that value to the corresponding bucket.
+  static int64_t
+  getIndexForValue(int64_t bucketCount, double min, double max, double value) {
+    VELOX_CHECK(value >= min && value < max);
+    return std::min(
+        static_cast<int64_t>((bucketCount * (value - min)) / (max - min)),
+        bucketCount - 1);
+  }
+
+  static double getLeftValueForIndex(
+      int64_t bucketCount,
+      double min,
+      double max,
+      int64_t index) {
+    return min + index * (max - min) / bucketCount;
+  }
+
+  static double getRightValueForIndex(
+      int64_t bucketCount,
+      double min,
+      double max,
+      int64_t index) {
+    return std::min(
+        max, getLeftValueForIndex(bucketCount, min, max, index + 1));
+  }
+
+  static void validateParameters(int64_t bucketCount, double min, double max) {
+    if (bucketCount < 2) {
+      VELOX_USER_FAIL("Bucket count must be at least 2.0");
+    }
+
+    if (min >= max) {
+      VELOX_USER_FAIL("Min must be less than max. Min: {}, max: {}", min, max);
+    }
+  }
+
+  static constexpr uint8_t kSerializationVersionHeader = 1;
+  /// In Java, the bucket count is of type 'long', which has a max possible
+  /// value of 2^63. In C++, a given vector can hold at most
+  /// std::vector::max_size() elements, which may be less than 2^63 depending
+  /// on the platform. To account for this, we use two vectors to store the
+  /// weights, with each vector being at most kMaxBucketCount in size.
+  static constexpr int64_t kMaxBucketCount =
+      std::numeric_limits<int32_t>::max();
+  std::vector<double, StlAllocator<double>> weightsOne_;
+  std::vector<double, StlAllocator<double>> weightsTwo_;
+  double totalWeights_{0};
+  int64_t bucketCount_{-1};
+  double min_{0};
+  double max_{1.0};
+  int64_t maxUsedIndex_{-1};
+};
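+
+// For intuition: the classification aggregates always use min = 0 and
+// max = 1, so with bucketCount = 5, getIndexForValue() maps a prediction p
+// to bucket floor(5 * p), capped at 4:
+//   p = 0.30          -> index 1, covering [0.2, 0.4)
+//   p = 0.99999999999 -> index 4, covering [0.8, 1.0)
+// and getLeftValueForIndex(5, 0, 1, j) == j / 5.0, which produces the
+// threshold sequence {0.0, 0.2, 0.4, 0.6, 0.8} returned by
+// classification_thresholds.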
+
+template <ClassificationType type>
+struct Accumulator {
+  explicit Accumulator(HashStringAllocator* allocator)
+      : trueWeights_(allocator), falseWeights_(allocator) {}
+
+  void
+  setWeights(int64_t bucketCount, bool outcome, double pred, double weight) {
+    VELOX_CHECK_EQ(bucketCount, trueWeights_.bucketCount());
+    VELOX_CHECK_EQ(bucketCount, falseWeights_.bucketCount());
+
+    /// Similar to Java Presto, the max prediction value for the histogram
+    /// is capped at 0.99999999999 to ensure the bin corresponding to 1 is
+    /// never reached.
+    static const double kMaxPredictionValue = 0.99999999999;
+    pred = std::min(pred, kMaxPredictionValue);
+    outcome ? trueWeights_.add(pred, weight) : falseWeights_.add(pred, weight);
+  }
+
+  void tryInit(int64_t bucketCount) {
+    trueWeights_.tryInit(bucketCount);
+    falseWeights_.tryInit(bucketCount);
+  }
+
+  vector_size_t size() const {
+    return trueWeights_.size();
+  }
+
+  size_t serialize(char* output) const {
+    size_t bytes = trueWeights_.serialize(output);
+    return bytes + falseWeights_.serialize(output + bytes);
+  }
+
+  size_t serializationSize() const {
+    return trueWeights_.serializationSize() + falseWeights_.serializationSize();
+  }
+
+  void mergeWith(StringView serialized, HashStringAllocator* allocator) {
+    auto input = serialized.data();
+    VELOX_CHECK_EQ(serialized.size() % 2, 0);
+    const size_t bufferSize = serialized.size() / 2;
+    trueWeights_.mergeWith(input, bufferSize, allocator);
+    falseWeights_.mergeWith(
+        input + serialized.size() / 2, bufferSize, allocator);
+  }
+
+  void extractValues(FlatVector<double>* flatResult, vector_size_t offset) {
+    const double totalTrueWeight = trueWeights_.totalWeights();
+    const double totalFalseWeight = falseWeights_.totalWeights();
+
+    double runningFalseWeight = 0;
+    double runningTrueWeight = 0;
+    int64_t trueWeightIndex = 0;
+    while (trueWeightIndex < trueWeights_.bucketCount() &&
+           totalTrueWeight > runningTrueWeight) {
+      auto trueBucketResult = trueWeights_.getBucket(trueWeightIndex);
+      auto falseBucketResult = falseWeights_.getBucket(trueWeightIndex);
+
+      const double falsePositive = totalFalseWeight - runningFalseWeight;
+      const double negative = totalFalseWeight;
+
+      if constexpr (type == ClassificationType::kFallout) {
+        flatResult->set(offset + trueWeightIndex, falsePositive / negative);
+      } else if constexpr (type == ClassificationType::kPrecision) {
+        const double truePositive = (totalTrueWeight - runningTrueWeight);
+        const double totalPositives = truePositive + falsePositive;
+        flatResult->set(
+            offset + trueWeightIndex, truePositive / totalPositives);
+      } else if constexpr (type == ClassificationType::kRecall) {
+        const double truePositive = (totalTrueWeight - runningTrueWeight);
+        flatResult->set(
+            offset + trueWeightIndex, truePositive / totalTrueWeight);
+      } else if constexpr (type == ClassificationType::kMissRate) {
+        flatResult->set(
+            offset + trueWeightIndex, runningTrueWeight / totalTrueWeight);
+      } else if constexpr (type == ClassificationType::kThresholds) {
+        flatResult->set(offset + trueWeightIndex, trueBucketResult.left);
+      } else {
+        VELOX_UNREACHABLE("Not expected to be called.");
+      }
+
+      runningTrueWeight += trueBucketResult.weight;
+      runningFalseWeight += falseBucketResult.weight;
+      trueWeightIndex += 1;
+    }
+  }
+
+ private:
+  FixedDoubleHistogram trueWeights_;
+  FixedDoubleHistogram falseWeights_;
+};
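+
+// Walk-through of extractValues() above: at the j-th threshold,
+// runningTrueWeight has accumulated the true-outcome weight with
+// predictions below the threshold, so
+//   TP_j = totalTrueWeight  - runningTrueWeight
+//   FP_j = totalFalseWeight - runningFalseWeight
+// which yields, per ClassificationType:
+//   kFallout    -> FP_j / totalFalseWeight
+//   kPrecision  -> TP_j / (TP_j + FP_j)
+//   kRecall     -> TP_j / totalTrueWeight
+//   kMissRate   -> runningTrueWeight / totalTrueWeight
+//   kThresholds -> the left edge of bucket j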
+
+template <ClassificationType type>
+class ClassificationAggregation : public exec::Aggregate {
+ public:
+  explicit ClassificationAggregation(
+      TypePtr resultType,
+      bool useDefaultWeight = false)
+      : Aggregate(std::move(resultType)),
+        useDefaultWeight_(useDefaultWeight) {}
+
+  int32_t accumulatorFixedWidthSize() const override {
+    return sizeof(Accumulator<type>);
+  }
+
+  bool isFixedSize() const override {
+    return false;
+  }
+
+  void addSingleGroupRawInput(
+      char* group,
+      const SelectivityVector& rows,
+      const std::vector<VectorPtr>& args,
+      bool /*mayPushdown*/) override {
+    decodeArguments(rows, args);
+    auto accumulator = value<Accumulator<type>>(group);
+
+    auto tracker = trackRowSize(group);
+    rows.applyToSelected([&](auto row) {
+      if (decodedBuckets_.isNullAt(row) || decodedOutcome_.isNullAt(row) ||
+          decodedPred_.isNullAt(row) ||
+          (!useDefaultWeight_ && decodedWeight_.isNullAt(row))) {
+        return;
+      }
+      clearNull(group);
+      accumulator->tryInit(decodedBuckets_.valueAt<int64_t>(row));
+      accumulator->setWeights(
+          decodedBuckets_.valueAt<int64_t>(row),
+          decodedOutcome_.valueAt<bool>(row),
+          decodedPred_.valueAt<double>(row),
+          useDefaultWeight_ ? 1.0 : decodedWeight_.valueAt<double>(row));
+    });
+  }
+
+  void addRawInput(
+      char** groups,
+      const SelectivityVector& rows,
+      const std::vector<VectorPtr>& args,
+      bool /*mayPushdown*/) override {
+    decodeArguments(rows, args);
+
+    rows.applyToSelected([&](vector_size_t row) {
+      if (decodedBuckets_.isNullAt(row) || decodedOutcome_.isNullAt(row) ||
+          decodedPred_.isNullAt(row) ||
+          (!useDefaultWeight_ && decodedWeight_.isNullAt(row))) {
+        return;
+      }
+
+      auto& group = groups[row];
+      auto tracker = trackRowSize(group);
+
+      clearNull(group);
+      auto* accumulator = value<Accumulator<type>>(group);
+      accumulator->tryInit(decodedBuckets_.valueAt<int64_t>(row));
+
+      accumulator->setWeights(
+          decodedBuckets_.valueAt<int64_t>(row),
+          decodedOutcome_.valueAt<bool>(row),
+          decodedPred_.valueAt<double>(row),
+          useDefaultWeight_ ? 1.0 : decodedWeight_.valueAt<double>(row));
+    });
+  }
+
+  void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result)
+      override {
+    VELOX_CHECK(result);
+    auto flatResult = (*result)->asFlatVector<StringView>();
+    flatResult->resize(numGroups);
+
+    uint64_t* rawNulls = nullptr;
+    if (flatResult->mayHaveNulls()) {
+      BufferPtr& nulls = flatResult->mutableNulls(flatResult->size());
+      rawNulls = nulls->asMutable<uint64_t>();
+    }
+
+    for (auto i = 0; i < numGroups; ++i) {
+      auto group = groups[i];
+      if (isNull(group)) {
+        flatResult->setNull(i, true);
+        continue;
+      }
+
+      if (rawNulls) {
+        bits::clearBit(rawNulls, i);
+      }
+      auto accumulator = value<Accumulator<type>>(group);
+      auto serializationSize = accumulator->serializationSize();
+      char* rawBuffer =
+          flatResult->getRawStringBufferWithSpace(serializationSize);
+
+      VELOX_CHECK_EQ(accumulator->serialize(rawBuffer), serializationSize);
+      auto sv = StringView(rawBuffer, serializationSize);
+      flatResult->set(i, std::move(sv));
+    }
+  }
+
+  void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
+      override {
+    auto vector = (*result)->as<ArrayVector>();
+    VELOX_CHECK(vector);
+    vector->resize(numGroups);
+
+    vector_size_t numValues = 0;
+    uint64_t* rawNulls = getRawNulls(result->get());
+
+    for (auto i = 0; i < numGroups; ++i) {
+      auto* group = groups[i];
+      auto* accumulator = value<Accumulator<type>>(group);
+      const auto size = accumulator->size();
+      if (isNull(group)) {
+        vector->setNull(i, true);
+        continue;
+      }
+
+      clearNull(rawNulls, i);
+      numValues += size;
+    }
+
+    auto flatResults = vector->elements()->asFlatVector<double>();
+    flatResults->resize(numValues);
+
+    auto* rawOffsets = vector->offsets()->asMutable<vector_size_t>();
+    auto* rawSizes = vector->sizes()->asMutable<vector_size_t>();
+
+    vector_size_t offset = 0;
+    for (auto i = 0; i < numGroups; ++i) {
+      auto* group = groups[i];
+
+      if (isNull(group)) {
+        continue;
+      }
+      auto* accumulator = value<Accumulator<type>>(group);
+      const vector_size_t size = accumulator->size();
+
+      rawOffsets[i] = offset;
+      rawSizes[i] = size;
+
+      accumulator->extractValues(flatResults, offset);
+
+      offset += size;
+    }
+  }
+
+  void addIntermediateResults(
+      char** groups,
+      const SelectivityVector& rows,
+      const std::vector<VectorPtr>& args,
+      bool /*mayPushdown*/) override {
+    VELOX_CHECK_EQ(args.size(), 1);
+    decodedAcc_.decode(*args[0], rows);
+
+    rows.applyToSelected([&](auto row) {
+      if (decodedAcc_.isNullAt(row)) {
+        return;
+      }
+
+      auto group = groups[row];
+      auto tracker = trackRowSize(group);
+      clearNull(group);
+
+      auto serialized = decodedAcc_.valueAt<StringView>(row);
+
+      auto accumulator = value<Accumulator<type>>(group);
+      accumulator->mergeWith(serialized, allocator_);
+    });
+  }
+
+  void addSingleGroupIntermediateResults(
+      char* group,
+      const SelectivityVector& rows,
+      const std::vector<VectorPtr>& args,
+      bool /* mayPushdown */) override {
+    VELOX_CHECK_EQ(args.size(), 1);
+    decodedAcc_.decode(*args[0], rows);
+    auto tracker = trackRowSize(group);
+
+    rows.applyToSelected([&](auto row) {
+      if (decodedAcc_.isNullAt(row)) {
+        return;
+      }
+
+      clearNull(group);
+
+      auto serialized = decodedAcc_.valueAt<StringView>(row);
+
+      auto accumulator = value<Accumulator<type>>(group);
+      accumulator->mergeWith(serialized, allocator_);
+    });
+  }
+
+ protected:
+  void initializeNewGroupsInternal(
+      char** groups,
+      folly::Range<const vector_size_t*> indices) override {
+    exec::Aggregate::setAllNulls(groups, indices);
+    for (auto i : indices) {
+      auto group = groups[i];
+      new (group + offset_) Accumulator<type>(allocator_);
+    }
+  }
+
+  void destroyInternal(folly::Range<char**> groups) override {
+    destroyAccumulators<Accumulator<type>>(groups);
+  }
+
+ private:
+  void decodeArguments(
+      const SelectivityVector& rows,
+      const std::vector<VectorPtr>& args) {
+    decodedBuckets_.decode(*args[0], rows, true);
+    decodedOutcome_.decode(*args[1], rows, true);
+    decodedPred_.decode(*args[2], rows, true);
+    if (!useDefaultWeight_) {
+      decodedWeight_.decode(*args[3], rows, true);
+    }
+  }
+
+  DecodedVector decodedAcc_;
+  DecodedVector decodedBuckets_;
+  DecodedVector decodedOutcome_;
+  DecodedVector decodedPred_;
+  DecodedVector decodedWeight_;
+  const bool useDefaultWeight_{false};
+};
+} // namespace
+
+template <ClassificationType type>
+void registerAggregateFunctionImpl(
+    const std::string& name,
+    bool withCompanionFunctions,
+    bool overwrite,
+    const std::vector<std::shared_ptr<exec::AggregateFunctionSignature>>&
+        signatures) {
+  exec::registerAggregateFunction(
+      name,
+      signatures,
+      [](core::AggregationNode::Step,
+         const std::vector<TypePtr>& args,
+         const TypePtr& resultType,
+         const core::QueryConfig& /*config*/)
+          -> std::unique_ptr<exec::Aggregate> {
+        if (args.size() == 4) {
+          return std::make_unique<ClassificationAggregation<type>>(resultType);
+        } else {
+          return std::make_unique<ClassificationAggregation<type>>(
+              resultType, true);
+        }
+      },
+      withCompanionFunctions,
+      overwrite);
+}
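+
+// Registers the five classification aggregates. Each gets both a
+// 3-argument signature (per-row weight defaults to 1.0, via
+// useDefaultWeight = true above) and a 4-argument signature with an
+// explicit weight column. Example queries, mirroring the tests:
+//
+//   SELECT classification_precision(5, y, x) FROM t;    -- weight = 1
+//   SELECT classification_precision(5, y, x, w) FROM t; -- per-row weight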
+
+void registerClassificationFunctions(
+    const std::string& prefix,
+    bool withCompanionFunctions,
+    bool overwrite) {
+  const auto signatures =
+      std::vector<std::shared_ptr<exec::AggregateFunctionSignature>>{
+          exec::AggregateFunctionSignatureBuilder()
+              .returnType("array(double)")
+              .intermediateType("varbinary")
+              .argumentType("bigint")
+              .argumentType("boolean")
+              .argumentType("double")
+              .build(),
+          exec::AggregateFunctionSignatureBuilder()
+              .returnType("array(double)")
+              .intermediateType("varbinary")
+              .argumentType("bigint")
+              .argumentType("boolean")
+              .argumentType("double")
+              .argumentType("double")
+              .build()};
+  registerAggregateFunctionImpl<ClassificationType::kFallout>(
+      prefix + kClassificationFallout,
+      withCompanionFunctions,
+      overwrite,
+      signatures);
+  registerAggregateFunctionImpl<ClassificationType::kPrecision>(
+      prefix + kClassificationPrecision,
+      withCompanionFunctions,
+      overwrite,
+      signatures);
+  registerAggregateFunctionImpl<ClassificationType::kRecall>(
+      prefix + kClassificationRecall,
+      withCompanionFunctions,
+      overwrite,
+      signatures);
+  registerAggregateFunctionImpl<ClassificationType::kMissRate>(
+      prefix + kClassificationMissRate,
+      withCompanionFunctions,
+      overwrite,
+      signatures);
+  registerAggregateFunctionImpl<ClassificationType::kThresholds>(
+      prefix + kClassificationThreshold,
+      withCompanionFunctions,
+      overwrite,
+      signatures);
+}
+
+} // namespace facebook::velox::aggregate::prestosql
diff --git a/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp b/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp
index 53bf0ce22ba9..1b781de2f247 100644
--- a/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp
+++ b/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp
@@ -43,6 +43,10 @@ extern void registerBitwiseXorAggregate(
     bool withCompanionFunctions,
     bool onlyPrestoSignatures,
     bool overwrite);
+extern void registerClassificationFunctions(
+    const std::string& prefix,
+    bool withCompanionFunctions,
+    bool overwrite);
 extern void registerChecksumAggregate(
     const std::string& prefix,
     bool withCompanionFunctions,
@@ -165,6 +169,7 @@ void registerAllAggregateFunctions(
   registerBoolAggregates(prefix, withCompanionFunctions, overwrite);
   registerCentralMomentsAggregates(prefix, withCompanionFunctions, overwrite);
   registerChecksumAggregate(prefix, withCompanionFunctions, overwrite);
+  registerClassificationFunctions(prefix, withCompanionFunctions, overwrite);
   registerCountAggregate(prefix, withCompanionFunctions, overwrite);
   registerCountIfAggregate(prefix, withCompanionFunctions, overwrite);
   registerCovarianceAggregates(prefix, withCompanionFunctions, overwrite);
diff --git a/velox/functions/prestosql/aggregates/tests/CMakeLists.txt b/velox/functions/prestosql/aggregates/tests/CMakeLists.txt
index 85265e3613ee..293e1e096367 100644
--- a/velox/functions/prestosql/aggregates/tests/CMakeLists.txt
+++ b/velox/functions/prestosql/aggregates/tests/CMakeLists.txt
@@ -25,6 +25,7 @@ add_executable(
   BoolAndOrTest.cpp
   CentralMomentsAggregationTest.cpp
   ChecksumAggregateTest.cpp
+  ClassificationAggregationTest.cpp
   CountAggregationTest.cpp
   CountDistinctTest.cpp
   CountIfAggregationTest.cpp
diff --git a/velox/functions/prestosql/aggregates/tests/ClassificationAggregationTest.cpp b/velox/functions/prestosql/aggregates/tests/ClassificationAggregationTest.cpp
new file mode 100644
index 000000000000..f1a71ab968d1
--- /dev/null
+++ b/velox/functions/prestosql/aggregates/tests/ClassificationAggregationTest.cpp
@@ -0,0 +1,221 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "velox/common/base/tests/GTestUtils.h"
+#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h"
+
+using namespace facebook::velox::functions::aggregate::test;
+
+namespace facebook::velox::aggregate::test {
+namespace {
+
+class ClassificationAggregationTest : public AggregationTestBase {
+ protected:
+  void SetUp() override {
+    AggregationTestBase::SetUp();
+  }
+};
+
+TEST_F(ClassificationAggregationTest, basic) {
+  auto runTest = [&](const std::string& expression,
+                     RowVectorPtr input,
+                     RowVectorPtr expected) {
+    testAggregations({input}, {}, {expression}, {expected});
+  };
+
+  /// Test without any nulls.
+  auto input = makeRowVector({
+      makeNullableFlatVector<bool>(
+          {true, false, true, false, false, false, false, false, true, false}),
+      makeNullableFlatVector<double>(
+          {0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.7, 1.0, 0.5, 0.5}),
+  });
+
+  /// Fallout test.
+  auto expected = makeRowVector({
+      makeArrayVector<double>({{1.0, 1.0, 3.0 / 7}}),
+  });
+  runTest("classification_fall_out(5, c0, c1)", input, expected);
+
+  /// Precision test.
+  expected = makeRowVector({
+      makeArrayVector<double>({{0.3, 2.0 / 9, 0.25}}),
+  });
+  runTest("classification_precision(5, c0, c1)", input, expected);
+
+  /// Recall test.
+  expected = makeRowVector({
+      makeArrayVector<double>({{1.0, 2.0 / 3, 1.0 / 3}}),
+  });
+  runTest("classification_recall(5, c0, c1)", input, expected);
+
+  /// Miss rate test.
+  expected = makeRowVector({
+      makeArrayVector<double>({{0, 1.0 / 3, 2.0 / 3}}),
+  });
+  runTest("classification_miss_rate(5, c0, c1)", input, expected);
+
+  /// Thresholds test.
+  expected = makeRowVector({
+      makeArrayVector<double>({{0, 0.2, 0.4}}),
+  });
+  runTest("classification_thresholds(5, c0, c1)", input, expected);
+
+  /// Test with some nulls.
+  input = makeRowVector({
+      makeNullableFlatVector<bool>(
+          {std::nullopt,
+           false,
+           true,
+           false,
+           false,
+           false,
+           false,
+           false,
+           std::nullopt,
+           false}),
+      makeNullableFlatVector<double>(
+          {0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.7, 1.0, std::nullopt, std::nullopt}),
+  });
+
+  /// Fallout test.
+  expected = makeRowVector({makeArrayVector<double>({{1.0, 1.0}})});
+  runTest("classification_fall_out(5, c0, c1)", input, expected);
+
+  /// Precision test.
+  expected = makeRowVector({makeArrayVector<double>({{1.0 / 7, 1.0 / 7}})});
+  runTest("classification_precision(5, c0, c1)", input, expected);
+
+  /// Recall test.
+  expected = makeRowVector({makeArrayVector<double>({{1, 1}})});
+  runTest("classification_recall(5, c0, c1)", input, expected);
+
+  /// Miss rate test.
+  expected = makeRowVector({makeArrayVector<double>({{0, 0}})});
+  runTest("classification_miss_rate(5, c0, c1)", input, expected);
+
+  /// Thresholds test.
+  expected = makeRowVector({makeArrayVector<double>({{0, 0.2}})});
+  runTest("classification_thresholds(5, c0, c1)", input, expected);
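+
+  /// Rows where any argument is null are skipped by the aggregation, so the
+  /// five cases above only see the seven fully-populated rows.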
+
+  /// Test with all nulls.
+  input = makeRowVector({
+      makeNullableFlatVector<bool>({std::nullopt, std::nullopt}),
+      makeNullableFlatVector<double>({std::nullopt, std::nullopt}),
+  });
+
+  expected = makeRowVector({makeNullableArrayVector(
+      std::vector<std::optional<std::vector<std::optional<double>>>>{
+          {std::nullopt}})});
+  runTest("classification_fall_out(5, c0, c1)", input, expected);
+  runTest("classification_precision(5, c0, c1)", input, expected);
+  runTest("classification_recall(5, c0, c1)", input, expected);
+  runTest("classification_miss_rate(5, c0, c1)", input, expected);
+  runTest("classification_thresholds(5, c0, c1)", input, expected);
+
+  /// Test invalid bucket counts.
+  input = makeRowVector({
+      makeNullableFlatVector<bool>({true}),
+      makeNullableFlatVector<double>({1.0}),
+  });
+  static const std::vector<std::string> expressions = {
+      "classification_fall_out(0, c0, c1)",
+      "classification_precision(0, c0, c1)",
+      "classification_recall(0, c0, c1)",
+      "classification_miss_rate(0, c0, c1)",
+      "classification_thresholds(0, c0, c1)",
+      "classification_fall_out(1, c0, c1)",
+      "classification_precision(1, c0, c1)",
+      "classification_recall(1, c0, c1)",
+      "classification_miss_rate(1, c0, c1)",
+      "classification_thresholds(1, c0, c1)"};
+  for (const auto& expression : expressions) {
+    VELOX_ASSERT_THROW(
+        runTest(expression, input, expected),
+        "Bucket count must be at least 2.0");
+  }
+}
+
+TEST_F(ClassificationAggregationTest, groupBy) {
+  auto input = makeRowVector({
+      makeNullableFlatVector<int32_t>({0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3}),
+      makeNullableFlatVector<bool>(
+          {true,
+           false,
+           true,
+           false,
+           false,
+           false,
+           false,
+           false,
+           true,
+           true,
+           false}),
+      makeNullableFlatVector<double>(
+          {0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.7, 1.0, 1.0, 0.5, 0.5}),
+  });
+
+  auto runTest = [this](
+                     const std::string& expression,
+                     RowVectorPtr input,
+                     RowVectorPtr expected) {
+    testAggregations({input}, {"c0"}, {expression}, {expected});
+  };
+  auto keys = makeFlatVector<int32_t>({0, 1, 2, 3});
+  runTest(
+      "classification_fall_out(5, c1, c2)",
+      input,
+      makeRowVector({
+          keys,
+          makeArrayVector<double>(
+              {{1}, {1, 1}, {1, 1, 2.0 / 3, 2.0 / 3, 1.0 / 3}, {1, 1, 1}}),
+      }));
+  runTest(
+      "classification_precision(5, c1, c2)",
+      input,
+      makeRowVector({
+          keys,
+          makeArrayVector<double>(
+              {{0.5},
+               {1.0 / 3, 1.0 / 3},
+               {0.25, 0.25, 1.0 / 3, 1.0 / 3, 0.5},
+               {0.5, 0.5, 0.5}}),
+      }));
+  runTest(
+      "classification_recall(5, c1, c2)",
+      input,
+      makeRowVector({
+          keys,
+          makeArrayVector<double>({{1}, {1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1}}),
+      }));
+  runTest(
+      "classification_miss_rate(5, c1, c2)",
+      input,
+      makeRowVector({
+          keys,
+          makeArrayVector<double>({{0}, {0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0}}),
+      }));
+  runTest(
+      "classification_thresholds(5, c1, c2)",
+      input,
+      makeRowVector({
+          keys,
+          makeArrayVector<double>(
+              {{0}, {0, 0.2}, {0, 0.2, 0.4, 0.6, 0.8}, {0, 0.2, 0.4}}),
+      }));
+}
+
+} // namespace
+} // namespace facebook::velox::aggregate::test
diff --git a/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp b/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp
index f7d863d7b809..e0c17c22c6c1 100644
--- a/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp
+++ b/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp
@@ -139,6 +139,11 @@ int main(int argc, char** argv) {
       // https://github.com/facebookincubator/velox/issues/6330
       {"max_data_size_for_stats", nullptr},
       {"sum_data_size_for_stats", nullptr},
+      {"classification_fall_out", nullptr},
+      {"classification_precision", nullptr},
+      {"classification_recall", nullptr},
+      {"classification_miss_rate", nullptr},
+      {"classification_thresholds", nullptr},
   };
 
   static const std::unordered_set<std::string> orderDependentFunctions = {