From 23803ed37df6f1a0ae1a0b47491246f398cf4df9 Mon Sep 17 00:00:00 2001 From: Joe Prosser <joe.prosser@uipath.com> Date: Wed, 8 Nov 2023 16:04:11 +0000 Subject: [PATCH 1/2] add more stream label validation stats --- cli/src/commands/get/streams.rs | 170 +++++++++++++++++++++++++++----- 1 file changed, 148 insertions(+), 22 deletions(-) diff --git a/cli/src/commands/get/streams.rs b/cli/src/commands/get/streams.rs index 7e220b9c..6cc59da7 100644 --- a/cli/src/commands/get/streams.rs +++ b/cli/src/commands/get/streams.rs @@ -103,16 +103,24 @@ pub struct StreamStat { recall: NotNan<f64>, compare_to_precision: Option<NotNan<f64>>, compare_to_recall: Option<NotNan<f64>>, + maintain_recall_precision: Option<NotNan<f64>>, + maintain_recall_threshold: Option<NotNan<f64>>, + maintain_precision_recall: Option<NotNan<f64>>, + maintain_precision_threshold: Option<NotNan<f64>>, } impl DisplayTable for StreamStat { fn to_table_headers() -> prettytable::Row { row![ "Name", - "Threshold", - "Current Precision", - "Current Recall", - "Compare to Precision", - "Compare to Recall" + "Threshold (T)", + "Precision (P)", + "Recall (R)", + "P at same T", + "R at same T", + "P at same R", + "R at same P", + "T at same R", + "T at same P" ] } fn to_table_row(&self) -> prettytable::Row { @@ -131,6 +139,26 @@ impl DisplayTable for StreamStat { } else { "none".dimmed() }, + if let Some(precision) = self.maintain_recall_precision { + red_if_lower_green_otherwise(precision, self.precision) + } else { + "none".dimmed() + }, + if let Some(recall) = self.maintain_precision_recall { + red_if_lower_green_otherwise(recall, self.recall) + } else { + "none".dimmed() + }, + if let Some(threshold) = self.maintain_recall_threshold { + format!("{:.5}", threshold).normal() + } else { + "none".dimmed() + }, + if let Some(threshold) = self.maintain_precision_threshold { + format!("{:.5}", threshold).normal() + } else { + "none".dimmed() + } ] } } @@ -138,39 +166,112 @@ impl DisplayTable for StreamStat { fn red_if_lower_green_otherwise(test: NotNan<f64>, threshold: NotNan<f64>) -> ColoredString { let test_str = format!("{:.3}", test); + let diff = test - threshold; + match test { - test if test < threshold => format!("{test_str} (decrease)").red(), - test if test > threshold => format!("{test_str} (increase)").green(), + test if test < threshold => format!("{test_str} ({diff:.3})").red(), + test if test > threshold => format!("{test_str} (+{diff:.3})").green(), _ => test_str.green(), } } +#[derive(Default)] +struct ThresholdAndPrecision { + threshold: Option<NotNan<f64>>, + precision: Option<NotNan<f64>>, +} + +fn get_threshold_and_precision_for_recall( + recall: NotNan<f64>, + label_name: &LabelName, + label_validation: &LabelValidation, +) -> Result<ThresholdAndPrecision> { + let recall_index = label_validation + .recalls + .iter() + .position(|&val_recall| val_recall >= recall) + .context(format!("Could not get recall for label {}", label_name.0))?; + + let precision = label_validation.precisions.get(recall_index); + + let threshold = label_validation.thresholds.get(recall_index); + + Ok(ThresholdAndPrecision { + threshold: threshold.cloned(), + precision: precision.cloned(), + }) +} + +#[derive(Default)] +struct ThresholdAndRecall { + threshold: Option<NotNan<f64>>, + recall: Option<NotNan<f64>>, +} + +fn get_threshold_and_recall_for_precision( + precision: NotNan<f64>, + label_name: &LabelName, + label_validation: &LabelValidation, +) -> Result<ThresholdAndRecall> { + // Get lowest index with greater than or equal precision + let mut precision_index = None; + label_validation + .precisions + .iter() + .enumerate() + .for_each(|(idx, val_precision)| { + if val_precision >= &precision { + precision_index = Some(idx); + } + }); + + let precision_index = precision_index.context(format!( + "Could not get precision index for label {}", + label_name.0 + ))?; + + let recall = label_validation.recalls.get(precision_index); + let threshold = label_validation.thresholds.get(precision_index); + + Ok(ThresholdAndRecall { + threshold: threshold.cloned(), + recall: recall.cloned(), + }) +} + +#[derive(Default)] +struct PrecisionAndRecall { + precision: NotNan<f64>, + recall: NotNan<f64>, +} + fn get_precision_and_recall_for_threshold( threshold: NotNan<f64>, - label_name: LabelName, - label_validation: LabelValidation, -) -> Result<(NotNan<f64>, NotNan<f64>)> { + label_name: &LabelName, + label_validation: &LabelValidation, +) -> Result<PrecisionAndRecall> { let threshold_index = label_validation .thresholds .iter() - .position(|&val_threshold| val_threshold < threshold) + .position(|&val_threshold| val_threshold <= threshold) .context(format!( "Could not find threshold for label {}", label_name.0 ))?; - let precision = label_validation + let precision = *label_validation .precisions .get(threshold_index) .context(format!( "Could not get precision for label {}", label_name.0 ))?; - let recall = label_validation + let recall = *label_validation .recalls .get(threshold_index) .context(format!("Could not get recall for label {}", label_name.0))?; - Ok((*precision, *recall)) + + Ok(PrecisionAndRecall { precision, recall }) } #[derive(Clone)] @@ -238,10 +339,10 @@ fn get_stream_stat( let label_validation = client.get_label_validation(&label_name, &stream_full_name.dataset, &model.version)?; - let (precision, recall) = get_precision_and_recall_for_threshold( + let PrecisionAndRecall { precision, recall } = get_precision_and_recall_for_threshold( label_threshold.threshold, - label_name.clone(), - label_validation, + &label_name.clone(), + &label_validation, )?; let mut stream_stat = StreamStat { @@ -251,6 +352,10 @@ fn get_stream_stat( recall, compare_to_precision: None, compare_to_recall: None, + maintain_recall_precision: None, + maintain_recall_threshold: None, + maintain_precision_recall: None, + maintain_precision_threshold: None, }; if let Some(ref compare_config) = compare_config { @@ -265,14 +370,35 @@ fn get_stream_stat( &compare_config.model_version, )?; - let (compare_to_precision, compare_to_recall) = get_precision_and_recall_for_threshold( + let same_threshold_precision_and_recall = get_precision_and_recall_for_threshold( label_threshold.threshold, - label_name, - compare_to_label_validation, + &label_name, + &compare_to_label_validation, )?; - stream_stat.compare_to_precision = Some(compare_to_precision); - stream_stat.compare_to_recall = Some(compare_to_recall); + let maintain_recall_threshold_and_precision = get_threshold_and_precision_for_recall( + recall, + &label_name, + &compare_to_label_validation, + ) + .unwrap_or_default(); + + let maintain_precision_threshold_and_recall = get_threshold_and_recall_for_precision( + precision, + &label_name, + &compare_to_label_validation, + ) + .unwrap_or_default(); + + stream_stat.compare_to_precision = Some(same_threshold_precision_and_recall.precision); + stream_stat.compare_to_recall = Some(same_threshold_precision_and_recall.recall); + stream_stat.maintain_recall_precision = + maintain_recall_threshold_and_precision.precision; + stream_stat.maintain_recall_threshold = + maintain_recall_threshold_and_precision.threshold; + stream_stat.maintain_precision_recall = maintain_precision_threshold_and_recall.recall; + stream_stat.maintain_precision_threshold = + maintain_precision_threshold_and_recall.threshold; } } Ok(stream_stat) From 71f77e107ea27b92caff71d3f201fec8bd07df3e Mon Sep 17 00:00:00 2001 From: Joe Prosser <joe.prosser@uipath.com> Date: Wed, 8 Nov 2023 17:07:35 +0000 Subject: [PATCH 2/2] fix formatting nit --- cli/src/commands/get/streams.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/src/commands/get/streams.rs b/cli/src/commands/get/streams.rs index 6cc59da7..bfe7c5f8 100644 --- a/cli/src/commands/get/streams.rs +++ b/cli/src/commands/get/streams.rs @@ -169,8 +169,8 @@ fn red_if_lower_green_otherwise(test: NotNan<f64>, threshold: NotNan<f64>) -> Co let diff = test - threshold; match test { - test if test < threshold => format!("{test_str} ({diff:.3})").red(), - test if test > threshold => format!("{test_str} (+{diff:.3})").green(), + test if test < threshold => format!("{test_str} ({diff:+.3})").red(), + test if test > threshold => format!("{test_str} ({diff:+.3})").green(), _ => test_str.green(), } }