Skip to content

Commit

Permalink
feat(commands): add dataset flags (#308)
Browse files Browse the repository at this point in the history
  • Loading branch information
joe-prosser authored Aug 29, 2024
1 parent 1440f93 commit 4776788
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Unreleased
- Add dataset flags to `create-dataset`
- Add `parse aic-classification-csv`

# v0.31.0
Expand Down
15 changes: 15 additions & 0 deletions api/src/resources/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ use std::{
str::FromStr,
};

/// Feature flags that can be attached to a dataset.
///
/// Every variant is serialized and deserialized using the snake_case form of its name.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DatasetFlag {
    /// Set by the CLI when generative AI features are requested (`--gen-ai`).
    Gpt4,
    /// Set by the CLI when the external LLM is requested (`--external-llm`).
    ExternalMoonLlm,
    /// Set by the CLI when QoS is requested (`--qos`).
    Qos,
    /// Set by the CLI when zero shot features are requested (`--zero-shot`).
    ZeroShotLabels,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct Dataset {
pub id: Id,
Expand All @@ -36,6 +45,8 @@ pub struct Dataset {
pub general_fields: Vec<GeneralFieldDef>,
pub label_defs: Vec<LabelDef>,
pub label_groups: Vec<LabelGroup>,
#[serde(rename = "_dataset_flags")]
pub dataset_flags: Vec<DatasetFlag>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
Expand Down Expand Up @@ -267,6 +278,10 @@ pub struct NewDataset<'request> {

#[serde(skip_serializing_if = "Option::is_none")]
pub copy_annotations_from: Option<&'request str>,

#[serde(skip_serializing_if = "Vec::is_empty")]
#[serde(rename = "_dataset_flags")]
pub dataset_flags: Vec<DatasetFlag>,
}

#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
Expand Down
57 changes: 54 additions & 3 deletions cli/src/commands/create/dataset.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::printer::Printer;
use anyhow::{anyhow, Context, Error, Result};
use anyhow::{anyhow, bail, Context, Error, Result};
use log::info;
use reinfer_client::{
resources::entity_def::NewGeneralFieldDef, Client, DatasetFullName, NewDataset, NewEntityDef,
NewLabelDef, NewLabelGroup, SourceIdentifier,
resources::{dataset::DatasetFlag, entity_def::NewGeneralFieldDef},
Client, DatasetFullName, NewDataset, NewEntityDef, NewLabelDef, NewLabelGroup,
SourceIdentifier,
};
use serde::Deserialize;
use std::str::FromStr;
Expand Down Expand Up @@ -58,6 +59,22 @@ pub struct CreateDatasetArgs {
/// Dataset ID of the dataset to copy annotations from
#[structopt(long = "copy-annotations-from")]
copy_annotations_from: Option<String>,

/// Whether the dataset should have QoS enabled
#[structopt(long = "qos")]
qos: Option<bool>,

/// Whether to use the external llm
#[structopt(long = "external-llm")]
external_llm: Option<bool>,

/// Whether to use generative ai features
#[structopt(long = "gen-ai")]
gen_ai: Option<bool>,

/// Whether to use zero shot ai features
#[structopt(long = "zero-shot")]
zero_shot: Option<bool>,
}

pub fn create(client: &Client, args: &CreateDatasetArgs, printer: &Printer) -> Result<()> {
Expand All @@ -73,6 +90,10 @@ pub fn create(client: &Client, args: &CreateDatasetArgs, printer: &Printer) -> R
label_groups,
model_family,
copy_annotations_from,
qos,
external_llm,
gen_ai,
zero_shot,
} = args;

let source_ids = {
Expand All @@ -88,6 +109,35 @@ pub fn create(client: &Client, args: &CreateDatasetArgs, printer: &Printer) -> R
source_ids
};

// Turns the optional boolean CLI switches into the list of dataset flags to send,
// failing when a switch that depends on `--gen-ai` is enabled without it.
let get_dataset_flags = || -> Result<Vec<DatasetFlag>> {
    // A switch that was not supplied is treated as `false`.
    let gen_ai_enabled = gen_ai.unwrap_or_default();
    let external_llm_enabled = external_llm.unwrap_or_default();
    let zero_shot_enabled = zero_shot.unwrap_or_default();
    let qos_enabled = qos.unwrap_or_default();

    // Both the external LLM and zero shot switches require gen ai to be enabled.
    if external_llm_enabled && !gen_ai_enabled {
        bail!("External Llm can only be used if gen ai features are enabled. Please add `--gen-ai true`");
    }
    if zero_shot_enabled && !gen_ai_enabled {
        bail!("Zero shot can only be used if gen ai features are enabled. Please add `--gen-ai true`");
    }

    // Each switch paired with the flag it enables; the resulting order is
    // Gpt4, ExternalMoonLlm, ZeroShotLabels, Qos.
    let switches = vec![
        (gen_ai_enabled, DatasetFlag::Gpt4),
        (external_llm_enabled, DatasetFlag::ExternalMoonLlm),
        (zero_shot_enabled, DatasetFlag::ZeroShotLabels),
        (qos_enabled, DatasetFlag::Qos),
    ];

    Ok(switches
        .into_iter()
        .filter(|(enabled, _)| *enabled)
        .map(|(_, flag)| flag)
        .collect())
};

// Unwrap the inner values, we only need the outer for argument parsing
let entity_defs = &entity_defs.0;
let general_fields = &general_fields.0;
Expand Down Expand Up @@ -124,6 +174,7 @@ pub fn create(client: &Client, args: &CreateDatasetArgs, printer: &Printer) -> R
},
model_family: model_family.as_deref(),
copy_annotations_from: copy_annotations_from.as_deref(),
dataset_flags: get_dataset_flags()?,
},
)
.context("Operation to create a dataset has failed.")?;
Expand Down
174 changes: 172 additions & 2 deletions cli/tests/test_datasets.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use backoff::{retry, ExponentialBackoff};
use pretty_assertions::assert_eq;
use reinfer_client::{
Dataset, EntityDef, EntityName, LabelDef, LabelDefPretrained, LabelDefPretrainedId, LabelGroup,
LabelGroupName, LabelName, MoonFormFieldDef, Source,
resources::dataset::DatasetFlag, Dataset, EntityDef, EntityName, LabelDef, LabelDefPretrained,
LabelDefPretrainedId, LabelGroup, LabelGroupName, LabelName, MoonFormFieldDef, Source,
};
use serde_json::json;
use uuid::Uuid;
Expand Down Expand Up @@ -334,6 +334,176 @@ fn test_create_dataset_with_source() {
assert_eq!(&source_info.owner.0, source.owner());
assert_eq!(&source_info.name.0, source.name());
}
// Checks that the `Gpt4` flag is reported on a dataset only when `--gen-ai` is true.
#[test]
fn test_create_dataset_with_gen_ai() {
    let cli = TestCli::get();

    // Case 1: `--gen-ai=false` -- the Gpt4 flag must be absent.
    let disabled_dataset = TestDataset::new_args(&["--gen-ai=false"]);
    let disabled_json = cli.run([
        "--output=json",
        "get",
        "datasets",
        disabled_dataset.identifier(),
    ]);
    let disabled_info: Dataset = serde_json::from_str(disabled_json.trim()).unwrap();
    assert_eq!(&disabled_info.owner.0, disabled_dataset.owner());
    assert_eq!(&disabled_info.name.0, disabled_dataset.name());
    let disabled_flags = &disabled_info.dataset_flags;
    assert!(!disabled_flags.contains(&DatasetFlag::Gpt4));

    // Case 2: `--gen-ai=true` -- the Gpt4 flag must be present.
    let enabled_dataset = TestDataset::new_args(&["--gen-ai=true"]);
    let enabled_json = cli.run([
        "--output=json",
        "get",
        "datasets",
        enabled_dataset.identifier(),
    ]);
    let enabled_info: Dataset = serde_json::from_str(enabled_json.trim()).unwrap();
    assert_eq!(&enabled_info.owner.0, enabled_dataset.owner());
    assert_eq!(&enabled_info.name.0, enabled_dataset.name());
    let enabled_flags = &enabled_info.dataset_flags;
    assert!(enabled_flags.contains(&DatasetFlag::Gpt4));
}
// Checks that the `ZeroShotLabels` flag is reported on a dataset only when `--zero-shot` is true.
#[test]
fn test_create_dataset_with_zero_shot() {
    let cli = TestCli::get();

    // Case 1: `--zero-shot=false` -- the ZeroShotLabels flag must be absent.
    let disabled_dataset = TestDataset::new_args(&["--zero-shot=false"]);
    let disabled_json = cli.run([
        "--output=json",
        "get",
        "datasets",
        disabled_dataset.identifier(),
    ]);
    let disabled_info: Dataset = serde_json::from_str(disabled_json.trim()).unwrap();
    assert_eq!(&disabled_info.owner.0, disabled_dataset.owner());
    assert_eq!(&disabled_info.name.0, disabled_dataset.name());
    let disabled_flags = &disabled_info.dataset_flags;
    assert!(!disabled_flags.contains(&DatasetFlag::ZeroShotLabels));

    // Case 2: `--zero-shot=true` -- passed together with `--gen-ai=true`, which the
    // create command requires for zero shot; the ZeroShotLabels flag must be present.
    let enabled_dataset = TestDataset::new_args(&["--zero-shot=true", "--gen-ai=true"]);
    let enabled_json = cli.run([
        "--output=json",
        "get",
        "datasets",
        enabled_dataset.identifier(),
    ]);
    let enabled_info: Dataset = serde_json::from_str(enabled_json.trim()).unwrap();
    assert_eq!(&enabled_info.owner.0, enabled_dataset.owner());
    assert_eq!(&enabled_info.name.0, enabled_dataset.name());
    let enabled_flags = &enabled_info.dataset_flags;
    assert!(enabled_flags.contains(&DatasetFlag::ZeroShotLabels));
}

