From 26b9e5b236d20cc4a1773d0f8cd25137bf2ea275 Mon Sep 17 00:00:00 2001
From: Jacob Rothstein
Date: Tue, 1 Aug 2023 15:06:33 -0700
Subject: [PATCH 1/5] =?UTF-8?q?na=C3=AFve=20version=20of=20task=20sync?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/queue/job.rs              |  3 ++
 src/queue/job/v1.rs           |  4 +++
 src/queue/job/v1/task_sync.rs | 53 +++++++++++++++++++++++++++++++++++
 3 files changed, 60 insertions(+)
 create mode 100644 src/queue/job/v1/task_sync.rs

diff --git a/src/queue/job.rs b/src/queue/job.rs
index 6c5cfa53..cb51a635 100644
--- a/src/queue/job.rs
+++ b/src/queue/job.rs
@@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize};
 use thiserror::Error;
 use time::{Duration, OffsetDateTime};
 use trillium::{Method, Status};
+use trillium_client::Client;
 use url::Url;
 
 mod v1;
@@ -77,12 +78,14 @@ impl From for JobError {
 pub struct SharedJobState {
     pub auth0_client: Auth0Client,
     pub postmark_client: PostmarkClient,
+    pub http_client: Client,
 }
 impl From<&Config> for SharedJobState {
     fn from(config: &Config) -> Self {
         Self {
             auth0_client: Auth0Client::new(config),
             postmark_client: PostmarkClient::new(config),
+            http_client: config.client.clone(),
         }
     }
 }
diff --git a/src/queue/job/v1.rs b/src/queue/job/v1.rs
index bca2943a..d6140e27 100644
--- a/src/queue/job/v1.rs
+++ b/src/queue/job/v1.rs
@@ -3,6 +3,7 @@ mod queue_cleanup;
 mod reset_password;
 mod send_invitation_email;
 mod session_cleanup;
+mod task_sync;
 
 use crate::queue::EnqueueJob;
 
@@ -15,6 +16,7 @@ pub use queue_cleanup::QueueCleanup;
 pub use reset_password::ResetPassword;
 pub use send_invitation_email::SendInvitationEmail;
 pub use session_cleanup::SessionCleanup;
+pub use task_sync::TaskSync;
 
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
 #[serde(tag = "type")]
@@ -24,6 +26,7 @@ pub enum V1 {
     ResetPassword(ResetPassword),
     SessionCleanup(SessionCleanup),
     QueueCleanup(QueueCleanup),
+    TaskSync(TaskSync),
 }
 
 impl V1 {
@@ -38,6 +41,7 @@ impl V1 {
             V1::ResetPassword(job) => job.perform(job_state, db).await,
             V1::SessionCleanup(job) => job.perform(job_state, db).await,
             V1::QueueCleanup(job) => job.perform(job_state, db).await,
+            V1::TaskSync(job) => job.perform(job_state, db).await,
         }
     }
 }
diff --git a/src/queue/job/v1/task_sync.rs b/src/queue/job/v1/task_sync.rs
new file mode 100644
index 00000000..9786351e
--- /dev/null
+++ b/src/queue/job/v1/task_sync.rs
@@ -0,0 +1,53 @@
+use crate::{
+    entity::*,
+    queue::job::{EnqueueJob, Job, JobError, SharedJobState, V1},
+};
+use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, PaginatorTrait, QueryFilter};
+use serde::{Deserialize, Serialize};
+use time::Duration;
+
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy)]
+pub struct TaskSync;
+
+const SYNC_PERIOD: Duration = Duration::weeks(1);
+
+impl TaskSync {
+    pub async fn perform(
+        &mut self,
+        job_state: &SharedJobState,
+        db: &impl ConnectionTrait,
+    ) -> Result<Option<EnqueueJob>, JobError> {
+        let aggregators = Aggregators::find()
+            .filter(AggregatorColumn::IsFirstParty.eq(true)) // eventually we may want to check for a capability
+            .all(db)
+            .await?;
+
+        for aggregator in aggregators {
+            let client = aggregator.client(job_state.http_client.clone());
+            for task_id in client.get_task_ids().await? {
+                if 0 == Tasks::find_by_id(&task_id).count(db).await? {
+                    client.delete_task(&task_id).await?;
+                }
+            }
+        }
+
+        Ok(Some(EnqueueJob::from(*self).scheduled_in(SYNC_PERIOD)))
+    }
+}
+
+impl From<TaskSync> for Job {
+    fn from(value: TaskSync) -> Self {
+        Self::V1(V1::TaskSync(value))
+    }
+}
+
+impl PartialEq<Job> for TaskSync {
+    fn eq(&self, other: &Job) -> bool {
+        matches!(other, Job::V1(V1::TaskSync(c)) if c == self)
+    }
+}
+impl PartialEq<TaskSync> for Job {
+    fn eq(&self, other: &TaskSync) -> bool {
+        matches!(self, Job::V1(V1::TaskSync(j)) if j == other)
+    }
+}
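The rule this first cut implements: list every task id the aggregator knows, and delete any task that has no corresponding row in our database; returning `Ok(Some(EnqueueJob::from(*self).scheduled_in(SYNC_PERIOD)))` then re-enqueues the job a week out, so the sync is self-perpetuating. A minimal sketch of the deletion rule in isolation (the helper below is hypothetical — the patch performs the membership test inline as a database count):

```rust
use std::collections::HashSet;

// Hypothetical helper, not in the patch: a task id the aggregator reports
// but the database does not contain is slated for deletion.
fn ids_to_delete(aggregator_ids: &[String], db_ids: &HashSet<String>) -> Vec<String> {
    aggregator_ids
        .iter()
        .filter(|id| !db_ids.contains(*id))
        .cloned()
        .collect()
}
```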
From b4612790874c49979da7c50df9ec579443697700 Mon Sep 17 00:00:00 2001
From: Jacob Rothstein
Date: Tue, 1 Aug 2023 17:45:23 -0700
Subject: [PATCH 2/5] add to auto-scheduled jobs

---
 src/queue.rs     | 48 ++++++++++++++++++++----------------------------
 src/queue/job.rs |  4 +++-
 2 files changed, 23 insertions(+), 29 deletions(-)

diff --git a/src/queue.rs b/src/queue.rs
index 576d1916..5842f2c1 100644
--- a/src/queue.rs
+++ b/src/queue.rs
@@ -73,35 +73,27 @@ impl Queue {
     }
 
     pub async fn schedule_recurring_tasks_if_needed(&self) -> Result<(), DbErr> {
-        let tx = self.db.begin().await?;
-
-        let session_cleanup_jobs = Entity::find()
-            .filter(all![
-                Expr::cust_with_expr("job->>'type' = $1", "SessionCleanup"),
-                Column::ScheduledAt.gt(OffsetDateTime::now_utc()),
-            ])
-            .count(&tx)
-            .await?;
-
-        if session_cleanup_jobs == 0 {
-            Job::from(SessionCleanup).insert(&tx).await?;
-        }
-        tx.commit().await?;
-
-        let tx = self.db.begin().await?;
-        let queue_cleanup_jobs = Entity::find()
-            .filter(all![
-                Expr::cust_with_expr("job->>'type' = $1", "QueueCleanup"),
-                Column::ScheduledAt.gt(OffsetDateTime::now_utc()),
-            ])
-            .count(&tx)
-            .await?;
-
-        if queue_cleanup_jobs == 0 {
-            Job::from(QueueCleanup).insert(&tx).await?;
+        let schedulable_jobs = [
+            (Job::from(SessionCleanup), "SessionCleanup"),
+            (Job::from(QueueCleanup), "QueueCleanup"),
+            (Job::from(TaskSync), "TaskSync"),
+        ];
+
+        for (job, name) in schedulable_jobs {
+            let tx = self.db.begin().await?;
+            let existing_jobs = Entity::find()
+                .filter(all![
+                    Expr::cust_with_expr("job->>'type' = $1", name),
+                    Column::ScheduledAt.gt(OffsetDateTime::now_utc()),
+                ])
+                .count(&tx)
+                .await?;
+
+            if existing_jobs == 0 {
+                job.insert(&tx).await?;
+                tx.commit().await?;
+            }
         }
-        tx.commit().await?;
-
         Ok(())
     }
 
diff --git a/src/queue/job.rs b/src/queue/job.rs
index cb51a635..02617c24 100644
--- a/src/queue/job.rs
+++ b/src/queue/job.rs
@@ -12,7 +12,9 @@ use trillium_client::Client;
 use url::Url;
 
 mod v1;
-pub use v1::{CreateUser, QueueCleanup, ResetPassword, SendInvitationEmail, SessionCleanup, V1};
+pub use v1::{
+    CreateUser, QueueCleanup, ResetPassword, SendInvitationEmail, SessionCleanup, TaskSync, V1,
+};
 
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
 #[serde(tag = "version")]
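The dedupe query relies on how jobs serialize: `Job` is internally tagged with `version` and `V1` with `type`, so serde flattens a stored job into a single JSON object carrying both tags, which is what the Postgres expression `job->>'type' = $1` matches on. A sketch of that invariant as a test (assumes `serde_json` is available; not part of the patch):

```rust
#[test]
fn job_json_carries_both_tags() {
    // #[serde(tag = "version")] on Job and #[serde(tag = "type")] on V1 are
    // both internally tagged, so the two tags land in one flat JSON object.
    let json = serde_json::to_value(Job::from(TaskSync)).unwrap();
    assert_eq!(json["version"], "V1");
    assert_eq!(json["type"], "TaskSync");
}
```

Note also the transaction shape: each candidate job gets its own transaction, and `commit` only runs when a row was actually inserted; when the job already exists, the transaction is simply dropped and rolls back with nothing to undo.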
From b58f1e7a41b31c1fa52117ac97515f5c3e0ef9de Mon Sep 17 00:00:00 2001
From: Jacob Rothstein
Date: Wed, 2 Aug 2023 12:54:39 -0700
Subject: [PATCH 3/5] add task id stream and task stream to aggregator client

---
 src/api_mocks/aggregator_api.rs            |  11 +-
 src/clients/aggregator_client.rs           | 177 +++++++++++++++++++--
 src/clients/aggregator_client/api_types.rs |   2 +-
 tests/aggregator_client.rs                 |  29 +++-
 4 files changed, 192 insertions(+), 27 deletions(-)

diff --git a/src/api_mocks/aggregator_api.rs b/src/api_mocks/aggregator_api.rs
index 01148528..e1b13186 100644
--- a/src/api_mocks/aggregator_api.rs
+++ b/src/api_mocks/aggregator_api.rs
@@ -13,7 +13,6 @@ use trillium::{Conn, Handler, Status};
 use trillium_api::{api, Json};
 use trillium_http::KnownHeaderName;
 use trillium_router::{router, RouterConnExt};
-use uuid::Uuid;
 
 pub const BAD_BEARER_TOKEN: &str = "badbearertoken";
 
@@ -129,21 +128,17 @@ async fn task_ids(conn: &mut Conn, (): ()) -> Result<Json<TaskIds>, Status> {
     let query = QueryStrong::parse(conn.querystring()).map_err(|_| Status::InternalServerError)?;
     match query.get_str("pagination_token") {
         None => Ok(Json(TaskIds {
-            task_ids: repeat_with(|| Uuid::new_v4().to_string())
-                .take(10)
-                .collect(),
+            task_ids: repeat_with(random).take(10).collect(),
             pagination_token: Some("second".into()),
         })),
 
         Some("second") => Ok(Json(TaskIds {
-            task_ids: repeat_with(|| Uuid::new_v4().to_string())
-                .take(10)
-                .collect(),
+            task_ids: repeat_with(random).take(10).collect(),
            pagination_token: Some("last".into()),
         })),
 
         _ => Ok(Json(TaskIds {
-            task_ids: repeat_with(|| Uuid::new_v4().to_string()).take(5).collect(),
+            task_ids: repeat_with(random).take(5).collect(),
             pagination_token: None,
         })),
     }
diff --git a/src/clients/aggregator_client.rs b/src/clients/aggregator_client.rs
index 1821bb3f..c7ed9095 100644
--- a/src/clients/aggregator_client.rs
+++ b/src/clients/aggregator_client.rs
@@ -1,8 +1,16 @@
+use std::{
+    collections::VecDeque,
+    fmt::{self, Formatter},
+    pin::Pin,
+    task::{ready, Context, Poll},
+};
+
 use crate::{
     clients::{ClientConnExt, ClientError},
     entity::{task::ProvisionableTask, Aggregator},
     handler::Error,
 };
+use futures_lite::{stream::Stream, Future, StreamExt};
 use serde::{de::DeserializeOwned, Serialize};
 use trillium::{HeaderValue, KnownHeaderName, Method};
 use trillium_client::{Client, Conn};
@@ -57,24 +65,17 @@ impl AggregatorClient {
             .map_err(Into::into)
     }
 
-    pub async fn get_task_ids(&self) -> Result<Vec<String>, ClientError> {
-        let mut ids = vec![];
-        let mut path = String::from("task_ids");
-        loop {
-            let TaskIds {
-                task_ids,
-                pagination_token,
-            } = self.get(&path).await?;
-
-            ids.extend(task_ids);
+    pub async fn get_task_id_page(&self, page: Option<&str>) -> Result<TaskIds, ClientError> {
+        let path = if let Some(pagination_token) = page {
+            format!("task_ids?pagination_token={pagination_token}")
+        } else {
+            "task_ids".into()
+        };
+        self.get(&path).await
+    }
 
-            match pagination_token {
-                Some(pagination_token) => {
-                    path = format!("task_ids?pagination_token={pagination_token}");
-                }
-                None => break Ok(ids),
-            }
-        }
+    pub async fn get_task_ids(&self) -> Result<Vec<String>, ClientError> {
+        self.task_id_stream().try_collect().await
     }
 
     pub async fn get_task(&self, task_id: &str) -> Result<TaskResponse, ClientError> {
@@ -138,4 +139,146 @@ impl AggregatorClient {
             .await?;
         Ok(())
     }
+
+    pub fn task_id_stream(&self) -> TaskIdStream<'_> {
+        TaskIdStream::new(self)
+    }
+
+    pub fn task_stream(&self) -> TaskStream<'_> {
+        TaskStream::new(self)
+    }
+}
+
+#[derive(Clone, Debug)]
+struct Page {
+    task_ids: VecDeque<String>,
+    pagination_token: Option<String>,
+}
+
+impl From<TaskIds> for Page {
+    fn from(
+        TaskIds {
+            task_ids,
+            pagination_token,
+        }: TaskIds,
+    ) -> Self {
+        Page {
+            task_ids: task_ids.into_iter().map(|t| t.to_string()).collect(),
+            pagination_token,
+        }
+    }
+}
+
+pub struct TaskIdStream<'a> {
+    client: &'a AggregatorClient,
+    page: Option<Page>,
+    future: Option<Pin<Box<dyn Future<Output = Result<TaskIds, ClientError>> + Send + 'a>>>,
+}
+
+impl<'a> TaskIdStream<'a> {
+    fn new(client: &'a AggregatorClient) -> Self {
+        Self {
+            client,
+            page: None,
+            future: None,
+        }
+    }
+}
+
+impl<'a> fmt::Debug for TaskIdStream<'a> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("TaskIdStream")
+            .field("client", &self.client)
+            .field("current_page", &self.page)
+            .field("current_future", &"..")
+            .finish()
+    }
+}
+
+impl Stream for TaskIdStream<'_> {
+    type Item = Result<String, ClientError>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let Self {
+            client,
+            ref mut page,
+            ref mut future,
+        } = *self;
+
+        loop {
+            if let Some(page) = page {
+                if let Some(task_id) = page.task_ids.pop_front() {
+                    return Poll::Ready(Some(Ok(task_id)));
+                }
+
+                if page.pagination_token.is_none() {
+                    return Poll::Ready(None);
+                }
+            }
+
+            if let Some(fut) = future {
+                *page = Some(ready!(Pin::new(&mut *fut).poll(cx))?.into());
+                *future = None;
+            } else {
+                let pagination_token = page.as_ref().and_then(|page| page.pagination_token.clone());
+
+                *future = Some(Box::pin(async move {
+                    client.get_task_id_page(pagination_token.as_deref()).await
+                }));
+            };
+        }
+    }
+}
+
+pub struct TaskStream<'a> {
+    client: &'a AggregatorClient,
+    task_id_stream: TaskIdStream<'a>,
+    task_future: Option<
+        Pin<Box<dyn Future<Output = Option<Result<TaskResponse, ClientError>>> + Send + 'a>>,
+    >,
+}
+
+impl<'a> fmt::Debug for TaskStream<'a> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("TaskStream").field("future", &"..").finish()
+    }
+}
+
+impl<'a> TaskStream<'a> {
+    fn new(client: &'a AggregatorClient) -> Self {
+        Self {
+            task_id_stream: client.task_id_stream(),
+            client,
+            task_future: None,
+        }
+    }
+}
+
+impl Stream for TaskStream<'_> {
+    type Item = Result<TaskResponse, ClientError>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let Self {
+            client,
+            ref mut task_id_stream,
+            ref mut task_future,
+        } = *self;
+
+        loop {
+            if let Some(future) = task_future {
+                let res = ready!(Pin::new(&mut *future).poll(cx));
+                *task_future = None;
+                return Poll::Ready(res);
+            }
+
+            *task_future = match ready!(Pin::new(&mut *task_id_stream).poll_next(cx)) {
+                Some(Ok(task_id)) => Some(Box::pin(async move {
+                    let task_id = task_id;
+                    Some(client.get_task(&task_id).await)
+                })),
+                None => return Poll::Ready(None),
+                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+            };
+        }
+    }
+}
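`TaskIdStream` hand-rolls `Stream`: it buffers one page, yields ids from the buffer, and lazily boxes a future for the next page fetch whenever the buffer runs dry but a pagination token remains. The same pagination logic can be written more compactly with `futures_lite::stream::unfold` — a sketch under the assumption that it lives in this module (so `AggregatorClient`, `TaskIds`, and `ClientError` are in scope); this is an alternative shape, not what the patch does:

```rust
use futures_lite::{stream, Stream, StreamExt};

// One unfold step per page, flattened to one item per task id. The trade-off:
// unfold's stream type is unnameable, while the hand-rolled TaskIdStream is a
// nameable type that can carry a Debug impl and appear in public signatures.
fn task_ids(client: &AggregatorClient) -> impl Stream<Item = Result<String, ClientError>> + '_ {
    // State: Some(None) = fetch the first page; Some(Some(token)) = fetch the
    // page named by `token`; None = pagination finished.
    stream::unfold(Some(None::<String>), move |state| async move {
        let token = state?;
        match client.get_task_id_page(token.as_deref()).await {
            Ok(page) => {
                let next = page.pagination_token.map(Some);
                let ids: Vec<_> = page
                    .task_ids
                    .into_iter()
                    .map(|id| Ok(id.to_string()))
                    .collect();
                Some((stream::iter(ids), next))
            }
            // Surface the error as the final item, then stop paginating.
            Err(e) => Some((stream::iter(vec![Err(e)]), None)),
        }
    })
    .flatten()
}
```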
diff --git a/src/clients/aggregator_client/api_types.rs b/src/clients/aggregator_client/api_types.rs
index 66430250..dc7a4cba 100644
--- a/src/clients/aggregator_client/api_types.rs
+++ b/src/clients/aggregator_client/api_types.rs
@@ -220,7 +220,7 @@ impl TaskResponse {
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct TaskIds {
-    pub task_ids: Vec<String>,
+    pub task_ids: Vec<TaskId>,
     pub pagination_token: Option<String>,
 }
 
diff --git a/tests/aggregator_client.rs b/tests/aggregator_client.rs
index a73b5cab..38dd1e06 100644
--- a/tests/aggregator_client.rs
+++ b/tests/aggregator_client.rs
@@ -1,10 +1,37 @@
 use divviup_api::{
     api_mocks::aggregator_api::{self, BAD_BEARER_TOKEN},
-    clients::AggregatorClient,
+    clients::{aggregator_client::TaskResponse, AggregatorClient},
 };
 use test_support::{assert_eq, test, *};
 use trillium::Handler;
 
+#[test(harness = with_client_logs)]
+async fn streaming_tasks(app: DivviupApi, client_logs: ClientLogs) -> TestResult {
+    use futures_lite::stream::StreamExt;
+    let aggregator = fixtures::aggregator(&app, None).await;
+    let client = aggregator.client(app.config().client.clone(), app.crypter())?;
+    let tasks: Vec<TaskResponse> = client.task_stream().try_collect().await?;
+    assert_eq!(tasks.len(), 25); // two pages of 10 plus a final page of 5
+
+    let logs = client_logs.logs();
+    assert!(logs.iter().all(|log| {
+        log.request_headers
+            .get_str(KnownHeaderName::Accept)
+            .unwrap()
+            == "application/vnd.janus.aggregator+json;version=0.1"
+    }));
+
+    assert!(logs.iter().all(|log| {
+        log.request_headers
+            .get_str(KnownHeaderName::Authorization)
+            .unwrap()
+            == &format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap())
+    }));
+
+    assert_eq!(logs.len(), 28); // one per task plus three pages
+    Ok(())
+}
+
 #[test(harness = with_client_logs)]
 async fn get_task_ids(app: DivviupApi, client_logs: ClientLogs) -> TestResult {
     let aggregator = fixtures::aggregator(&app, None).await;
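A usage sketch of what the laziness buys (same same-module assumption as above; `first_id_matching` is a hypothetical helper, not in the patch): a consumer can stop mid-stream, and pages past the match are never requested.

```rust
use futures_lite::StreamExt;

async fn first_id_matching(
    client: &AggregatorClient,
    predicate: impl Fn(&str) -> bool,
) -> Result<Option<String>, ClientError> {
    let mut ids = client.task_id_stream();
    while let Some(id) = ids.next().await.transpose()? {
        if predicate(&id) {
            return Ok(Some(id)); // remaining pages are never fetched
        }
    }
    Ok(None)
}
```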
From 250b5b781408006a39a50a67791aafc188b9fa8f Mon Sep 17 00:00:00 2001
From: Jacob Rothstein
Date: Wed, 2 Aug 2023 14:28:36 -0700
Subject: [PATCH 4/5] use streaming in task sync

---
 src/queue/job/v1/task_sync.rs | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/queue/job/v1/task_sync.rs b/src/queue/job/v1/task_sync.rs
index 9786351e..2449789a 100644
--- a/src/queue/job/v1/task_sync.rs
+++ b/src/queue/job/v1/task_sync.rs
@@ -2,7 +2,8 @@ use crate::{
     entity::*,
     queue::job::{EnqueueJob, Job, JobError, SharedJobState, V1},
 };
-use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, PaginatorTrait, QueryFilter};
+use futures_lite::StreamExt;
+use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, ModelTrait, PaginatorTrait, QueryFilter};
 use serde::{Deserialize, Serialize};
 use time::Duration;
 
@@ -24,8 +25,12 @@ impl TaskSync {
 
         for aggregator in aggregators {
             let client = aggregator.client(job_state.http_client.clone());
-            for task_id in client.get_task_ids().await? {
-                if 0 == Tasks::find_by_id(&task_id).count(db).await? {
+            let mut tasks = client.task_stream();
+            while let Some(task_from_aggregator) = tasks.next().await.transpose()? {
+                let task_id = task_from_aggregator.task_id.to_string();
+                if let Some(_task_from_db) = Tasks::find_by_id(&task_id).one(db).await? {
+                    // TODO: confirm that the task matches
+                } else {
                    client.delete_task(&task_id).await?;
                 }
             }
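The `next().await.transpose()?` idiom above converts each `Option<Result<TaskResponse, ClientError>>` the stream yields into a `Result<Option<TaskResponse>, ClientError>`, so `?` ends the loop on the first transport error while `None` ends it normally. The same shape in miniature (illustrative only, not in the patch):

```rust
use futures_lite::{stream, StreamExt};

async fn sum_until_error() -> Result<i32, &'static str> {
    let mut items = stream::iter(vec![Ok(1), Ok(2), Err("boom"), Ok(3)]);
    let mut sum = 0;
    while let Some(n) = items.next().await.transpose()? {
        sum += n; // runs for Ok(1) and Ok(2); Err("boom") short-circuits
    }
    Ok(sum) // reached only when the stream ends without an error
}
```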
From c8baafef5727d66895c337fb9034c86d7eb4cc21 Mon Sep 17 00:00:00 2001
From: Jacob Rothstein
Date: Wed, 2 Aug 2023 14:56:45 -0700
Subject: [PATCH 5/5] rebase and clippy

---
 src/clients/aggregator_client.rs             | 147 +-----------------
 .../aggregator_client/task_id_stream.rs      |  93 +++++++++++
 src/clients/aggregator_client/task_stream.rs |  62 ++++++++
 src/queue/job.rs                             |   2 +
 src/queue/job/v1/task_sync.rs                |   6 +++-
 tests/aggregator_client.rs                   |  33 ++--
 tests/crypter.rs                             |   8 +++---
 7 files changed, 187 insertions(+), 164 deletions(-)
 create mode 100644 src/clients/aggregator_client/task_id_stream.rs
 create mode 100644 src/clients/aggregator_client/task_stream.rs

diff --git a/src/clients/aggregator_client.rs b/src/clients/aggregator_client.rs
index c7ed9095..d4c00982 100644
--- a/src/clients/aggregator_client.rs
+++ b/src/clients/aggregator_client.rs
@@ -1,16 +1,15 @@
-use std::{
-    collections::VecDeque,
-    fmt::{self, Formatter},
-    pin::Pin,
-    task::{ready, Context, Poll},
-};
+mod task_id_stream;
+mod task_stream;
+
+use task_id_stream::TaskIdStream;
+use task_stream::TaskStream;
 
 use crate::{
     clients::{ClientConnExt, ClientError},
     entity::{task::ProvisionableTask, Aggregator},
     handler::Error,
 };
-use futures_lite::{stream::Stream, Future, StreamExt};
+use futures_lite::StreamExt;
 use serde::{de::DeserializeOwned, Serialize};
 use trillium::{HeaderValue, KnownHeaderName, Method};
 use trillium_client::{Client, Conn};
@@ -148,137 +147,3 @@ impl AggregatorClient {
         TaskStream::new(self)
     }
 }
-
-#[derive(Clone, Debug)]
-struct Page {
-    task_ids: VecDeque<String>,
-    pagination_token: Option<String>,
-}
-
-impl From<TaskIds> for Page {
-    fn from(
-        TaskIds {
-            task_ids,
-            pagination_token,
-        }: TaskIds,
-    ) -> Self {
-        Page {
-            task_ids: task_ids.into_iter().map(|t| t.to_string()).collect(),
-            pagination_token,
-        }
-    }
-}
-
-pub struct TaskIdStream<'a> {
-    client: &'a AggregatorClient,
-    page: Option<Page>,
-    future: Option<Pin<Box<dyn Future<Output = Result<TaskIds, ClientError>> + Send + 'a>>>,
-}
-
-impl<'a> TaskIdStream<'a> {
-    fn new(client: &'a AggregatorClient) -> Self {
-        Self {
-            client,
-            page: None,
-            future: None,
-        }
-    }
-}
-
-impl<'a> fmt::Debug for TaskIdStream<'a> {
-    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        f.debug_struct("TaskIdStream")
-            .field("client", &self.client)
-            .field("current_page", &self.page)
-            .field("current_future", &"..")
-            .finish()
-    }
-}
-
-impl Stream for TaskIdStream<'_> {
-    type Item = Result<String, ClientError>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let Self {
-            client,
-            ref mut page,
-            ref mut future,
-        } = *self;
-
-        loop {
-            if let Some(page) = page {
-                if let Some(task_id) = page.task_ids.pop_front() {
-                    return Poll::Ready(Some(Ok(task_id)));
-                }
-
-                if page.pagination_token.is_none() {
-                    return Poll::Ready(None);
-                }
-            }
-
-            if let Some(fut) = future {
-                *page = Some(ready!(Pin::new(&mut *fut).poll(cx))?.into());
-                *future = None;
-            } else {
-                let pagination_token = page.as_ref().and_then(|page| page.pagination_token.clone());
-
-                *future = Some(Box::pin(async move {
-                    client.get_task_id_page(pagination_token.as_deref()).await
-                }));
-            };
-        }
-    }
-}
-
-pub struct TaskStream<'a> {
-    client: &'a AggregatorClient,
-    task_id_stream: TaskIdStream<'a>,
-    task_future: Option<
-        Pin<Box<dyn Future<Output = Option<Result<TaskResponse, ClientError>>> + Send + 'a>>,
-    >,
-}
-
-impl<'a> fmt::Debug for TaskStream<'a> {
-    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        f.debug_struct("TaskStream").field("future", &"..").finish()
-    }
-}
-
-impl<'a> TaskStream<'a> {
-    fn new(client: &'a AggregatorClient) -> Self {
-        Self {
-            task_id_stream: client.task_id_stream(),
-            client,
-            task_future: None,
-        }
-    }
-}
-
-impl Stream for TaskStream<'_> {
-    type Item = Result<TaskResponse, ClientError>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let Self {
-            client,
-            ref mut task_id_stream,
-            ref mut task_future,
-        } = *self;
-
-        loop {
-            if let Some(future) = task_future {
-                let res = ready!(Pin::new(&mut *future).poll(cx));
-                *task_future = None;
-                return Poll::Ready(res);
-            }
-
-            *task_future = match ready!(Pin::new(&mut *task_id_stream).poll_next(cx)) {
-                Some(Ok(task_id)) => Some(Box::pin(async move {
-                    let task_id = task_id;
-                    Some(client.get_task(&task_id).await)
-                })),
-                None => return Poll::Ready(None),
-                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
-            };
-        }
-    }
-}
diff --git a/src/clients/aggregator_client/task_id_stream.rs b/src/clients/aggregator_client/task_id_stream.rs
new file mode 100644
index 00000000..b37e275e
--- /dev/null
+++ b/src/clients/aggregator_client/task_id_stream.rs
@@ -0,0 +1,93 @@
+use std::{
+    collections::VecDeque,
+    fmt::{self, Debug, Formatter},
+    pin::Pin,
+    task::{ready, Context, Poll},
+};
+
+use super::{AggregatorClient, TaskIds};
+use crate::clients::ClientError;
+use futures_lite::{stream::Stream, Future};
+
+#[derive(Clone, Debug)]
+struct Page {
+    task_ids: VecDeque<String>,
+    pagination_token: Option<String>,
+}
+
+impl From<TaskIds> for Page {
+    fn from(
+        TaskIds {
+            task_ids,
+            pagination_token,
+        }: TaskIds,
+    ) -> Self {
+        Page {
+            task_ids: task_ids.into_iter().map(|t| t.to_string()).collect(),
+            pagination_token,
+        }
+    }
+}
+
+type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
+
+pub struct TaskIdStream<'a> {
+    client: &'a AggregatorClient,
+    page: Option<Page>,
+    future: Option<BoxFuture<'a, Result<TaskIds, ClientError>>>,
+}
+
+impl<'a> TaskIdStream<'a> {
+    pub(super) fn new(client: &'a AggregatorClient) -> Self {
+        Self {
+            client,
+            page: None,
+            future: None,
+        }
+    }
+}
+
+impl<'a> Debug for TaskIdStream<'a> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("TaskIdStream")
+            .field("client", &self.client)
+            .field("current_page", &self.page)
+            .field("current_future", &"..")
+            .finish()
+    }
+}
+
+impl Stream for TaskIdStream<'_> {
+    type Item = Result<String, ClientError>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let Self {
+            client,
+            ref mut page,
+            ref mut future,
+        } = *self;
+
+        loop {
+            if let Some(page) = page {
+                if let Some(task_id) = page.task_ids.pop_front() {
+                    return Poll::Ready(Some(Ok(task_id)));
+                }
+
+                if page.pagination_token.is_none() {
+                    return Poll::Ready(None);
+                }
+            }
+
+            if let Some(fut) = future {
+                *page = Some(ready!(Pin::new(&mut *fut).poll(cx))?.into());
+                *future = None;
+            } else {
+                let pagination_token = page.as_ref().and_then(|page| page.pagination_token.clone());
+
+                *future = Some(Box::pin(async move {
+                    client.get_task_id_page(pagination_token.as_deref()).await
+                }));
+            };
+        }
+    }
+}
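One subtlety in `poll_next` above: `ready!(Pin::new(&mut *fut).poll(cx))?` stacks two early returns — `ready!` propagates `Poll::Pending`, and the trailing `?` then converts an `Err(ClientError)` into `Poll::Ready(Some(Err(..)))` via std's `FromResidual` impl for `Poll<Option<Result<T, E>>>`. The same combination in miniature (illustrative, not in the patch):

```rust
use std::task::{ready, Poll};

fn step(polled: Poll<Result<u32, String>>) -> Poll<Option<Result<u32, String>>> {
    let value = ready!(polled)?; // Pending stays Pending; Err becomes Ready(Some(Err(..)))
    Poll::Ready(Some(Ok(value)))
}
```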
.field("current_page", &self.page) + .field("current_future", &"..") + .finish() + } +} + +impl Stream for TaskIdStream<'_> { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Self { + client, + ref mut page, + ref mut future, + } = *self; + + loop { + if let Some(page) = page { + if let Some(task_id) = page.task_ids.pop_front() { + return Poll::Ready(Some(Ok(task_id))); + } + + if page.pagination_token.is_none() { + return Poll::Ready(None); + } + } + + if let Some(fut) = future { + *page = Some(ready!(Pin::new(&mut *fut).poll(cx))?.into()); + *future = None; + } else { + let pagination_token = page.as_ref().and_then(|page| page.pagination_token.clone()); + + *future = Some(Box::pin(async move { + client.get_task_id_page(pagination_token.as_deref()).await + })); + }; + } + } +} diff --git a/src/clients/aggregator_client/task_stream.rs b/src/clients/aggregator_client/task_stream.rs new file mode 100644 index 00000000..a23a2972 --- /dev/null +++ b/src/clients/aggregator_client/task_stream.rs @@ -0,0 +1,62 @@ +use std::{ + fmt::{self, Debug, Formatter}, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use super::{task_id_stream::TaskIdStream, AggregatorClient, TaskResponse}; +use crate::clients::ClientError; +use futures_lite::{stream::Stream, Future}; + +type BoxFuture<'a, T> = Pin + Send + 'a>>; + +pub struct TaskStream<'a> { + client: &'a AggregatorClient, + task_id_stream: TaskIdStream<'a>, + task_future: Option>>>, +} + +impl<'a> Debug for TaskStream<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("TaskStream").field("future", &"..").finish() + } +} + +impl<'a> TaskStream<'a> { + pub(super) fn new(client: &'a AggregatorClient) -> Self { + Self { + task_id_stream: client.task_id_stream(), + client, + task_future: None, + } + } +} + +impl Stream for TaskStream<'_> { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Self { + client, + ref mut task_id_stream, + ref mut task_future, + } = *self; + + loop { + if let Some(future) = task_future { + let res = ready!(Pin::new(&mut *future).poll(cx)); + *task_future = None; + return Poll::Ready(res); + } + + *task_future = match ready!(Pin::new(&mut *task_id_stream).poll_next(cx)) { + Some(Ok(task_id)) => Some(Box::pin(async move { + let task_id = task_id; + Some(client.get_task(&task_id).await) + })), + None => return Poll::Ready(None), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + }; + } + } +} diff --git a/src/queue/job.rs b/src/queue/job.rs index 02617c24..66272adf 100644 --- a/src/queue/job.rs +++ b/src/queue/job.rs @@ -81,6 +81,7 @@ pub struct SharedJobState { pub auth0_client: Auth0Client, pub postmark_client: PostmarkClient, pub http_client: Client, + pub crypter: crate::Crypter, } impl From<&Config> for SharedJobState { fn from(config: &Config) -> Self { @@ -88,6 +89,7 @@ impl From<&Config> for SharedJobState { auth0_client: Auth0Client::new(config), postmark_client: PostmarkClient::new(config), http_client: config.client.clone(), + crypter: config.crypter.clone(), } } } diff --git a/src/queue/job/v1/task_sync.rs b/src/queue/job/v1/task_sync.rs index 2449789a..ea02335d 100644 --- a/src/queue/job/v1/task_sync.rs +++ b/src/queue/job/v1/task_sync.rs @@ -3,7 +3,7 @@ use crate::{ queue::job::{EnqueueJob, Job, JobError, SharedJobState, V1}, }; use futures_lite::StreamExt; -use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, ModelTrait, PaginatorTrait, QueryFilter}; +use sea_orm::{ColumnTrait, 
diff --git a/src/queue/job/v1/task_sync.rs b/src/queue/job/v1/task_sync.rs
index 2449789a..ea02335d 100644
--- a/src/queue/job/v1/task_sync.rs
+++ b/src/queue/job/v1/task_sync.rs
@@ -3,7 +3,7 @@ use crate::{
     queue::job::{EnqueueJob, Job, JobError, SharedJobState, V1},
 };
 use futures_lite::StreamExt;
-use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, ModelTrait, PaginatorTrait, QueryFilter};
+use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter};
 use serde::{Deserialize, Serialize};
 use time::Duration;
 
@@ -24,7 +24,9 @@ impl TaskSync {
 
         for aggregator in aggregators {
-            let client = aggregator.client(job_state.http_client.clone());
+            let client = aggregator
+                .client(job_state.http_client.clone(), &job_state.crypter)
+                .map_err(|e| JobError::ClientOther(e.to_string()))?;
             let mut tasks = client.task_stream();
             while let Some(task_from_aggregator) = tasks.next().await.transpose()? {
                 let task_id = task_from_aggregator.task_id.to_string();
                 if let Some(_task_from_db) = Tasks::find_by_id(&task_id).one(db).await? {
diff --git a/tests/aggregator_client.rs b/tests/aggregator_client.rs
index 38dd1e06..588d5a8e 100644
--- a/tests/aggregator_client.rs
+++ b/tests/aggregator_client.rs
@@ -21,11 +21,13 @@ async fn streaming_tasks(app: DivviupApi, client_logs: ClientLogs) -> TestResult
             == "application/vnd.janus.aggregator+json;version=0.1"
     }));
 
+    let expected_header = format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap());
     assert!(logs.iter().all(|log| {
-        log.request_headers
-            .get_str(KnownHeaderName::Authorization)
-            .unwrap()
-            == &format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap())
+        expected_header
+            == log
+                .request_headers
+                .get_str(KnownHeaderName::Authorization)
+                .unwrap()
     }));
 
     assert_eq!(logs.len(), 28); // one per task plus three pages
@@ -41,23 +43,22 @@ async fn get_task_ids(app: DivviupApi, client_logs: ClientLogs) -> TestResult {
 
     let logs = client_logs.logs();
     assert!(logs.iter().all(|log| {
-        log.request_headers
-            .get_str(KnownHeaderName::Accept)
-            .unwrap()
-            == "application/vnd.janus.aggregator+json;version=0.1"
+        log.request_headers.eq_ignore_ascii_case(
+            KnownHeaderName::Accept,
+            "application/vnd.janus.aggregator+json;version=0.1",
+        )
     }));
 
+    let expected_header = format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap());
     assert!(logs.iter().all(|log| {
-        log.request_headers
-            .get_str(KnownHeaderName::Authorization)
-            .unwrap()
-            == &format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap())
+        expected_header
+            == log
+                .request_headers
+                .get_str(KnownHeaderName::Authorization)
+                .unwrap()
     }));
 
-    let queries = logs
-        .iter()
-        .map(|log| log.url.query().clone())
-        .collect::<Vec<_>>();
+    let queries = logs.iter().map(|log| log.url.query()).collect::<Vec<_>>();
     assert_eq!(
         &queries,
         &[
diff --git a/tests/crypter.rs b/tests/crypter.rs
index 674f5bdb..85ffc171 100644
--- a/tests/crypter.rs
+++ b/tests/crypter.rs
@@ -15,7 +15,7 @@ fn round_trip_with_current_key() {
 #[test]
 fn round_trip_with_old_key() {
     let old_key = Crypter::generate_key();
-    let crypter = Crypter::from(old_key.clone());
+    let crypter = Crypter::from(old_key);
     let encrypted = crypter.encrypt(AAD, PLAINTEXT).unwrap();
 
     let crypter = Crypter::new(Crypter::generate_key(), [old_key]);
@@ -49,12 +49,10 @@ fn parsing() {
     let keys = std::iter::repeat_with(Crypter::generate_key)
         .take(5)
         .collect::<Vec<_>>();
-    let encrypted = Crypter::from(keys[0].clone())
-        .encrypt(AAD, PLAINTEXT)
-        .unwrap();
+    let encrypted = Crypter::from(keys[0]).encrypt(AAD, PLAINTEXT).unwrap();
     let crypter = keys
         .iter()
-        .map(|k| URL_SAFE_NO_PAD.encode(&k))
+        .map(|k| URL_SAFE_NO_PAD.encode(k))
         .collect::<Vec<_>>()
         .join(",")
         .parse::<Crypter>()