Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more stream label validation stats #241

Merged
merged 2 commits into from
Nov 8, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 148 additions & 22 deletions cli/src/commands/get/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,24 @@ pub struct StreamStat {
recall: NotNan<f64>,
compare_to_precision: Option<NotNan<f64>>,
compare_to_recall: Option<NotNan<f64>>,
maintain_recall_precision: Option<NotNan<f64>>,
maintain_recall_threshold: Option<NotNan<f64>>,
maintain_precision_recall: Option<NotNan<f64>>,
maintain_precision_threshold: Option<NotNan<f64>>,
}
impl DisplayTable for StreamStat {
fn to_table_headers() -> prettytable::Row {
row![
"Name",
"Threshold",
"Current Precision",
"Current Recall",
"Compare to Precision",
"Compare to Recall"
"Threshold (T)",
"Precision (P)",
"Recall (R)",
"P at same T",
"R at same T",
"P at same R",
"R at same P",
"T at same R",
"T at same P"
]
}
fn to_table_row(&self) -> prettytable::Row {
Expand All @@ -131,46 +139,139 @@ impl DisplayTable for StreamStat {
} else {
"none".dimmed()
},
if let Some(precision) = self.maintain_recall_precision {
red_if_lower_green_otherwise(precision, self.precision)
} else {
"none".dimmed()
},
if let Some(recall) = self.maintain_precision_recall {
red_if_lower_green_otherwise(recall, self.recall)
} else {
"none".dimmed()
},
if let Some(threshold) = self.maintain_recall_threshold {
format!("{:.5}", threshold).normal()
} else {
"none".dimmed()
},
if let Some(threshold) = self.maintain_precision_threshold {
format!("{:.5}", threshold).normal()
} else {
"none".dimmed()
}
]
}
}

fn red_if_lower_green_otherwise(test: NotNan<f64>, threshold: NotNan<f64>) -> ColoredString {
let test_str = format!("{:.3}", test);

let diff = test - threshold;

match test {
test if test < threshold => format!("{test_str} (decrease)").red(),
test if test > threshold => format!("{test_str} (increase)").green(),
test if test < threshold => format!("{test_str} ({diff:.3})").red(),
test if test > threshold => format!("{test_str} (+{diff:.3})").green(),
Copy link
Collaborator

@tommilligan tommilligan Nov 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use the printf directive + to always add the appropriate sign here:

fn main() {
    for value in [0.0, -0.0, 0.0, -42.3, 42.3] {
        println!("{:+0.3}", value);
    }
}

https://doc.rust-lang.org/std/fmt/#sign0

https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=54b7555880b48aa31f24fd5f3fc499e1

_ => test_str.green(),
}
}

#[derive(Default)]
struct ThresholdAndPrecision {
threshold: Option<NotNan<f64>>,
precision: Option<NotNan<f64>>,
}

fn get_threshold_and_precision_for_recall(
recall: NotNan<f64>,
label_name: &LabelName,
label_validation: &LabelValidation,
) -> Result<ThresholdAndPrecision> {
let recall_index = label_validation
.recalls
.iter()
.position(|&val_recall| val_recall >= recall)
.context(format!("Could not get recall for label {}", label_name.0))?;

let precision = label_validation.precisions.get(recall_index);

let threshold = label_validation.thresholds.get(recall_index);

Ok(ThresholdAndPrecision {
threshold: threshold.cloned(),
precision: precision.cloned(),
})
}

#[derive(Default)]
struct ThresholdAndRecall {
threshold: Option<NotNan<f64>>,
recall: Option<NotNan<f64>>,
}

