diff --git a/gasket/Cargo.toml b/gasket/Cargo.toml index bcbb643..06a8dfa 100644 --- a/gasket/Cargo.toml +++ b/gasket/Cargo.toml @@ -28,4 +28,6 @@ tokio = { version = "1", features = ["rt", "time", "sync", "macros", "rt-multi-t approx = "0.5.1" [features] +default = [] derive = ["gasket-derive"] +threaded = [] diff --git a/gasket/examples/dumb.rs b/gasket/examples/dumb.rs index da1bb26..5f5e0b6 100644 --- a/gasket/examples/dumb.rs +++ b/gasket/examples/dumb.rs @@ -44,7 +44,7 @@ struct TickerUnit { instant: Instant, } -#[async_trait::async_trait(?Send)] +#[async_trait::async_trait] impl Worker for Ticker { async fn bootstrap(_: &TickerSpec) -> Result { Ok(Self { @@ -117,7 +117,7 @@ struct Terminal { value: Option, } -#[async_trait::async_trait(?Send)] +#[async_trait::async_trait] impl Worker for Terminal { async fn bootstrap(_: &TerminalSpec) -> Result { Ok(Self { @@ -164,7 +164,8 @@ impl Worker for Terminal { } } -fn main() { +#[tokio::main] +async fn main() { tracing::subscriber::set_global_default( tracing_subscriber::FmtSubscriber::builder() .with_max_level(tracing::Level::TRACE) @@ -232,9 +233,8 @@ fn main() { ); let tethers = vec![tether1, tether2, tether3]; - let pipeline = gasket::daemon::Daemon::new(tethers); - pipeline.block(); + pipeline.block().await; // match tether.read_metrics() { // Ok(readings) => { diff --git a/gasket/src/daemon.rs b/gasket/src/daemon.rs index 0f6b840..8d3fdfd 100644 --- a/gasket/src/daemon.rs +++ b/gasket/src/daemon.rs @@ -56,6 +56,7 @@ impl Daemon { false } + #[cfg(feature = "threaded")] pub fn teardown(self) { // first pass is to notify that we should stop for tether in self.0.iter() { @@ -76,6 +77,37 @@ impl Daemon { } } + #[cfg(not(feature = "threaded"))] + pub async fn teardown(self) { + // first pass is to notify that we should stop + for tether in self.0.iter() { + let state = tether.check_state(); + info!(stage = tether.name(), ?state, "dismissing stage"); + + match tether.dismiss_stage() { + Ok(_) => (), + Err(crate::error::Error::TetherDropped) => debug!("stage already dismissed"), + error => warn!(?error, "couldn't dismiss stage"), + } + } + + // second pass is to wait for graceful shutdown + info!("waiting for stages to end"); + for tether in self.0.into_iter() { + tether.join_stage().await; + } + } + + #[cfg(not(feature = "threaded"))] + pub async fn block(self) { + while !self.should_stop() { + tokio::time::sleep(Duration::from_millis(1500)).await; + } + + self.teardown().await; + } + + #[cfg(feature = "threaded")] pub fn block(self) { while !self.should_stop() { std::thread::sleep(Duration::from_millis(1500)); diff --git a/gasket/src/framework.rs b/gasket/src/framework.rs index c1b9fbd..98d2ad4 100644 --- a/gasket/src/framework.rs +++ b/gasket/src/framework.rs @@ -6,8 +6,8 @@ use tracing::{error, warn}; #[cfg(feature = "derive")] pub use gasket_derive::*; -pub trait Stage: Sized + Send { - type Unit; +pub trait Stage: Sized + Send + Sync { + type Unit: Send + Sync; type Worker: Worker; fn name(&self) -> &str; @@ -37,7 +37,7 @@ pub enum WorkerError { type Result = core::result::Result; -pub trait AsWorkError { +pub trait AsWorkError: Send + Sync { fn or_panic(self) -> Result; fn or_retry(self) -> Result; fn or_restart(self) -> Result; @@ -45,7 +45,8 @@ pub trait AsWorkError { impl AsWorkError for core::result::Result where - E: Display, + T: Send + Sync, + E: Send + Sync + Display, { fn or_panic(self) -> Result { match self { @@ -87,8 +88,8 @@ pub enum WorkSchedule { Done, } -#[async_trait::async_trait(?Send)] -pub trait Worker: Sized +#[async_trait::async_trait] +pub trait Worker: Send + Sync + Sized where S: Stage, { diff --git a/gasket/src/runtime.rs b/gasket/src/runtime.rs index c1f2024..3495536 100644 --- a/gasket/src/runtime.rs +++ b/gasket/src/runtime.rs @@ -1,6 +1,5 @@ use std::{ sync::{Arc, Weak}, - thread::JoinHandle, time::{Duration, Instant}, }; @@ -135,7 +134,7 @@ where impl StageMachine where - S: Stage, + S: Stage + Send + Sync, { fn new(anchor: Arc, stage: S, policy: Policy, name: String) -> Self { StageMachine { @@ -335,7 +334,6 @@ impl Anchor { } fn dismiss_stage(&self) -> Result<(), crate::error::Error> { - println!("cancelling stage"); self.dismissed.cancel(); Ok(()) @@ -346,8 +344,13 @@ impl Anchor { pub struct Tether { name: String, anchor_ref: Weak, - thread_handle: JoinHandle<()>, policy: Policy, + + #[cfg(feature = "threaded")] + thread_handle: std::thread::JoinHandle<()>, + + #[cfg(not(feature = "threaded"))] + thread_handle: tokio::task::JoinHandle<()>, } #[derive(Debug, PartialEq)] @@ -362,12 +365,18 @@ impl Tether { &self.name } + #[cfg(feature = "threaded")] pub fn join_stage(self) { self.thread_handle .join() .expect("called from outside thread"); } + #[cfg(not(feature = "threaded"))] + pub async fn join_stage(self) { + self.thread_handle.await; + } + pub fn try_anchor(&self) -> Result, crate::error::Error> { match self.anchor_ref.upgrade() { Some(anchor) => Ok(anchor), @@ -439,18 +448,41 @@ impl Default for Policy { } #[instrument(name="stage", level = Level::INFO, skip_all, fields(stage = machine.name))] -fn fullfil_stage(mut machine: StageMachine) +async fn fullfil_stage(mut machine: StageMachine) where S: Stage, { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); + while machine.transition().await != StagePhase::Ended {} +} + +#[cfg(not(feature = "threaded"))] +pub fn spawn_stage(stage: S, policy: Policy) -> Tether +where + S: Stage + 'static, +{ + let name = stage.name().to_owned(); + + let metrics = stage.metrics(); + + let anchor = Arc::new(Anchor::new(metrics)); + let anchor_ref = Arc::downgrade(&anchor); + + let policy2 = policy.clone(); + let name2 = name.clone(); + + let machine = StageMachine::::new(anchor, stage, policy2, name2); - rt.block_on(async { while machine.transition().await != StagePhase::Ended {} }); + let thread_handle = tokio::spawn(fullfil_stage(machine)); + + Tether { + name, + anchor_ref, + thread_handle, + policy, + } } +#[cfg(feature = "threaded")] pub fn spawn_stage(stage: S, policy: Policy) -> Tether where S: Stage + 'static, @@ -466,7 +498,13 @@ where let name2 = name.clone(); let thread_handle = std::thread::spawn(move || { let machine = StageMachine::::new(anchor, stage, policy2, name2); - fullfil_stage(machine); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(fullfil_stage(machine)); }); Tether { @@ -521,7 +559,7 @@ pub mod tests { } } - #[async_trait::async_trait(?Send)] + #[async_trait::async_trait] impl Worker for MockWorker { async fn bootstrap(_: &MockStage) -> Result { Ok(Self { @@ -688,6 +726,7 @@ pub mod tests { // assert!(elapsed.as_millis() <= 2250); // } + #[cfg(feature = "threaded")] #[tokio::test(flavor = "multi_thread")] async fn honors_cancel_in_time() { let expected_shutdown = Duration::from_millis(1_000);