Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cli): add validation to dataset stats #311

Merged
merged 3 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Unreleased
- Add latest model validation (model version, score and quality) to `get datasets --stats`

# v0.32.0
- Add dataset flags to `create-dataset`
- Add `parse aic-classification-csv`
Expand Down
20 changes: 20 additions & 0 deletions api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,12 @@ impl Client {
Retry::Yes,
)
}
/// Requests the latest validation for the dataset identified by `dataset_name`.
///
/// The URL comes from `Endpoints::latest_validation`; the request itself is
/// delegated to `self.get`, typed to return a [`ValidationResponse`].
///
/// # Errors
///
/// Returns an error if the endpoint URL cannot be constructed, or if
/// `self.get` reports a failure.
pub fn get_latest_validation(
    &self,
    dataset_name: &DatasetFullName,
) -> Result<ValidationResponse> {
    let endpoint = self.endpoints.latest_validation(dataset_name)?;
    self.get::<_, ValidationResponse>(endpoint)
}

pub fn get_validation(
&self,
Expand Down Expand Up @@ -1838,6 +1844,20 @@ impl Endpoints {
],
)
}
/// Builds the URL of the private "latest validation" endpoint for `dataset_name`.
///
/// The segments `api`, `_private`, `datasets`, the dataset's full name,
/// `labellers`, `latest` and `validation` are handed, in that order, to
/// `construct_endpoint` together with `self.base`.
fn latest_validation(&self, dataset_name: &DatasetFullName) -> Result<Url> {
    let segments: [&str; 7] = [
        "api",
        "_private",
        "datasets",
        &dataset_name.0,
        "labellers",
        "latest",
        "validation",
    ];
    construct_endpoint(&self.base, &segments)
}

fn validation(
&self,
Expand Down
14 changes: 11 additions & 3 deletions api/src/resources/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use std::{
str::FromStr,
};

use super::validation::ValidationResponse;

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub enum DatasetFlag {
Expand Down Expand Up @@ -49,13 +51,13 @@ pub struct Dataset {
pub dataset_flags: Vec<DatasetFlag>,
}

/// Statistics reported alongside a dataset.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct DatasetStats {
// NOTE(review): presumably the number of reviewed comments — confirm against callers.
pub num_reviewed: NotNan<f64>,
/// Total comment count; the CLI fills this from the unfiltered statistics' `num_comments`.
pub total_verbatims: NotNan<f64>,
/// Latest validation for the dataset, or `None` when none could be obtained
/// (the CLI maps a failed `get_latest_validation` call to `None`).
pub validation: Option<ValidationResponse>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct DatasetAndStats {
pub dataset: Dataset,
pub stats: DatasetStats,
Expand Down Expand Up @@ -203,6 +205,12 @@ pub struct ModelFamily(pub String);
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct ModelVersion(pub u32);

impl std::fmt::Display for ModelVersion {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

// TODO(mcobzarenco)[3963]: Make `Identifier` into a trait (ensure it still implements
// `FromStr` so we can take T: Identifier as a clap command line argument).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)]
Expand Down
53 changes: 51 additions & 2 deletions api/src/resources/validation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{LabelGroup, LabelName};
use crate::{LabelGroup, LabelName, ModelVersion};
use ordered_float::NotNan;
use serde::{Deserialize, Serialize};

Expand All @@ -19,9 +19,58 @@ pub struct LabelValidationResponse {
pub label_validation: LabelValidation,
}

/// Newtype around the `f32` value used as the `score` of a `ModelRating`.
#[derive(Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelScore(pub f32);

impl std::fmt::Display for ModelScore {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

Comment on lines +25 to +29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, this looks like the kind of boilerplate the `derive_more` crate can derive for you: https://jeltef.github.io/derive_more/derive_more/#example-code

/// Quality rating attached to a model rating.
///
/// Variants are (de)serialized with snake_case names, as set by the
/// `rename_all` attribute below.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum DatasetQuality {
// NOTE: this variant shares its name with `Option::None`; it is a distinct
// value of this enum, not an absent rating.
None,
Poor,
Average,
Good,
Excellent,
}

/// Renders the quality as a capitalised, human-readable label
/// (e.g. `Excellent`), independent of the serialized representation.
impl std::fmt::Display for DatasetQuality {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::None => "None",
            Self::Poor => "Poor",
            Self::Average => "Average",
            Self::Good => "Good",
            Self::Excellent => "Excellent",
        };
        f.write_str(label)
    }
}

/// Rating of a model: a numeric score paired with a quality label.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelRating {
/// Numeric score (an `f32` newtype).
pub score: ModelScore,
/// Quality label accompanying the score.
pub quality: DatasetQuality,
}

/// Summary part of a validation response.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ValidationSummary {
/// Model version the validation refers to.
pub version: ModelVersion,
/// Score and quality of that model.
pub model_rating: ModelRating,
// NOTE(review): presumably the number of reviewed comments used for validation;
// the CLI prints it in the "Num Reviewed" column — confirm against the API.
pub reviewed_size: usize,
}

/// Validation payload, as returned by `Client::get_latest_validation`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ValidationResponse {
/// Label groups included in the response.
pub label_groups: Vec<LabelGroup>,
/// Summary carrying the model version, rating and reviewed size.
pub validation: ValidationSummary,
}