// Checks that the `ExternalMoonLlm` flag is reported on a dataset only when `--external-llm` is true.
#[test]
fn test_create_dataset_with_external_llm() {
    let cli = TestCli::get();

    // Case 1: `--external-llm=false` -- the ExternalMoonLlm flag must be absent.
    let disabled_dataset = TestDataset::new_args(&["--external-llm=false"]);
    let disabled_json = cli.run([
        "--output=json",
        "get",
        "datasets",
        disabled_dataset.identifier(),
    ]);
    let disabled_info: Dataset = serde_json::from_str(disabled_json.trim()).unwrap();
    assert_eq!(&disabled_info.owner.0, disabled_dataset.owner());
    assert_eq!(&disabled_info.name.0, disabled_dataset.name());
    let disabled_flags = &disabled_info.dataset_flags;
    assert!(!disabled_flags.contains(&DatasetFlag::ExternalMoonLlm));

    // Case 2: `--external-llm=true` -- passed together with `--gen-ai=true`, which the
    // create command requires for the external LLM; the ExternalMoonLlm flag must be present.
    let enabled_dataset = TestDataset::new_args(&["--external-llm=true", "--gen-ai=true"]);
    let enabled_json = cli.run([
        "--output=json",
        "get",
        "datasets",
        enabled_dataset.identifier(),
    ]);
    let enabled_info: Dataset = serde_json::from_str(enabled_json.trim()).unwrap();
    assert_eq!(&enabled_info.owner.0, enabled_dataset.owner());
    assert_eq!(&enabled_info.name.0, enabled_dataset.name());
    let enabled_flags = &enabled_info.dataset_flags;
    assert!(enabled_flags.contains(&DatasetFlag::ExternalMoonLlm));
}

