diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 1b84befb0269..ed3d6d49f9f3 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -28,12 +28,12 @@ use crate::aggregates::{ PhysicalGroupBy, }; use crate::common::IPCWriter; -use crate::metrics::{BaselineMetrics, RecordOutput}; +use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; use crate::sorts::sort::sort_batch; use crate::sorts::streaming_merge; use crate::spill::read_spill_as_stream; use crate::stream::RecordBatchStreamAdapter; -use crate::{aggregates, ExecutionPlan, PhysicalExpr}; +use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; @@ -117,10 +117,22 @@ struct SkipAggregationProbe { /// Flag indicating that further updates of `SkipAggregationProbe` /// state won't make any effect is_locked: bool, + + /// Number of rows where state was output without aggregation. + /// + /// * If 0, all input rows were aggregated (should_skip was always false) + /// + /// * if greater than zero, the number of rows which were output directly + /// without aggregation + skipped_aggregation_rows: metrics::Count, } impl SkipAggregationProbe { - fn new(probe_rows_threshold: usize, probe_ratio_threshold: f64) -> Self { + fn new( + probe_rows_threshold: usize, + probe_ratio_threshold: f64, + skipped_aggregation_rows: metrics::Count, + ) -> Self { Self { input_rows: 0, num_groups: 0, @@ -128,6 +140,7 @@ impl SkipAggregationProbe { probe_ratio_threshold, should_skip: false, is_locked: false, + skipped_aggregation_rows, } } @@ -160,6 +173,11 @@ impl SkipAggregationProbe { self.should_skip = false; self.is_locked = true; } + + /// Record the number of rows that were output directly without aggregation + fn record_skipped(&mut self, batch: &RecordBatch) { + self.skipped_aggregation_rows.add(batch.num_rows()); + } } /// HashTable based Grouping Aggregator @@ -473,17 +491,17 @@ impl GroupedHashAggregateStream { .all(|acc| acc.supports_convert_to_state()) && agg_group_by.is_single() { + let options = &context.session_config().options().execution; + let probe_rows_threshold = + options.skip_partial_aggregation_probe_rows_threshold; + let probe_ratio_threshold = + options.skip_partial_aggregation_probe_ratio_threshold; + let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics) + .counter("skipped_aggregation_rows", partition); Some(SkipAggregationProbe::new( - context - .session_config() - .options() - .execution - .skip_partial_aggregation_probe_rows_threshold, - context - .session_config() - .options() - .execution - .skip_partial_aggregation_probe_ratio_threshold, + probe_rows_threshold, + probe_ratio_threshold, + skipped_aggregation_rows, )) } else { None @@ -611,6 +629,9 @@ impl Stream for GroupedHashAggregateStream { match ready!(self.input.poll_next_unpin(cx)) { Some(Ok(batch)) => { let _timer = elapsed_compute.timer(); + if let Some(probe) = self.skip_aggregation_probe.as_mut() { + probe.record_skipped(&batch); + } let states = self.transform_to_states(batch)?; return Poll::Ready(Some(Ok( states.record_output(&self.baseline_metrics)