From 6c0ff018a4bcce9e43c6c90623a21aa77e842c76 Mon Sep 17 00:00:00 2001 From: Joe Prosser Date: Thu, 29 Aug 2024 14:36:44 +0100 Subject: [PATCH] feat(commands): add dataset flags --- CHANGELOG.md | 1 + api/src/resources/dataset.rs | 15 +++ cli/src/commands/create/dataset.rs | 57 +++++++++- cli/tests/test_datasets.rs | 174 ++++++++++++++++++++++++++++- 4 files changed, 242 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f82b09de..e1229033 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ # Unreleased +- Add dataset flags to `create-dataset` - Add `parse aic-classification-csv` # v0.31.0 diff --git a/api/src/resources/dataset.rs b/api/src/resources/dataset.rs index 23942679..90566c6e 100644 --- a/api/src/resources/dataset.rs +++ b/api/src/resources/dataset.rs @@ -18,6 +18,15 @@ use std::{ str::FromStr, }; +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] +pub enum DatasetFlag { + Gpt4, + ExternalMoonLlm, + Qos, + ZeroShotLabels, +} + #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] pub struct Dataset { pub id: Id, @@ -36,6 +45,8 @@ pub struct Dataset { pub general_fields: Vec, pub label_defs: Vec, pub label_groups: Vec, + #[serde(rename = "_dataset_flags")] + pub dataset_flags: Vec, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] @@ -267,6 +278,10 @@ pub struct NewDataset<'request> { #[serde(skip_serializing_if = "Option::is_none")] pub copy_annotations_from: Option<&'request str>, + + #[serde(skip_serializing_if = "Vec::is_empty")] + #[serde(rename = "_dataset_flags")] + pub dataset_flags: Vec, } #[derive(Debug, Clone, Serialize, PartialEq, Eq)] diff --git a/cli/src/commands/create/dataset.rs b/cli/src/commands/create/dataset.rs index b22bd587..5fb2a703 100644 --- a/cli/src/commands/create/dataset.rs +++ b/cli/src/commands/create/dataset.rs @@ -1,9 +1,10 @@ use crate::printer::Printer; -use anyhow::{anyhow, 
Context, Error, Result}; +use anyhow::{anyhow, bail, Context, Error, Result}; use log::info; use reinfer_client::{ - resources::entity_def::NewGeneralFieldDef, Client, DatasetFullName, NewDataset, NewEntityDef, - NewLabelDef, NewLabelGroup, SourceIdentifier, + resources::{dataset::DatasetFlag, entity_def::NewGeneralFieldDef}, + Client, DatasetFullName, NewDataset, NewEntityDef, NewLabelDef, NewLabelGroup, + SourceIdentifier, }; use serde::Deserialize; use std::str::FromStr; @@ -58,6 +59,22 @@ pub struct CreateDatasetArgs { /// Dataset ID of the dataset to copy annotations from #[structopt(long = "copy-annotations-from")] copy_annotations_from: Option, + + /// Whether the dataset should have QoS enabled + #[structopt(long = "qos")] + qos: Option, + + /// Whether to use the external llm + #[structopt(long = "external-llm")] + external_llm: Option, + + /// Whether to use generative ai features + #[structopt(long = "gen-ai")] + gen_ai: Option, + + /// Whether to use zero shot ai features + #[structopt(long = "zero-shot")] + zero_shot: Option, } pub fn create(client: &Client, args: &CreateDatasetArgs, printer: &Printer) -> Result<()> { @@ -73,6 +90,10 @@ pub fn create(client: &Client, args: &CreateDatasetArgs, printer: &Printer) -> R label_groups, model_family, copy_annotations_from, + qos, + external_llm, + gen_ai, + zero_shot, } = args; let source_ids = { @@ -88,6 +109,35 @@ pub fn create(client: &Client, args: &CreateDatasetArgs, printer: &Printer) -> R source_ids }; + let get_dataset_flags = || -> Result> { + if external_llm.unwrap_or_default() && !gen_ai.unwrap_or_default() { + bail!("External Llm can only be used if gen ai features are enabled. Please add `--gen-ai true`") + } + + if zero_shot.unwrap_or_default() && !gen_ai.unwrap_or_default() { + bail!("Zero shot can only be used if gen ai features are enabled. 
Please add `--gen-ai true`") + } + + let mut dataset_flags = Vec::new(); + + if gen_ai.unwrap_or_default() { + dataset_flags.push(DatasetFlag::Gpt4) + } + + if external_llm.unwrap_or_default() { + dataset_flags.push(DatasetFlag::ExternalMoonLlm) + } + + if zero_shot.unwrap_or_default() { + dataset_flags.push(DatasetFlag::ZeroShotLabels) + } + + if qos.unwrap_or_default() { + dataset_flags.push(DatasetFlag::Qos) + } + Ok(dataset_flags) + }; + // Unwrap the inner values, we only need the outer for argument parsing let entity_defs = &entity_defs.0; let general_fields = &general_fields.0; @@ -124,6 +174,7 @@ pub fn create(client: &Client, args: &CreateDatasetArgs, printer: &Printer) -> R }, model_family: model_family.as_deref(), copy_annotations_from: copy_annotations_from.as_deref(), + dataset_flags: get_dataset_flags()?, }, ) .context("Operation to create a dataset has failed.")?; diff --git a/cli/tests/test_datasets.rs b/cli/tests/test_datasets.rs index a16e2dc9..50ce664c 100644 --- a/cli/tests/test_datasets.rs +++ b/cli/tests/test_datasets.rs @@ -1,8 +1,8 @@ use backoff::{retry, ExponentialBackoff}; use pretty_assertions::assert_eq; use reinfer_client::{ - Dataset, EntityDef, EntityName, LabelDef, LabelDefPretrained, LabelDefPretrainedId, LabelGroup, - LabelGroupName, LabelName, MoonFormFieldDef, Source, + resources::dataset::DatasetFlag, Dataset, EntityDef, EntityName, LabelDef, LabelDefPretrained, + LabelDefPretrainedId, LabelGroup, LabelGroupName, LabelName, MoonFormFieldDef, Source, }; use serde_json::json; use uuid::Uuid; @@ -334,6 +334,176 @@ fn test_create_dataset_with_source() { assert_eq!(&source_info.owner.0, source.owner()); assert_eq!(&source_info.name.0, source.name()); } +#[test] +fn test_create_dataset_with_gen_ai() { + let cli = TestCli::get(); + + // Run with false gen_ai Flag + let dataset_gen_ai_false = TestDataset::new_args(&[&format!("--gen-ai={}", false)]); + let output = cli.run([ + "--output=json", + "get", + "datasets", + 
dataset_gen_ai_false.identifier(), + ]); + let dataset_gen_ai_false_info: Dataset = serde_json::from_str(output.trim()).unwrap(); + assert_eq!( + &dataset_gen_ai_false_info.owner.0, + dataset_gen_ai_false.owner() + ); + assert_eq!( + &dataset_gen_ai_false_info.name.0, + dataset_gen_ai_false.name() + ); + assert!(!dataset_gen_ai_false_info + .dataset_flags + .contains(&DatasetFlag::Gpt4)); + + // Run with true gen_ai Flag + let dataset_gen_ai = TestDataset::new_args(&[&format!("--gen-ai={}", true)]); + let output = cli.run([ + "--output=json", + "get", + "datasets", + dataset_gen_ai.identifier(), + ]); + let dataset_info: Dataset = serde_json::from_str(output.trim()).unwrap(); + assert_eq!(&dataset_info.owner.0, dataset_gen_ai.owner()); + assert_eq!(&dataset_info.name.0, dataset_gen_ai.name()); + assert!(dataset_info.dataset_flags.contains(&DatasetFlag::Gpt4)); +} +#[test] +fn test_create_dataset_with_zero_shot() { + let cli = TestCli::get(); + + // Run with false zero_shot Flag + let dataset_zero_shot_false = TestDataset::new_args(&[&format!("--zero-shot={}", false)]); + let output = cli.run([ + "--output=json", + "get", + "datasets", + dataset_zero_shot_false.identifier(), + ]); + let dataset_zero_shot_false_info: Dataset = serde_json::from_str(output.trim()).unwrap(); + assert_eq!( + &dataset_zero_shot_false_info.owner.0, + dataset_zero_shot_false.owner() + ); + assert_eq!( + &dataset_zero_shot_false_info.name.0, + dataset_zero_shot_false.name() + ); + assert!(!dataset_zero_shot_false_info + .dataset_flags + .contains(&DatasetFlag::ZeroShotLabels)); + + // Run with true zero_shot Flag + let dataset_zero_shot = + TestDataset::new_args(&[&format!("--zero-shot={}", true), "--gen-ai=true"]); + let output = cli.run([ + "--output=json", + "get", + "datasets", + dataset_zero_shot.identifier(), + ]); + let dataset_info: Dataset = serde_json::from_str(output.trim()).unwrap(); + assert_eq!(&dataset_info.owner.0, dataset_zero_shot.owner()); + assert_eq!(&dataset_info.name.0, 
dataset_zero_shot.name()); + assert!(dataset_info + .dataset_flags + .contains(&DatasetFlag::ZeroShotLabels)); +} + +#[test] +fn test_create_dataset_with_external_llm() { + let cli = TestCli::get(); + + // Run with false ellm Flag + let dataset_ellm_false = TestDataset::new_args(&[&format!("--external-llm={}", false)]); + let output = cli.run([ + "--output=json", + "get", + "datasets", + dataset_ellm_false.identifier(), + ]); + let dataset_ellm_false_info: Dataset = serde_json::from_str(output.trim()).unwrap(); + assert_eq!(&dataset_ellm_false_info.owner.0, dataset_ellm_false.owner()); + assert_eq!(&dataset_ellm_false_info.name.0, dataset_ellm_false.name()); + assert!(!dataset_ellm_false_info + .dataset_flags + .contains(&DatasetFlag::ExternalMoonLlm)); + + // Run with true ellm Flag + let dataset_ellm = + TestDataset::new_args(&[&format!("--external-llm={}", true), "--gen-ai=true"]); + let output = cli.run([ + "--output=json", + "get", + "datasets", + dataset_ellm.identifier(), + ]); + let dataset_info: Dataset = serde_json::from_str(output.trim()).unwrap(); + assert_eq!(&dataset_info.owner.0, dataset_ellm.owner()); + assert_eq!(&dataset_info.name.0, dataset_ellm.name()); + assert!(dataset_info + .dataset_flags + .contains(&DatasetFlag::ExternalMoonLlm)); +} + +#[test] +fn test_create_dataset_with_no_flags() { + let cli = TestCli::get(); + // Run with no QoS Flag + let dataset_qos_none = TestDataset::new_args(&[]); + let output = cli.run([ + "--output=json", + "get", + "datasets", + dataset_qos_none.identifier(), + ]); + let dataset_qos_none_info: Dataset = serde_json::from_str(output.trim()).unwrap(); + assert_eq!(&dataset_qos_none_info.owner.0, dataset_qos_none.owner()); + assert_eq!(&dataset_qos_none_info.name.0, dataset_qos_none.name()); + assert!(!dataset_qos_none_info + .dataset_flags + .contains(&DatasetFlag::Qos)); + assert!(!dataset_qos_none_info + .dataset_flags + .contains(&DatasetFlag::ExternalMoonLlm)); + assert!(!dataset_qos_none_info + 
.dataset_flags + .contains(&DatasetFlag::Gpt4)); + assert!(!dataset_qos_none_info + .dataset_flags + .contains(&DatasetFlag::ZeroShotLabels)); +} +#[test] +fn test_create_dataset_with_qos() { + let cli = TestCli::get(); + + // Run with false QoS Flag + let dataset_qos_false = TestDataset::new_args(&[&format!("--qos={}", false)]); + let output = cli.run([ + "--output=json", + "get", + "datasets", + dataset_qos_false.identifier(), + ]); + let dataset_qos_false_info: Dataset = serde_json::from_str(output.trim()).unwrap(); + assert_eq!(&dataset_qos_false_info.owner.0, dataset_qos_false.owner()); + assert_eq!(&dataset_qos_false_info.name.0, dataset_qos_false.name()); + assert!(!dataset_qos_false_info + .dataset_flags + .contains(&DatasetFlag::Qos)); + + // Run with true QoS Flag + let dataset_qos = TestDataset::new_args(&[&format!("--qos={}", true)]); + let output = cli.run(["--output=json", "get", "datasets", dataset_qos.identifier()]); + let dataset_info: Dataset = serde_json::from_str(output.trim()).unwrap(); + assert_eq!(&dataset_info.owner.0, dataset_qos.owner()); + assert_eq!(&dataset_info.name.0, dataset_qos.name()); + assert!(dataset_info.dataset_flags.contains(&DatasetFlag::Qos)); +} #[test] fn test_create_dataset_requires_owner() {