Skip to content

Commit

Permalink
streams: add ability to create streams (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
joe-prosser authored Sep 14, 2023
1 parent e8af1bd commit bc0e7c8
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
## Unreleased
- Add create streams command
- Show source statistics in table when getting sources

## v0.18.2
Expand Down
3 changes: 3 additions & 0 deletions api/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ pub enum Error {
#[error("Expected <owner>/<dataset>/<stream>: {}", identifier)]
BadStreamName { identifier: String },

#[error("Expected u64: {}", version)]
BadStreamModelVersion { version: String },

#[error(
"Expected a user id (usernames and emails are not supported), got: {}",
identifier
Expand Down
12 changes: 12 additions & 0 deletions api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use resources::{
project::ForceDeleteProject,
quota::{GetQuotasResponse, Quota},
source::StatisticsRequestParams as SourceStatisticsRequestParams,
stream::{NewStream, PutStreamRequest, PutStreamResponse},
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down Expand Up @@ -393,6 +394,17 @@ impl Client {
)
}

pub fn put_stream(
&self,
dataset_name: &DatasetFullName,
stream: &NewStream,
) -> Result<PutStreamResponse> {
self.put(
self.endpoints.streams(dataset_name)?,
Some(PutStreamRequest { stream }),
)
}

pub fn sync_comments(
&self,
source_name: &SourceFullName,
Expand Down
1 change: 1 addition & 0 deletions api/src/resources/comment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ pub struct CommentFilter {
pub user_properties: Option<UserPropertiesFilter>,

#[serde(skip_serializing_if = "Vec::is_empty")]
#[serde(default)]
pub sources: Vec<SourceId>,
}

Expand Down
59 changes: 59 additions & 0 deletions api/src/resources/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,49 @@ impl FromStr for FullName {
}
}

#[derive(Debug, Clone, Serialize)]
pub(crate) struct PutStreamRequest<'request> {
pub stream: &'request NewStream,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PutStreamResponse {
pub stream: Stream,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct NewStream {
pub name: Name,
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub comment_filter: Option<CommentFilter>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<StreamModel>,
}

impl NewStream {
pub fn set_model_version(&mut self, model_version: &UserModelVersion) {
if let Some(model) = &mut self.model {
model.version = model_version.clone()
}
}
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamModel {
pub version: UserModelVersion,
pub label_thresholds: Vec<StreamLabelThreshold>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamLabelThreshold {
name: Vec<String>,
threshold: NotNan<f64>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Stream {
pub id: Id,
Expand All @@ -58,6 +101,9 @@ pub struct Stream {

#[serde(rename = "label_threshold_filter")]
pub label_filter: Option<LabelFilter>,

#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<StreamModel>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
Expand All @@ -70,6 +116,19 @@ pub struct LabelFilter {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UserModelVersion(pub u64);

impl FromStr for UserModelVersion {
type Err = Error;

fn from_str(s: &str) -> Result<Self> {
match s.parse::<u64>() {
Ok(version) => Ok(UserModelVersion(version)),
Err(_) => Err(Error::BadStreamModelVersion {
version: s.to_string(),
}),
}
}
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Batch {
pub results: Vec<StreamResult>,
Expand Down
8 changes: 7 additions & 1 deletion cli/src/commands/create/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ mod project;
mod quota;
mod source;
mod stream_exception;
mod streams;
mod user;

use self::{
annotations::CreateAnnotationsArgs, bucket::CreateBucketArgs, comments::CreateCommentsArgs,
dataset::CreateDatasetArgs, emails::CreateEmailsArgs, project::CreateProjectArgs,
quota::CreateQuotaArgs, source::CreateSourceArgs, stream_exception::CreateStreamExceptionArgs,
user::CreateUserArgs,
streams::CreateStreamsArgs, user::CreateUserArgs,
};
use crate::printer::Printer;
use anyhow::Result;
Expand Down Expand Up @@ -62,6 +63,10 @@ pub enum CreateArgs {
#[structopt(name = "quota")]
/// Set a new value for a quota
Quota(CreateQuotaArgs),

#[structopt(name = "stream")]
/// Create a stream
Stream(CreateStreamsArgs),
}

pub fn run(
Expand All @@ -85,5 +90,6 @@ pub fn run(
stream_exception::create(&client, stream_exception_args, printer)
}
CreateArgs::Quota(quota_args) => quota::create(&client, quota_args),
CreateArgs::Stream(stream_args) => streams::create(&client, stream_args),
}
}
80 changes: 80 additions & 0 deletions cli/src/commands/create/streams.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::{
fs::File,
io::{BufRead, BufReader},
path::PathBuf,
};

use anyhow::{Context, Result};
use log::info;
use reinfer_client::{
resources::stream::{NewStream, UserModelVersion},
Client, DatasetIdentifier,
};

use structopt::StructOpt;

#[derive(Debug, StructOpt)]
pub struct CreateStreamsArgs {
#[structopt(short = "d", long = "dataset")]
/// Dataset where the streams should be created
dataset_id: DatasetIdentifier,

#[structopt(short = "f", long = "file", parse(from_os_str))]
/// Path to JSON file with streams
path: PathBuf,

#[structopt(short = "v", long = "model-version")]
/// The model version for the new streams to use
model_version: UserModelVersion,
}

pub fn create(client: &Client, args: &CreateStreamsArgs) -> Result<()> {
let CreateStreamsArgs {
path,
dataset_id,
model_version,
} = args;

let file = BufReader::new(
File::open(path).with_context(|| format!("Could not open file `{}`", path.display()))?,
);

let dataset = client.get_dataset(dataset_id.clone())?;

for read_stream_result in read_streams_iter(file) {
let mut new_stream = read_stream_result?;

new_stream.set_model_version(model_version);

client.put_stream(&dataset.full_name(), &new_stream)?;
info!("Created stream {}", new_stream.name.0)
}
Ok(())
}

fn read_streams_iter<'a>(
mut streams: impl BufRead + 'a,
) -> impl Iterator<Item = Result<NewStream>> + 'a {
let mut line = String::new();
let mut line_number: u32 = 0;
std::iter::from_fn(move || {
line_number += 1;
line.clear();

let read_result = streams
.read_line(&mut line)
.with_context(|| format!("Could not read line {line_number} from input stream"));

match read_result {
Ok(0) => return None,
Err(e) => return Some(Err(e)),
_ => {}
}

Some(
serde_json::from_str::<NewStream>(line.trim_end()).with_context(|| {
format!("Could not parse stream at line {line_number} from input stream")
}),
)
})
}
30 changes: 27 additions & 3 deletions cli/src/commands/get/streams.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use anyhow::{Context, Result};
use reinfer_client::{Client, DatasetIdentifier, StreamFullName};
use std::io;
use std::{
fs::File,
io,
io::{BufWriter, Write},
path::PathBuf,
};
use structopt::StructOpt;

use crate::printer::{print_resources_as_json, Printer};
Expand All @@ -10,6 +15,10 @@ pub struct GetStreamsArgs {
#[structopt(short = "d", long = "dataset")]
/// The dataset name or id
dataset: DatasetIdentifier,

#[structopt(short = "f", long = "file", parse(from_os_str))]
/// Path where to write streams as JSON.
path: Option<PathBuf>,
}

#[derive(Debug, StructOpt)]
Expand All @@ -32,7 +41,17 @@ pub struct GetStreamCommentsArgs {
}

pub fn get(client: &Client, args: &GetStreamsArgs, printer: &Printer) -> Result<()> {
let GetStreamsArgs { dataset } = args;
let GetStreamsArgs { dataset, path } = args;

let file: Option<Box<dyn Write>> = match path {
Some(path) => Some(Box::new(
File::create(path)
.with_context(|| format!("Could not open file for writing `{}`", path.display()))
.map(BufWriter::new)?,
)),
None => None,
};

let dataset_name = client
.get_dataset(dataset.clone())
.context("Operation to get dataset has failed.")?
Expand All @@ -41,7 +60,12 @@ pub fn get(client: &Client, args: &GetStreamsArgs, printer: &Printer) -> Result<
.get_streams(&dataset_name)
.context("Operation to list streams has failed.")?;
streams.sort_unstable_by(|lhs, rhs| lhs.name.0.cmp(&rhs.name.0));
printer.print_resources(&streams)

if let Some(file) = file {
print_resources_as_json(streams, file)
} else {
printer.print_resources(&streams)
}
}

pub fn get_stream_comments(client: &Client, args: &GetStreamCommentsArgs) -> Result<()> {
Expand Down

0 comments on commit bc0e7c8

Please sign in to comment.