Skip to content

Commit

Permalink
feat(stream-stats): add command for displaying stream validation stats (
Browse files Browse the repository at this point in the history
  • Loading branch information
joe-prosser authored Nov 8, 2023
1 parent a63e314 commit 47759bc
Show file tree
Hide file tree
Showing 11 changed files with 445 additions and 28 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
## Unreleased

- Fix url used for fetching streams
- Return `is_end_sequence` on stream fetch
- Make `transform_tag` optional on `create bucket`
- Retry `put_emails` requests
- Add `get stream-stats` to expose and compare model validation

## v0.20.0

Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

91 changes: 89 additions & 2 deletions api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ use resources::{
project::ForceDeleteProject,
quota::{GetQuotasResponse, Quota},
source::StatisticsRequestParams as SourceStatisticsRequestParams,
stream::{NewStream, PutStreamRequest, PutStreamResponse},
stream::{GetStreamResponse, NewStream, PutStreamRequest, PutStreamResponse},
validation::{
LabelValidation, LabelValidationRequest, LabelValidationResponse, ValidationResponse,
},
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down Expand Up @@ -58,7 +61,7 @@ use crate::resources::{
statistics::GetResponse as GetStatisticsResponse,
stream::{
AdvanceRequest as StreamAdvanceRequest, FetchRequest as StreamFetchRequest,
GetResponse as GetStreamsResponse, ResetRequest as StreamResetRequest,
GetStreamsResponse, ResetRequest as StreamResetRequest,
TagExceptionsRequest as TagStreamExceptionsRequest,
},
tenant_id::TenantId,
Expand Down Expand Up @@ -405,6 +408,32 @@ impl Client {
)
}

pub fn get_validation(
&self,
dataset_name: &DatasetFullName,
model_version: &ModelVersion,
) -> Result<ValidationResponse> {
self.get::<_, ValidationResponse>(self.endpoints.validation(dataset_name, model_version)?)
}

pub fn get_label_validation(
&self,
label: &LabelName,
dataset_name: &DatasetFullName,
model_version: &ModelVersion,
) -> Result<LabelValidation> {
Ok(self
.post::<_, _, LabelValidationResponse>(
self.endpoints
.label_validation(dataset_name, model_version)?,
LabelValidationRequest {
label: label.clone(),
},
Retry::Yes,
)?
.label_validation)
}

pub fn sync_comments(
&self,
source_name: &SourceFullName,
Expand Down Expand Up @@ -799,6 +828,12 @@ impl Client {
)
}

pub fn get_stream(&self, stream_name: &StreamFullName) -> Result<Stream> {
Ok(self
.get::<_, GetStreamResponse>(self.endpoints.stream(stream_name)?)?
.stream)
}

pub fn advance_stream(
&self,
stream_name: &StreamFullName,
Expand Down Expand Up @@ -1279,6 +1314,44 @@ impl Endpoints {
})
}

fn validation(
&self,
dataset_name: &DatasetFullName,
model_version: &ModelVersion,
) -> Result<Url> {
construct_endpoint(
&self.base,
&[
"api",
"_private",
"datasets",
&dataset_name.0,
"labellers",
&model_version.0.to_string(),
"validation",
],
)
}

fn label_validation(
&self,
dataset_name: &DatasetFullName,
model_version: &ModelVersion,
) -> Result<Url> {
construct_endpoint(
&self.base,
&[
"api",
"_private",
"datasets",
&dataset_name.0,
"labellers",
&model_version.0.to_string(),
"label-validation",
],
)
}

fn dataset_summary(&self, dataset_name: &DatasetFullName) -> Result<Url> {
construct_endpoint(
&self.base,
Expand All @@ -1300,6 +1373,20 @@ impl Endpoints {
)
}

fn stream(&self, stream_name: &StreamFullName) -> Result<Url> {
construct_endpoint(
&self.base,
&[
"api",
"v1",
"datasets",
&stream_name.dataset.0,
"streams",
&stream_name.stream.0,
],
)
}

fn stream_fetch(&self, stream_name: &StreamFullName) -> Result<Url> {
construct_endpoint(
&self.base,
Expand Down
1 change: 1 addition & 0 deletions api/src/resources/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub mod statistics;
pub mod stream;
pub mod tenant_id;
pub mod user;
pub mod validation;

use crate::error::{Error, Result};
use reqwest::StatusCode;
Expand Down
31 changes: 18 additions & 13 deletions api/src/resources/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use ordered_float::NotNan;
use serde::{Deserialize, Serialize};
use std::str::FromStr;

use crate::error::{Error, Result};
use crate::{
error::{Error, Result},
ModelVersion,
};

use super::{
comment::{Comment, CommentFilter, Entity, PredictedLabel, Uid as CommentUid},
Expand Down Expand Up @@ -67,7 +70,7 @@ pub struct NewStream {
}

impl NewStream {
pub fn set_model_version(&mut self, model_version: &UserModelVersion) {
pub fn set_model_version(&mut self, model_version: &ModelVersion) {
if let Some(model) = &mut self.model {
model.version = model_version.clone()
}
Expand All @@ -76,14 +79,14 @@ impl NewStream {

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamModel {
pub version: UserModelVersion,
pub version: ModelVersion,
pub label_thresholds: Vec<StreamLabelThreshold>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamLabelThreshold {
name: Vec<String>,
threshold: NotNan<f64>,
pub name: Vec<String>,
pub threshold: NotNan<f64>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
Expand All @@ -109,19 +112,16 @@ pub struct Stream {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LabelFilter {
pub label: LabelName,
pub model_version: UserModelVersion,
pub model_version: ModelVersion,
pub threshold: NotNan<f64>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UserModelVersion(pub u64);

impl FromStr for UserModelVersion {
impl FromStr for ModelVersion {
type Err = Error;

fn from_str(s: &str) -> Result<Self> {
match s.parse::<u64>() {
Ok(version) => Ok(UserModelVersion(version)),
match s.parse::<u32>() {
Ok(version) => Ok(ModelVersion(version)),
Err(_) => Err(Error::BadStreamModelVersion {
version: s.to_string(),
}),
Expand All @@ -146,10 +146,15 @@ pub struct StreamResult {
}

#[derive(Debug, Clone, Deserialize)]
pub(crate) struct GetResponse {
pub(crate) struct GetStreamsResponse {
pub streams: Vec<Stream>,
}

#[derive(Debug, Clone, Deserialize)]
pub(crate) struct GetStreamResponse {
pub stream: Stream,
}

#[derive(Debug, Clone, Serialize)]
pub(crate) struct FetchRequest {
pub size: u32,
Expand Down
33 changes: 33 additions & 0 deletions api/src/resources/validation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use crate::{LabelGroup, LabelName};
use ordered_float::NotNan;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
pub struct LabelValidation {
pub thresholds: Vec<NotNan<f64>>,
pub precisions: Vec<NotNan<f64>>,
pub recalls: Vec<NotNan<f64>>,
}

#[derive(Serialize)]
pub struct LabelValidationRequest {
pub label: LabelName,
}

#[derive(Deserialize)]
pub struct LabelValidationResponse {
pub label_validation: LabelValidation,
}

#[derive(Clone, Deserialize)]
pub struct ValidationResponse {
pub label_groups: Vec<LabelGroup>,
}

impl ValidationResponse {
pub fn get_default_label_group(&self) -> Option<&LabelGroup> {
self.label_groups
.iter()
.find(|group| group.name.0 == "default")
}
}
1 change: 1 addition & 0 deletions cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ scoped_threadpool = "0.1.9"
backoff = "0.4.0"
cfb = "0.9.0"
encoding_rs = "0.8.33"
ordered-float = { version = "3.9.1", features = ["serde"] }

[dev-dependencies]
pretty_assertions = "1.3.0"
Expand Down
8 changes: 3 additions & 5 deletions cli/src/commands/create/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ use std::{

use anyhow::{Context, Result};
use log::info;
use reinfer_client::{
resources::stream::{NewStream, UserModelVersion},
Client, DatasetIdentifier,
};
use reinfer_client::ModelVersion;
use reinfer_client::{resources::stream::NewStream, Client, DatasetIdentifier};

use structopt::StructOpt;

Expand All @@ -25,7 +23,7 @@ pub struct CreateStreamsArgs {

#[structopt(short = "v", long = "model-version")]
/// The model version for the new streams to use
model_version: UserModelVersion,
model_version: ModelVersion,
}

pub fn create(client: &Client, args: &CreateStreamsArgs) -> Result<()> {
Expand Down
10 changes: 8 additions & 2 deletions cli/src/commands/get/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod users;

use anyhow::Result;
use reinfer_client::Client;
use scoped_threadpool::Pool;
use structopt::StructOpt;

use self::{
Expand All @@ -17,7 +18,7 @@ use self::{
datasets::GetDatasetsArgs,
projects::GetProjectsArgs,
sources::GetSourcesArgs,
streams::{GetStreamCommentsArgs, GetStreamsArgs},
streams::{GetStreamCommentsArgs, GetStreamStatsArgs, GetStreamsArgs},
users::GetUsersArgs,
};
use crate::printer::Printer;
Expand Down Expand Up @@ -56,6 +57,10 @@ pub enum GetArgs {
/// Fetch comments from a stream
StreamComments(GetStreamCommentsArgs),

#[structopt(name = "stream-stats")]
/// Get the validation stats for a given stream
StreamStats(GetStreamStatsArgs),

#[structopt(name = "users")]
/// List the available users
Users(GetUsersArgs),
Expand All @@ -69,7 +74,7 @@ pub enum GetArgs {
Quotas,
}

pub fn run(args: &GetArgs, client: Client, printer: &Printer) -> Result<()> {
pub fn run(args: &GetArgs, client: Client, printer: &Printer, pool: &mut Pool) -> Result<()> {
match args {
GetArgs::Buckets(args) => buckets::get(&client, args, printer),
GetArgs::Comment(args) => comments::get_single(&client, args),
Expand All @@ -79,6 +84,7 @@ pub fn run(args: &GetArgs, client: Client, printer: &Printer) -> Result<()> {
GetArgs::Sources(args) => sources::get(&client, args, printer),
GetArgs::Streams(args) => streams::get(&client, args, printer),
GetArgs::StreamComments(args) => streams::get_stream_comments(&client, args),
GetArgs::StreamStats(args) => streams::get_stream_stats(&client, args, printer, pool),
GetArgs::Users(args) => users::get(&client, args, printer),
GetArgs::CurrentUser => users::get_current_user(&client, printer),
GetArgs::Quotas => quota::get(&client, printer),
Expand Down
Loading

0 comments on commit 47759bc

Please sign in to comment.