// Checks that a dataset created without any flag arguments reports none of the dataset flags.
#[test]
fn test_create_dataset_with_no_flags() {
    let cli = TestCli::get();

    // No extra arguments at all: none of the flag switches are passed.
    let plain_dataset = TestDataset::new_args(&[]);
    let plain_json = cli.run([
        "--output=json",
        "get",
        "datasets",
        plain_dataset.identifier(),
    ]);
    let plain_info: Dataset = serde_json::from_str(plain_json.trim()).unwrap();
    assert_eq!(&plain_info.owner.0, plain_dataset.owner());
    assert_eq!(&plain_info.name.0, plain_dataset.name());

    // Every known flag must be absent.
    let flags = &plain_info.dataset_flags;
    assert!(!flags.contains(&DatasetFlag::Qos));
    assert!(!flags.contains(&DatasetFlag::ExternalMoonLlm));
    assert!(!flags.contains(&DatasetFlag::Gpt4));
    assert!(!flags.contains(&DatasetFlag::ZeroShotLabels));
}
// Checks that the `Qos` flag is reported on a dataset only when `--qos` is true.
#[test]
fn test_create_dataset_with_qos() {
    let cli = TestCli::get();

    // Case 1: `--qos=false` -- the Qos flag must be absent.
    let disabled_dataset = TestDataset::new_args(&["--qos=false"]);
    let disabled_json = cli.run([
        "--output=json",
        "get",
        "datasets",
        disabled_dataset.identifier(),
    ]);
    let disabled_info: Dataset = serde_json::from_str(disabled_json.trim()).unwrap();
    assert_eq!(&disabled_info.owner.0, disabled_dataset.owner());
    assert_eq!(&disabled_info.name.0, disabled_dataset.name());
    let disabled_flags = &disabled_info.dataset_flags;
    assert!(!disabled_flags.contains(&DatasetFlag::Qos));

    // Case 2: `--qos=true` -- the Qos flag must be present.
    let enabled_dataset = TestDataset::new_args(&["--qos=true"]);
    let enabled_json = cli.run([
        "--output=json",
        "get",
        "datasets",
        enabled_dataset.identifier(),
    ]);
    let enabled_info: Dataset = serde_json::from_str(enabled_json.trim()).unwrap();
    assert_eq!(&enabled_info.owner.0, enabled_dataset.owner());
    assert_eq!(&enabled_info.name.0, enabled_dataset.name());
    let enabled_flags = &enabled_info.dataset_flags;
    assert!(enabled_flags.contains(&DatasetFlag::Qos));
}

#[test]
fn test_create_dataset_requires_owner() {
Expand Down

0 comments on commit 4776788

Please sign in to comment.