From 9962076c4ee7b7eece5f7aa4ae080b19aaecc6e6 Mon Sep 17 00:00:00 2001 From: Joe Prosser Date: Thu, 14 Sep 2023 14:57:50 +0100 Subject: [PATCH] streams: add ability to create streams --- CHANGELOG.md | 1 + api/src/error.rs | 3 ++ api/src/lib.rs | 12 +++++ api/src/resources/comment.rs | 1 + api/src/resources/stream.rs | 59 ++++++++++++++++++++++ cli/src/commands/create/mod.rs | 8 ++- cli/src/commands/create/streams.rs | 80 ++++++++++++++++++++++++++++++ cli/src/commands/get/streams.rs | 30 +++++++++-- 8 files changed, 190 insertions(+), 4 deletions(-) create mode 100644 cli/src/commands/create/streams.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index ca439cb1..45b301ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ ## Unreleased +- Add create streams command - Show source statistics in table when getting sources ## v0.18.2 diff --git a/api/src/error.rs b/api/src/error.rs index 34fde247..535d47e0 100644 --- a/api/src/error.rs +++ b/api/src/error.rs @@ -25,6 +25,9 @@ pub enum Error { #[error("Expected <owner>/<dataset>/<stream>: {}", identifier)] BadStreamName { identifier: String }, + #[error("Expected u64: {}", version)] + BadStreamModelVersion { version: String }, + #[error( "Expected a user id (usernames and emails are not supported), got: {}", identifier diff --git a/api/src/lib.rs b/api/src/lib.rs index 4b9d177c..f402485b 100644 --- a/api/src/lib.rs +++ b/api/src/lib.rs @@ -22,6 +22,7 @@ use resources::{ project::ForceDeleteProject, quota::{GetQuotasResponse, Quota}, source::StatisticsRequestParams as SourceStatisticsRequestParams, + stream::{NewStream, PutStreamRequest, PutStreamResponse}, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -393,6 +394,17 @@ impl Client { ) } + pub fn put_stream( + &self, + dataset_name: &DatasetFullName, + stream: &NewStream, + ) -> Result<PutStreamResponse> { + self.put( + self.endpoints.streams(dataset_name)?, + Some(PutStreamRequest { stream }), + ) + } + pub fn sync_comments( &self, source_name: &SourceFullName, diff --git
a/api/src/resources/comment.rs b/api/src/resources/comment.rs index 3b0e7350..d932362f 100644 --- a/api/src/resources/comment.rs +++ b/api/src/resources/comment.rs @@ -87,6 +87,7 @@ pub struct CommentFilter { pub user_properties: Option<UserPropertiesFilter>, #[serde(skip_serializing_if = "Vec::is_empty")] + #[serde(default)] pub sources: Vec<SourceId>, } diff --git a/api/src/resources/stream.rs b/api/src/resources/stream.rs index 1d05607e..c7c93bc4 100644 --- a/api/src/resources/stream.rs +++ b/api/src/resources/stream.rs @@ -43,6 +43,49 @@ impl FromStr for FullName { } } +#[derive(Debug, Clone, Serialize)] +pub(crate) struct PutStreamRequest<'request> { + pub stream: &'request NewStream, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PutStreamResponse { + pub stream: Stream, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct NewStream { + pub name: Name, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub comment_filter: Option<CommentFilter>, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option<StreamModel>, +} + +impl NewStream { + pub fn set_model_version(&mut self, model_version: &UserModelVersion) { + if let Some(model) = &mut self.model { + model.version = model_version.clone() + } + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct StreamModel { + pub version: UserModelVersion, + pub label_thresholds: Vec<StreamLabelThreshold>, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct StreamLabelThreshold { + name: Vec<String>, + threshold: NotNan<f64>, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Stream { pub id: Id, @@ -58,6 +101,9 @@ pub struct Stream { #[serde(rename = "label_threshold_filter")] pub label_filter: Option<LabelFilter>, + + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option<StreamModel>, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -70,6 +116,19 @@ pub struct LabelFilter {
#[derive(Debug, Clone, Deserialize, Serialize)] pub struct UserModelVersion(pub u64); +impl FromStr for UserModelVersion { + type Err = Error; + + fn from_str(s: &str) -> Result<Self, Error> { + match s.parse::<u64>() { + Ok(version) => Ok(UserModelVersion(version)), + Err(_) => Err(Error::BadStreamModelVersion { + version: s.to_string(), + }), + } + } +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Batch { pub results: Vec<StreamResult>, diff --git a/cli/src/commands/create/mod.rs b/cli/src/commands/create/mod.rs index 47ceb3b1..2c7a9755 100644 --- a/cli/src/commands/create/mod.rs +++ b/cli/src/commands/create/mod.rs @@ -7,13 +7,14 @@ mod project; mod quota; mod source; mod stream_exception; +mod streams; mod user; use self::{ annotations::CreateAnnotationsArgs, bucket::CreateBucketArgs, comments::CreateCommentsArgs, dataset::CreateDatasetArgs, emails::CreateEmailsArgs, project::CreateProjectArgs, quota::CreateQuotaArgs, source::CreateSourceArgs, stream_exception::CreateStreamExceptionArgs, - user::CreateUserArgs, + streams::CreateStreamsArgs, user::CreateUserArgs, }; use crate::printer::Printer; use anyhow::Result; @@ -62,6 +63,10 @@ pub enum CreateArgs { #[structopt(name = "quota")] /// Set a new value for a quota Quota(CreateQuotaArgs), + + #[structopt(name = "stream")] + /// Create a stream + Stream(CreateStreamsArgs), } pub fn run( @@ -85,5 +90,6 @@ pub fn run( stream_exception::create(&client, stream_exception_args, printer) } CreateArgs::Quota(quota_args) => quota::create(&client, quota_args), + CreateArgs::Stream(stream_args) => streams::create(&client, stream_args), } } diff --git a/cli/src/commands/create/streams.rs b/cli/src/commands/create/streams.rs new file mode 100644 index 00000000..201e051d --- /dev/null +++ b/cli/src/commands/create/streams.rs @@ -0,0 +1,80 @@ +use std::{ + fs::File, + io::{BufRead, BufReader}, + path::PathBuf, +}; + +use anyhow::{Context, Result}; +use log::info; +use reinfer_client::{ + resources::stream::{NewStream, UserModelVersion}, +
Client, DatasetIdentifier, +}; + +use structopt::StructOpt; + +#[derive(Debug, StructOpt)] +pub struct CreateStreamsArgs { + #[structopt(short = "d", long = "dataset")] + /// Dataset where the streams should be created + dataset_id: DatasetIdentifier, + + #[structopt(short = "f", long = "file", parse(from_os_str))] + /// Path to JSON file with streams + path: PathBuf, + + #[structopt(short = "v", long = "model-version")] + /// The model version for the new streams to use + model_version: UserModelVersion, +} + +pub fn create(client: &Client, args: &CreateStreamsArgs) -> Result<()> { + let CreateStreamsArgs { + path, + dataset_id, + model_version, + } = args; + + let file = BufReader::new( + File::open(path).with_context(|| format!("Could not open file `{}`", path.display()))?, + ); + + let dataset = client.get_dataset(dataset_id.clone())?; + + for read_stream_result in read_streams_iter(file) { + let mut new_stream = read_stream_result?; + + new_stream.set_model_version(model_version); + + client.put_stream(&dataset.full_name(), &new_stream)?; + info!("Created stream {}", new_stream.name.0) + } + Ok(()) +} + +fn read_streams_iter<'a>( + mut streams: impl BufRead + 'a, +) -> impl Iterator<Item = Result<NewStream>> + 'a { + let mut line = String::new(); + let mut line_number: u32 = 0; + std::iter::from_fn(move || { + line_number += 1; + line.clear(); + + let read_result = streams + .read_line(&mut line) + .with_context(|| format!("Could not read line {line_number} from input stream")); + + match read_result { + Ok(0) => return None, + Err(e) => return Some(Err(e)), + _ => {} + } + + Some( + serde_json::from_str::<NewStream>(line.trim_end()).with_context(|| { + format!("Could not parse stream at line {line_number} from input stream") + }), + ) + }) +} diff --git a/cli/src/commands/get/streams.rs b/cli/src/commands/get/streams.rs index 742c477d..e035e232 100644 --- a/cli/src/commands/get/streams.rs +++ b/cli/src/commands/get/streams.rs @@ -1,6 +1,11 @@ use anyhow::{Context, Result}; use
reinfer_client::{Client, DatasetIdentifier, StreamFullName}; -use std::io; +use std::{ + fs::File, + io, + io::{BufWriter, Write}, + path::PathBuf, +}; use structopt::StructOpt; use crate::printer::{print_resources_as_json, Printer}; @@ -10,6 +15,10 @@ pub struct GetStreamsArgs { #[structopt(short = "d", long = "dataset")] /// The dataset name or id dataset: DatasetIdentifier, + + #[structopt(short = "f", long = "file", parse(from_os_str))] + /// Path where to write streams as JSON. + path: Option<PathBuf>, } #[derive(Debug, StructOpt)] @@ -32,7 +41,17 @@ pub struct GetStreamCommentsArgs { } pub fn get(client: &Client, args: &GetStreamsArgs, printer: &Printer) -> Result<()> { - let GetStreamsArgs { dataset } = args; + let GetStreamsArgs { dataset, path } = args; + + let file: Option<Box<dyn Write>> = match path { + Some(path) => Some(Box::new( + File::create(path) + .with_context(|| format!("Could not open file for writing `{}`", path.display())) + .map(BufWriter::new)?, + )), + None => None, + }; + let dataset_name = client .get_dataset(dataset.clone()) .context("Operation to get dataset has failed.")? @@ -41,7 +60,12 @@ pub fn get(client: &Client, args: &GetStreamsArgs, printer: &Printer) -> Result< .get_streams(&dataset_name) .context("Operation to list streams has failed.")?; streams.sort_unstable_by(|lhs, rhs| lhs.name.0.cmp(&rhs.name.0)); - printer.print_resources(&streams) + + if let Some(file) = file { + print_resources_as_json(streams, file) + } else { + printer.print_resources(&streams) + } } pub fn get_stream_comments(client: &Client, args: &GetStreamCommentsArgs) -> Result<()> {