Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

naïve version of task sync #364

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions src/api_mocks/aggregator_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use trillium::{Conn, Handler, Status};
use trillium_api::{api, Json};
use trillium_http::KnownHeaderName;
use trillium_router::{router, RouterConnExt};
use uuid::Uuid;

pub const BAD_BEARER_TOKEN: &str = "badbearertoken";

Expand Down Expand Up @@ -129,21 +128,17 @@ async fn task_ids(conn: &mut Conn, (): ()) -> Result<Json<TaskIds>, Status> {
let query = QueryStrong::parse(conn.querystring()).map_err(|_| Status::InternalServerError)?;
match query.get_str("pagination_token") {
None => Ok(Json(TaskIds {
task_ids: repeat_with(|| Uuid::new_v4().to_string())
.take(10)
.collect(),
task_ids: repeat_with(random).take(10).collect(),
pagination_token: Some("second".into()),
})),

Some("second") => Ok(Json(TaskIds {
task_ids: repeat_with(|| Uuid::new_v4().to_string())
.take(10)
.collect(),
task_ids: repeat_with(random).take(10).collect(),
pagination_token: Some("last".into()),
})),

_ => Ok(Json(TaskIds {
task_ids: repeat_with(|| Uuid::new_v4().to_string()).take(5).collect(),
task_ids: repeat_with(random).take(5).collect(),
pagination_token: None,
})),
}
Expand Down
42 changes: 25 additions & 17 deletions src/clients/aggregator_client.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
mod task_id_stream;
mod task_stream;

use task_id_stream::TaskIdStream;
use task_stream::TaskStream;

use crate::{
clients::{ClientConnExt, ClientError},
entity::{task::ProvisionableTask, Aggregator},
handler::Error,
};
use futures_lite::StreamExt;
use serde::{de::DeserializeOwned, Serialize};
use trillium::{HeaderValue, KnownHeaderName, Method};
use trillium_client::{Client, Conn};
Expand Down Expand Up @@ -57,24 +64,17 @@ impl AggregatorClient {
.map_err(Into::into)
}

/// Fetches a single page of task ids from the aggregator API.
///
/// `page` is the pagination token returned alongside the previous page,
/// or `None` to request the first page.
pub async fn get_task_id_page(&self, page: Option<&str>) -> Result<TaskIds, ClientError> {
    let path = match page {
        Some(pagination_token) => format!("task_ids?pagination_token={pagination_token}"),
        None => String::from("task_ids"),
    };
    self.get(&path).await
}

pub async fn get_task_ids(&self) -> Result<Vec<String>, ClientError> {
let mut ids = vec![];
let mut path = String::from("task_ids");
loop {
let TaskIds {
task_ids,
pagination_token,
} = self.get(&path).await?;

ids.extend(task_ids);

match pagination_token {
Some(pagination_token) => {
path = format!("task_ids?pagination_token={pagination_token}");
}
None => break Ok(ids),
}
}
self.task_id_stream().try_collect().await
}

pub async fn get_task(&self, task_id: &str) -> Result<TaskResponse, ClientError> {
Expand Down Expand Up @@ -138,4 +138,12 @@ impl AggregatorClient {
.await?;
Ok(())
}

/// Returns a stream over every task id known to the aggregator,
/// transparently following pagination tokens across pages.
pub fn task_id_stream(&self) -> TaskIdStream<'_> {
TaskIdStream::new(self)
}

/// Returns a stream over every task, fetching each full task record
/// for every id yielded by [`AggregatorClient::task_id_stream`].
pub fn task_stream(&self) -> TaskStream<'_> {
TaskStream::new(self)
}
}
2 changes: 1 addition & 1 deletion src/clients/aggregator_client/api_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ impl TaskResponse {

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TaskIds {
pub task_ids: Vec<String>,
pub task_ids: Vec<TaskId>,
pub pagination_token: Option<String>,
}

Expand Down
93 changes: 93 additions & 0 deletions src/clients/aggregator_client/task_id_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use std::{
collections::VecDeque,
fmt::{self, Debug, Formatter},
pin::Pin,
task::{ready, Context, Poll},
};

use super::{AggregatorClient, TaskIds};
use crate::clients::ClientError;
use futures_lite::{stream::Stream, Future};

/// A buffered page of task ids plus the token needed to request the
/// next page (`None` once the server reports no further pages).
#[derive(Clone, Debug)]
struct Page {
// Ids still to be yielded from this page; popped from the front.
task_ids: VecDeque<String>,
// Continuation token for the next page, if any.
pagination_token: Option<String>,
}

impl From<TaskIds> for Page {
    // Build a page by rendering every task id to its string form and
    // carrying the pagination token along unchanged.
    fn from(task_ids: TaskIds) -> Self {
        let TaskIds {
            task_ids,
            pagination_token,
        } = task_ids;

        let task_ids = task_ids
            .into_iter()
            .map(|task_id| task_id.to_string())
            .collect();

        Page {
            task_ids,
            pagination_token,
        }
    }
}

// Owned, heap-allocated future type used to hold the in-flight page request.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// A `Stream` of task ids fetched page-by-page from the aggregator API.
pub struct TaskIdStream<'a> {
// Client used to request each page.
client: &'a AggregatorClient,
// Most recently fetched page, if any; drained as ids are yielded.
page: Option<Page>,
// In-flight request for the next page, if any.
future: Option<BoxFuture<'a, Result<TaskIds, ClientError>>>,
}

