diff --git a/src/api_mocks/aggregator_api.rs b/src/api_mocks/aggregator_api.rs
index 01148528..e1b13186 100644
--- a/src/api_mocks/aggregator_api.rs
+++ b/src/api_mocks/aggregator_api.rs
@@ -13,7 +13,6 @@ use trillium::{Conn, Handler, Status};
 use trillium_api::{api, Json};
 use trillium_http::KnownHeaderName;
 use trillium_router::{router, RouterConnExt};
-use uuid::Uuid;
 
 pub const BAD_BEARER_TOKEN: &str = "badbearertoken";
 
@@ -129,21 +128,17 @@ async fn task_ids(conn: &mut Conn, (): ()) -> Result<Json<TaskIds>, Status> {
     let query = QueryStrong::parse(conn.querystring()).map_err(|_| Status::InternalServerError)?;
     match query.get_str("pagination_token") {
         None => Ok(Json(TaskIds {
-            task_ids: repeat_with(|| Uuid::new_v4().to_string())
-                .take(10)
-                .collect(),
+            task_ids: repeat_with(random).take(10).collect(),
             pagination_token: Some("second".into()),
         })),
         Some("second") => Ok(Json(TaskIds {
-            task_ids: repeat_with(|| Uuid::new_v4().to_string())
-                .take(10)
-                .collect(),
+            task_ids: repeat_with(random).take(10).collect(),
             pagination_token: Some("last".into()),
         })),
         _ => Ok(Json(TaskIds {
-            task_ids: repeat_with(|| Uuid::new_v4().to_string()).take(5).collect(),
+            task_ids: repeat_with(random).take(5).collect(),
             pagination_token: None,
         })),
     }
diff --git a/src/clients/aggregator_client.rs b/src/clients/aggregator_client.rs
index 1821bb3f..d4c00982 100644
--- a/src/clients/aggregator_client.rs
+++ b/src/clients/aggregator_client.rs
@@ -1,8 +1,15 @@
+mod task_id_stream;
+mod task_stream;
+
+use task_id_stream::TaskIdStream;
+use task_stream::TaskStream;
+
 use crate::{
     clients::{ClientConnExt, ClientError},
     entity::{task::ProvisionableTask, Aggregator},
     handler::Error,
 };
+use futures_lite::StreamExt;
 use serde::{de::DeserializeOwned, Serialize};
 use trillium::{HeaderValue, KnownHeaderName, Method};
 use trillium_client::{Client, Conn};
@@ -57,24 +64,17 @@ impl AggregatorClient {
             .map_err(Into::into)
     }
 
+    pub async fn get_task_id_page(&self, page: Option<&str>) -> Result<TaskIds, ClientError> {
+        let path = if let Some(pagination_token) = page {
+            format!("task_ids?pagination_token={pagination_token}")
+        } else {
+            "task_ids".into()
+        };
+        self.get(&path).await
+    }
+
     pub async fn get_task_ids(&self) -> Result<Vec<String>, ClientError> {
-        let mut ids = vec![];
-        let mut path = String::from("task_ids");
-        loop {
-            let TaskIds {
-                task_ids,
-                pagination_token,
-            } = self.get(&path).await?;
-
-            ids.extend(task_ids);
-
-            match pagination_token {
-                Some(pagination_token) => {
-                    path = format!("task_ids?pagination_token={pagination_token}");
-                }
-                None => break Ok(ids),
-            }
-        }
+        self.task_id_stream().try_collect().await
     }
 
     pub async fn get_task(&self, task_id: &str) -> Result<TaskResponse, ClientError> {
@@ -138,4 +138,12 @@
             .await?;
         Ok(())
     }
+
+    pub fn task_id_stream(&self) -> TaskIdStream<'_> {
+        TaskIdStream::new(self)
+    }
+
+    pub fn task_stream(&self) -> TaskStream<'_> {
+        TaskStream::new(self)
+    }
 }
diff --git a/src/clients/aggregator_client/api_types.rs b/src/clients/aggregator_client/api_types.rs
index 66430250..dc7a4cba 100644
--- a/src/clients/aggregator_client/api_types.rs
+++ b/src/clients/aggregator_client/api_types.rs
@@ -220,7 +220,7 @@ impl TaskResponse {
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct TaskIds {
-    pub task_ids: Vec<String>,
+    pub task_ids: Vec<TaskId>,
     pub pagination_token: Option<String>,
 }
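Aside: the new `get_task_id_page` above is the single-page primitive that `TaskIdStream` (next file) drives. As a sanity check that the stream preserves the old semantics, here is a minimal hand-rolled sketch equivalent to the pagination loop removed from `get_task_ids`; the helper name `collect_all_task_ids` is hypothetical, everything else comes from this diff:

    // Hypothetical helper, not part of this diff: drains all pages eagerly.
    async fn collect_all_task_ids(client: &AggregatorClient) -> Result<Vec<String>, ClientError> {
        let mut ids = Vec::new();
        let mut token: Option<String> = None;
        loop {
            // Each page carries an opaque pagination_token; None marks the final page.
            let TaskIds { task_ids, pagination_token } =
                client.get_task_id_page(token.as_deref()).await?;
            ids.extend(task_ids.iter().map(|id| id.to_string()));
            match pagination_token {
                Some(next) => token = Some(next),
                None => return Ok(ids),
            }
        }
    }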
diff --git a/src/clients/aggregator_client/task_id_stream.rs b/src/clients/aggregator_client/task_id_stream.rs
new file mode 100644
index 00000000..b37e275e
--- /dev/null
+++ b/src/clients/aggregator_client/task_id_stream.rs
@@ -0,0 +1,93 @@
+use std::{
+    collections::VecDeque,
+    fmt::{self, Debug, Formatter},
+    pin::Pin,
+    task::{ready, Context, Poll},
+};
+
+use super::{AggregatorClient, TaskIds};
+use crate::clients::ClientError;
+use futures_lite::{stream::Stream, Future};
+
+#[derive(Clone, Debug)]
+struct Page {
+    task_ids: VecDeque<String>,
+    pagination_token: Option<String>,
+}
+
+impl From<TaskIds> for Page {
+    fn from(
+        TaskIds {
+            task_ids,
+            pagination_token,
+        }: TaskIds,
+    ) -> Self {
+        Page {
+            task_ids: task_ids.into_iter().map(|t| t.to_string()).collect(),
+            pagination_token,
+        }
+    }
+}
+
+type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
+
+pub struct TaskIdStream<'a> {
+    client: &'a AggregatorClient,
+    page: Option<Page>,
+    future: Option<BoxFuture<'a, Result<TaskIds, ClientError>>>,
+}
+
+impl<'a> TaskIdStream<'a> {
+    pub(super) fn new(client: &'a AggregatorClient) -> Self {
+        Self {
+            client,
+            page: None,
+            future: None,
+        }
+    }
+}
+
+impl<'a> Debug for TaskIdStream<'a> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("TaskIdStream")
+            .field("client", &self.client)
+            .field("current_page", &self.page)
+            .field("current_future", &"..")
+            .finish()
+    }
+}
+
+impl Stream for TaskIdStream<'_> {
+    type Item = Result<String, ClientError>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let Self {
+            client,
+            ref mut page,
+            ref mut future,
+        } = *self;
+
+        loop {
+            if let Some(page) = page {
+                if let Some(task_id) = page.task_ids.pop_front() {
+                    return Poll::Ready(Some(Ok(task_id)));
+                }
+
+                if page.pagination_token.is_none() {
+                    return Poll::Ready(None);
+                }
+            }
+
+            if let Some(fut) = future {
+                *page = Some(ready!(Pin::new(&mut *fut).poll(cx))?.into());
+                *future = None;
+            } else {
+                let pagination_token =
+                    page.as_ref().and_then(|page| page.pagination_token.clone());
+
+                *future = Some(Box::pin(async move {
+                    client.get_task_id_page(pagination_token.as_deref()).await
+                }));
+            };
+        }
+    }
+}
diff --git a/src/clients/aggregator_client/task_stream.rs b/src/clients/aggregator_client/task_stream.rs
new file mode 100644
index 00000000..a23a2972
--- /dev/null
+++ b/src/clients/aggregator_client/task_stream.rs
@@ -0,0 +1,62 @@
+use std::{
+    fmt::{self, Debug, Formatter},
+    pin::Pin,
+    task::{ready, Context, Poll},
+};
+
+use super::{task_id_stream::TaskIdStream, AggregatorClient, TaskResponse};
+use crate::clients::ClientError;
+use futures_lite::{stream::Stream, Future};
+
+type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
+
+pub struct TaskStream<'a> {
+    client: &'a AggregatorClient,
+    task_id_stream: TaskIdStream<'a>,
+    task_future: Option<BoxFuture<'a, Option<Result<TaskResponse, ClientError>>>>,
+}
+
+impl<'a> Debug for TaskStream<'a> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("TaskStream").field("future", &"..").finish()
+    }
+}
+
+impl<'a> TaskStream<'a> {
+    pub(super) fn new(client: &'a AggregatorClient) -> Self {
+        Self {
+            task_id_stream: client.task_id_stream(),
+            client,
+            task_future: None,
+        }
+    }
+}
+
+impl Stream for TaskStream<'_> {
+    type Item = Result<TaskResponse, ClientError>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let Self {
+            client,
+            ref mut task_id_stream,
+            ref mut task_future,
+        } = *self;
+
+        loop {
+            if let Some(future) = task_future {
+                let res = ready!(Pin::new(&mut *future).poll(cx));
+                *task_future = None;
+                return Poll::Ready(res);
+            }
+
+            *task_future = match ready!(Pin::new(&mut *task_id_stream).poll_next(cx)) {
+                Some(Ok(task_id)) => Some(Box::pin(async move {
+                    let task_id = task_id;
+                    Some(client.get_task(&task_id).await)
+                })),
+                None => return Poll::Ready(None),
+                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+            };
+        }
+    }
+}
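Both adapters are lazy: no request is issued until the first poll, `TaskIdStream` buffers one page at a time, and `TaskStream` issues one `get_task` per id. A usage sketch, assuming only the API introduced in this diff plus `futures_lite::StreamExt` (which supplies the `next` and `try_collect` combinators used elsewhere in this diff); `print_all_tasks` is a hypothetical caller:

    use futures_lite::StreamExt;

    // Hypothetical caller, for illustration only.
    async fn print_all_tasks(client: &AggregatorClient) -> Result<(), ClientError> {
        // Ids arrive one at a time; a new page is fetched only when the buffer is drained.
        let mut ids = client.task_id_stream();
        while let Some(id) = ids.next().await.transpose()? {
            println!("task id: {id}");
        }

        // Or collect every TaskResponse eagerly (one GET per task id).
        let tasks: Vec<TaskResponse> = client.task_stream().try_collect().await?;
        println!("fetched {} tasks", tasks.len());
        Ok(())
    }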
diff --git a/src/queue.rs b/src/queue.rs
index 576d1916..5842f2c1 100644
--- a/src/queue.rs
+++ b/src/queue.rs
@@ -73,35 +73,27 @@ impl Queue {
     }
 
     pub async fn schedule_recurring_tasks_if_needed(&self) -> Result<(), DbErr> {
-        let tx = self.db.begin().await?;
-
-        let session_cleanup_jobs = Entity::find()
-            .filter(all![
-                Expr::cust_with_expr("job->>'type' = $1", "SessionCleanup"),
-                Column::ScheduledAt.gt(OffsetDateTime::now_utc()),
-            ])
-            .count(&tx)
-            .await?;
-
-        if session_cleanup_jobs == 0 {
-            Job::from(SessionCleanup).insert(&tx).await?;
-        }
-        tx.commit().await?;
-
-        let tx = self.db.begin().await?;
-        let queue_cleanup_jobs = Entity::find()
-            .filter(all![
-                Expr::cust_with_expr("job->>'type' = $1", "QueueCleanup"),
-                Column::ScheduledAt.gt(OffsetDateTime::now_utc()),
-            ])
-            .count(&tx)
-            .await?;
-
-        if queue_cleanup_jobs == 0 {
-            Job::from(QueueCleanup).insert(&tx).await?;
+        let schedulable_jobs = [
+            (Job::from(SessionCleanup), "SessionCleanup"),
+            (Job::from(QueueCleanup), "QueueCleanup"),
+            (Job::from(TaskSync), "TaskSync"),
+        ];
+
+        for (job, name) in schedulable_jobs {
+            let tx = self.db.begin().await?;
+            let existing_jobs = Entity::find()
+                .filter(all![
+                    Expr::cust_with_expr("job->>'type' = $1", name),
+                    Column::ScheduledAt.gt(OffsetDateTime::now_utc()),
+                ])
+                .count(&tx)
+                .await?;
+
+            if existing_jobs == 0 {
+                job.insert(&tx).await?;
+                tx.commit().await?;
+            }
         }
-        tx.commit().await?;
-
         Ok(())
     }
diff --git a/src/queue/job.rs b/src/queue/job.rs
index 6c5cfa53..66272adf 100644
--- a/src/queue/job.rs
+++ b/src/queue/job.rs
@@ -8,10 +8,13 @@ use serde::{Deserialize, Serialize};
 use thiserror::Error;
 use time::{Duration, OffsetDateTime};
 use trillium::{Method, Status};
+use trillium_client::Client;
 use url::Url;
 
 mod v1;
-pub use v1::{CreateUser, QueueCleanup, ResetPassword, SendInvitationEmail, SessionCleanup, V1};
+pub use v1::{
+    CreateUser, QueueCleanup, ResetPassword, SendInvitationEmail, SessionCleanup, TaskSync, V1,
+};
 
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
 #[serde(tag = "version")]
@@ -77,12 +80,16 @@ impl From<ClientError> for JobError {
 pub struct SharedJobState {
     pub auth0_client: Auth0Client,
     pub postmark_client: PostmarkClient,
+    pub http_client: Client,
+    pub crypter: crate::Crypter,
 }
 
 impl From<&Config> for SharedJobState {
     fn from(config: &Config) -> Self {
         Self {
             auth0_client: Auth0Client::new(config),
             postmark_client: PostmarkClient::new(config),
+            http_client: config.client.clone(),
+            crypter: config.crypter.clone(),
         }
     }
 }
diff --git a/src/queue/job/v1.rs b/src/queue/job/v1.rs
index bca2943a..d6140e27 100644
--- a/src/queue/job/v1.rs
+++ b/src/queue/job/v1.rs
@@ -3,6 +3,7 @@
 mod queue_cleanup;
 mod reset_password;
 mod send_invitation_email;
 mod session_cleanup;
+mod task_sync;
 
 use crate::queue::EnqueueJob;
@@ -15,6 +16,7 @@ pub use queue_cleanup::QueueCleanup;
 pub use reset_password::ResetPassword;
 pub use send_invitation_email::SendInvitationEmail;
 pub use session_cleanup::SessionCleanup;
+pub use task_sync::TaskSync;
 
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
 #[serde(tag = "type")]
@@ -24,6 +26,7 @@ pub enum V1 {
     ResetPassword(ResetPassword),
     SessionCleanup(SessionCleanup),
     QueueCleanup(QueueCleanup),
+    TaskSync(TaskSync),
 }
 
 impl V1 {
@@ -38,6 +41,7 @@ impl V1 {
             V1::ResetPassword(job) => job.perform(job_state, db).await,
             V1::SessionCleanup(job) => job.perform(job_state, db).await,
             V1::QueueCleanup(job) => job.perform(job_state, db).await,
+            V1::TaskSync(job) => job.perform(job_state, db).await,
         }
     }
 }
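Note on the queue.rs change above: the dedup check counts still-scheduled rows whose serde `type` tag matches, via Postgres JSON text extraction (`job->>'type'`). Condensed into a standalone helper for clarity; `pending_jobs_of_type` is a hypothetical name, while `Entity`, `Column`, `Expr`, and the `all!` macro are the ones queue.rs already imports from sea-orm/sea-query:

    // Counts jobs of the given `type` tag scheduled later than now; `count` is
    // sea-orm's PaginatorTrait::count, which returns u64 in recent releases.
    async fn pending_jobs_of_type(db: &impl ConnectionTrait, name: &str) -> Result<u64, DbErr> {
        Entity::find()
            .filter(all![
                Expr::cust_with_expr("job->>'type' = $1", name),
                Column::ScheduledAt.gt(OffsetDateTime::now_utc()),
            ])
            .count(db)
            .await
    }

With `TaskSync` added to `schedulable_jobs`, the startup check only inserts a job when no future-scheduled instance exists; the job then re-enqueues itself (see `scheduled_in(SYNC_PERIOD)` in the next file), so exactly one recurring instance stays pending.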
diff --git a/src/queue/job/v1/task_sync.rs b/src/queue/job/v1/task_sync.rs
new file mode 100644
index 00000000..ea02335d
--- /dev/null
+++ b/src/queue/job/v1/task_sync.rs
@@ -0,0 +1,60 @@
+use crate::{
+    entity::*,
+    queue::job::{EnqueueJob, Job, JobError, SharedJobState, V1},
+};
+use futures_lite::StreamExt;
+use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter};
+use serde::{Deserialize, Serialize};
+use time::Duration;
+
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy)]
+pub struct TaskSync;
+
+const SYNC_PERIOD: Duration = Duration::weeks(1);
+
+impl TaskSync {
+    pub async fn perform(
+        &mut self,
+        job_state: &SharedJobState,
+        db: &impl ConnectionTrait,
+    ) -> Result<Option<EnqueueJob>, JobError> {
+        let aggregators = Aggregators::find()
+            .filter(AggregatorColumn::IsFirstParty.eq(true)) // eventually we may want to check for a capability
+            .all(db)
+            .await?;
+
+        for aggregator in aggregators {
+            let client = aggregator
+                .client(job_state.http_client.clone(), &job_state.crypter)
+                .map_err(|e| JobError::ClientOther(e.to_string()))?;
+            let mut tasks = client.task_stream();
+            while let Some(task_from_aggregator) = tasks.next().await.transpose()? {
+                let task_id = task_from_aggregator.task_id.to_string();
+                if let Some(_task_from_db) = Tasks::find_by_id(&task_id).one(db).await? {
+                    // TODO: confirm that the task matches
+                } else {
+                    client.delete_task(&task_id).await?;
+                }
+            }
+        }
+
+        Ok(Some(EnqueueJob::from(*self).scheduled_in(SYNC_PERIOD)))
+    }
+}
+
+impl From<TaskSync> for Job {
+    fn from(value: TaskSync) -> Self {
+        Self::V1(V1::TaskSync(value))
+    }
+}
+
+impl PartialEq<Job> for TaskSync {
+    fn eq(&self, other: &Job) -> bool {
+        matches!(other, Job::V1(V1::TaskSync(c)) if c == self)
+    }
+}
+impl PartialEq<TaskSync> for Job {
+    fn eq(&self, other: &TaskSync) -> bool {
+        matches!(self, Job::V1(V1::TaskSync(j)) if j == other)
+    }
+}
diff --git a/tests/aggregator_client.rs b/tests/aggregator_client.rs
index a73b5cab..588d5a8e 100644
--- a/tests/aggregator_client.rs
+++ b/tests/aggregator_client.rs
@@ -1,16 +1,17 @@
 use divviup_api::{
     api_mocks::aggregator_api::{self, BAD_BEARER_TOKEN},
-    clients::AggregatorClient,
+    clients::{aggregator_client::TaskResponse, AggregatorClient},
 };
 use test_support::{assert_eq, test, *};
 use trillium::Handler;
 
 #[test(harness = with_client_logs)]
-async fn get_task_ids(app: DivviupApi, client_logs: ClientLogs) -> TestResult {
+async fn streaming_tasks(app: DivviupApi, client_logs: ClientLogs) -> TestResult {
+    use futures_lite::stream::StreamExt;
     let aggregator = fixtures::aggregator(&app, None).await;
     let client = aggregator.client(app.config().client.clone(), app.crypter())?;
-    let task_ids = client.get_task_ids().await?;
-    assert_eq!(task_ids.len(), 25); // two pages of 10 plus a final page of 5
+    let tasks: Vec<TaskResponse> = client.task_stream().try_collect().await?;
+    assert_eq!(tasks.len(), 25); // two pages of 10 plus a final page of 5
 
     let logs = client_logs.logs();
     assert!(logs.iter().all(|log| {
@@ -20,17 +21,44 @@ async fn get_task_ids(app: DivviupApi, client_logs: ClientLogs) -> TestResult {
             == "application/vnd.janus.aggregator+json;version=0.1"
     }));
 
+    let expected_header = format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap());
     assert!(logs.iter().all(|log| {
-        log.request_headers
-            .get_str(KnownHeaderName::Authorization)
-            .unwrap()
-            == &format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap())
+        expected_header
+            == log
+                .request_headers
+                .get_str(KnownHeaderName::Authorization)
+                .unwrap()
     }));
+
+    assert_eq!(logs.len(), 28); // one per task plus three pages
+
Ok(()) +} + +#[test(harness = with_client_logs)] +async fn get_task_ids(app: DivviupApi, client_logs: ClientLogs) -> TestResult { + let aggregator = fixtures::aggregator(&app, None).await; + let client = aggregator.client(app.config().client.clone(), app.crypter())?; + let task_ids = client.get_task_ids().await?; + assert_eq!(task_ids.len(), 25); // two pages of 10 plus a final page of 5 + + let logs = client_logs.logs(); + assert!(logs.iter().all(|log| { + log.request_headers.eq_ignore_ascii_case( + KnownHeaderName::Accept, + "application/vnd.janus.aggregator+json;version=0.1", + ) + })); + + let expected_header = format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap()); + assert!(logs.iter().all(|log| { + expected_header + == log + .request_headers + .get_str(KnownHeaderName::Authorization) + .unwrap() })); - let queries = logs - .iter() - .map(|log| log.url.query().clone()) - .collect::>(); + let queries = logs.iter().map(|log| log.url.query()).collect::>(); assert_eq!( &queries, &[ diff --git a/tests/crypter.rs b/tests/crypter.rs index 674f5bdb..85ffc171 100644 --- a/tests/crypter.rs +++ b/tests/crypter.rs @@ -15,7 +15,7 @@ fn round_trip_with_current_key() { #[test] fn round_trip_with_old_key() { let old_key = Crypter::generate_key(); - let crypter = Crypter::from(old_key.clone()); + let crypter = Crypter::from(old_key); let encrypted = crypter.encrypt(AAD, PLAINTEXT).unwrap(); let crypter = Crypter::new(Crypter::generate_key(), [old_key]); @@ -49,12 +49,10 @@ fn parsing() { let keys = std::iter::repeat_with(Crypter::generate_key) .take(5) .collect::>(); - let encrypted = Crypter::from(keys[0].clone()) - .encrypt(AAD, PLAINTEXT) - .unwrap(); + let encrypted = Crypter::from(keys[0]).encrypt(AAD, PLAINTEXT).unwrap(); let crypter = keys .iter() - .map(|k| URL_SAFE_NO_PAD.encode(&k)) + .map(|k| URL_SAFE_NO_PAD.encode(k)) .collect::>() .join(",") .parse::()