Skip to content

Commit

Permalink
feat(cli): add validation to dataset stats (#311)
Browse files Browse the repository at this point in the history
* feat(cli): add validation to dataset stats
  • Loading branch information
joe-prosser authored Aug 29, 2024
1 parent efdb00e commit fa87853
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 51 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Unreleased
- Add validation to dataset `--stats`

# v0.32.0
- Add dataset flags to `create-dataset`
- Add `parse aic-classification-csv`
Expand Down
20 changes: 20 additions & 0 deletions api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,12 @@ impl Client {
Retry::Yes,
)
}
pub fn get_latest_validation(
&self,
dataset_name: &DatasetFullName,
) -> Result<ValidationResponse> {
self.get::<_, ValidationResponse>(self.endpoints.latest_validation(dataset_name)?)
}

pub fn get_validation(
&self,
Expand Down Expand Up @@ -1838,6 +1844,20 @@ impl Endpoints {
],
)
}
fn latest_validation(&self, dataset_name: &DatasetFullName) -> Result<Url> {
construct_endpoint(
&self.base,
&[
"api",
"_private",
"datasets",
&dataset_name.0,
"labellers",
"latest",
"validation",
],
)
}

fn validation(
&self,
Expand Down
14 changes: 11 additions & 3 deletions api/src/resources/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use std::{
str::FromStr,
};

use super::validation::ValidationResponse;

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub enum DatasetFlag {
Expand Down Expand Up @@ -49,13 +51,13 @@ pub struct Dataset {
pub dataset_flags: Vec<DatasetFlag>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct DatasetStats {
pub num_reviewed: NotNan<f64>,
pub total_verbatims: NotNan<f64>,
pub validation: Option<ValidationResponse>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct DatasetAndStats {
pub dataset: Dataset,
pub stats: DatasetStats,
Expand Down Expand Up @@ -203,6 +205,12 @@ pub struct ModelFamily(pub String);
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct ModelVersion(pub u32);

impl std::fmt::Display for ModelVersion {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

// TODO(mcobzarenco)[3963]: Make `Identifier` into a trait (ensure it still implements
// `FromStr` so we can take T: Identifier as a clap command line argument).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)]
Expand Down
53 changes: 51 additions & 2 deletions api/src/resources/validation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{LabelGroup, LabelName};
use crate::{LabelGroup, LabelName, ModelVersion};
use ordered_float::NotNan;
use serde::{Deserialize, Serialize};

Expand All @@ -19,9 +19,58 @@ pub struct LabelValidationResponse {
pub label_validation: LabelValidation,
}

#[derive(Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelScore(pub f32);

impl std::fmt::Display for ModelScore {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum DatasetQuality {
None,
Poor,
Average,
Good,
Excellent,
}

impl std::fmt::Display for DatasetQuality {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
Self::None => "None",
Self::Poor => "Poor",
Self::Average => "Average",
Self::Good => "Good",
Self::Excellent => "Excellent",
}
)
}
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelRating {
pub score: ModelScore,
pub quality: DatasetQuality,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ValidationSummary {
pub version: ModelVersion,
pub model_rating: ModelRating,
pub reviewed_size: usize,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ValidationResponse {
pub label_groups: Vec<LabelGroup>,
pub validation: ValidationSummary,
}

impl ValidationResponse {
Expand Down
85 changes: 49 additions & 36 deletions cli/src/commands/get/datasets.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use std::sync::mpsc::channel;

use anyhow::{Context, Result};
use log::info;
use reinfer_client::{
resources::dataset::{DatasetAndStats, DatasetStats, StatisticsRequestParams},
Client, CommentFilter, DatasetIdentifier, SourceIdentifier,
Client, DatasetIdentifier, SourceIdentifier,
};
use scoped_threadpool::Pool;
use structopt::StructOpt;

use crate::printer::Printer;
Expand All @@ -23,7 +26,12 @@ pub struct GetDatasetsArgs {
source_identifier: Option<SourceIdentifier>,
}

pub fn get(client: &Client, args: &GetDatasetsArgs, printer: &Printer) -> Result<()> {
pub fn get(
client: &Client,
args: &GetDatasetsArgs,
printer: &Printer,
pool: &mut Pool,
) -> Result<()> {
let GetDatasetsArgs {
dataset,
include_stats,
Expand All @@ -49,42 +57,47 @@ pub fn get(client: &Client, args: &GetDatasetsArgs, printer: &Printer) -> Result
datasets.retain(|d| d.source_ids.contains(&source.id));
}

let mut dataset_stats = Vec::new();
let (sender, receiver) = channel();

if *include_stats {
datasets.iter().try_for_each(|dataset| -> Result<()> {
info!("Getting statistics for dataset {}", dataset.full_name().0);
let unfiltered_stats = client
.get_dataset_statistics(
&dataset.full_name(),
&StatisticsRequestParams {
..Default::default()
},
)
.context("Could not get statistics for dataset")?;

let reviewed_stats = client
.get_dataset_statistics(
&dataset.full_name(),
&StatisticsRequestParams {
comment_filter: CommentFilter {
reviewed:Some(reinfer_client::resources::comment::ReviewedFilterEnum::OnlyReviewed),
..Default::default()
pool.scoped(|scope| {
datasets.iter().for_each(|dataset| {
let get_stats = || -> Result<DatasetAndStats> {
info!("Getting statistics for dataset {}", dataset.full_name().0);
let unfiltered_stats = client
.get_dataset_statistics(
&dataset.full_name(),
&StatisticsRequestParams {
..Default::default()
},
)
.context("Could not get statistics for dataset")?;

let validation_response = client.get_latest_validation(&dataset.full_name());

Ok(DatasetAndStats {
dataset: dataset.clone(),
stats: DatasetStats {
total_verbatims: unfiltered_stats.num_comments,
validation: validation_response.ok(),
},
..Default::default()
},
)
.context("Could not get statistics for dataset")?;

let dataset_and_stats = DatasetAndStats {
dataset: dataset.clone(),
stats: DatasetStats {
num_reviewed: reviewed_stats.num_comments,
total_verbatims: unfiltered_stats.num_comments
}
};
dataset_stats.push(dataset_and_stats);
Ok(())
})?;
})
};

let sender = sender.clone();
scope.execute(move || {
sender.send(get_stats()).expect("Could not send error");
});
});
});

drop(sender);
let mut dataset_stats = Vec::new();
let results: Vec<Result<DatasetAndStats>> = receiver.iter().collect();

for result in results {
dataset_stats.push(result?);
}

printer.print_resources(&dataset_stats)
} else {
Expand Down
2 changes: 1 addition & 1 deletion cli/src/commands/get/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pub fn run(args: &GetArgs, client: Client, printer: &Printer, pool: &mut Pool) -
GetArgs::Emails(args) => emails::get_many(&client, args),
GetArgs::Comment(args) => comments::get_single(&client, args),
GetArgs::Comments(args) => comments::get_many(&client, args),
GetArgs::Datasets(args) => datasets::get(&client, args, printer),
GetArgs::Datasets(args) => datasets::get(&client, args, printer, pool),
GetArgs::Projects(args) => projects::get(&client, args, printer),
GetArgs::Sources(args) => sources::get(&client, args, printer),
GetArgs::Streams(args) => streams::get(&client, args, printer),
Expand Down
36 changes: 27 additions & 9 deletions cli/src/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl DisplayTable for Dataset {

impl DisplayTable for DatasetAndStats {
fn to_table_headers() -> Row {
row![bFg => "Name", "ID", "Updated (UTC)", "Title","Total Verbatims", "Num Reviewed"]
row![bFg => "Name", "ID", "Updated (UTC)", "Title","Total Verbatims", "Num Reviewed","Latest Model", "Score", "Quality"]
}

fn to_table_row(&self) -> Row {
Expand All @@ -146,14 +146,32 @@ impl DisplayTable for DatasetAndStats {
"/".dimmed(),
self.dataset.name.0
);
row![
full_name,
self.dataset.id.0,
self.dataset.updated_at.format("%Y-%m-%d %H:%M:%S"),
self.dataset.title,
self.stats.total_verbatims,
self.stats.num_reviewed
]

if let Some(validation_response) = &self.stats.validation {
row![
full_name,
self.dataset.id.0,
self.dataset.updated_at.format("%Y-%m-%d %H:%M:%S"),
self.dataset.title,
self.stats.total_verbatims,
validation_response.validation.reviewed_size,
validation_response.validation.version,
validation_response.validation.model_rating.score,
validation_response.validation.model_rating.quality
]
} else {
row![
full_name,
self.dataset.id.0,
self.dataset.updated_at.format("%Y-%m-%d %H:%M:%S"),
self.dataset.title,
self.stats.total_verbatims,
"N/A".dimmed(),
"N/A".dimmed(),
"N/A".dimmed(),
"N/A".dimmed(),
]
}
}
}

Expand Down

0 comments on commit fa87853

Please sign in to comment.