From 23803ed37df6f1a0ae1a0b47491246f398cf4df9 Mon Sep 17 00:00:00 2001
From: Joe Prosser <joe.prosser@uipath.com>
Date: Wed, 8 Nov 2023 16:04:11 +0000
Subject: [PATCH 1/2] add more stream label validation stats

---
 cli/src/commands/get/streams.rs | 170 +++++++++++++++++++++++++++-----
 1 file changed, 148 insertions(+), 22 deletions(-)

diff --git a/cli/src/commands/get/streams.rs b/cli/src/commands/get/streams.rs
index 7e220b9c..6cc59da7 100644
--- a/cli/src/commands/get/streams.rs
+++ b/cli/src/commands/get/streams.rs
@@ -103,16 +103,24 @@ pub struct StreamStat {
     recall: NotNan<f64>,
     compare_to_precision: Option<NotNan<f64>>,
     compare_to_recall: Option<NotNan<f64>>,
+    maintain_recall_precision: Option<NotNan<f64>>,
+    maintain_recall_threshold: Option<NotNan<f64>>,
+    maintain_precision_recall: Option<NotNan<f64>>,
+    maintain_precision_threshold: Option<NotNan<f64>>,
 }
 impl DisplayTable for StreamStat {
     fn to_table_headers() -> prettytable::Row {
         row![
             "Name",
-            "Threshold",
-            "Current Precision",
-            "Current Recall",
-            "Compare to Precision",
-            "Compare to Recall"
+            "Threshold (T)",
+            "Precision (P)",
+            "Recall (R)",
+            "P at same T",
+            "R at same T",
+            "P at same R",
+            "R at same P",
+            "T at same R",
+            "T at same P"
         ]
     }
     fn to_table_row(&self) -> prettytable::Row {
@@ -131,6 +139,26 @@ impl DisplayTable for StreamStat {
             } else {
                 "none".dimmed()
             },
+            if let Some(precision) = self.maintain_recall_precision {
+                red_if_lower_green_otherwise(precision, self.precision)
+            } else {
+                "none".dimmed()
+            },
+            if let Some(recall) = self.maintain_precision_recall {
+                red_if_lower_green_otherwise(recall, self.recall)
+            } else {
+                "none".dimmed()
+            },
+            if let Some(threshold) = self.maintain_recall_threshold {
+                format!("{:.5}", threshold).normal()
+            } else {
+                "none".dimmed()
+            },
+            if let Some(threshold) = self.maintain_precision_threshold {
+                format!("{:.5}", threshold).normal()
+            } else {
+                "none".dimmed()
+            }
         ]
     }
 }
