From afeb980ab6967d238c8f3a15a3fb45271c18dadd Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Wed, 2 Aug 2023 14:56:45 -0700 Subject: [PATCH] rebase and clippy --- src/clients/aggregator_client.rs | 147 +----------------- .../aggregator_client/task_id_stream.rs | 93 +++++++++++ src/clients/aggregator_client/task_stream.rs | 62 ++++++++ src/queue/job.rs | 2 + src/queue/job/v1/task_sync.rs | 6 +- tests/aggregator_client.rs | 33 ++-- tests/crypter.rs | 6 +- 7 files changed, 187 insertions(+), 162 deletions(-) create mode 100644 src/clients/aggregator_client/task_id_stream.rs create mode 100644 src/clients/aggregator_client/task_stream.rs diff --git a/src/clients/aggregator_client.rs b/src/clients/aggregator_client.rs index c7ed9095..d4c00982 100644 --- a/src/clients/aggregator_client.rs +++ b/src/clients/aggregator_client.rs @@ -1,16 +1,15 @@ -use std::{ - collections::VecDeque, - fmt::{self, Formatter}, - pin::Pin, - task::{ready, Context, Poll}, -}; +mod task_id_stream; +mod task_stream; + +use task_id_stream::TaskIdStream; +use task_stream::TaskStream; use crate::{ clients::{ClientConnExt, ClientError}, entity::{task::ProvisionableTask, Aggregator}, handler::Error, }; -use futures_lite::{stream::Stream, Future, StreamExt}; +use futures_lite::StreamExt; use serde::{de::DeserializeOwned, Serialize}; use trillium::{HeaderValue, KnownHeaderName, Method}; use trillium_client::{Client, Conn}; @@ -148,137 +147,3 @@ impl AggregatorClient { TaskStream::new(self) } } - -#[derive(Clone, Debug)] -struct Page { - task_ids: VecDeque, - pagination_token: Option, -} - -impl From for Page { - fn from( - TaskIds { - task_ids, - pagination_token, - }: TaskIds, - ) -> Self { - Page { - task_ids: task_ids.into_iter().map(|t| t.to_string()).collect(), - pagination_token, - } - } -} - -pub struct TaskIdStream<'a> { - client: &'a AggregatorClient, - page: Option, - future: Option> + Send + 'a>>>, -} - -impl<'a> TaskIdStream<'a> { - fn new(client: &'a AggregatorClient) -> Self { - Self { - client, - page: None, - future: None, - } - } -} - -impl<'a> fmt::Debug for TaskIdStream<'a> { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("TaskIdStream") - .field("client", &self.client) - .field("current_page", &self.page) - .field("current_future", &"..") - .finish() - } -} - -impl Stream for TaskIdStream<'_> { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let Self { - client, - ref mut page, - ref mut future, - } = *self; - - loop { - if let Some(page) = page { - if let Some(task_id) = page.task_ids.pop_front() { - return Poll::Ready(Some(Ok(task_id))); - } - - if page.pagination_token.is_none() { - return Poll::Ready(None); - } - } - - if let Some(fut) = future { - *page = Some(ready!(Pin::new(&mut *fut).poll(cx))?.into()); - *future = None; - } else { - let pagination_token = page.as_ref().and_then(|page| page.pagination_token.clone()); - - *future = Some(Box::pin(async move { - client.get_task_id_page(pagination_token.as_deref()).await - })); - }; - } - } -} - -pub struct TaskStream<'a> { - client: &'a AggregatorClient, - task_id_stream: TaskIdStream<'a>, - task_future: Option< - Pin>> + Send + 'a>>, - >, -} - -impl<'a> fmt::Debug for TaskStream<'a> { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("TaskStream").field("future", &"..").finish() - } -} - -impl<'a> TaskStream<'a> { - fn new(client: &'a AggregatorClient) -> Self { - Self { - task_id_stream: client.task_id_stream(), - client, - task_future: None, - } - } -} - -impl Stream for TaskStream<'_> { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let Self { - client, - ref mut task_id_stream, - ref mut task_future, - } = *self; - - loop { - if let Some(future) = task_future { - let res = ready!(Pin::new(&mut *future).poll(cx)); - *task_future = None; - return Poll::Ready(res); - } - - *task_future = match ready!(Pin::new(&mut *task_id_stream).poll_next(cx)) { - Some(Ok(task_id)) => Some(Box::pin(async move { - let task_id = task_id; - Some(client.get_task(&task_id).await) - })), - None => return Poll::Ready(None), - Some(Err(e)) => return Poll::Ready(Some(Err(e))), - }; - } - } -} diff --git a/src/clients/aggregator_client/task_id_stream.rs b/src/clients/aggregator_client/task_id_stream.rs new file mode 100644 index 00000000..b37e275e --- /dev/null +++ b/src/clients/aggregator_client/task_id_stream.rs @@ -0,0 +1,93 @@ +use std::{ + collections::VecDeque, + fmt::{self, Debug, Formatter}, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use super::{AggregatorClient, TaskIds}; +use crate::clients::ClientError; +use futures_lite::{stream::Stream, Future}; + +#[derive(Clone, Debug)] +struct Page { + task_ids: VecDeque, + pagination_token: Option, +} + +impl From for Page { + fn from( + TaskIds { + task_ids, + pagination_token, + }: TaskIds, + ) -> Self { + Page { + task_ids: task_ids.into_iter().map(|t| t.to_string()).collect(), + pagination_token, + } + } +} + +type BoxFuture<'a, T> = Pin + Send + 'a>>; + +pub struct TaskIdStream<'a> { + client: &'a AggregatorClient, + page: Option, + future: Option>>, +} + +impl<'a> TaskIdStream<'a> { + pub(super) fn new(client: &'a AggregatorClient) -> Self { + Self { + client, + page: None, + future: None, + } + } +} + +impl<'a> Debug for TaskIdStream<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("TaskIdStream") + .field("client", &self.client) + .field("current_page", &self.page) + .field("current_future", &"..") + .finish() + } +} + +impl Stream for TaskIdStream<'_> { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Self { + client, + ref mut page, + ref mut future, + } = *self; + + loop { + if let Some(page) = page { + if let Some(task_id) = page.task_ids.pop_front() { + return Poll::Ready(Some(Ok(task_id))); + } + + if page.pagination_token.is_none() { + return Poll::Ready(None); + } + } + + if let Some(fut) = future { + *page = Some(ready!(Pin::new(&mut *fut).poll(cx))?.into()); + *future = None; + } else { + let pagination_token = page.as_ref().and_then(|page| page.pagination_token.clone()); + + *future = Some(Box::pin(async move { + client.get_task_id_page(pagination_token.as_deref()).await + })); + }; + } + } +} diff --git a/src/clients/aggregator_client/task_stream.rs b/src/clients/aggregator_client/task_stream.rs new file mode 100644 index 00000000..a23a2972 --- /dev/null +++ b/src/clients/aggregator_client/task_stream.rs @@ -0,0 +1,62 @@ +use std::{ + fmt::{self, Debug, Formatter}, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use super::{task_id_stream::TaskIdStream, AggregatorClient, TaskResponse}; +use crate::clients::ClientError; +use futures_lite::{stream::Stream, Future}; + +type BoxFuture<'a, T> = Pin + Send + 'a>>; + +pub struct TaskStream<'a> { + client: &'a AggregatorClient, + task_id_stream: TaskIdStream<'a>, + task_future: Option>>>, +} + +impl<'a> Debug for TaskStream<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("TaskStream").field("future", &"..").finish() + } +} + +impl<'a> TaskStream<'a> { + pub(super) fn new(client: &'a AggregatorClient) -> Self { + Self { + task_id_stream: client.task_id_stream(), + client, + task_future: None, + } + } +} + +impl Stream for TaskStream<'_> { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Self { + client, + ref mut task_id_stream, + ref mut task_future, + } = *self; + + loop { + if let Some(future) = task_future { + let res = ready!(Pin::new(&mut *future).poll(cx)); + *task_future = None; + return Poll::Ready(res); + } + + *task_future = match ready!(Pin::new(&mut *task_id_stream).poll_next(cx)) { + Some(Ok(task_id)) => Some(Box::pin(async move { + let task_id = task_id; + Some(client.get_task(&task_id).await) + })), + None => return Poll::Ready(None), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + }; + } + } +} diff --git a/src/queue/job.rs b/src/queue/job.rs index 02617c24..66272adf 100644 --- a/src/queue/job.rs +++ b/src/queue/job.rs @@ -81,6 +81,7 @@ pub struct SharedJobState { pub auth0_client: Auth0Client, pub postmark_client: PostmarkClient, pub http_client: Client, + pub crypter: crate::Crypter, } impl From<&Config> for SharedJobState { fn from(config: &Config) -> Self { @@ -88,6 +89,7 @@ impl From<&Config> for SharedJobState { auth0_client: Auth0Client::new(config), postmark_client: PostmarkClient::new(config), http_client: config.client.clone(), + crypter: config.crypter.clone(), } } } diff --git a/src/queue/job/v1/task_sync.rs b/src/queue/job/v1/task_sync.rs index 2449789a..ea02335d 100644 --- a/src/queue/job/v1/task_sync.rs +++ b/src/queue/job/v1/task_sync.rs @@ -3,7 +3,7 @@ use crate::{ queue::job::{EnqueueJob, Job, JobError, SharedJobState, V1}, }; use futures_lite::StreamExt; -use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, ModelTrait, PaginatorTrait, QueryFilter}; +use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter}; use serde::{Deserialize, Serialize}; use time::Duration; @@ -24,7 +24,9 @@ impl TaskSync { .await?; for aggregator in aggregators { - let client = aggregator.client(job_state.http_client.clone()); + let client = aggregator + .client(job_state.http_client.clone(), &job_state.crypter) + .map_err(|e| JobError::ClientOther(e.to_string()))?; while let Some(task_from_aggregator) = client.task_stream().next().await.transpose()? { let task_id = task_from_aggregator.task_id.to_string(); if let Some(_task_from_db) = Tasks::find_by_id(&task_id).one(db).await? { diff --git a/tests/aggregator_client.rs b/tests/aggregator_client.rs index 38dd1e06..588d5a8e 100644 --- a/tests/aggregator_client.rs +++ b/tests/aggregator_client.rs @@ -21,11 +21,13 @@ async fn streaming_tasks(app: DivviupApi, client_logs: ClientLogs) -> TestResult == "application/vnd.janus.aggregator+json;version=0.1" })); + let expected_header = format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap()); assert!(logs.iter().all(|log| { - log.request_headers - .get_str(KnownHeaderName::Authorization) - .unwrap() - == &format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap()) + expected_header + == log + .request_headers + .get_str(KnownHeaderName::Authorization) + .unwrap() })); assert_eq!(logs.len(), 28); // one per task plus three pages @@ -41,23 +43,22 @@ async fn get_task_ids(app: DivviupApi, client_logs: ClientLogs) -> TestResult { let logs = client_logs.logs(); assert!(logs.iter().all(|log| { - log.request_headers - .get_str(KnownHeaderName::Accept) - .unwrap() - == "application/vnd.janus.aggregator+json;version=0.1" + log.request_headers.eq_ignore_ascii_case( + KnownHeaderName::Accept, + "application/vnd.janus.aggregator+json;version=0.1", + ) })); + let expected_header = format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap()); assert!(logs.iter().all(|log| { - log.request_headers - .get_str(KnownHeaderName::Authorization) - .unwrap() - == &format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap()) + expected_header + == log + .request_headers + .get_str(KnownHeaderName::Authorization) + .unwrap() })); - let queries = logs - .iter() - .map(|log| log.url.query().clone()) - .collect::>(); + let queries = logs.iter().map(|log| log.url.query()).collect::>(); assert_eq!( &queries, &[ diff --git a/tests/crypter.rs b/tests/crypter.rs index 674f5bdb..ec54d6fd 100644 --- a/tests/crypter.rs +++ b/tests/crypter.rs @@ -15,7 +15,7 @@ fn round_trip_with_current_key() { #[test] fn round_trip_with_old_key() { let old_key = Crypter::generate_key(); - let crypter = Crypter::from(old_key.clone()); + let crypter = Crypter::from(old_key); let encrypted = crypter.encrypt(AAD, PLAINTEXT).unwrap(); let crypter = Crypter::new(Crypter::generate_key(), [old_key]); @@ -49,12 +49,12 @@ fn parsing() { let keys = std::iter::repeat_with(Crypter::generate_key) .take(5) .collect::>(); - let encrypted = Crypter::from(keys[0].clone()) + let encrypted = Crypter::from(keys[0]) .encrypt(AAD, PLAINTEXT) .unwrap(); let crypter = keys .iter() - .map(|k| URL_SAFE_NO_PAD.encode(&k)) + .map(|k| URL_SAFE_NO_PAD.encode(k)) .collect::>() .join(",") .parse::()