Skip to content

Commit

Permalink
add more stream label validation stats (#241)
Browse files Browse the repository at this point in the history
* add more stream label validation stats
  • Loading branch information
joe-prosser authored Nov 8, 2023
1 parent 121dc95 commit 2af59dc
Showing 1 changed file with 148 additions and 22 deletions.
170 changes: 148 additions & 22 deletions cli/src/commands/get/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,24 @@ pub struct StreamStat {
recall: NotNan<f64>,
compare_to_precision: Option<NotNan<f64>>,
compare_to_recall: Option<NotNan<f64>>,
maintain_recall_precision: Option<NotNan<f64>>,
maintain_recall_threshold: Option<NotNan<f64>>,
maintain_precision_recall: Option<NotNan<f64>>,
maintain_precision_threshold: Option<NotNan<f64>>,
}
impl DisplayTable for StreamStat {
fn to_table_headers() -> prettytable::Row {
row![
"Name",
"Threshold",
"Current Precision",
"Current Recall",
"Compare to Precision",
"Compare to Recall"
"Threshold (T)",
"Precision (P)",
"Recall (R)",
"P at same T",
"R at same T",
"P at same R",
"R at same P",
"T at same R",
"T at same P"
]
}
fn to_table_row(&self) -> prettytable::Row {
Expand All @@ -131,46 +139,139 @@ impl DisplayTable for StreamStat {
} else {
"none".dimmed()
},
if let Some(precision) = self.maintain_recall_precision {
red_if_lower_green_otherwise(precision, self.precision)
} else {
"none".dimmed()
},
if let Some(recall) = self.maintain_precision_recall {
red_if_lower_green_otherwise(recall, self.recall)
} else {
"none".dimmed()
},
if let Some(threshold) = self.maintain_recall_threshold {
format!("{:.5}", threshold).normal()
} else {
"none".dimmed()
},
if let Some(threshold) = self.maintain_precision_threshold {
format!("{:.5}", threshold).normal()
} else {
"none".dimmed()
}
]
}
}

fn red_if_lower_green_otherwise(test: NotNan<f64>, threshold: NotNan<f64>) -> ColoredString {
let test_str = format!("{:.3}", test);

let diff = test - threshold;

match test {
test if test < threshold => format!("{test_str} (decrease)").red(),
test if test > threshold => format!("{test_str} (increase)").green(),
test if test < threshold => format!("{test_str} ({diff:+.3})").red(),
test if test > threshold => format!("{test_str} ({diff:+.3})").green(),
_ => test_str.green(),
}
}

#[derive(Default)]
struct ThresholdAndPrecision {
threshold: Option<NotNan<f64>>,
precision: Option<NotNan<f64>>,
}

fn get_threshold_and_precision_for_recall(
recall: NotNan<f64>,
label_name: &LabelName,
label_validation: &LabelValidation,
) -> Result<ThresholdAndPrecision> {
let recall_index = label_validation
.recalls
.iter()
.position(|&val_recall| val_recall >= recall)
.context(format!("Could not get recall for label {}", label_name.0))?;

let precision = label_validation.precisions.get(recall_index);

let threshold = label_validation.thresholds.get(recall_index);

Ok(ThresholdAndPrecision {
threshold: threshold.cloned(),
precision: precision.cloned(),
})
}

#[derive(Default)]
struct ThresholdAndRecall {
threshold: Option<NotNan<f64>>,
recall: Option<NotNan<f64>>,
}

fn get_threshold_and_recall_for_precision(
precision: NotNan<f64>,
label_name: &LabelName,
label_validation: &LabelValidation,
) -> Result<ThresholdAndRecall> {
// Get lowest index with greater than or equal precision
let mut precision_index = None;
label_validation
.precisions
.iter()
.enumerate()
.for_each(|(idx, val_precision)| {
if val_precision >= &precision {
precision_index = Some(idx);
}
});

let precision_index = precision_index.context(format!(
"Could not get precision index for label {}",
label_name.0
))?;

let recall = label_validation.recalls.get(precision_index);
let threshold = label_validation.thresholds.get(precision_index);

Ok(ThresholdAndRecall {
threshold: threshold.cloned(),
recall: recall.cloned(),
})
}