fn get_threshold_and_recall_for_precision(
precision: NotNan<f64>,
label_name: &LabelName,
label_validation: &LabelValidation,
) -> Result<ThresholdAndRecall> {
// Get lowest index with greater than or equal precision
let mut precision_index = None;
label_validation
.precisions
.iter()
.enumerate()
.for_each(|(idx, val_precision)| {
if val_precision >= &precision {
precision_index = Some(idx);
}
});

let precision_index = precision_index.context(format!(
"Could not get precision index for label {}",
label_name.0
))?;

let recall = label_validation.recalls.get(precision_index);
let threshold = label_validation.thresholds.get(precision_index);

Ok(ThresholdAndRecall {
threshold: threshold.cloned(),
recall: recall.cloned(),
})
}

#[derive(Default)]
struct PrecisionAndRecall {
precision: NotNan<f64>,
recall: NotNan<f64>,
}

fn get_precision_and_recall_for_threshold(
threshold: NotNan<f64>,
label_name: LabelName,
label_validation: LabelValidation,
) -> Result<(NotNan<f64>, NotNan<f64>)> {
label_name: &LabelName,
label_validation: &LabelValidation,
) -> Result<PrecisionAndRecall> {
let threshold_index = label_validation
.thresholds
.iter()
.position(|&val_threshold| val_threshold < threshold)
.position(|&val_threshold| val_threshold <= threshold)
.context(format!(
"Could not find threshold for label {}",
label_name.0
))?;

let precision = label_validation
let precision = *label_validation
.precisions
.get(threshold_index)
.context(format!(
"Could not get precision for label {}",
label_name.0
))?;
let recall = label_validation
let recall = *label_validation
.recalls
.get(threshold_index)
.context(format!("Could not get recall for label {}", label_name.0))?;
Ok((*precision, *recall))

Ok(PrecisionAndRecall { precision, recall })
}

#[derive(Clone)]
Expand Down Expand Up @@ -238,10 +339,10 @@ fn get_stream_stat(
let label_validation =
client.get_label_validation(&label_name, &stream_full_name.dataset, &model.version)?;

let (precision, recall) = get_precision_and_recall_for_threshold(
let PrecisionAndRecall { precision, recall } = get_precision_and_recall_for_threshold(
label_threshold.threshold,
label_name.clone(),
label_validation,
&label_name.clone(),
&label_validation,
)?;

let mut stream_stat = StreamStat {
Expand All @@ -251,6 +352,10 @@ fn get_stream_stat(
recall,
compare_to_precision: None,
compare_to_recall: None,
maintain_recall_precision: None,
maintain_recall_threshold: None,
maintain_precision_recall: None,
maintain_precision_threshold: None,
};

if let Some(ref compare_config) = compare_config {
Expand All @@ -265,14 +370,35 @@ fn get_stream_stat(
&compare_config.model_version,
)?;

let (compare_to_precision, compare_to_recall) = get_precision_and_recall_for_threshold(
let same_threshold_precision_and_recall = get_precision_and_recall_for_threshold(
label_threshold.threshold,
label_name,
compare_to_label_validation,
&label_name,
&compare_to_label_validation,
)?;

stream_stat.compare_to_precision = Some(compare_to_precision);
stream_stat.compare_to_recall = Some(compare_to_recall);
let maintain_recall_threshold_and_precision = get_threshold_and_precision_for_recall(
recall,
&label_name,
&compare_to_label_validation,
)
.unwrap_or_default();

let maintain_precision_threshold_and_recall = get_threshold_and_recall_for_precision(
precision,
&label_name,
&compare_to_label_validation,
)
.unwrap_or_default();

stream_stat.compare_to_precision = Some(same_threshold_precision_and_recall.precision);
stream_stat.compare_to_recall = Some(same_threshold_precision_and_recall.recall);
stream_stat.maintain_recall_precision =
maintain_recall_threshold_and_precision.precision;
stream_stat.maintain_recall_threshold =
maintain_recall_threshold_and_precision.threshold;
stream_stat.maintain_precision_recall = maintain_precision_threshold_and_recall.recall;
stream_stat.maintain_precision_threshold =
maintain_precision_threshold_and_recall.threshold;
}
}
Ok(stream_stat)
Expand Down
Loading