Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cli): add parse-aic-classification-csv #307

Merged
merged 1 commit into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Unreleased
- Add `parse aic-classification-csv`

# v0.31.0
- Add `get keyed sync states`
- Add `delete keyed sync states`
Expand Down
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions api/src/resources/comment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::{
str::FromStr,
};

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Id(pub String);

impl FromStr for Id {
Expand Down Expand Up @@ -259,7 +259,7 @@ pub struct Comment {
pub has_annotations: bool,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct NewComment {
pub id: Id,
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -272,7 +272,7 @@ pub struct NewComment {
pub attachments: Vec<AttachmentMetadata>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct Message {
pub body: MessageBody,

Expand Down Expand Up @@ -301,7 +301,7 @@ pub struct Message {
pub sent_at: Option<DateTime<Utc>>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct MessageBody {
pub text: String,
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down
1 change: 1 addition & 0 deletions cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ ordered-float = { version = "3.9.1", features = ["serde"] }
mailparse = "0.14.0"
diff = "0.1.13"
rand = "0.8.5"
csv = "1.3.0"

[dev-dependencies]
pretty_assertions = "1.3.0"
Expand Down
24 changes: 12 additions & 12 deletions cli/src/commands/create/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
mod annotations;
mod bucket;
mod comments;
mod dataset;
mod emails;
mod integrations;
mod project;
mod quota;
mod source;
mod stream_exception;
mod streams;
mod user;
pub mod annotations;
pub mod bucket;
pub mod comments;
pub mod dataset;
pub mod emails;
pub mod integrations;
pub mod project;
pub mod quota;
pub mod source;
pub mod stream_exception;
pub mod streams;
pub mod user;

use self::{
annotations::CreateAnnotationsArgs, bucket::CreateBucketArgs, comments::CreateCommentsArgs,
Expand Down
193 changes: 193 additions & 0 deletions cli/src/commands/parse/aic_classification_csv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
use crate::{
commands::{
create::annotations::{upload_batch_of_annotations, CommentIdComment, NewAnnotation},
parse::{get_progress_bar, upload_batch_of_comments},
},
parse::Statistics,
};
use anyhow::Result;
use log::{error, info};
use scoped_threadpool::Pool;
use serde::Deserialize;
use std::sync::{mpsc::channel, Arc};

use reinfer_client::{
Client, CommentId, DatasetFullName, DatasetIdentifier, EitherLabelling, Label, Message,
MessageBody, NewComment, NewLabelling, Source, SourceIdentifier, DEFAULT_LABEL_GROUP_NAME,
};
use std::path::PathBuf;
use structopt::StructOpt;

/// Number of comments sent per upload request. Also controls when the buffer
/// is flushed: an upload is triggered once more than
/// `thread_count * UPLOAD_BATCH_SIZE` comments have accumulated.
const UPLOAD_BATCH_SIZE: usize = 4;

/// Command-line arguments for `parse aic-classification-csv`: reads a
/// classification CSV and uploads its rows as comments plus label annotations.
#[derive(Debug, StructOpt)]
pub struct ParseAicClassificationCsvArgs {
    #[structopt(short = "f", long = "file", parse(from_os_str))]
    /// Path to the csv to parse
    file_path: PathBuf,

    #[structopt(short = "s", long = "source")]
    /// The source to upload the data to
    source: SourceIdentifier,

    #[structopt(short = "d", long = "dataset")]
    /// The dataset to upload annotations to
    dataset: DatasetIdentifier,

    #[structopt(short = "n", long = "no-charge")]
    /// Whether to attempt to bypass billing (internal only)
    no_charge: bool,
}

/// One row of an AI Center classification CSV.
///
/// The CSV is expected to contain (at least) an `input` column with the
/// message text and a `target` column with the label name to assign.
#[derive(Debug, Deserialize)]
pub struct AicClassificationRecord {
    /// Text used as the body of the uploaded comment.
    input: String,
    /// Label name assigned (positively) to the uploaded comment.
    target: String,
}

/// Flushes the buffered `comments` and their `annotations` to the API.
///
/// A no-op unless `force_send` is set or the buffer has grown past
/// `thread_count * UPLOAD_BATCH_SIZE`, so the caller can cheaply invoke this
/// once per parsed row.
///
/// Comments are uploaded first, in `UPLOAD_BATCH_SIZE` chunks spread across
/// the thread pool. Annotations are uploaded only if every comment chunk
/// succeeded, since they reference the comment ids just uploaded. On success
/// both buffers are cleared; on failure they are left intact and the first
/// reported error is returned.
#[allow(clippy::too_many_arguments)]
fn send_comments_if_needed(
    comments: &mut Vec<NewComment>,
    annotations: &mut Vec<NewAnnotation>,
    force_send: bool,
    pool: &mut Pool,
    client: &Client,
    source: &Source,
    statistics: &Statistics,
    dataset: &DatasetFullName,
    no_charge: bool,
) -> Result<()> {
    let thread_count = pool.thread_count();
    let should_upload = comments.len() > (thread_count as usize * UPLOAD_BATCH_SIZE);

    if !force_send && !should_upload {
        return Ok(());
    }

    let chunks: Vec<_> = comments.chunks(UPLOAD_BATCH_SIZE).collect();

    let (error_sender, error_receiver) = channel();
    pool.scoped(|scope| {
        for chunk in chunks {
            scope.execute(|| {
                let result =
                    upload_batch_of_comments(client, source, chunk, no_charge, statistics);

                if let Err(error) = result {
                    error_sender.send(error).expect("Could not send error");
                }
            });
        }
    });

    // Surface comment-upload failures *before* uploading annotations: the
    // annotations point at the comment ids above, so uploading them after a
    // failed comment batch would create references to comments that were
    // never created.
    if let Ok(error) = error_receiver.try_recv() {
        return Err(error);
    }

    upload_batch_of_annotations(
        annotations,
        client,
        source,
        statistics,
        dataset,
        pool,
        false,
    )?;

    comments.clear();
    annotations.clear();
    Ok(())
}

pub fn parse(client: &Client, args: &ParseAicClassificationCsvArgs, pool: &mut Pool) -> Result<()> {
let ParseAicClassificationCsvArgs {
file_path,
source,
dataset,
no_charge,
} = args;

let source = client.get_source(source.clone())?;
let dataset = client.get_dataset(dataset.clone())?;
let record_count = csv::Reader::from_path(file_path)?.records().count();

let statistics = Arc::new(Statistics::new());
let _progress = get_progress_bar(record_count as u64, &statistics);

let mut reader = csv::Reader::from_path(file_path)?;

let headers = reader.headers()?.clone();

let mut comments: Vec<NewComment> = Vec::new();
let mut annotations: Vec<NewAnnotation> = Vec::new();
for (idx, row) in reader.records().enumerate() {
match row {
Ok(row) => {
let record: AicClassificationRecord = row.deserialize(Some(&headers))?;
let comment_id = CommentId(idx.to_string());

comments.push(NewComment {
id: comment_id.clone(),
timestamp: chrono::Utc::now(),
messages: vec![Message {
body: MessageBody {
text: record.input,
..Default::default()
},
..Default::default()
}],
..Default::default()
});
annotations.push(NewAnnotation {
comment: CommentIdComment { id: comment_id },
labelling: Some(EitherLabelling::Labelling(vec![NewLabelling {
group: DEFAULT_LABEL_GROUP_NAME.clone(),
assigned: Some(vec![Label {
name: reinfer_client::LabelName(record.target),
sentiment: reinfer_client::Sentiment::Positive,
metadata: None,
}]),
dismissed: None,
}])),
entities: None,
moon_forms: None,
});

send_comments_if_needed(
&mut comments,
&mut annotations,
false,
pool,
client,
&source,
&statistics,
&dataset.full_name(),
*no_charge,
)?;
statistics.increment_processed()
}
Err(_) => {
error!("Failed to process row {}", idx);
statistics.increment_failed();
statistics.increment_processed()
}
}
}
send_comments_if_needed(
&mut comments,
&mut annotations,
true,
pool,
client,
&source,
&statistics,
&dataset.full_name(),
*no_charge,
)?;

info!(
"Uploaded {}. {} Failed",
statistics.num_uploaded(),
statistics.num_failed()
);
Ok(())
}
Loading
Loading