Skip to content

Commit

Permalink
feat(cli): add parse-aic-classification-csv (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
joe-prosser authored Aug 16, 2024
1 parent 1db3a13 commit 1440f93
Show file tree
Hide file tree
Showing 7 changed files with 271 additions and 28 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Unreleased
- Add `parse aic-classification-csv`

# v0.31.0
- Add `get keyed sync states`
- Add `delete keyed sync states`
Expand Down
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions api/src/resources/comment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::{
str::FromStr,
};

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Id(pub String);

impl FromStr for Id {
Expand Down Expand Up @@ -259,7 +259,7 @@ pub struct Comment {
pub has_annotations: bool,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct NewComment {
pub id: Id,
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -272,7 +272,7 @@ pub struct NewComment {
pub attachments: Vec<AttachmentMetadata>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct Message {
pub body: MessageBody,

Expand Down Expand Up @@ -301,7 +301,7 @@ pub struct Message {
pub sent_at: Option<DateTime<Utc>>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct MessageBody {
pub text: String,
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down
1 change: 1 addition & 0 deletions cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ ordered-float = { version = "3.9.1", features = ["serde"] }
mailparse = "0.14.0"
diff = "0.1.13"
rand = "0.8.5"
csv = "1.3.0"

[dev-dependencies]
pretty_assertions = "1.3.0"
Expand Down
24 changes: 12 additions & 12 deletions cli/src/commands/create/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
mod annotations;
mod bucket;
mod comments;
mod dataset;
mod emails;
mod integrations;
mod project;
mod quota;
mod source;
mod stream_exception;
mod streams;
mod user;
pub mod annotations;
pub mod bucket;
pub mod comments;
pub mod dataset;
pub mod emails;
pub mod integrations;
pub mod project;
pub mod quota;
pub mod source;
pub mod stream_exception;
pub mod streams;
pub mod user;

use self::{
annotations::CreateAnnotationsArgs, bucket::CreateBucketArgs, comments::CreateCommentsArgs,
Expand Down
193 changes: 193 additions & 0 deletions cli/src/commands/parse/aic_classification_csv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
use crate::{
commands::{
create::annotations::{upload_batch_of_annotations, CommentIdComment, NewAnnotation},
parse::{get_progress_bar, upload_batch_of_comments},
},
parse::Statistics,
};
use anyhow::Result;
use log::{error, info};
use scoped_threadpool::Pool;
use serde::Deserialize;
use std::sync::{mpsc::channel, Arc};

use reinfer_client::{
Client, CommentId, DatasetFullName, DatasetIdentifier, EitherLabelling, Label, Message,
MessageBody, NewComment, NewLabelling, Source, SourceIdentifier, DEFAULT_LABEL_GROUP_NAME,
};
use std::path::PathBuf;
use structopt::StructOpt;

const UPLOAD_BATCH_SIZE: usize = 4;

// Command-line arguments for `parse aic-classification-csv`.
//
// NOTE(review): the `///` doc comments on the fields below double as the
// user-facing CLI help text rendered by structopt, so they are part of the
// command's behavior and must not be reworded casually.
#[derive(Debug, StructOpt)]
pub struct ParseAicClassificationCsvArgs {
    #[structopt(short = "f", long = "file", parse(from_os_str))]
    /// Path to the csv to parse
    file_path: PathBuf,

    #[structopt(short = "s", long = "source")]
    /// The source to upload the data to
    source: SourceIdentifier,

    #[structopt(short = "d", long = "dataset")]
    /// The dataset to upload annotations to
    dataset: DatasetIdentifier,

    #[structopt(short = "n", long = "no-charge")]
    /// Whether to attempt to bypass billing (internal only)
    no_charge: bool,
}

// One row of the AIC classification csv. The field names must match the csv
// header columns exactly, since serde deserializes by header name.
#[derive(Deserialize)]
pub struct AicClassificationRecord {
    // Text to upload as the comment body.
    input: String,
    // Label name to assign to the comment.
    target: String,
}

/// Upload the buffered `comments` (and then their `annotations`) once the
/// buffer holds more than one `UPLOAD_BATCH_SIZE` batch per pool thread, or
/// unconditionally when `force_send` is set (used for the final flush).
///
/// On success both buffers are cleared; on failure they are left untouched
/// and the first comment-upload error is returned.
#[allow(clippy::too_many_arguments)]
fn send_comments_if_needed(
    comments: &mut Vec<NewComment>,
    annotations: &mut Vec<NewAnnotation>,
    force_send: bool,
    pool: &mut Pool,
    client: &Client,
    source: &Source,
    statistics: &Statistics,
    dataset: &DatasetFullName,
    no_charge: bool,
) -> Result<()> {
    let thread_count = pool.thread_count();
    // Only upload once every worker thread can be given a full batch.
    let should_upload = comments.len() > (thread_count as usize * UPLOAD_BATCH_SIZE);

    if !force_send && !should_upload {
        return Ok(());
    }

    let chunks: Vec<_> = comments.chunks(UPLOAD_BATCH_SIZE).collect();

    // Upload comment batches in parallel. Worker closures cannot return a
    // Result, so failures are reported back through a channel.
    let (error_sender, error_receiver) = channel();
    pool.scoped(|scope| {
        for chunk in chunks {
            scope.execute(|| {
                let result = upload_batch_of_comments(client, source, chunk, no_charge, statistics);

                if let Err(error) = result {
                    error_sender.send(error).expect("Could not send error");
                }
            });
        }
    });

    // Bail out *before* uploading annotations if any comment batch failed:
    // the annotations reference the comment ids above, so sending them for
    // comments that never reached the platform would be wrong. Only the
    // first error is surfaced; any further errors are dropped.
    if let Ok(error) = error_receiver.try_recv() {
        return Err(error);
    }

    upload_batch_of_annotations(
        annotations,
        client,
        source,
        statistics,
        dataset,
        pool,
        false,
    )?;

    comments.clear();
    annotations.clear();
    Ok(())
}

/// Parse an AIC classification csv (columns `input` and `target`), uploading
/// each row as a comment to `source` and a matching label annotation to
/// `dataset`.
///
/// Rows that cannot be read *or* deserialized are logged, counted as failed,
/// and skipped — one bad row does not abort the rest of the file.
pub fn parse(client: &Client, args: &ParseAicClassificationCsvArgs, pool: &mut Pool) -> Result<()> {
    let ParseAicClassificationCsvArgs {
        file_path,
        source,
        dataset,
        no_charge,
    } = args;

    let source = client.get_source(source.clone())?;
    let dataset = client.get_dataset(dataset.clone())?;
    // First pass over the file just counts records so the progress bar has a
    // total; the file is re-opened below for the actual parse.
    let record_count = csv::Reader::from_path(file_path)?.records().count();

    let statistics = Arc::new(Statistics::new());
    let _progress = get_progress_bar(record_count as u64, &statistics);

    let mut reader = csv::Reader::from_path(file_path)?;

    let headers = reader.headers()?.clone();

    let mut comments: Vec<NewComment> = Vec::new();
    let mut annotations: Vec<NewAnnotation> = Vec::new();
    for (idx, row) in reader.records().enumerate() {
        // Treat unreadable rows and rows that fail to deserialize uniformly
        // as per-row failures (previously a deserialize error aborted the
        // whole run via `?`, while a read error was tolerated).
        let record = row
            .ok()
            .and_then(|row| row.deserialize::<AicClassificationRecord>(Some(&headers)).ok());

        match record {
            Some(record) => {
                // The row index doubles as the comment id.
                let comment_id = CommentId(idx.to_string());

                comments.push(NewComment {
                    id: comment_id.clone(),
                    timestamp: chrono::Utc::now(),
                    messages: vec![Message {
                        body: MessageBody {
                            text: record.input,
                            ..Default::default()
                        },
                        ..Default::default()
                    }],
                    ..Default::default()
                });
                annotations.push(NewAnnotation {
                    comment: CommentIdComment { id: comment_id },
                    labelling: Some(EitherLabelling::Labelling(vec![NewLabelling {
                        group: DEFAULT_LABEL_GROUP_NAME.clone(),
                        assigned: Some(vec![Label {
                            name: reinfer_client::LabelName(record.target),
                            sentiment: reinfer_client::Sentiment::Positive,
                            metadata: None,
                        }]),
                        dismissed: None,
                    }])),
                    entities: None,
                    moon_forms: None,
                });

                // Flushes the buffers only once they are large enough to
                // keep the pool busy; the final flush happens after the loop.
                send_comments_if_needed(
                    &mut comments,
                    &mut annotations,
                    false,
                    pool,
                    client,
                    &source,
                    &statistics,
                    &dataset.full_name(),
                    *no_charge,
                )?;
                statistics.increment_processed()
            }
            None => {
                error!("Failed to process row {}", idx);
                statistics.increment_failed();
                statistics.increment_processed()
            }
        }
    }
    // Final flush of whatever is left in the buffers.
    send_comments_if_needed(
        &mut comments,
        &mut annotations,
        true,
        pool,
        client,
        &source,
        &statistics,
        &dataset.full_name(),
        *no_charge,
    )?;

    info!(
        "Uploaded {}. {} Failed",
        statistics.num_uploaded(),
        statistics.num_failed()
    );
    Ok(())
}
Loading

0 comments on commit 1440f93

Please sign in to comment.