Skip to content

Commit

Permalink
quantize_stats: print rmse and max error as fraction of <x> (#21)
Browse files Browse the repository at this point in the history
This allows for a better comparison between different models
or different tensors of the same model where the magnitude of
the model weights may differ.

Co-authored-by: Iwan Kawrakow <[email protected]>
  • Loading branch information
ikawrakow and Kawrakow authored Aug 19, 2024
1 parent c7b47fc commit 5652100
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion examples/quantize-stats/quantize-stats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct error_stats {
size_t num_samples;
double total_error;
double max_error;
double sum_x2;
uint64_t error_histogram[HISTOGRAM_BUCKETS];
};

Expand Down Expand Up @@ -89,6 +90,7 @@ static void update_error_stats(int64_t nelements, const float * input, const flo
double diff = input[i] - output[i];
stats.total_error += diff * diff;
stats.max_error = fmax(fabs(diff), stats.max_error);
stats.sum_x2 += input[i]*input[i];
stats.error_histogram[std::max(std::min((size_t) floor(fabs(diff) / HISTOGRAM_RANGE * HISTOGRAM_BUCKETS), HISTOGRAM_BUCKETS-1), (size_t) 0)]++;
}
stats.num_samples += nelements;
Expand All @@ -97,6 +99,7 @@ static void update_error_stats(int64_t nelements, const float * input, const flo
// Fold the statistics accumulated in `from` into `into`.
// Sample count, squared-error sum and squared-input sum are added,
// the maximum error keeps the larger of the two values, and the
// error histograms are summed bucket by bucket.
static void combine_error_stats(error_stats & into, const error_stats & from) {
    // Additive accumulators.
    into.num_samples += from.num_samples;
    into.total_error += from.total_error;
    into.sum_x2      += from.sum_x2;

    // Keep the larger of the two maxima.
    if (into.max_error < from.max_error) {
        into.max_error = from.max_error;
    }

    // Merge the histograms one bucket at a time.
    for (size_t bucket = 0; bucket < HISTOGRAM_BUCKETS; ++bucket) {
        into.error_histogram[bucket] += from.error_histogram[bucket];
    }
}
Expand All @@ -116,9 +119,11 @@ static double find_quantile(const error_stats & stats, double quantile) {

static void print_error_stats(const std::string & name, const error_stats & stats, bool print_histogram) {
double rmse = sqrt(stats.total_error / (double) stats.num_samples);
double av_x = sqrt(stats.sum_x2 / (double) stats.num_samples);
double median = find_quantile(stats, .5);
double pct95 = find_quantile(stats, .95);
printf("%-50s: rmse %.8f, maxerr %.8f, 95pct<%.4f, median<%.4f\n", name.c_str(), rmse, stats.max_error, pct95, median);
printf("%-40s: rmse %.8f, %.6f maxerr %.8f, %.6f 95pct<%.4f, median<%.4f\n", name.c_str(), rmse, rmse/av_x,
stats.max_error, stats.max_error/av_x, pct95, median);
if (print_histogram) {
printf("Error distribution:\n");
for (size_t i = 0; i < HISTOGRAM_BUCKETS; i++) {
Expand Down

0 comments on commit 5652100

Please sign in to comment.