impl<'a> TaskIdStream<'a> {
pub(super) fn new(client: &'a AggregatorClient) -> Self {
Self {
client,
page: None,
future: None,
}
}
}

impl<'a> Debug for TaskIdStream<'a> {
    // Hand-rolled because the boxed future is not `Debug`; it is rendered
    // as a ".." placeholder instead.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut output = f.debug_struct("TaskIdStream");
        output.field("client", &self.client);
        output.field("current_page", &self.page);
        output.field("current_future", &"..");
        output.finish()
    }
}

impl Stream for TaskIdStream<'_> {
    type Item = Result<String, ClientError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Self {
            client,
            ref mut page,
            ref mut future,
        } = *self;

        loop {
            // Serve task ids out of the buffered page first.
            if let Some(page) = page {
                if let Some(task_id) = page.task_ids.pop_front() {
                    return Poll::Ready(Some(Ok(task_id)));
                }

                // Page drained and the server gave no continuation token:
                // the stream is exhausted.
                if page.pagination_token.is_none() {
                    return Poll::Ready(None);
                }
            }

            if let Some(fut) = future {
                let result = ready!(Pin::new(&mut *fut).poll(cx));
                // Clear the slot before inspecting the result. Previously the
                // error path (`?`) returned early with the completed future
                // still stored, so a subsequent `poll_next` would re-poll a
                // finished future — a contract violation that may panic.
                *future = None;
                match result {
                    Ok(task_ids) => *page = Some(task_ids.into()),
                    Err(error) => return Poll::Ready(Some(Err(error))),
                }
            } else {
                // No request in flight: start fetching the next page, using
                // the continuation token from the current page (`None` for
                // the very first request).
                let pagination_token =
                    page.as_ref().and_then(|page| page.pagination_token.clone());

                *future = Some(Box::pin(async move {
                    client.get_task_id_page(pagination_token.as_deref()).await
                }));
            }
        }
    }
}
62 changes: 62 additions & 0 deletions src/clients/aggregator_client/task_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use std::{
fmt::{self, Debug, Formatter},
pin::Pin,
task::{ready, Context, Poll},
};

use super::{task_id_stream::TaskIdStream, AggregatorClient, TaskResponse};
use crate::clients::ClientError;
use futures_lite::{stream::Stream, Future};

// Owned, heap-allocated future type used to hold the in-flight task fetch.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// A `Stream` of full `TaskResponse`s: draws ids from a `TaskIdStream`
/// and fetches each corresponding task one at a time.
pub struct TaskStream<'a> {
// Client used to fetch each task.
client: &'a AggregatorClient,
// Source of the task ids to fetch.
task_id_stream: TaskIdStream<'a>,
// In-flight fetch of a single task, if any.
task_future: Option<BoxFuture<'a, Option<Result<TaskResponse, ClientError>>>>,
}

impl<'a> Debug for TaskStream<'a> {
    // Hand-rolled because the boxed future is not `Debug`. Also renders the
    // client field for parity with `TaskIdStream`'s `Debug` impl, which
    // already shows its client.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("TaskStream")
            .field("client", &self.client)
            .field("future", &"..")
            .finish()
    }
}

impl<'a> TaskStream<'a> {
pub(super) fn new(client: &'a AggregatorClient) -> Self {
Self {
task_id_stream: client.task_id_stream(),
client,
task_future: None,
}
}
}

