diff --git a/uplink/src/collector/downloader.rs b/uplink/src/collector/downloader.rs index 2c8b08d74..b121a6ca3 100644 --- a/uplink/src/collector/downloader.rs +++ b/uplink/src/collector/downloader.rs @@ -56,11 +56,11 @@ use reqwest::{Certificate, Client, ClientBuilder, Error as ReqwestError, Identit use rsa::sha2::{Digest, Sha256}; use serde::{Deserialize, Serialize}; use tokio::select; -use tokio::time::{timeout_at, Instant}; +use tokio::time::{sleep, timeout_at, Instant}; use std::fs::{metadata, read, remove_dir_all, remove_file, write, File}; use std::io; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Duration; #[cfg(unix)] use std::{ @@ -108,6 +108,7 @@ pub struct FileDownloader { bridge_tx: BridgeTx, client: Client, shutdown_rx: Receiver, + disabled: Arc>, } impl FileDownloader { @@ -117,6 +118,7 @@ impl FileDownloader { actions_rx: Receiver, bridge_tx: BridgeTx, shutdown_rx: Receiver, + disabled: Arc>, ) -> Result { // Authenticate with TLS certs from config let client_builder = ClientBuilder::new(); @@ -140,6 +142,7 @@ impl FileDownloader { bridge_tx, action_id: String::default(), shutdown_rx, + disabled, }) } @@ -259,7 +262,14 @@ impl FileDownloader { }; // Download and store to disk by streaming as chunks - while let Some(item) = stream.next().await { + loop { + // Checks if downloader is disabled by user or not + if *self.disabled.lock().unwrap() { + // async to ensure download can be cancelled during sleep + sleep(Duration::from_secs(1)).await; + continue; + } + let Some(item) = stream.next().await else { break }; let chunk = match item { Ok(c) => c, // Retry non-status errors @@ -605,8 +615,14 @@ mod test { // Create channels to forward and push actions on let (download_tx, download_rx) = bounded(1); let (_, ctrl_rx) = bounded(1); - let downloader = - FileDownloader::new(Arc::new(config), download_rx, bridge_tx, ctrl_rx).unwrap(); + let downloader = FileDownloader::new( + Arc::new(config), + download_rx, + bridge_tx, + ctrl_rx, + Arc::new(Mutex::new(false)), + ) + .unwrap(); // Start FileDownloader in separate thread std::thread::spawn(|| downloader.start()); @@ -666,8 +682,14 @@ mod test { // Create channels to forward and push action_status on let (download_tx, download_rx) = bounded(1); let (_, ctrl_rx) = bounded(1); - let downloader = - FileDownloader::new(Arc::new(config), download_rx, bridge_tx, ctrl_rx).unwrap(); + let downloader = FileDownloader::new( + Arc::new(config), + download_rx, + bridge_tx, + ctrl_rx, + Arc::new(Mutex::new(false)), + ) + .unwrap(); // Start FileDownloader in separate thread std::thread::spawn(|| downloader.start()); diff --git a/uplink/src/console.rs b/uplink/src/console.rs index 187bbda34..400658bfd 100644 --- a/uplink/src/console.rs +++ b/uplink/src/console.rs @@ -1,4 +1,12 @@ -use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::post, Router}; +use std::sync::{Arc, Mutex}; + +use axum::{ + extract::State, + http::StatusCode, + response::IntoResponse, + routing::{post, put}, + Router, +}; use log::info; use uplink::base::CtrlTx; @@ -8,16 +16,24 @@ use crate::ReloadHandle; struct StateHandle { reload_handle: ReloadHandle, ctrl_tx: CtrlTx, + downloader_disable: Arc>, } #[tokio::main] -pub async fn start(port: u16, reload_handle: ReloadHandle, ctrl_tx: CtrlTx) { +pub async fn start( + port: u16, + reload_handle: ReloadHandle, + ctrl_tx: CtrlTx, + downloader_disable: Arc>, +) { let address = format!("0.0.0.0:{port}"); info!("Starting uplink console server: {address}"); - let state = StateHandle { reload_handle, ctrl_tx }; + let state = StateHandle { reload_handle, ctrl_tx, downloader_disable }; let app = Router::new() .route("/logs", post(reload_loglevel)) .route("/shutdown", post(shutdown)) + .route("/disable_downloader", put(disable_downloader)) + .route("/enable_downloader", put(enable_downloader)) .with_state(state); axum::Server::bind(&address.parse().unwrap()).serve(app.into_make_service()).await.unwrap(); @@ -38,3 +54,27 @@ async fn shutdown(State(state): State) -> impl IntoResponse { StatusCode::OK } + +// Stops downloader from downloading even if it was already stopped +async fn disable_downloader(State(state): State) -> impl IntoResponse { + info!("Downloader stopped"); + let mut is_disabled = state.downloader_disable.lock().unwrap(); + if *is_disabled { + StatusCode::ACCEPTED + } else { + *is_disabled = true; + StatusCode::OK + } +} + +// Start downloader back up even if it was already not stopped +async fn enable_downloader(State(state): State) -> impl IntoResponse { + info!("Downloader started"); + let mut is_disabled = state.downloader_disable.lock().unwrap(); + if *state.downloader_disable.lock().unwrap() { + *is_disabled = false; + StatusCode::OK + } else { + StatusCode::ACCEPTED + } +} diff --git a/uplink/src/lib.rs b/uplink/src/lib.rs index ff5ba87c7..bace1a150 100644 --- a/uplink/src/lib.rs +++ b/uplink/src/lib.rs @@ -40,7 +40,7 @@ //!``` //! [`port`]: base::AppConfig#structfield.port //! [`name`]: Action#structfield.name -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; @@ -129,7 +129,11 @@ impl Uplink { ) } - pub fn spawn(&mut self, mut bridge: Bridge) -> Result { + pub fn spawn( + &mut self, + mut bridge: Bridge, + downloader_disable: Arc>, + ) -> Result { let (mqtt_metrics_tx, mqtt_metrics_rx) = bounded(10); let (ctrl_actions_lane, ctrl_data_lane) = bridge.ctrl_tx(); @@ -151,8 +155,13 @@ impl Uplink { // Downloader thread if configured if !self.config.downloader.actions.is_empty() { let actions_rx = bridge.register_action_routes(&self.config.downloader.actions)?; - let file_downloader = - FileDownloader::new(self.config.clone(), actions_rx, bridge.bridge_tx(), ctrl_rx)?; + let file_downloader = FileDownloader::new( + self.config.clone(), + actions_rx, + bridge.bridge_tx(), + ctrl_rx, + downloader_disable, + )?; spawn_named_thread("File Downloader", || file_downloader.start()); } diff --git a/uplink/src/main.rs b/uplink/src/main.rs index 572fea502..1809bb58c 100644 --- a/uplink/src/main.rs +++ b/uplink/src/main.rs @@ -1,7 +1,7 @@ mod console; use std::path::PathBuf; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Duration; use anyhow::Error; @@ -307,7 +307,8 @@ fn main() -> Result<(), Error> { _ => None, }; - let ctrl_tx = uplink.spawn(bridge)?; + let downloader_disable = Arc::new(Mutex::new(false)); + let ctrl_tx = uplink.spawn(bridge, downloader_disable.clone())?; if let Some(config) = config.simulator.clone() { spawn_named_thread("Simulator", || { @@ -318,7 +319,9 @@ fn main() -> Result<(), Error> { if config.console.enabled { let port = config.console.port; let ctrl_tx = ctrl_tx.clone(); - spawn_named_thread("Uplink Console", move || console::start(port, reload_handle, ctrl_tx)); + spawn_named_thread("Uplink Console", move || { + console::start(port, reload_handle, ctrl_tx, downloader_disable) + }); } let rt = tokio::runtime::Builder::new_current_thread()