feat(cli): add stop after on get comments
joe-prosser committed Oct 31, 2024
1 parent 820e7c9 commit 08fdf0b
Showing 3 changed files with 162 additions and 152 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,6 +1,7 @@
 # Unreleased
 - update `reinfer.io` urls to `reinfer.dev`
 - fix validation when providing property filter as json
+- add stop after on `get comments`

 # v0.34.0
 - Round trip `field_id`
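For example (hypothetical invocation: the `re` binary name and the source argument are assumptions, not shown in this commit), `re get comments my-source --stop-after 1000` would stop paging once at least 1000 comments have been downloaded, ending at the next batch boundary rather than at exactly 1000.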
312 changes: 160 additions & 152 deletions cli/src/commands/get/comments.rs
@@ -132,6 +132,10 @@ pub struct GetManyCommentsArgs {
     #[structopt(long = "--shuffle")]
     /// Whether to return comments in a random order
     shuffle: Option<bool>,
+
+    #[structopt(long = "--stop-after")]
+    /// Stop downloading comments after X comments (stops in following batch)
+    stop_after: Option<usize>,
 }

 #[derive(Debug, Deserialize)]
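A rough stand-alone sketch of how structopt maps the new flag onto `Option<usize>`; `DemoArgs` and the test values are hypothetical, not part of this commit:

    use structopt::StructOpt;

    #[derive(Debug, StructOpt)]
    struct DemoArgs {
        #[structopt(long = "--stop-after")]
        /// Stop downloading comments after X comments
        stop_after: Option<usize>,
    }

    fn main() {
        // Flag present: parsed into Some(500).
        let args = DemoArgs::from_iter(vec!["demo", "--stop-after", "500"]);
        assert_eq!(args.stop_after, Some(500));

        // Flag absent: None, so no download limit applies.
        let args = DemoArgs::from_iter(vec!["demo"]);
        assert_eq!(args.stop_after, None);
    }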
@@ -453,6 +457,7 @@ pub fn get_many(client: &Client, args: &GetManyCommentsArgs) -> Result<()> {
         include_attachment_content,
         only_with_attachments,
         shuffle,
+        stop_after,
     } = args;

     let by_timerange = from_timestamp.is_some() || to_timestamp.is_some();
@@ -602,6 +607,7 @@ pub fn get_many(client: &Client, args: &GetManyCommentsArgs) -> Result<()> {
         attachments_dir,
         only_with_attachments_filter,
         shuffle: shuffle.unwrap_or(false),
+        stop_after: *stop_after,
     };

     if let Some(file) = jsonl_file {
@@ -658,6 +664,7 @@ struct CommentDownloadOptions {
     attachments_dir: Option<PathBuf>,
     only_with_attachments_filter: Option<AttributeFilter>,
     shuffle: bool,
+    stop_after: Option<usize>,
 }

 impl CommentDownloadOptions {
@@ -752,46 +759,42 @@ fn download_comments(
             dataset_name,
             source,
             &statistics,
-            options.include_predictions,
-            options.attachments_dir,
             writer,
             options,
         )?;
     } else {
-        get_comments_from_uids(
-            client,
-            dataset_name,
-            source,
-            &statistics,
-            options.include_predictions,
-            options.model_version,
-            writer,
-            &options,
-        )?;
+        get_comments_from_uids(client, dataset_name, source, &statistics, writer, &options)?;
     }
 } else {
     let _progress = if options.show_progress {
         Some(make_progress(None)?)
     } else {
         None
     };
-    client
-        .get_comments_iter(&source.full_name(), None, options.timerange)
-        .try_for_each(|page| {
-            let page = page.context("Operation to get comments has failed.")?;
-            statistics.add_comments(page.len());
-
-            print_resources_as_json(
-                page.into_iter().map(|comment| AnnotatedComment {
-                    comment,
-                    labelling: None,
-                    entities: None,
-                    thread_properties: None,
-                    moon_forms: None,
-                    label_properties: None,
-                }),
-                &mut writer,
-            )
-        })?;
+    for page in client.get_comments_iter(&source.full_name(), None, options.timerange) {
+        let page = page.context("Operation to get comments has failed.")?;
+
+        if options
+            .stop_after
+            .is_some_and(|stop_after| statistics.num_downloaded() >= stop_after)
+        {
+            break;
+        }
+
+        statistics.add_comments(page.len());
+
+        print_resources_as_json(
+            page.into_iter().map(|comment| AnnotatedComment {
+                comment,
+                labelling: None,
+                entities: None,
+                thread_properties: None,
+                moon_forms: None,
+                label_properties: None,
+            }),
+            &mut writer,
+        )?;
+    }
 }
 log::info!(
     "Successfully downloaded {} comments [{} annotated].",
@@ -809,8 +812,6 @@ fn get_comments_from_uids(
     dataset_name: DatasetFullName,
     source: Source,
     statistics: &Arc<Statistics>,
-    include_predictions: bool,
-    model_version: Option<u32>,
     mut writer: impl Write,
     options: &CommentDownloadOptions,
 ) -> Result<()> {
@@ -837,104 +838,109 @@
         },
     };

-    client
-        .get_dataset_query_iter(&dataset_name, &mut params)
-        .try_for_each(|page| {
-            let page = page.context("Operation to get comments has failed.")?;
-            if page.is_empty() {
-                return Ok(());
-            }
-
-            statistics.add_comments(page.len());
-
-            if let Some(model_version) = &model_version {
-                let predictions = client
-                    .get_comment_predictions(
-                        &dataset_name,
-                        &ModelVersion(*model_version),
-                        page.iter().map(|comment| &comment.comment.uid),
-                        Some(CommentPredictionsThreshold::Auto),
-                        None,
-                    )
-                    .context("Operation to get predictions has failed.")?;
-                // since predict-comments endpoint doesn't return some fields,
-                // they are set to None or [] here
-                let comments: Vec<_> = page
-                    .into_iter()
-                    .zip(predictions.into_iter())
-                    .map(|(comment, prediction)| AnnotatedComment {
-                        comment: comment.comment,
-                        labelling: Some(vec![Labelling {
-                            group: DEFAULT_LABEL_GROUP_NAME.clone(),
-                            assigned: Vec::new(),
-                            dismissed: Vec::new(),
-                            predicted: prediction.labels.map(|auto_threshold_labels| {
-                                auto_threshold_labels
-                                    .iter()
-                                    .map(|auto_threshold_label| PredictedLabel {
-                                        name: auto_threshold_label.name.clone(),
-                                        sentiment: None,
-                                        probability: auto_threshold_label.probability,
-                                        auto_thresholds: Some(
-                                            auto_threshold_label
-                                                .auto_thresholds
-                                                .clone()
-                                                .expect("Could not get auto thresholds")
-                                                .to_vec(),
-                                        ),
-                                    })
-                                    .collect()
-                            }),
-                        }]),
-                        entities: Some(Entities {
-                            assigned: Vec::new(),
-                            dismissed: Vec::new(),
-                            predicted: prediction.entities,
-                        }),
-                        thread_properties: None,
-                        moon_forms: None,
-                        label_properties: None,
-                    })
-                    .collect();
-
-                if let Some(attachments_dir) = &options.attachments_dir {
-                    comments.iter().try_for_each(|comment| -> Result<()> {
-                        download_comment_attachments(
-                            client,
-                            attachments_dir,
-                            &comment.comment,
-                            statistics,
-                        )
-                    })?;
-                }
-                print_resources_as_json(comments, &mut writer)
-            } else {
-                let comments: Vec<_> = page
-                    .into_iter()
-                    .map(|mut annotated_comment| {
-                        if !include_predictions {
-                            annotated_comment = annotated_comment.without_predictions();
-                        }
-                        if annotated_comment.has_annotations() {
-                            statistics.add_annotated(1);
-                        }
-                        annotated_comment
-                    })
-                    .collect();
-                if let Some(attachments_dir) = &options.attachments_dir {
-                    comments.iter().try_for_each(|comment| -> Result<()> {
-                        download_comment_attachments(
-                            client,
-                            attachments_dir,
-                            &comment.comment,
-                            statistics,
-                        )
-                    })?;
-                }
-
-                print_resources_as_json(comments, &mut writer)
-            }
-        })?;
+    for page in client.get_dataset_query_iter(&dataset_name, &mut params) {
+        let page = page.context("Operation to get comments has failed.")?;
+        if page.is_empty() {
+            return Ok(());
+        }
+
+        if options
+            .stop_after
+            .is_some_and(|stop_after| statistics.num_downloaded() >= stop_after)
+        {
+            break;
+        }
+
+        statistics.add_comments(page.len());
+
+        if let Some(model_version) = &options.model_version {
+            let predictions = client
+                .get_comment_predictions(
+                    &dataset_name,
+                    &ModelVersion(*model_version),
+                    page.iter().map(|comment| &comment.comment.uid),
+                    Some(CommentPredictionsThreshold::Auto),
+                    None,
+                )
+                .context("Operation to get predictions has failed.")?;
+            // since predict-comments endpoint doesn't return some fields,
+            // they are set to None or [] here
+            let comments: Vec<_> = page
+                .into_iter()
+                .zip(predictions.into_iter())
+                .map(|(comment, prediction)| AnnotatedComment {
+                    comment: comment.comment,
+                    labelling: Some(vec![Labelling {
+                        group: DEFAULT_LABEL_GROUP_NAME.clone(),
+                        assigned: Vec::new(),
+                        dismissed: Vec::new(),
+                        predicted: prediction.labels.map(|auto_threshold_labels| {
+                            auto_threshold_labels
+                                .iter()
+                                .map(|auto_threshold_label| PredictedLabel {
+                                    name: auto_threshold_label.name.clone(),
+                                    sentiment: None,
+                                    probability: auto_threshold_label.probability,
+                                    auto_thresholds: Some(
+                                        auto_threshold_label
+                                            .auto_thresholds
+                                            .clone()
+                                            .expect("Could not get auto thresholds")
+                                            .to_vec(),
+                                    ),
+                                })
+                                .collect()
+                        }),
+                    }]),
+                    entities: Some(Entities {
+                        assigned: Vec::new(),
+                        dismissed: Vec::new(),
+                        predicted: prediction.entities,
+                    }),
+                    thread_properties: None,
+                    moon_forms: None,
+                    label_properties: None,
+                })
+                .collect();
+
+            if let Some(attachments_dir) = &options.attachments_dir {
+                comments.iter().try_for_each(|comment| -> Result<()> {
+                    download_comment_attachments(
+                        client,
+                        attachments_dir,
+                        &comment.comment,
+                        statistics,
+                    )
+                })?;
+            }
+            print_resources_as_json(comments, &mut writer)?;
+        } else {
+            let comments: Vec<_> = page
+                .into_iter()
+                .map(|mut annotated_comment| {
+                    if !options.include_predictions {
+                        annotated_comment = annotated_comment.without_predictions();
+                    }
+                    if annotated_comment.has_annotations() {
+                        statistics.add_annotated(1);
+                    }
+                    annotated_comment
+                })
+                .collect();
+            if let Some(attachments_dir) = &options.attachments_dir {
+                comments.iter().try_for_each(|comment| -> Result<()> {
+                    download_comment_attachments(
+                        client,
+                        attachments_dir,
+                        &comment.comment,
+                        statistics,
+                    )
+                })?;
+            }
+
+            print_resources_as_json(comments, &mut writer)?;
+        }
+    }
     Ok(())
 }

@@ -974,38 +980,40 @@ fn get_reviewed_comments_in_bulk(
     dataset_name: DatasetFullName,
     source: Source,
     statistics: &Arc<Statistics>,
-    include_predictions: bool,
-    attachments_dir: Option<PathBuf>,
     mut writer: impl Write,
     options: CommentDownloadOptions,
 ) -> Result<()> {
-    client
-        .get_labellings_iter(&dataset_name, &source.id, include_predictions, None)
-        .try_for_each(|page| {
-            let page = page.context("Operation to get labellings has failed.")?;
-            statistics.add_comments(page.len());
-            statistics.add_annotated(page.len());
-
-            if let Some(attachments_dir) = &attachments_dir {
-                page.iter().try_for_each(|comment| -> Result<()> {
-                    download_comment_attachments(
-                        client,
-                        attachments_dir,
-                        &comment.comment,
-                        statistics,
-                    )
-                })?;
-            }
-
-            let comments = page.into_iter().map(|comment| {
-                if !include_predictions {
-                    comment.without_predictions()
-                } else {
-                    comment
-                }
-            });
-
-            print_resources_as_json(comments, &mut writer)
-        })?;
+    for page in
+        client.get_labellings_iter(&dataset_name, &source.id, options.include_predictions, None)
+    {
+        let page = page.context("Operation to get labellings has failed.")?;
+
+        if options
+            .stop_after
+            .is_some_and(|stop_after| statistics.num_downloaded() >= stop_after)
+        {
+            break;
+        }
+
+        statistics.add_comments(page.len());
+        statistics.add_annotated(page.len());
+
+        if let Some(attachments_dir) = &options.attachments_dir {
+            page.iter().try_for_each(|comment| -> Result<()> {
+                download_comment_attachments(client, attachments_dir, &comment.comment, statistics)
+            })?;
+        }
+
+        let comments = page.into_iter().map(|comment| {
+            if !options.include_predictions {
+                comment.without_predictions()
+            } else {
+                comment
+            }
+        });
+
+        print_resources_as_json(comments, &mut writer)?;
+    }
     Ok(())
 }
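The stop-after guard relies on `statistics.num_downloaded()`; the `Statistics` type is not part of this diff, but a thread-safe counter along these lines would support the calls made here (hypothetical sketch, not the crate's actual definition):

    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    // Hypothetical stand-in: the diff only shows that Statistics is shared
    // behind an Arc and exposes add_comments, add_annotated and num_downloaded.
    #[derive(Default)]
    struct Statistics {
        downloaded: AtomicUsize,
        annotated: AtomicUsize,
    }

    impl Statistics {
        fn add_comments(&self, n: usize) {
            self.downloaded.fetch_add(n, Ordering::Relaxed);
        }
        fn add_annotated(&self, n: usize) {
            self.annotated.fetch_add(n, Ordering::Relaxed);
        }
        fn num_downloaded(&self) -> usize {
            self.downloaded.load(Ordering::Relaxed)
        }
    }

    fn main() {
        let stats = Arc::new(Statistics::default());
        stats.add_comments(100);
        stats.add_annotated(40);
        assert_eq!(stats.num_downloaded(), 100);
    }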

