diff --git a/CHANGELOG.md b/CHANGELOG.md index edf812c3..40fbb385 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,10 @@ ## Unreleased + - Fix url used for fetching streams - Return `is_end_sequence` on stream fetch - Make `transform_tag` optional on `create bucket` - Retry `put_emails` requests +- Add `get stream-stats` to expose and compare model validation ## v0.20.0 diff --git a/Cargo.lock b/Cargo.lock index f553a827..858686c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1250,6 +1250,7 @@ dependencies = [ "log", "maplit", "once_cell", + "ordered-float", "pretty_assertions", "prettytable-rs", "regex", diff --git a/api/src/lib.rs b/api/src/lib.rs index 4fb0279e..19679baf 100644 --- a/api/src/lib.rs +++ b/api/src/lib.rs @@ -22,7 +22,10 @@ use resources::{ project::ForceDeleteProject, quota::{GetQuotasResponse, Quota}, source::StatisticsRequestParams as SourceStatisticsRequestParams, - stream::{NewStream, PutStreamRequest, PutStreamResponse}, + stream::{GetStreamResponse, NewStream, PutStreamRequest, PutStreamResponse}, + validation::{ + LabelValidation, LabelValidationRequest, LabelValidationResponse, ValidationResponse, + }, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -58,7 +61,7 @@ use crate::resources::{ statistics::GetResponse as GetStatisticsResponse, stream::{ AdvanceRequest as StreamAdvanceRequest, FetchRequest as StreamFetchRequest, - GetResponse as GetStreamsResponse, ResetRequest as StreamResetRequest, + GetStreamsResponse, ResetRequest as StreamResetRequest, TagExceptionsRequest as TagStreamExceptionsRequest, }, tenant_id::TenantId, @@ -405,6 +408,32 @@ impl Client { ) } + pub fn get_validation( + &self, + dataset_name: &DatasetFullName, + model_version: &ModelVersion, + ) -> Result { + self.get::<_, ValidationResponse>(self.endpoints.validation(dataset_name, model_version)?) + } + + pub fn get_label_validation( + &self, + label: &LabelName, + dataset_name: &DatasetFullName, + model_version: &ModelVersion, + ) -> Result { + Ok(self + .post::<_, _, LabelValidationResponse>( + self.endpoints + .label_validation(dataset_name, model_version)?, + LabelValidationRequest { + label: label.clone(), + }, + Retry::Yes, + )? + .label_validation) + } + pub fn sync_comments( &self, source_name: &SourceFullName, @@ -799,6 +828,12 @@ impl Client { ) } + pub fn get_stream(&self, stream_name: &StreamFullName) -> Result { + Ok(self + .get::<_, GetStreamResponse>(self.endpoints.stream(stream_name)?)? + .stream) + } + pub fn advance_stream( &self, stream_name: &StreamFullName, @@ -1279,6 +1314,44 @@ impl Endpoints { }) } + fn validation( + &self, + dataset_name: &DatasetFullName, + model_version: &ModelVersion, + ) -> Result { + construct_endpoint( + &self.base, + &[ + "api", + "_private", + "datasets", + &dataset_name.0, + "labellers", + &model_version.0.to_string(), + "validation", + ], + ) + } + + fn label_validation( + &self, + dataset_name: &DatasetFullName, + model_version: &ModelVersion, + ) -> Result { + construct_endpoint( + &self.base, + &[ + "api", + "_private", + "datasets", + &dataset_name.0, + "labellers", + &model_version.0.to_string(), + "label-validation", + ], + ) + } + fn dataset_summary(&self, dataset_name: &DatasetFullName) -> Result { construct_endpoint( &self.base, @@ -1300,6 +1373,20 @@ impl Endpoints { ) } + fn stream(&self, stream_name: &StreamFullName) -> Result { + construct_endpoint( + &self.base, + &[ + "api", + "v1", + "datasets", + &stream_name.dataset.0, + "streams", + &stream_name.stream.0, + ], + ) + } + fn stream_fetch(&self, stream_name: &StreamFullName) -> Result { construct_endpoint( &self.base, diff --git a/api/src/resources/mod.rs b/api/src/resources/mod.rs index 34676fb8..9ea27430 100644 --- a/api/src/resources/mod.rs +++ b/api/src/resources/mod.rs @@ -13,6 +13,7 @@ pub mod statistics; pub mod stream; pub mod tenant_id; pub mod user; +pub mod validation; use crate::error::{Error, Result}; use reqwest::StatusCode; diff --git a/api/src/resources/stream.rs b/api/src/resources/stream.rs index 564c2547..d9fc7cc5 100644 --- a/api/src/resources/stream.rs +++ b/api/src/resources/stream.rs @@ -3,7 +3,10 @@ use ordered_float::NotNan; use serde::{Deserialize, Serialize}; use std::str::FromStr; -use crate::error::{Error, Result}; +use crate::{ + error::{Error, Result}, + ModelVersion, +}; use super::{ comment::{Comment, CommentFilter, Entity, PredictedLabel, Uid as CommentUid}, @@ -67,7 +70,7 @@ pub struct NewStream { } impl NewStream { - pub fn set_model_version(&mut self, model_version: &UserModelVersion) { + pub fn set_model_version(&mut self, model_version: &ModelVersion) { if let Some(model) = &mut self.model { model.version = model_version.clone() } @@ -76,14 +79,14 @@ impl NewStream { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct StreamModel { - pub version: UserModelVersion, + pub version: ModelVersion, pub label_thresholds: Vec, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct StreamLabelThreshold { - name: Vec, - threshold: NotNan, + pub name: Vec, + pub threshold: NotNan, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -109,19 +112,16 @@ pub struct Stream { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct LabelFilter { pub label: LabelName, - pub model_version: UserModelVersion, + pub model_version: ModelVersion, pub threshold: NotNan, } -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct UserModelVersion(pub u64); - -impl FromStr for UserModelVersion { +impl FromStr for ModelVersion { type Err = Error; fn from_str(s: &str) -> Result { - match s.parse::() { - Ok(version) => Ok(UserModelVersion(version)), + match s.parse::() { + Ok(version) => Ok(ModelVersion(version)), Err(_) => Err(Error::BadStreamModelVersion { version: s.to_string(), }), @@ -146,10 +146,15 @@ pub struct StreamResult { } #[derive(Debug, Clone, Deserialize)] -pub(crate) struct GetResponse { +pub(crate) struct GetStreamsResponse { pub streams: Vec, } +#[derive(Debug, Clone, Deserialize)] +pub(crate) struct GetStreamResponse { + pub stream: Stream, +} + #[derive(Debug, Clone, Serialize)] pub(crate) struct FetchRequest { pub size: u32, diff --git a/api/src/resources/validation.rs b/api/src/resources/validation.rs new file mode 100644 index 00000000..cc46864b --- /dev/null +++ b/api/src/resources/validation.rs @@ -0,0 +1,33 @@ +use crate::{LabelGroup, LabelName}; +use ordered_float::NotNan; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub struct LabelValidation { + pub thresholds: Vec>, + pub precisions: Vec>, + pub recalls: Vec>, +} + +#[derive(Serialize)] +pub struct LabelValidationRequest { + pub label: LabelName, +} + +#[derive(Deserialize)] +pub struct LabelValidationResponse { + pub label_validation: LabelValidation, +} + +#[derive(Clone, Deserialize)] +pub struct ValidationResponse { + pub label_groups: Vec, +} + +impl ValidationResponse { + pub fn get_default_label_group(&self) -> Option<&LabelGroup> { + self.label_groups + .iter() + .find(|group| group.name.0 == "default") + } +} diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 7c858dc7..5e5264d6 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -40,6 +40,7 @@ scoped_threadpool = "0.1.9" backoff = "0.4.0" cfb = "0.9.0" encoding_rs = "0.8.33" +ordered-float = { version = "3.9.1", features = ["serde"] } [dev-dependencies] pretty_assertions = "1.3.0" diff --git a/cli/src/commands/create/streams.rs b/cli/src/commands/create/streams.rs index 201e051d..6e551312 100644 --- a/cli/src/commands/create/streams.rs +++ b/cli/src/commands/create/streams.rs @@ -6,10 +6,8 @@ use std::{ use anyhow::{Context, Result}; use log::info; -use reinfer_client::{ - resources::stream::{NewStream, UserModelVersion}, - Client, DatasetIdentifier, -}; +use reinfer_client::ModelVersion; +use reinfer_client::{resources::stream::NewStream, Client, DatasetIdentifier}; use structopt::StructOpt; @@ -25,7 +23,7 @@ pub struct CreateStreamsArgs { #[structopt(short = "v", long = "model-version")] /// The model version for the new streams to use - model_version: UserModelVersion, + model_version: ModelVersion, } pub fn create(client: &Client, args: &CreateStreamsArgs) -> Result<()> { diff --git a/cli/src/commands/get/mod.rs b/cli/src/commands/get/mod.rs index a492053c..6820d797 100644 --- a/cli/src/commands/get/mod.rs +++ b/cli/src/commands/get/mod.rs @@ -9,6 +9,7 @@ mod users; use anyhow::Result; use reinfer_client::Client; +use scoped_threadpool::Pool; use structopt::StructOpt; use self::{ @@ -17,7 +18,7 @@ use self::{ datasets::GetDatasetsArgs, projects::GetProjectsArgs, sources::GetSourcesArgs, - streams::{GetStreamCommentsArgs, GetStreamsArgs}, + streams::{GetStreamCommentsArgs, GetStreamStatsArgs, GetStreamsArgs}, users::GetUsersArgs, }; use crate::printer::Printer; @@ -56,6 +57,10 @@ pub enum GetArgs { /// Fetch comments from a stream StreamComments(GetStreamCommentsArgs), + #[structopt(name = "stream-stats")] + /// Get the validation stats for a given stream + StreamStats(GetStreamStatsArgs), + #[structopt(name = "users")] /// List the available users Users(GetUsersArgs), @@ -69,7 +74,7 @@ pub enum GetArgs { Quotas, } -pub fn run(args: &GetArgs, client: Client, printer: &Printer) -> Result<()> { +pub fn run(args: &GetArgs, client: Client, printer: &Printer, pool: &mut Pool) -> Result<()> { match args { GetArgs::Buckets(args) => buckets::get(&client, args, printer), GetArgs::Comment(args) => comments::get_single(&client, args), @@ -79,6 +84,7 @@ pub fn run(args: &GetArgs, client: Client, printer: &Printer) -> Result<()> { GetArgs::Sources(args) => sources::get(&client, args, printer), GetArgs::Streams(args) => streams::get(&client, args, printer), GetArgs::StreamComments(args) => streams::get_stream_comments(&client, args), + GetArgs::StreamStats(args) => streams::get_stream_stats(&client, args, printer, pool), GetArgs::Users(args) => users::get(&client, args, printer), GetArgs::CurrentUser => users::get_current_user(&client, printer), GetArgs::Quotas => quota::get(&client, printer), diff --git a/cli/src/commands/get/streams.rs b/cli/src/commands/get/streams.rs index e035e232..7e220b9c 100644 --- a/cli/src/commands/get/streams.rs +++ b/cli/src/commands/get/streams.rs @@ -1,5 +1,17 @@ -use anyhow::{Context, Result}; -use reinfer_client::{Client, DatasetIdentifier, StreamFullName}; +use anyhow::{anyhow, Context, Result}; +use colored::{ColoredString, Colorize}; +use log::info; +use ordered_float::NotNan; +use prettytable::row; +use reinfer_client::resources::stream::{StreamLabelThreshold, StreamModel}; +use reinfer_client::resources::validation::ValidationResponse; +use reinfer_client::{ + resources::validation::LabelValidation, Client, DatasetIdentifier, ModelVersion, StreamFullName, +}; +use reinfer_client::{DatasetFullName, LabelDef, LabelName}; +use scoped_threadpool::Pool; +use serde::Serialize; +use std::sync::mpsc::channel; use std::{ fs::File, io, @@ -8,7 +20,7 @@ use std::{ }; use structopt::StructOpt; -use crate::printer::{print_resources_as_json, Printer}; +use crate::printer::{print_resources_as_json, DisplayTable, Printer}; #[derive(Debug, StructOpt)] pub struct GetStreamsArgs { @@ -40,6 +52,21 @@ pub struct GetStreamCommentsArgs { individual_advance: bool, } +#[derive(Debug, StructOpt)] +pub struct GetStreamStatsArgs { + #[structopt(name = "stream")] + /// The full stream name `//`. + stream_full_name: StreamFullName, + + #[structopt(long = "compare-version", short = "v")] + /// The model version to compare stats with + compare_to_model_version: Option, + + #[structopt(long = "compare-dataset", short = "d")] + /// The dataset to compare stats with + compare_to_dataset: Option, +} + pub fn get(client: &Client, args: &GetStreamsArgs, printer: &Printer) -> Result<()> { let GetStreamsArgs { dataset, path } = args; @@ -68,6 +95,259 @@ pub fn get(client: &Client, args: &GetStreamsArgs, printer: &Printer) -> Result< } } +#[derive(Serialize)] +pub struct StreamStat { + label_name: LabelName, + threshold: NotNan, + precision: NotNan, + recall: NotNan, + compare_to_precision: Option>, + compare_to_recall: Option>, +} +impl DisplayTable for StreamStat { + fn to_table_headers() -> prettytable::Row { + row![ + "Name", + "Threshold", + "Current Precision", + "Current Recall", + "Compare to Precision", + "Compare to Recall" + ] + } + fn to_table_row(&self) -> prettytable::Row { + row![ + self.label_name.0, + format!("{:.3}", self.threshold), + format!("{:.3}", self.precision), + format!("{:.3}", self.recall), + if let Some(precision) = self.compare_to_precision { + red_if_lower_green_otherwise(precision, self.precision) + } else { + "none".dimmed() + }, + if let Some(recall) = self.compare_to_recall { + red_if_lower_green_otherwise(recall, self.recall) + } else { + "none".dimmed() + }, + ] + } +} + +fn red_if_lower_green_otherwise(test: NotNan, threshold: NotNan) -> ColoredString { + let test_str = format!("{:.3}", test); + + match test { + test if test < threshold => format!("{test_str} (decrease)").red(), + test if test > threshold => format!("{test_str} (increase)").green(), + _ => test_str.green(), + } +} + +fn get_precision_and_recall_for_threshold( + threshold: NotNan, + label_name: LabelName, + label_validation: LabelValidation, +) -> Result<(NotNan, NotNan)> { + let threshold_index = label_validation + .thresholds + .iter() + .position(|&val_threshold| val_threshold < threshold) + .context(format!( + "Could not find threshold for label {}", + label_name.0 + ))?; + + let precision = label_validation + .precisions + .get(threshold_index) + .context(format!( + "Could not get precision for label {}", + label_name.0 + ))?; + let recall = label_validation + .recalls + .get(threshold_index) + .context(format!("Could not get recall for label {}", label_name.0))?; + Ok((*precision, *recall)) +} + +#[derive(Clone)] +struct CompareConfig { + validation: ValidationResponse, + dataset_name: DatasetFullName, + model_version: ModelVersion, +} + +impl CompareConfig { + pub fn get_label_def(&self, label_name: &LabelName) -> Result> { + Ok(self + .validation + .get_default_label_group() + .context("Compare to dataset does not have a default label group")? + .label_defs + .iter() + .find(|label| label.name == *label_name)) + } +} + +fn get_compare_config( + client: &Client, + model_version: &Option, + dataset_name: &Option, + stream_name: &StreamFullName, +) -> Result> { + if model_version.is_none() && dataset_name.is_none() { + return Ok(None); + } + + let dataset_name = if let Some(dataset_name) = dataset_name { + dataset_name + } else { + &stream_name.dataset + }; + + let model_version = model_version + .clone() + .context("No compare to model version provided")?; + + info!("Getting validation for {}", dataset_name.0); + let validation = client.get_validation(dataset_name, &model_version)?; + + Ok(Some(CompareConfig { + validation, + dataset_name: dataset_name.clone(), + model_version, + })) +} + +fn get_stream_stat( + label_threshold: &StreamLabelThreshold, + stream_full_name: &StreamFullName, + model: &StreamModel, + compare_config: &Option, + client: &Client, +) -> Result { + let label_name = reinfer_client::LabelName(label_threshold.name.join(" > ")); + + info!( + "Getting label validation for {} in dataset {}", + label_name.0, stream_full_name.dataset.0 + ); + let label_validation = + client.get_label_validation(&label_name, &stream_full_name.dataset, &model.version)?; + + let (precision, recall) = get_precision_and_recall_for_threshold( + label_threshold.threshold, + label_name.clone(), + label_validation, + )?; + + let mut stream_stat = StreamStat { + label_name: label_name.clone(), + threshold: label_threshold.threshold, + precision, + recall, + compare_to_precision: None, + compare_to_recall: None, + }; + + if let Some(ref compare_config) = compare_config { + if compare_config.get_label_def(&label_name)?.is_some() { + info!( + "Getting label validation for {} in dataset {}", + label_name.0, compare_config.dataset_name.0 + ); + let compare_to_label_validation = client.get_label_validation( + &label_name, + &compare_config.dataset_name, + &compare_config.model_version, + )?; + + let (compare_to_precision, compare_to_recall) = get_precision_and_recall_for_threshold( + label_threshold.threshold, + label_name, + compare_to_label_validation, + )?; + + stream_stat.compare_to_precision = Some(compare_to_precision); + stream_stat.compare_to_recall = Some(compare_to_recall); + } + } + Ok(stream_stat) +} + +pub fn get_stream_stats( + client: &Client, + args: &GetStreamStatsArgs, + printer: &Printer, + pool: &mut Pool, +) -> Result<()> { + let GetStreamStatsArgs { + stream_full_name, + compare_to_model_version, + compare_to_dataset, + } = args; + + if compare_to_dataset.is_some() && compare_to_model_version.is_none() { + return Err(anyhow!( + "You cannot provide `compare_to_dataset` without `compare_to_model_version`" + )); + } + + info!("Getting Stream"); + let stream = client.get_stream(stream_full_name)?; + let model = stream.model.context("No model associated with stream.")?; + + let compare_config = get_compare_config( + client, + compare_to_model_version, + compare_to_dataset, + stream_full_name, + )?; + + let mut stream_stats = Vec::new(); + + let (sender, receiver) = channel(); + + pool.scoped(|scope| { + for label_threshold in &model.label_thresholds { + if label_threshold.threshold >= NotNan::new(1.0).expect("Could not create NotNan") { + // As the precision and recall will always be 0 + continue; + } + let sender = sender.clone(); + let model = model.clone(); + let compare_config = compare_config.clone(); + + scope.execute(move || { + let result = get_stream_stat( + label_threshold, + stream_full_name, + &model, + &compare_config, + client, + ); + sender.send(result).expect("Could not send result"); + }); + } + }); + + drop(sender); + let results: Vec> = receiver.iter().collect(); + + for result in results { + let stream_stat = result?; + stream_stats.push(stream_stat) + } + + stream_stats.sort_by(|a, b| a.label_name.0.cmp(&b.label_name.0)); + + printer.print_resources(&stream_stats)?; + Ok(()) +} + pub fn get_stream_comments(client: &Client, args: &GetStreamCommentsArgs) -> Result<()> { let GetStreamCommentsArgs { stream, diff --git a/cli/src/main.rs b/cli/src/main.rs index 45585bf1..cb3f9cd0 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -56,9 +56,12 @@ fn run(args: Args) -> Result<()> { app.gen_completions_to("re", clap_shell, &mut io::stdout()); Ok(()) } - Command::Get { get_args } => { - get::run(get_args, client_from_args(&args, &config)?, &printer) - } + Command::Get { get_args } => get::run( + get_args, + client_from_args(&args, &config)?, + &printer, + &mut pool, + ), Command::Delete { delete_args } => { delete::run(delete_args, client_from_args(&args, &config)?) }