impl Stream for TaskStream<'_> {
    type Item = Result<TaskResponse, ClientError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Self {
            client,
            ref mut task_id_stream,
            ref mut task_future,
        } = *self;

        loop {
            // A task fetch is in flight: finish it first. The slot is
            // cleared before returning so a completed future is never
            // polled again.
            if let Some(future) = task_future {
                let res = ready!(Pin::new(&mut *future).poll(cx));
                *task_future = None;
                return Poll::Ready(res);
            }

            // No fetch in flight: pull the next task id from the inner
            // stream and start fetching the corresponding task. (Dropped a
            // redundant `let task_id = task_id;` rebinding — `async move`
            // already captures `task_id` by value.)
            *task_future = match ready!(Pin::new(&mut *task_id_stream).poll_next(cx)) {
                Some(Ok(task_id)) => {
                    Some(Box::pin(async move { Some(client.get_task(&task_id).await) }))
                }
                None => return Poll::Ready(None),
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
            };
        }
    }
}
48 changes: 20 additions & 28 deletions src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,35 +73,27 @@ impl Queue {
}

pub async fn schedule_recurring_tasks_if_needed(&self) -> Result<(), DbErr> {
let tx = self.db.begin().await?;

let session_cleanup_jobs = Entity::find()
.filter(all![
Expr::cust_with_expr("job->>'type' = $1", "SessionCleanup"),
Column::ScheduledAt.gt(OffsetDateTime::now_utc()),
])
.count(&tx)
.await?;

if session_cleanup_jobs == 0 {
Job::from(SessionCleanup).insert(&tx).await?;
}
tx.commit().await?;

let tx = self.db.begin().await?;
let queue_cleanup_jobs = Entity::find()
.filter(all![
Expr::cust_with_expr("job->>'type' = $1", "QueueCleanup"),
Column::ScheduledAt.gt(OffsetDateTime::now_utc()),
])
.count(&tx)
.await?;

if queue_cleanup_jobs == 0 {
Job::from(QueueCleanup).insert(&tx).await?;
let schedulable_jobs = [
(Job::from(SessionCleanup), "SessionCleanup"),
(Job::from(QueueCleanup), "QueueCleanup"),
(Job::from(TaskSync), "TaskSync"),
];

for (job, name) in schedulable_jobs {
let tx = self.db.begin().await?;
let existing_jobs = Entity::find()
.filter(all![
Expr::cust_with_expr("job->>'type' = $1", name),
Column::ScheduledAt.gt(OffsetDateTime::now_utc()),
])
.count(&tx)
.await?;

if existing_jobs == 0 {
job.insert(&tx).await?;
tx.commit().await?;
}
}
tx.commit().await?;

Ok(())
}

Expand Down
9 changes: 8 additions & 1 deletion src/queue/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use time::{Duration, OffsetDateTime};
use trillium::{Method, Status};
use trillium_client::Client;
use url::Url;

mod v1;
pub use v1::{CreateUser, QueueCleanup, ResetPassword, SendInvitationEmail, SessionCleanup, V1};
pub use v1::{
CreateUser, QueueCleanup, ResetPassword, SendInvitationEmail, SessionCleanup, TaskSync, V1,
};

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "version")]
Expand Down Expand Up @@ -77,12 +80,16 @@ impl From<ClientError> for JobError {
pub struct SharedJobState {
pub auth0_client: Auth0Client,
pub postmark_client: PostmarkClient,
pub http_client: Client,
pub crypter: crate::Crypter,
}
impl From<&Config> for SharedJobState {
fn from(config: &Config) -> Self {
Self {
auth0_client: Auth0Client::new(config),
postmark_client: PostmarkClient::new(config),
http_client: config.client.clone(),
crypter: config.crypter.clone(),
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/queue/job/v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod queue_cleanup;
mod reset_password;
mod send_invitation_email;
mod session_cleanup;
mod task_sync;

use crate::queue::EnqueueJob;

Expand All @@ -15,6 +16,7 @@ pub use queue_cleanup::QueueCleanup;
pub use reset_password::ResetPassword;
pub use send_invitation_email::SendInvitationEmail;
pub use session_cleanup::SessionCleanup;
pub use task_sync::TaskSync;

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "type")]
Expand All @@ -24,6 +26,7 @@ pub enum V1 {
ResetPassword(ResetPassword),
SessionCleanup(SessionCleanup),
QueueCleanup(QueueCleanup),
TaskSync(TaskSync),
}

impl V1 {
Expand All @@ -38,6 +41,7 @@ impl V1 {
V1::ResetPassword(job) => job.perform(job_state, db).await,
V1::SessionCleanup(job) => job.perform(job_state, db).await,
V1::QueueCleanup(job) => job.perform(job_state, db).await,
V1::TaskSync(job) => job.perform(job_state, db).await,
}
}
}
Loading