@@ -138,39 +166,112 @@ impl DisplayTable for StreamStat {
 fn red_if_lower_green_otherwise(test: NotNan<f64>, threshold: NotNan<f64>) -> ColoredString {
     let test_str = format!("{:.3}", test);
 
+    let diff = test - threshold;
+
     match test {
-        test if test < threshold => format!("{test_str} (decrease)").red(),
-        test if test > threshold => format!("{test_str} (increase)").green(),
+        test if test < threshold => format!("{test_str} ({diff:.3})").red(),
+        test if test > threshold => format!("{test_str} (+{diff:.3})").green(),
         _ => test_str.green(),
     }
 }
 
+#[derive(Default)]
+struct ThresholdAndPrecision {
+    threshold: Option<NotNan<f64>>,
+    precision: Option<NotNan<f64>>,
+}
+
+fn get_threshold_and_precision_for_recall(
+    recall: NotNan<f64>,
+    label_name: &LabelName,
+    label_validation: &LabelValidation,
+) -> Result<ThresholdAndPrecision> {
+    let recall_index = label_validation
+        .recalls
+        .iter()
+        .position(|&val_recall| val_recall >= recall)
+        .context(format!("Could not get recall for label {}", label_name.0))?;
+
+    let precision = label_validation.precisions.get(recall_index);
+
+    let threshold = label_validation.thresholds.get(recall_index);
+
+    Ok(ThresholdAndPrecision {
+        threshold: threshold.cloned(),
+        precision: precision.cloned(),
+    })
+}
+
+#[derive(Default)]
+struct ThresholdAndRecall {
+    threshold: Option<NotNan<f64>>,
+    recall: Option<NotNan<f64>>,
+}
+
+fn get_threshold_and_recall_for_precision(
+    precision: NotNan<f64>,
+    label_name: &LabelName,
+    label_validation: &LabelValidation,
+) -> Result<ThresholdAndRecall> {
+    // Get lowest index with greater than or equal precision
+    let mut precision_index = None;
+    label_validation
+        .precisions
+        .iter()
+        .enumerate()
+        .for_each(|(idx, val_precision)| {
+            if val_precision >= &precision {
+                precision_index = Some(idx);
+            }
+        });
+
+    let precision_index = precision_index.context(format!(
+        "Could not get precision index for label {}",
+        label_name.0
+    ))?;
+
+    let recall = label_validation.recalls.get(precision_index);
+    let threshold = label_validation.thresholds.get(precision_index);
+
+    Ok(ThresholdAndRecall {
+        threshold: threshold.cloned(),
+        recall: recall.cloned(),
+    })
+}
+
+#[derive(Default)]
+struct PrecisionAndRecall {
+    precision: NotNan<f64>,
+    recall: NotNan<f64>,
+}
+
 fn get_precision_and_recall_for_threshold(
     threshold: NotNan<f64>,
-    label_name: LabelName,
-    label_validation: LabelValidation,
-) -> Result<(NotNan<f64>, NotNan<f64>)> {
+    label_name: &LabelName,
+    label_validation: &LabelValidation,
+) -> Result<PrecisionAndRecall> {
     let threshold_index = label_validation
         .thresholds
         .iter()
-        .position(|&val_threshold| val_threshold < threshold)
+        .position(|&val_threshold| val_threshold <= threshold)
         .context(format!(
             "Could not find threshold for label {}",
             label_name.0
         ))?;
 
-    let precision = label_validation
+    let precision = *label_validation
         .precisions
         .get(threshold_index)
         .context(format!(
             "Could not get precision for label {}",
             label_name.0
         ))?;
-    let recall = label_validation
+    let recall = *label_validation
         .recalls
         .get(threshold_index)
         .context(format!("Could not get recall for label {}", label_name.0))?;
-    Ok((*precision, *recall))
+
+    Ok(PrecisionAndRecall { precision, recall })
 }
 
 #[derive(Clone)]
@@ -238,10 +339,10 @@ fn get_stream_stat(
     let label_validation =
         client.get_label_validation(&label_name, &stream_full_name.dataset, &model.version)?;
 
-    let (precision, recall) = get_precision_and_recall_for_threshold(
+    let PrecisionAndRecall { precision, recall } = get_precision_and_recall_for_threshold(
         label_threshold.threshold,
-        label_name.clone(),
-        label_validation,
+        &label_name.clone(),
+        &label_validation,
     )?;
 
     let mut stream_stat = StreamStat {
@@ -251,6 +352,10 @@ fn get_stream_stat(
         recall,
         compare_to_precision: None,
         compare_to_recall: None,
+        maintain_recall_precision: None,
+        maintain_recall_threshold: None,
+        maintain_precision_recall: None,
+        maintain_precision_threshold: None,
     };
 
     if let Some(ref compare_config) = compare_config {
@@ -265,14 +370,35 @@ fn get_stream_stat(
                 &compare_config.model_version,
             )?;
 
-            let (compare_to_precision, compare_to_recall) = get_precision_and_recall_for_threshold(
+            let same_threshold_precision_and_recall = get_precision_and_recall_for_threshold(
                 label_threshold.threshold,
-                label_name,
-                compare_to_label_validation,
+                &label_name,
+                &compare_to_label_validation,
             )?;
 
-            stream_stat.compare_to_precision = Some(compare_to_precision);
-            stream_stat.compare_to_recall = Some(compare_to_recall);
+            let maintain_recall_threshold_and_precision = get_threshold_and_precision_for_recall(
+                recall,
+                &label_name,
+                &compare_to_label_validation,
+            )
+            .unwrap_or_default();
+
+            let maintain_precision_threshold_and_recall = get_threshold_and_recall_for_precision(
+                precision,
+                &label_name,
+                &compare_to_label_validation,
+            )
+            .unwrap_or_default();
+
+            stream_stat.compare_to_precision = Some(same_threshold_precision_and_recall.precision);
+            stream_stat.compare_to_recall = Some(same_threshold_precision_and_recall.recall);
+            stream_stat.maintain_recall_precision =
+                maintain_recall_threshold_and_precision.precision;
+            stream_stat.maintain_recall_threshold =
+                maintain_recall_threshold_and_precision.threshold;
+            stream_stat.maintain_precision_recall = maintain_precision_threshold_and_recall.recall;
+            stream_stat.maintain_precision_threshold =
+                maintain_precision_threshold_and_recall.threshold;
         }
     }
     Ok(stream_stat)

From 71f77e107ea27b92caff71d3f201fec8bd07df3e Mon Sep 17 00:00:00 2001
From: Joe Prosser <joe.prosser@uipath.com>
Date: Wed, 8 Nov 2023 17:07:35 +0000
Subject: [PATCH 2/2] fix formatting nit

---
 cli/src/commands/get/streams.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cli/src/commands/get/streams.rs b/cli/src/commands/get/streams.rs
index 6cc59da7..bfe7c5f8 100644
--- a/cli/src/commands/get/streams.rs
+++ b/cli/src/commands/get/streams.rs
@@ -169,8 +169,8 @@ fn red_if_lower_green_otherwise(test: NotNan<f64>, threshold: NotNan<f64>) -> Co
     let diff = test - threshold;
 
     match test {
-        test if test < threshold => format!("{test_str} ({diff:.3})").red(),
-        test if test > threshold => format!("{test_str} (+{diff:.3})").green(),
+        test if test < threshold => format!("{test_str} ({diff:+.3})").red(),
+        test if test > threshold => format!("{test_str} ({diff:+.3})").green(),
         _ => test_str.green(),
     }
 }