#[derive(Default)]
struct PrecisionAndRecall {
precision: NotNan<f64>,
recall: NotNan<f64>,
}

fn get_precision_and_recall_for_threshold(
threshold: NotNan<f64>,
label_name: LabelName,
label_validation: LabelValidation,
) -> Result<(NotNan<f64>, NotNan<f64>)> {
label_name: &LabelName,
label_validation: &LabelValidation,
) -> Result<PrecisionAndRecall> {
let threshold_index = label_validation
.thresholds
.iter()
.position(|&val_threshold| val_threshold < threshold)
.position(|&val_threshold| val_threshold <= threshold)
.context(format!(
"Could not find threshold for label {}",
label_name.0
))?;

let precision = label_validation
let precision = *label_validation
.precisions
.get(threshold_index)
.context(format!(
"Could not get precision for label {}",
label_name.0
))?;
let recall = label_validation
let recall = *label_validation
.recalls
.get(threshold_index)
.context(format!("Could not get recall for label {}", label_name.0))?;
Ok((*precision, *recall))

Ok(PrecisionAndRecall { precision, recall })
}

#[derive(Clone)]
Expand Down Expand Up @@ -238,10 +339,10 @@ fn get_stream_stat(
let label_validation =
client.get_label_validation(&label_name, &stream_full_name.dataset, &model.version)?;

let (precision, recall) = get_precision_and_recall_for_threshold(
let PrecisionAndRecall { precision, recall } = get_precision_and_recall_for_threshold(
label_threshold.threshold,
label_name.clone(),
label_validation,
&label_name.clone(),
&label_validation,
)?;

let mut stream_stat = StreamStat {
Expand All @@ -251,6 +352,10 @@ fn get_stream_stat(
recall,
compare_to_precision: None,
compare_to_recall: None,
maintain_recall_precision: None,
maintain_recall_threshold: None,
maintain_precision_recall: None,
maintain_precision_threshold: None,
};

if let Some(ref compare_config) = compare_config {
Expand All @@ -265,14 +370,35 @@ fn get_stream_stat(
&compare_config.model_version,
)?;

let (compare_to_precision, compare_to_recall) = get_precision_and_recall_for_threshold(
let same_threshold_precision_and_recall = get_precision_and_recall_for_threshold(
label_threshold.threshold,
label_name,
compare_to_label_validation,
&label_name,
&compare_to_label_validation,
)?;

stream_stat.compare_to_precision = Some(compare_to_precision);
stream_stat.compare_to_recall = Some(compare_to_recall);
let maintain_recall_threshold_and_precision = get_threshold_and_precision_for_recall(
recall,
&label_name,
&compare_to_label_validation,
)
.unwrap_or_default();

let maintain_precision_threshold_and_recall = get_threshold_and_recall_for_precision(
precision,
&label_name,
&compare_to_label_validation,
)
.unwrap_or_default();

stream_stat.compare_to_precision = Some(same_threshold_precision_and_recall.precision);
stream_stat.compare_to_recall = Some(same_threshold_precision_and_recall.recall);
stream_stat.maintain_recall_precision =
maintain_recall_threshold_and_precision.precision;
stream_stat.maintain_recall_threshold =
maintain_recall_threshold_and_precision.threshold;
stream_stat.maintain_precision_recall = maintain_precision_threshold_and_recall.recall;
stream_stat.maintain_precision_threshold =
maintain_precision_threshold_and_recall.threshold;
}
}
Ok(stream_stat)
Expand Down

0 comments on commit 2af59dc

Please sign in to comment.