impl ValidationResponse {
Expand Down
85 changes: 49 additions & 36 deletions cli/src/commands/get/datasets.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use std::sync::mpsc::channel;

use anyhow::{Context, Result};
use log::info;
use reinfer_client::{
resources::dataset::{DatasetAndStats, DatasetStats, StatisticsRequestParams},
Client, CommentFilter, DatasetIdentifier, SourceIdentifier,
Client, DatasetIdentifier, SourceIdentifier,
};
use scoped_threadpool::Pool;
use structopt::StructOpt;

use crate::printer::Printer;
Expand All @@ -23,7 +26,12 @@ pub struct GetDatasetsArgs {
source_identifier: Option<SourceIdentifier>,
}

pub fn get(client: &Client, args: &GetDatasetsArgs, printer: &Printer) -> Result<()> {
pub fn get(
client: &Client,
args: &GetDatasetsArgs,
printer: &Printer,
pool: &mut Pool,
) -> Result<()> {
let GetDatasetsArgs {
dataset,
include_stats,
Expand All @@ -49,42 +57,47 @@ pub fn get(client: &Client, args: &GetDatasetsArgs, printer: &Printer) -> Result
datasets.retain(|d| d.source_ids.contains(&source.id));
}

let mut dataset_stats = Vec::new();
let (sender, receiver) = channel();

if *include_stats {
datasets.iter().try_for_each(|dataset| -> Result<()> {
info!("Getting statistics for dataset {}", dataset.full_name().0);
let unfiltered_stats = client
.get_dataset_statistics(
&dataset.full_name(),
&StatisticsRequestParams {
..Default::default()
},
)
.context("Could not get statistics for dataset")?;

let reviewed_stats = client
.get_dataset_statistics(
&dataset.full_name(),
&StatisticsRequestParams {
comment_filter: CommentFilter {
reviewed:Some(reinfer_client::resources::comment::ReviewedFilterEnum::OnlyReviewed),
..Default::default()
pool.scoped(|scope| {
datasets.iter().for_each(|dataset| {
let get_stats = || -> Result<DatasetAndStats> {
info!("Getting statistics for dataset {}", dataset.full_name().0);
let unfiltered_stats = client
.get_dataset_statistics(
&dataset.full_name(),
&StatisticsRequestParams {
..Default::default()
},
)
.context("Could not get statistics for dataset")?;

let validation_response = client.get_latest_validation(&dataset.full_name());

Ok(DatasetAndStats {
dataset: dataset.clone(),
stats: DatasetStats {
total_verbatims: unfiltered_stats.num_comments,
validation: validation_response.ok(),
},
..Default::default()
},
)
.context("Could not get statistics for dataset")?;

let dataset_and_stats = DatasetAndStats {
dataset: dataset.clone(),
stats: DatasetStats {
num_reviewed: reviewed_stats.num_comments,
total_verbatims: unfiltered_stats.num_comments
}
};
dataset_stats.push(dataset_and_stats);
Ok(())
})?;
})
};

let sender = sender.clone();
scope.execute(move || {
sender.send(get_stats()).expect("Could not send error");
});
});
});

drop(sender);
let mut dataset_stats = Vec::new();
let results: Vec<Result<DatasetAndStats>> = receiver.iter().collect();

for result in results {
dataset_stats.push(result?);
}

printer.print_resources(&dataset_stats)
} else {
Expand Down
2 changes: 1 addition & 1 deletion cli/src/commands/get/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pub fn run(args: &GetArgs, client: Client, printer: &Printer, pool: &mut Pool) -
GetArgs::Emails(args) => emails::get_many(&client, args),
GetArgs::Comment(args) => comments::get_single(&client, args),
GetArgs::Comments(args) => comments::get_many(&client, args),
GetArgs::Datasets(args) => datasets::get(&client, args, printer),
GetArgs::Datasets(args) => datasets::get(&client, args, printer, pool),
GetArgs::Projects(args) => projects::get(&client, args, printer),
GetArgs::Sources(args) => sources::get(&client, args, printer),
GetArgs::Streams(args) => streams::get(&client, args, printer),
Expand Down
36 changes: 27 additions & 9 deletions cli/src/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl DisplayTable for Dataset {

impl DisplayTable for DatasetAndStats {
/// Header row for the dataset-with-stats table.
///
/// The column order here must match the cell order produced by `to_table_row`.
fn to_table_headers() -> Row {
row![bFg => "Name", "ID", "Updated (UTC)", "Title","Total Verbatims", "Num Reviewed"]
row![bFg => "Name", "ID", "Updated (UTC)", "Title","Total Verbatims", "Num Reviewed","Latest Model", "Score", "Quality"]
}

fn to_table_row(&self) -> Row {
Expand All @@ -146,14 +146,32 @@ impl DisplayTable for DatasetAndStats {
"/".dimmed(),
self.dataset.name.0
);
row![
full_name,
self.dataset.id.0,
self.dataset.updated_at.format("%Y-%m-%d %H:%M:%S"),
self.dataset.title,
self.stats.total_verbatims,
self.stats.num_reviewed
]

if let Some(validation_response) = &self.stats.validation {
row![
full_name,
self.dataset.id.0,
self.dataset.updated_at.format("%Y-%m-%d %H:%M:%S"),
self.dataset.title,
self.stats.total_verbatims,
validation_response.validation.reviewed_size,
validation_response.validation.version,
validation_response.validation.model_rating.score,
validation_response.validation.model_rating.quality
]
} else {
row![
full_name,
self.dataset.id.0,
self.dataset.updated_at.format("%Y-%m-%d %H:%M:%S"),
self.dataset.title,
self.stats.total_verbatims,
"N/A".dimmed(),
"N/A".dimmed(),
"N/A".dimmed(),
"N/A".dimmed(),
]
}
}
}

Expand Down
Loading