diff --git a/configs/config.toml b/configs/config.toml index d97ff2a4..de88c546 100644 --- a/configs/config.toml +++ b/configs/config.toml @@ -9,7 +9,7 @@ action_redirections = { "firmware_update" = "install_update", "send_file" = "loa # Script runner allows users to trigger an action that will run an already downloaded # file as described by the download_path field of the JSON payload of a download action. -script_runner = [{ name = "run_script" }] +script_runner = [{ name = "run_script", timeout = 10 }] # Location on disk for persisted streams to write backlogs into, also used to write persistence_path = "/tmp/uplink/" @@ -141,7 +141,7 @@ priority = 255 # - actions: List of actions names that can trigger the downloader, with configurable timeouts # - path: Location in fs where the files are downloaded into [downloader] -actions = [{ name = "update_firmware" }, { name = "send_file" }, { name = "send_script" }] +actions = [{ name = "update_firmware" }, { name = "send_file", timeout = 10 }, { name = "send_script" }] path = "/var/tmp/ota-file" # Configurations associated with the system stats module of uplink, if enabled diff --git a/uplink/src/base/actions.rs b/uplink/src/base/actions.rs index 1433a1ea..d9c07e4a 100644 --- a/uplink/src/base/actions.rs +++ b/uplink/src/base/actions.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use tokio::time::Instant; use crate::{Payload, Point}; @@ -17,9 +16,6 @@ pub struct Action { pub name: String, // action payload. json. can be args/payload. depends on the invoked command pub payload: String, - // Instant at which action must be timedout - #[serde(skip)] - pub deadline: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -114,3 +110,10 @@ impl Point for ActionResponse { self.timestamp } } + +#[derive(Debug, Deserialize, Serialize)] +pub struct Cancellation { + pub action_id: String, + #[serde(rename = "name")] + pub action_name: String, +} diff --git a/uplink/src/base/bridge/actions_lane.rs b/uplink/src/base/bridge/actions_lane.rs index ffb9083a..62926e41 100644 --- a/uplink/src/base/bridge/actions_lane.rs +++ b/uplink/src/base/bridge/actions_lane.rs @@ -11,6 +11,7 @@ use std::{collections::HashMap, fmt::Debug, pin::Pin, sync::Arc, time::Duration} use super::streams::Streams; use super::{ActionBridgeShutdown, Package, StreamMetrics}; +use crate::base::actions::Cancellation; use crate::config::ActionRoute; use crate::{Action, ActionResponse, Config}; @@ -34,10 +35,16 @@ pub enum Error { Busy, #[error("Action Route clash: \"{0}\"")] ActionRouteClash(String), + #[error("Cancellation request received for action currently not in execution!")] + UnexpectedCancellation, + #[error("Cancellation request for action in execution, but names don't match!")] + CorruptedCancellation, + #[error("Cancellation request failed as action completed execution!")] + FailedCancellation, + #[error("Action cancelled by action_id: {0}")] + Cancelled(String), } -struct RedirectionError(Action); - pub struct ActionsBridge { /// All configuration config: Arc, @@ -103,10 +110,10 @@ impl ActionsBridge { pub fn register_action_route( &mut self, - ActionRoute { name, timeout: duration }: ActionRoute, + ActionRoute { name, timeout: duration, cancellable }: ActionRoute, actions_tx: Sender, ) -> Result<(), Error> { - let action_router = ActionRouter { actions_tx, duration }; + let action_router = ActionRouter { actions_tx, duration, cancellable }; if self.action_routes.insert(name.clone(), action_router).is_some() { return Err(Error::ActionRouteClash(name)); } @@ -149,20 +156,59 @@ impl ActionsBridge { select! { action = self.actions_rx.recv_async() => { let action = action?; + + if action.name == "cancel_action" { + self.handle_cancellation(action).await?; + continue + } + self.handle_action(action).await; + } + response = self.status_rx.recv_async() => { let response = response?; self.forward_action_response(response).await; } + _ = &mut self.current_action.as_mut().map(|a| &mut a.timeout).unwrap_or(&mut end) => { - let action = self.current_action.take().unwrap(); - error!("Timeout waiting for action response. Action ID = {}", action.id); - self.forward_action_error(action.action, Error::ActionTimeout).await; + let curret_action = self.current_action.as_mut().unwrap(); + let action_id = curret_action.action.action_id.clone(); + let action_name = curret_action.action.name.clone(); + + let route = self + .action_routes + .get(&action_name) + .expect("Action shouldn't be in execution if it can't be routed!"); + + if !route.is_cancellable() { + // Directly send timeout failure response if handler doesn't allow action cancellation + error!("Timeout waiting for action response. Action ID = {}", action_id); + self.forward_action_error(&action_id, Error::ActionTimeout).await; + + // Remove action because it timedout + self.clear_current_action(); + continue; + } + + let cancellation = Cancellation { action_id, action_name }; + let payload = serde_json::to_string(&cancellation)?; + let cancel_action = Action { + action_id: "timeout".to_owned(), // Describes cause of action cancellation. NOTE: Action handler shouldn't expect an integer. + name: "cancel_action".to_owned(), + payload, + }; + if route.try_send(cancel_action).is_err() { + error!("Couldn't cancel action ({}) on timeout: {}", cancellation.action_id, Error::UnresponsiveReceiver); + // Remove action anyways + self.clear_current_action(); + continue; + } - // Remove action because it timedout - self.clear_current_action() + // set timeout to end of time in wait of cancellation response + curret_action.timeout = Box::pin(time::sleep(Duration::from_secs(u64::MAX))); } + // Flush streams that timeout Some(timedout_stream) = self.streams.stream_timeouts.next(), if self.streams.stream_timeouts.has_pending() => { debug!("Flushing stream = {}", timedout_stream); @@ -170,12 +216,14 @@ impl ActionsBridge { error!("Failed to flush stream = {}. Error = {}", timedout_stream, e); } } + // Flush all metrics when timed out _ = metrics_timeout.tick() => { if let Err(e) = self.streams.check_and_flush_metrics() { debug!("Failed to flush stream metrics. Error = {}", e); } } + // Handle a shutdown signal _ = self.ctrl_rx.recv_async() => { if let Err(e) = self.save_current_action() { @@ -199,9 +247,9 @@ impl ActionsBridge { if action.name != TUNSHELL_ACTION { warn!( "Another action is currently occupying uplink; action_id = {}", - current_action.id + current_action.action.action_id ); - self.forward_action_error(action, Error::Busy).await; + self.forward_action_error(&action.action_id, Error::Busy).await; return; } } @@ -211,7 +259,7 @@ impl ActionsBridge { let error = match self.try_route_action(action.clone()) { Ok(_) => { let response = ActionResponse::progress(&action_id, "Received", 0); - self.forward_action_response(response).await; + self.streams.forward(response).await; return; } Err(e) => e, @@ -231,7 +279,52 @@ impl ActionsBridge { } error!("Failed to route action to app. Error = {:?}", error); - self.forward_action_error(action, error).await; + self.forward_action_error(&action.action_id, error).await; + } + + /// Forwards cancellation request to the handler if it can handle the same, + /// else marks the current action as cancelled and avoids further redirections + async fn handle_cancellation(&mut self, action: Action) -> Result<(), Error> { + let action_id = action.action_id.clone(); + let Some(current_action) = self.current_action.as_ref() else { + self.forward_action_error(&action_id, Error::UnexpectedCancellation).await; + return Ok(()); + }; + let mut cancellation: Cancellation = serde_json::from_str(&action.payload)?; + if cancellation.action_id != current_action.action.action_id { + warn!("Unexpected cancellation: {cancellation:?}"); + self.forward_action_error(&action_id, Error::UnexpectedCancellation).await; + return Ok(()); + } + + info!("Received action cancellation: {:?}", cancellation); + if cancellation.action_name != current_action.action.name { + debug!( + "Action was redirected: {} ~> {}", + cancellation.action_name, current_action.action.name + ); + current_action.action.name.clone_into(&mut cancellation.action_name); + } + + let route = self + .action_routes + .get(&cancellation.action_name) + .expect("Action shouldn't be in execution if it can't be routed!"); + + // Ensure that action redirections for the action are turned off, + // action will be cancelled on next attempt to redirect + self.current_action.as_mut().unwrap().cancelled_by = Some(action_id.clone()); + + if route.is_cancellable() { + if let Err(e) = route.try_send(action).map_err(|_| Error::UnresponsiveReceiver) { + self.forward_action_error(&action_id, e).await; + return Ok(()); + } + } + let response = ActionResponse::progress(&action_id, "Received", 0); + self.streams.forward(response).await; + + Ok(()) } /// Save current action information in persistence @@ -255,7 +348,10 @@ impl ActionsBridge { if path.is_file() { let current_action = CurrentAction::read_from_disk(path)?; - info!("Loading saved action from persistence; action_id: {}", current_action.id); + info!( + "Loading saved action from persistence; action_id: {}", + current_action.action.action_id + ); self.current_action = Some(current_action) } @@ -282,6 +378,11 @@ impl ActionsBridge { } async fn forward_action_response(&mut self, mut response: ActionResponse) { + // Ignore responses to timeout action + if response.action_id == "timeout" { + return; + } + if self.parallel_actions.contains(&response.action_id) { self.forward_parallel_action_response(response).await; @@ -296,8 +397,13 @@ impl ActionsBridge { } }; - if *inflight_action.id != response.action_id { - error!("response id({}) != active action({})", response.action_id, inflight_action.id); + if !inflight_action.is_executing(&response.action_id) + && !inflight_action.is_cancelled_by(&response.action_id) + { + error!( + "response id({}) != active action({}); response = {:?}", + response.action_id, inflight_action.action.action_id, response + ); return; } @@ -305,35 +411,63 @@ impl ActionsBridge { self.streams.forward(response.clone()).await; if response.is_completed() || response.is_failed() { - self.clear_current_action(); + if let Some(CurrentAction { cancelled_by: Some(cancel_action), .. }) = + self.current_action.take() + { + let response = ActionResponse::success(&cancel_action); + self.streams.forward(response).await; + } return; } // Forward actions included in the config to the appropriate forward route, when // they have reached 100% progress but haven't been marked as "Completed"/"Finished". if response.is_done() { - let mut action = inflight_action.action.clone(); + let mut action = self.current_action.take().unwrap().action; if let Some(a) = response.done_response.take() { action = a; } - if let Err(RedirectionError(action)) = self.redirect_action(action).await { - // NOTE: send success reponse for actions that don't have redirections configured - warn!("Action redirection is not configured for: {:?}", action); - let response = ActionResponse::success(&action.action_id); - self.streams.forward(response).await; + match self.redirect_action(&mut action).await { + Ok(_) => (), + Err(Error::NoRoute(_)) => { + // NOTE: send success reponse for actions that don't have redirections configured + warn!("Action redirection is not configured for: {:?}", action); + let response = ActionResponse::success(&action.action_id); + self.streams.forward(response).await; + + if let Some(CurrentAction { cancelled_by: Some(cancel_action), .. }) = + self.current_action.take() + { + // Marks the cancellation as a failure as action has reached completion without being cancelled + self.forward_action_error(&cancel_action, Error::FailedCancellation).await + } + } + Err(Error::Cancelled(cancel_action)) => { + let response = ActionResponse::success(&cancel_action); + self.streams.forward(response).await; - self.clear_current_action(); + self.forward_action_error(&action.action_id, Error::Cancelled(cancel_action)) + .await; + } + Err(e) => self.forward_action_error(&action.action_id, e).await, } } } - async fn redirect_action(&mut self, mut action: Action) -> Result<(), RedirectionError> { + async fn redirect_action(&mut self, action: &mut Action) -> Result<(), Error> { let fwd_name = self .action_redirections .get(&action.name) - .ok_or_else(|| RedirectionError(action.clone()))?; + .ok_or_else(|| Error::NoRoute(action.name.clone()))?; + + // Cancelled action should not be redirected + if let Some(CurrentAction { cancelled_by: Some(cancel_action), .. }) = + self.current_action.as_ref() + { + return Err(Error::Cancelled(cancel_action.clone())); + } debug!( "Redirecting action: {} ~> {}; action_id = {}", @@ -341,14 +475,7 @@ impl ActionsBridge { ); fwd_name.clone_into(&mut action.name); - - if let Err(e) = self.try_route_action(action.clone()) { - error!("Failed to route action to app. Error = {:?}", e); - self.forward_action_error(action, e).await; - - // Remove action because it couldn't be forwarded - self.clear_current_action() - } + self.try_route_action(action.clone())?; Ok(()) } @@ -362,8 +489,8 @@ impl ActionsBridge { self.streams.forward(response).await; } - async fn forward_action_error(&mut self, action: Action, error: Error) { - let response = ActionResponse::failure(&action.action_id, error.to_string()); + async fn forward_action_error(&mut self, action_id: &str, error: Error) { + let response = ActionResponse::failure(action_id, error.to_string()); self.streams.forward(response).await; } @@ -371,29 +498,25 @@ impl ActionsBridge { #[derive(Debug, Deserialize, Serialize)] struct SaveAction { - pub id: String, pub action: Action, pub timeout: Duration, } struct CurrentAction { - pub id: String, pub action: Action, pub timeout: Pin>, + // cancel_action request + pub cancelled_by: Option, } impl CurrentAction { pub fn new(action: Action, deadline: Instant) -> CurrentAction { - CurrentAction { - id: action.action_id.clone(), - action, - timeout: Box::pin(time::sleep_until(deadline)), - } + CurrentAction { action, timeout: Box::pin(time::sleep_until(deadline)), cancelled_by: None } } pub fn write_to_disk(self, path: PathBuf) -> Result<(), Error> { let timeout = self.timeout.as_ref().deadline() - Instant::now(); - let save_action = SaveAction { id: self.id, action: self.action, timeout }; + let save_action = SaveAction { action: self.action, timeout }; let json = serde_json::to_string(&save_action)?; fs::write(path, json)?; @@ -406,29 +529,41 @@ impl CurrentAction { fs::remove_file(path)?; Ok(CurrentAction { - id: json.id, action: json.action, timeout: Box::pin(time::sleep(json.timeout)), + cancelled_by: None, }) } + + fn is_executing(&self, action_id: &str) -> bool { + self.action.action_id == action_id + } + + fn is_cancelled_by(&self, action_id: &str) -> bool { + self.cancelled_by.as_ref().is_some_and(|id| id == action_id) + } } #[derive(Debug)] pub struct ActionRouter { pub(crate) actions_tx: Sender, duration: Duration, + cancellable: bool, } impl ActionRouter { #[allow(clippy::result_large_err)] /// Forwards action to the appropriate application and returns the instance in time at which it should be timedout if incomplete - pub fn try_send(&self, mut action: Action) -> Result> { + pub fn try_send(&self, action: Action) -> Result> { let deadline = Instant::now() + self.duration; - action.deadline = Some(deadline); self.actions_tx.try_send(action)?; Ok(deadline) } + + pub fn is_cancellable(&self) -> bool { + self.cancellable + } } /// Handle for apps to send action status to bridge @@ -458,7 +593,7 @@ impl CtrlTx { #[cfg(test)] mod tests { - use tokio::runtime::Runtime; + use tokio::{runtime::Runtime, select}; use crate::config::{StreamConfig, StreamMetricsConfig}; @@ -520,13 +655,21 @@ mod tests { std::env::set_current_dir(&tmpdir).unwrap(); let config = Arc::new(default_config()); let (mut bridge, actions_tx, data_rx) = create_bridge(config); - let route_1 = ActionRoute { name: "route_1".to_string(), timeout: Duration::from_secs(10) }; + let route_1 = ActionRoute { + name: "route_1".to_string(), + timeout: Duration::from_secs(10), + cancellable: false, + }; let (route_tx, route_1_rx) = bounded(1); bridge.register_action_route(route_1, route_tx).unwrap(); let (route_tx, route_2_rx) = bounded(1); - let route_2 = ActionRoute { name: "route_2".to_string(), timeout: Duration::from_secs(30) }; + let route_2 = ActionRoute { + name: "route_2".to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; bridge.register_action_route(route_2, route_tx).unwrap(); spawn_bridge(bridge); @@ -556,7 +699,6 @@ mod tests { action_id: "1".to_string(), name: "route_1".to_string(), payload: "test".to_string(), - deadline: None, }; actions_tx.send(action_1).unwrap(); @@ -579,7 +721,6 @@ mod tests { action_id: "2".to_string(), name: "route_2".to_string(), payload: "test".to_string(), - deadline: None, }; actions_tx.send(action_2).unwrap(); @@ -604,7 +745,11 @@ mod tests { let config = Arc::new(default_config()); let (mut bridge, actions_tx, data_rx) = create_bridge(config); - let test_route = ActionRoute { name: "test".to_string(), timeout: Duration::from_secs(30) }; + let test_route = ActionRoute { + name: "test".to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; let (route_tx, action_rx) = bounded(1); bridge.register_action_route(test_route, route_tx).unwrap(); @@ -622,7 +767,6 @@ mod tests { action_id: "1".to_string(), name: "test".to_string(), payload: "test".to_string(), - deadline: None, }; actions_tx.send(action_1).unwrap(); @@ -636,7 +780,6 @@ mod tests { action_id: "2".to_string(), name: "test".to_string(), payload: "test".to_string(), - deadline: None, }; actions_tx.send(action_2).unwrap(); @@ -654,7 +797,11 @@ mod tests { let config = Arc::new(default_config()); let (mut bridge, actions_tx, data_rx) = create_bridge(config); - let test_route = ActionRoute { name: "test".to_string(), timeout: Duration::from_secs(30) }; + let test_route = ActionRoute { + name: "test".to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; let (route_tx, action_rx) = bounded(1); bridge.register_action_route(test_route, route_tx).unwrap(); @@ -676,7 +823,6 @@ mod tests { action_id: "1".to_string(), name: "test".to_string(), payload: "test".to_string(), - deadline: None, }; actions_tx.send(action).unwrap(); @@ -704,12 +850,19 @@ mod tests { let bridge_tx_2 = bridge.status_tx(); let (route_tx, action_rx_1) = bounded(1); - let test_route = ActionRoute { name: "test".to_string(), timeout: Duration::from_secs(30) }; + let test_route = ActionRoute { + name: "test".to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; bridge.register_action_route(test_route, route_tx).unwrap(); let (route_tx, action_rx_2) = bounded(1); - let redirect_route = - ActionRoute { name: "redirect".to_string(), timeout: Duration::from_secs(30) }; + let redirect_route = ActionRoute { + name: "redirect".to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; bridge.register_action_route(redirect_route, route_tx).unwrap(); spawn_bridge(bridge); @@ -740,7 +893,6 @@ mod tests { action_id: "1".to_string(), name: "test".to_string(), payload: "test".to_string(), - deadline: None, }; actions_tx.send(action).unwrap(); @@ -771,12 +923,19 @@ mod tests { let bridge_tx_2 = bridge.status_tx(); let (route_tx, action_rx_1) = bounded(1); - let tunshell_route = - ActionRoute { name: TUNSHELL_ACTION.to_string(), timeout: Duration::from_secs(30) }; + let tunshell_route = ActionRoute { + name: TUNSHELL_ACTION.to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; bridge.register_action_route(tunshell_route, route_tx).unwrap(); let (route_tx, action_rx_2) = bounded(1); - let test_route = ActionRoute { name: "test".to_string(), timeout: Duration::from_secs(30) }; + let test_route = ActionRoute { + name: "test".to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; bridge.register_action_route(test_route, route_tx).unwrap(); spawn_bridge(bridge); @@ -809,7 +968,6 @@ mod tests { action_id: "1".to_string(), name: "launch_shell".to_string(), payload: "test".to_string(), - deadline: None, }; actions_tx.send(action).unwrap(); @@ -819,7 +977,6 @@ mod tests { action_id: "2".to_string(), name: "test".to_string(), payload: "test".to_string(), - deadline: None, }; actions_tx.send(action).unwrap(); @@ -860,12 +1017,19 @@ mod tests { let bridge_tx_2 = bridge.status_tx(); let (route_tx, action_rx_1) = bounded(1); - let test_route = ActionRoute { name: "test".to_string(), timeout: Duration::from_secs(30) }; + let test_route = ActionRoute { + name: "test".to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; bridge.register_action_route(test_route, route_tx).unwrap(); let (route_tx, action_rx_2) = bounded(1); - let tunshell_route = - ActionRoute { name: TUNSHELL_ACTION.to_string(), timeout: Duration::from_secs(30) }; + let tunshell_route = ActionRoute { + name: TUNSHELL_ACTION.to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; bridge.register_action_route(tunshell_route, route_tx).unwrap(); spawn_bridge(bridge); @@ -898,7 +1062,6 @@ mod tests { action_id: "1".to_string(), name: "test".to_string(), payload: "test".to_string(), - deadline: None, }; actions_tx.send(action).unwrap(); @@ -908,7 +1071,6 @@ mod tests { action_id: "2".to_string(), name: "launch_shell".to_string(), payload: "test".to_string(), - deadline: None, }; actions_tx.send(action).unwrap(); diff --git a/uplink/src/collector/downloader.rs b/uplink/src/collector/downloader.rs index b121a6ca..4257bfb2 100644 --- a/uplink/src/collector/downloader.rs +++ b/uplink/src/collector/downloader.rs @@ -56,7 +56,7 @@ use reqwest::{Certificate, Client, ClientBuilder, Error as ReqwestError, Identit use rsa::sha2::{Digest, Sha256}; use serde::{Deserialize, Serialize}; use tokio::select; -use tokio::time::{sleep, timeout_at, Instant}; +use tokio::time::{sleep, Instant}; use std::fs::{metadata, read, remove_dir_all, remove_file, write, File}; use std::io; @@ -69,6 +69,7 @@ use std::{ }; use std::{io::Write, path::PathBuf}; +use crate::base::actions::Cancellation; use crate::{base::bridge::BridgeTx, config::DownloaderConfig, Action, ActionResponse, Config}; #[derive(thiserror::Error, Debug)] @@ -93,8 +94,8 @@ pub enum Error { BadSave, #[error("Save file doesn't exist")] NoSave, - #[error("Download timedout")] - Timeout, + #[error("Download has been cancelled by '{0}'")] + Cancelled(String), } /// This struct contains the necessary components to download and store file as notified by a download file @@ -207,14 +208,19 @@ impl FileDownloader { // Accepts `DownloadState`, sets a timeout for the action async fn download(&mut self, state: &mut DownloadState) -> Result<(), Error> { let shutdown_rx = self.shutdown_rx.clone(); - let deadline = match &state.current.action.deadline { - Some(d) => *d, - _ => { - error!("Unconfigured deadline: {}", state.current.action.name); - return Ok(()); - } - }; select! { + // Wait till download completes + o = self.continuous_retry(state) => o?, + // Cancel download on receiving cancel action, e.g. on action timeout + Ok(action) = self.actions_rx.recv_async() => { + let cancellation: Cancellation = serde_json::from_str(&action.payload)?; + + trace!("Deleting partially downloaded file: {cancellation:?}"); + state.clean()?; + + return Err(Error::Cancelled(action.action_id)); + }, + Ok(_) = shutdown_rx.recv_async(), if !shutdown_rx.is_disconnected() => { if let Err(e) = state.save(&self.config) { error!("Error saving current_download: {e}"); @@ -222,18 +228,6 @@ impl FileDownloader { return Ok(()); }, - - // NOTE: if download has timedout don't do anything, else ensure errors are forwarded after three retries - o = timeout_at(deadline, self.continuous_retry(state)) => match o { - Ok(r) => r?, - Err(_) => { - // unwrap is safe because download_path is expected to be Some - _ = remove_file(state.current.meta.download_path.as_ref().unwrap()); - error!("Last download has timedout; file deleted"); - - return Err(Error::Timeout); - }, - } } state.current.meta.verify_checksum()?; @@ -244,7 +238,7 @@ impl FileDownloader { } // A download must be retried with Range header when HTTP/reqwest errors are faced - async fn continuous_retry(&mut self, state: &mut DownloadState) -> Result<(), Error> { + async fn continuous_retry(&self, state: &mut DownloadState) -> Result<(), Error> { 'outer: loop { let mut req = self.client.get(&state.current.meta.url); if let Some(range) = state.retry_range() { @@ -256,7 +250,7 @@ impl FileDownloader { Err(e) => { error!("Download failed: {e}"); // Retry after wait - tokio::time::sleep(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)).await; continue 'outer; } }; @@ -280,7 +274,7 @@ impl FileDownloader { self.bridge_tx.send_action_response(status).await; error!("Download failed: {e}"); // Retry after wait - tokio::time::sleep(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)).await; continue 'outer; } Err(e) => return Err(e.into()), @@ -391,7 +385,6 @@ impl DownloadFile { struct CurrentDownload { action: Action, meta: DownloadFile, - time_left: Option, } // A temporary structure to help us retry downloads @@ -444,7 +437,7 @@ impl DownloadState { human_bytes(meta.content_length as f64) ); meta.download_path = Some(file_path); - let current = CurrentDownload { action, meta, time_left: None }; + let current = CurrentDownload { action, meta }; Ok(Self { current, @@ -464,11 +457,11 @@ impl DownloadState { } let read = read(&path)?; - let mut current: CurrentDownload = serde_json::from_slice(&read)?; - // Calculate deadline based on written time left - current.action.deadline = current.time_left.map(|t| Instant::now() + t); + let current: CurrentDownload = serde_json::from_slice(&read)?; - let file = File::options().append(true).open(current.meta.download_path.as_ref().unwrap())?; + // Unwrap is ok here as it is expected to be set for actions once received + let file = + File::options().append(true).open(current.meta.download_path.as_ref().unwrap())?; let bytes_written = file.metadata()?.len() as usize; remove_file(path)?; @@ -487,9 +480,7 @@ impl DownloadState { return Ok(()); } - let mut current = self.current.clone(); - // Calculate time left based on deadline - current.time_left = current.action.deadline.map(|t| t.duration_since(Instant::now())); + let current = self.current.clone(); let json = serde_json::to_vec(¤t)?; let mut path = config.path.clone(); @@ -499,6 +490,14 @@ impl DownloadState { Ok(()) } + /// Deletes contents of file + fn clean(&self) -> Result<(), Error> { + // Unwrap is ok here as it is expected to be set for actions once received + remove_file(self.current.meta.download_path.as_ref().unwrap())?; + + Ok(()) + } + fn retry_range(&self) -> Option { if self.bytes_written == 0 { return None; @@ -598,6 +597,7 @@ mod test { actions: vec![ActionRoute { name: "firmware_update".to_owned(), timeout: Duration::from_secs(10), + cancellable: true, }], path, }; @@ -643,7 +643,6 @@ mod test { action_id: "1".to_string(), name: "firmware_update".to_string(), payload: json!(download_update).to_string(), - deadline: Some(Instant::now() + Duration::from_secs(60)), }; std::thread::sleep(Duration::from_millis(10)); @@ -710,7 +709,6 @@ mod test { action_id: "1".to_string(), name: "firmware_update".to_string(), payload: json!(correct_update).to_string(), - deadline: Some(Instant::now() + Duration::from_secs(100)), }; // Send the correct action to FileDownloader @@ -748,7 +746,6 @@ mod test { action_id: "1".to_string(), name: "firmware_update".to_string(), payload: json!(wrong_update).to_string(), - deadline: Some(Instant::now() + Duration::from_secs(100)), }; // Send the wrong action to FileDownloader diff --git a/uplink/src/collector/process.rs b/uplink/src/collector/process.rs index f1aed754..79eb578e 100644 --- a/uplink/src/collector/process.rs +++ b/uplink/src/collector/process.rs @@ -1,11 +1,11 @@ use flume::{Receiver, RecvError, SendError}; -use log::{debug, error, info}; +use log::{debug, error, info, trace}; use thiserror::Error; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::{Child, Command}; use tokio::select; -use tokio::time::timeout_at; +use crate::base::actions::Cancellation; use crate::base::bridge::BridgeTx; use crate::{Action, ActionResponse, Package}; @@ -26,6 +26,8 @@ pub enum Error { Busy, #[error("No stdout in spawned action")] NoStdout, + #[error("Process has been cancelled by '{0}'")] + Cancelled(String), } /// Process abstracts functions to spawn process and handle their output @@ -55,25 +57,37 @@ impl ProcessHandler { } /// Capture stdout of the running process in a spawned task - pub async fn spawn_and_capture_stdout(&mut self, mut child: Child) -> Result<(), Error> { + pub async fn spawn_and_capture_stdout( + &mut self, + mut child: Child, + action_id: &str, + ) -> Result<(), Error> { let stdout = child.stdout.take().ok_or(Error::NoStdout)?; let mut stdout = BufReader::new(stdout).lines(); loop { select! { - Ok(Some(line)) = stdout.next_line() => { + Ok(Some(line)) = stdout.next_line() => { let status: ActionResponse = match serde_json::from_str(&line) { Ok(status) => status, - Err(e) => ActionResponse::failure("dummy", e.to_string()), + Err(e) => ActionResponse::failure(action_id, e.to_string()), }; debug!("Action status: {:?}", status); self.bridge_tx.send_action_response(status).await; - } - status = child.wait() => { + } + status = child.wait() => { info!("Action done!! Status = {:?}", status); return Ok(()); }, + // Cancel process on receiving cancel action, e.g. on action timeout + Ok(action) = self.actions_rx.recv_async() => { + let cancellation: Cancellation = serde_json::from_str(&action.payload)?; + + trace!("Cancelling process: '{}'", cancellation.action_id); + let status = ActionResponse::failure(action_id, Error::Cancelled(action.action_id).to_string()); + self.bridge_tx.send_action_response(status).await; + }, } } } @@ -83,21 +97,10 @@ impl ProcessHandler { loop { let action = self.actions_rx.recv_async().await?; let command = format!("tools/{}", action.name); - let deadline = match &action.deadline { - Some(d) => *d, - _ => { - error!("Unconfigured deadline: {}", action.name); - continue; - } - }; // Spawn the action and capture its stdout, ignore timeouts let child = self.run(&action.action_id, &command, &action.payload).await?; - if let Ok(o) = timeout_at(deadline, self.spawn_and_capture_stdout(child)).await { - o?; - } else { - error!("Process timedout: {command}; action_id = {}", action.action_id); - } + self.spawn_and_capture_stdout(child, &action.action_id).await?; } } } diff --git a/uplink/src/collector/script_runner.rs b/uplink/src/collector/script_runner.rs index db8f74ac..8e3b06ce 100644 --- a/uplink/src/collector/script_runner.rs +++ b/uplink/src/collector/script_runner.rs @@ -1,12 +1,12 @@ use flume::{Receiver, RecvError, SendError}; -use log::{debug, error, info, warn}; +use log::{debug, error, info, trace, warn}; use thiserror::Error; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::{Child, Command}; use tokio::select; -use tokio::time::timeout_at; use super::downloader::DownloadFile; +use crate::base::actions::Cancellation; use crate::base::bridge::BridgeTx; use crate::{Action, ActionResponse, Package}; @@ -28,6 +28,8 @@ pub enum Error { Busy, #[error("No stdout in spawned action")] NoStdout, + #[error("Script has been cancelled: '{0}'")] + Cancelled(String), } /// Script runner runs a script downloaded with FileDownloader and handles their output over the action_status stream. @@ -86,6 +88,14 @@ impl ScriptRunner { self.forward_status(ActionResponse::success(id)).await; break; }, + // Cancel script run on receiving cancel action, e.g. on action timeout + Ok(action) = self.actions_rx.recv_async() => { + let cancellation: Cancellation = serde_json::from_str(&action.payload)?; + + trace!("Cancelling script: '{}'", cancellation.action_id); + let status = ActionResponse::failure(id, Error::Cancelled(action.action_id).to_string()); + self.bridge_tx.send_action_response(status).await; + }, } } @@ -119,20 +129,9 @@ impl ScriptRunner { continue; } }; - let deadline = match &action.deadline { - Some(d) => *d, - _ => { - error!("Unconfigured deadline: {}", action.name); - continue; - } - }; // Spawn the action and capture its stdout let child = self.run(command).await?; - if let Ok(o) = - timeout_at(deadline, self.spawn_and_capture_stdout(child, &action.action_id)).await - { - o? - } + self.spawn_and_capture_stdout(child, &action.action_id).await? } } @@ -176,7 +175,6 @@ mod tests { action_id: "1".to_string(), name: "test".to_string(), payload: "".to_string(), - deadline: None, }) .unwrap(); @@ -200,7 +198,6 @@ mod tests { name: "test".to_string(), payload: "{\"url\": \"...\", \"content_length\": 0,\"file_name\": \"...\"}" .to_string(), - deadline: None, }) .unwrap(); diff --git a/uplink/src/config.rs b/uplink/src/config.rs index d7718a97..bef2e116 100644 --- a/uplink/src/config.rs +++ b/uplink/src/config.rs @@ -214,6 +214,9 @@ pub struct ActionRoute { #[serde(default = "default_timeout")] #[serde_as(as = "DurationSeconds")] pub timeout: Duration, + // Can the action handler cancel actions mid execution? + #[serde(default)] + pub cancellable: bool, } impl From<&ActionRoute> for ActionRoute { diff --git a/uplink/src/lib.rs b/uplink/src/lib.rs index bace1a15..3763a1b9 100644 --- a/uplink/src/lib.rs +++ b/uplink/src/lib.rs @@ -245,8 +245,11 @@ impl Uplink { pub fn spawn_builtins(&mut self, bridge: &mut Bridge) -> Result<(), Error> { let bridge_tx = bridge.bridge_tx(); - let route = - ActionRoute { name: "launch_shell".to_owned(), timeout: Duration::from_secs(10) }; + let route = ActionRoute { + name: "launch_shell".to_owned(), + timeout: Duration::from_secs(10), + cancellable: false, + }; let actions_rx = bridge.register_action_route(route)?; let tunshell_client = TunshellClient::new(actions_rx, bridge_tx.clone()); spawn_named_thread("Tunshell Client", move || tunshell_client.start()); @@ -266,6 +269,7 @@ impl Uplink { let route = ActionRoute { name: "journalctl_config".to_string(), timeout: Duration::from_secs(10), + cancellable: false, }; let actions_rx = bridge.register_action_route(route)?; let logger = JournalCtl::new(config, actions_rx, bridge_tx.clone()); @@ -281,6 +285,7 @@ impl Uplink { let route = ActionRoute { name: "journalctl_config".to_string(), timeout: Duration::from_secs(10), + cancellable: false, }; let actions_rx = bridge.register_action_route(route)?; let logger = Logcat::new(config, actions_rx, bridge_tx.clone()); diff --git a/uplink/src/main.rs b/uplink/src/main.rs index 1809bb58..6bf775e0 100644 --- a/uplink/src/main.rs +++ b/uplink/src/main.rs @@ -181,6 +181,21 @@ impl CommandLine { config.actions_subscription = format!("/tenants/{tenant_id}/devices/{device_id}/actions"); + // downloader actions are cancellable by default + for route in config.downloader.actions.iter_mut() { + route.cancellable = true; + } + + // process actions are cancellable by default + for route in config.processes.iter_mut() { + route.cancellable = true; + } + + // script runner actions are cancellable by default + for route in config.script_runner.iter_mut() { + route.cancellable = true; + } + Ok(config) }