Skip to content

Commit

Permalink
feat(commands): add custom label trend report (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
joe-prosser authored Sep 9, 2024
1 parent 60e4e44 commit 06468c9
Show file tree
Hide file tree
Showing 9 changed files with 918 additions and 97 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Unreleased
- add custom label trend report
- Add validation to dataset `--stats`
- fix issue when adding configs from url

Expand Down
25 changes: 23 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 36 additions & 10 deletions api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ use resources::{
bucket_statistics::GetBucketStatisticsResponse,
comment::{AttachmentReference, CommentTimestampFilter},
dataset::{
QueryRequestParams, QueryResponse,
StatisticsRequestParams as DatasetStatisticsRequestParams, SummaryRequestParams,
SummaryResponse,
GetAllModelsInDatasetRequest, GetAllModelsInDatasetRespone, QueryRequestParams,
QueryResponse, StatisticsRequestParams as DatasetStatisticsRequestParams,
SummaryRequestParams, SummaryResponse, UserModelMetadata,
},
documents::{Document, SyncRawEmailsRequest, SyncRawEmailsResponse},
email::{Email, GetEmailResponse},
Expand Down Expand Up @@ -106,11 +106,13 @@ pub use crate::{
Identifier as BucketIdentifier, Name as BucketName, NewBucket,
},
comment::{
AnnotatedComment, Comment, CommentFilter, CommentsIterPage, Continuation,
EitherLabelling, Entities, Entity, HasAnnotations, Id as CommentId, Label, Labelling,
AnnotatedComment, Comment, CommentFilter, CommentPredictionsThreshold,
CommentsIterPage, Continuation, EitherLabelling, Entities, Entity,
GetCommentPredictionsRequest, HasAnnotations, Id as CommentId, Label, Labelling,
Message, MessageBody, MessageSignature, MessageSubject, NewAnnotatedComment,
NewComment, NewEntities, NewLabelling, NewMoonForm, PredictedLabel, Prediction,
PropertyMap, PropertyValue, Sentiment, SyncCommentsResponse, Uid as CommentUid,
PropertyMap, PropertyValue, Sentiment, SyncCommentsResponse, TriggerLabelThreshold,
Uid as CommentUid,
},
dataset::{
Dataset, FullName as DatasetFullName, Id as DatasetId, Identifier as DatasetIdentifier,
Expand Down Expand Up @@ -629,6 +631,16 @@ impl Client {
self.get::<_, ValidationResponse>(self.endpoints.validation(dataset_name, model_version)?)
}

/// Lists the models ("labellers") recorded for a dataset.
///
/// Sends an empty [`GetAllModelsInDatasetRequest`] as a POST to the dataset's
/// `labellers` endpoint (with `Retry::Yes`) and returns the `labellers` field
/// of the decoded response.
///
/// # Errors
///
/// Propagates any error from building the endpoint URL or from the POST call.
pub fn get_labellers(&self, dataset_name: &DatasetFullName) -> Result<Vec<UserModelMetadata>> {
    let endpoint = self.endpoints.labellers(dataset_name)?;
    let response = self.post::<_, _, GetAllModelsInDatasetRespone>(
        endpoint,
        GetAllModelsInDatasetRequest {},
        Retry::Yes,
    )?;

    Ok(response.labellers)
}

pub fn get_label_validation(
&self,
label: &LabelName,
Expand Down Expand Up @@ -973,15 +985,22 @@ impl Client {
dataset_name: &DatasetFullName,
model_version: &ModelVersion,
comment_uids: impl Iterator<Item = &'a CommentUid>,
threshold: Option<CommentPredictionsThreshold>,
labels: Option<Vec<TriggerLabelThreshold>>,
) -> Result<Vec<Prediction>> {
Ok(self
.post::<_, _, GetPredictionsResponse>(
self.endpoints
.get_comment_predictions(dataset_name, model_version)?,
json!({
"threshold": "auto",
"uids": comment_uids.into_iter().map(|id| id.0.as_str()).collect::<Vec<_>>(),
}),
GetCommentPredictionsRequest {
uids: comment_uids
.into_iter()
.map(|id| id.0.clone())
.collect::<Vec<_>>(),

threshold,
labels,
},
Retry::Yes,
)?
.predictions)
Expand Down Expand Up @@ -2151,6 +2170,13 @@ impl Endpoints {
)
}

/// Builds the URL of the private `labellers` endpoint for a dataset.
///
/// The path segments passed to `construct_endpoint` are
/// `api/_private/datasets/<dataset_name.0>/labellers`, resolved against `self.base`.
///
/// # Errors
///
/// Returns whatever error `construct_endpoint` produces.
fn labellers(&self, dataset_name: &DatasetFullName) -> Result<Url> {
construct_endpoint(
&self.base,
&["api", "_private", "datasets", &dataset_name.0, "labellers"],
)
}

fn get_comment_predictions(
&self,
dataset_name: &DatasetFullName,
Expand Down
54 changes: 41 additions & 13 deletions api/src/resources/comment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,18 @@ type UserPropertyName = String;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserPropertiesFilter(pub HashMap<UserPropertyName, PropertyFilter>);

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct PropertyFilter {
#[serde(skip_serializing_if = "<[_]>::is_empty", default)]
#[serde(skip_serializing_if = "<[_]>::is_empty")]
pub one_of: Vec<PropertyValue>,
#[serde(skip_serializing_if = "<[_]>::is_empty", default)]
#[serde(skip_serializing_if = "<[_]>::is_empty")]
pub not_one_of: Vec<PropertyValue>,
#[serde(skip_serializing_if = "<[_]>::is_empty", default)]
#[serde(skip_serializing_if = "<[_]>::is_empty")]
pub domain_not_one_of: Vec<PropertyValue>,
#[serde(skip_serializing_if = "Option::is_none")]
pub minimum: Option<NotNan<f64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub maximum: Option<NotNan<f64>>,
}

impl PropertyFilter {
Expand All @@ -90,6 +94,7 @@ impl PropertyFilter {
one_of,
not_one_of,
domain_not_one_of,
..Default::default()
}
}
}
Expand Down Expand Up @@ -122,6 +127,27 @@ pub struct MessagesFilter {
pub to: Option<PropertyFilter>,
}

/// Named threshold mode that can be sent in a [`GetCommentPredictionsRequest`].
///
/// Variant names are serialized in lowercase, so `Auto` is written as `"auto"`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all(serialize = "lowercase"))]
pub enum CommentPredictionsThreshold {
/// Serialized as `"auto"`.
Auto,
}

/// A label name paired with a numeric threshold, sent in the `labels` field of a
/// [`GetCommentPredictionsRequest`].
#[derive(Debug, Clone, Serialize)]
pub struct TriggerLabelThreshold {
/// Label name as a list of strings.
// NOTE(review): presumably the hierarchical parts of the label name — confirm.
pub name: Vec<String>,
/// Threshold value for this label; `NotNan` rules out NaN.
pub threshold: NotNan<f64>,
}

/// Request body for fetching predictions for a set of comments.
///
/// `threshold` and `labels` are optional and are left out of the serialized
/// body entirely when they are `None`.
#[derive(Debug, Clone, Serialize)]
pub struct GetCommentPredictionsRequest {
/// Comment uids, as plain strings, to fetch predictions for.
pub uids: Vec<String>,
/// Optional threshold mode; omitted from the body when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub threshold: Option<CommentPredictionsThreshold>,
/// Optional per-label thresholds; omitted from the body when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub labels: Option<Vec<TriggerLabelThreshold>>,
}

#[derive(Debug, Clone, Serialize)]
pub(crate) struct GetRecentRequest<'a> {
pub limit: usize,
Expand Down Expand Up @@ -488,7 +514,7 @@ pub struct AnnotatedComment {
pub struct Prediction {
pub uid: Uid,
#[serde(skip_serializing_if = "should_skip_serializing_optional_vec")]
pub labels: Option<Vec<AutoThresholdLabel>>,
pub labels: Option<Vec<PredictedLabel>>,
#[serde(skip_serializing_if = "should_skip_serializing_optional_vec")]
pub entities: Option<Vec<Entity>>,
}
Expand Down Expand Up @@ -683,13 +709,22 @@ pub struct Label {
pub metadata: Option<HashMap<String, JsonValue>>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Hash, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(untagged)]
pub enum PredictedLabelName {
Parts(Vec<String>),
String(LabelName),
}

impl PredictedLabelName {
    /// Converts this predicted label name into a single [`LabelName`].
    ///
    /// A `Parts` name has its segments joined with `" > "`; a `String` name is
    /// returned as a clone of the wrapped [`LabelName`].
    pub fn to_label_name(&self) -> LabelName {
        match self {
            Self::String(label_name) => label_name.clone(),
            Self::Parts(segments) => {
                let joined = segments.join(" > ");
                LabelName(joined)
            }
        }
    }
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct PredictedLabel {
pub name: PredictedLabelName,
Expand All @@ -700,13 +735,6 @@ pub struct PredictedLabel {
pub auto_thresholds: Option<Vec<String>>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct AutoThresholdLabel {
pub name: Vec<String>,
pub probability: NotNan<f64>,
pub auto_thresholds: Vec<String>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct LabelProperty {
pub id: String,
Expand Down
40 changes: 35 additions & 5 deletions api/src/resources/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,19 @@ pub struct AttributeFilter {
pub filter: AttributeFilterEnum,
}

/// Request body for listing the models in a dataset.
///
/// The struct has no fields.
#[derive(Debug, Clone, Serialize, Default)]
pub struct GetAllModelsInDatasetRequest {}

/// Metadata describing one model entry returned in the `labellers` list.
///
/// Only the model version is modelled here.
// NOTE(review): the server presumably returns more fields per labeller — confirm
// whether any others are needed by callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserModelMetadata {
/// Version of the model.
pub version: ModelVersion,
}

/// Response body carrying the list of models for a dataset.
///
/// NOTE: the type name is misspelled (`Respone` rather than `Response`). It is a
/// public identifier, so any rename must update every use site at the same time.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct GetAllModelsInDatasetRespone {
/// One entry per model.
pub labellers: Vec<UserModelMetadata>,
}

#[derive(Debug, Clone, Serialize, Default)]
pub struct StatisticsRequestParams {
#[serde(skip_serializing_if = "<[_]>::is_empty")]
Expand All @@ -149,15 +162,21 @@ pub enum OrderEnum {
Sample { seed: usize },
}

#[derive(Debug, Clone, Serialize)]
impl Default for OrderEnum {
fn default() -> Self {
Self::Recent
}
}

#[derive(Debug, Clone, Serialize, Default)]
pub struct SummaryRequestParams {
#[serde(skip_serializing_if = "<[_]>::is_empty")]
pub attribute_filters: Vec<AttributeFilter>,

pub filter: CommentFilter,
}

#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Default, Clone, Serialize)]
pub struct QueryRequestParams {
#[serde(skip_serializing_if = "<[_]>::is_empty")]
pub attribute_filters: Vec<AttributeFilter>,
Expand All @@ -173,14 +192,25 @@ pub struct QueryRequestParams {
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UserPropertySummary {
pub struct UserPropertySummaryValue {
pub value: String,
}

/// Summary entry for a string-valued user property.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UserPropertySummaryString {
/// Full name of the user property.
pub full_name: String,
/// Values reported for this property in the summary.
pub values: Vec<UserPropertySummaryValue>,
}

/// Summary entry for a number-valued user property.
///
/// This entry carries only the property name.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UserPropertySummaryNumber {
/// Full name of the user property.
pub full_name: String,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UserPropertySummaryList {
pub string: Vec<UserPropertySummary>,
pub number: Vec<UserPropertySummary>,
pub string: Vec<UserPropertySummaryString>,
pub number: Vec<UserPropertySummaryNumber>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
Expand Down
2 changes: 1 addition & 1 deletion cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ structopt = { version = "0.3.26", default-features = false }
url = { version = "2.3.1", features = ["serde"] }

reinfer-client = { version = "0.32.0", path = "../api" }
dialoguer = "0.10.4"
dialoguer = { version="0.11.0", features = ["fuzzy-select"] }
scoped_threadpool = "0.1.9"
backoff = "0.4.0"
cfb = "0.9.0"
Expand Down
Loading

0 comments on commit 06468c9

Please sign in to comment.