diff --git a/uplink/src/base/mod.rs b/uplink/src/base/mod.rs index 442a0f81..74887842 100644 --- a/uplink/src/base/mod.rs +++ b/uplink/src/base/mod.rs @@ -6,6 +6,7 @@ use tokio::join; use self::bridge::{ActionsLaneCtrlTx, DataLaneCtrlTx}; use self::mqtt::CtrlTx as MqttCtrlTx; use self::serializer::CtrlTx as SerializerCtrlTx; +use crate::collector::downloader::CtrlTx as DownloaderCtrlTx; pub mod actions; pub mod bridge; @@ -26,6 +27,7 @@ pub struct CtrlTx { pub data_lane: DataLaneCtrlTx, pub mqtt: MqttCtrlTx, pub serializer: SerializerCtrlTx, + pub downloader: DownloaderCtrlTx, } impl CtrlTx { @@ -34,7 +36,8 @@ impl CtrlTx { self.actions_lane.trigger_shutdown(), self.data_lane.trigger_shutdown(), self.mqtt.trigger_shutdown(), - self.serializer.trigger_shutdown() + self.serializer.trigger_shutdown(), + self.downloader.trigger_shutdown() ); } } diff --git a/uplink/src/collector/downloader.rs b/uplink/src/collector/downloader.rs index abe8e53c..b3556fb3 100644 --- a/uplink/src/collector/downloader.rs +++ b/uplink/src/collector/downloader.rs @@ -48,16 +48,17 @@ //! [`action_redirections`]: Config#structfield.action_redirections use bytes::BytesMut; -use flume::Receiver; -use futures_util::{Future, StreamExt}; +use flume::{Receiver, Sender}; +use futures_util::StreamExt; use human_bytes::human_bytes; use log::{debug, error, info, trace, warn}; -use reqwest::{Certificate, Client, ClientBuilder, Error as ReqwestError, Identity, Response}; +use reqwest::{Certificate, Client, ClientBuilder, Error as ReqwestError, Identity}; use rsa::sha2::{Digest, Sha256}; use serde::{Deserialize, Serialize}; +use tokio::select; use tokio::time::{timeout_at, Instant}; -use std::fs::{metadata, remove_dir_all, File}; +use std::fs::{metadata, read, remove_dir_all, remove_file, write, File}; use std::io; use std::sync::Arc; use std::time::Duration; @@ -88,6 +89,10 @@ pub enum Error { BadChecksum, #[error("Disk space is insufficient: {0}")] InsufficientDisk(String), + #[error("Save file is corrupted")] + BadSave, + #[error("Save file doesn't exist")] + NoSave, } /// This struct contains the necessary components to download and store file as notified by a download file @@ -101,6 +106,7 @@ pub struct FileDownloader { bridge_tx: BridgeTx, client: Client, sequence: u32, + shutdown_rx: Receiver, } impl FileDownloader { @@ -109,6 +115,7 @@ impl FileDownloader { config: Arc, actions_rx: Receiver, bridge_tx: BridgeTx, + shutdown_rx: Receiver, ) -> Result { // Authenticate with TLS certs from config let client_builder = ClientBuilder::new(); @@ -132,6 +139,7 @@ impl FileDownloader { bridge_tx, sequence: 0, action_id: String::default(), + shutdown_rx, }) } @@ -139,6 +147,8 @@ impl FileDownloader { /// back to bridge for further processing, e.g. OTA update installation. #[tokio::main(flavor = "current_thread")] pub async fn start(mut self) { + self.reload().await; + info!("Downloader thread is ready to receive download actions"); loop { self.sequence = 0; @@ -150,24 +160,40 @@ impl FileDownloader { } }; self.action_id = action.action_id.clone(); - let deadline = match &action.deadline { - Some(d) => *d, - _ => { - error!("Unconfigured deadline: {}", action.name); + + let download = match DownloadState::prepare(action, &self.config).await { + Ok(d) => d, + Err(e) => { + self.forward_error(e).await; continue; } }; - // NOTE: if download has timedout don't do anything, else ensure errors are forwarded after three retries + // Update action status for process initiated + self.forward_progress(0).await; - match timeout_at(deadline, self.run(action)).await { - Ok(Err(e)) => self.forward_error(e).await, - Err(_) => error!("Last download has timedout"), - _ => {} + if let Err(e) = self.download(download).await { + self.forward_error(e).await; } } } + // reloads a download if it wasn't completed during the previous run of uplink + async fn reload(&mut self) { + let download = match DownloadState::load(&self.config) { + Ok(s) => s, + Err(Error::NoSave) => return, + Err(e) => { + warn!("Couldn't reload current_download: {e}"); + return; + } + }; + + if let Err(e) = self.download(download).await { + self.forward_error(e).await; + } + } + // Forward errors as action response to bridge async fn forward_error(&mut self, err: Error) { let status = @@ -175,177 +201,113 @@ impl FileDownloader { self.bridge_tx.send_action_response(status).await; } - // A download must be retried with Range header when HTTP/reqwest errors are faced - async fn continuous_retry( - &mut self, - url: &str, - mut download: DownloadState, - ) -> Result<(), Error> { - let mut req = self.client.get(url).send(); - loop { - match self.download(req, &mut download).await { - Ok(_) => break, - Err(Error::Reqwest(e)) => error!("Download failed: {e}"), - Err(e) => return Err(e), - } - tokio::time::sleep(Duration::from_secs(1)).await; - - let range = download.retry_range(); - warn!("Retrying download; Continuing to download file from: {range}"); - req = self.client.get(url).header("Range", range).send(); - } - - Ok(()) - } - - // Accepts a download `Action` and performs necessary data extraction to actually download the file - async fn run(&mut self, mut action: Action) -> Result<(), Error> { - // Update action status for process initiated - let status = ActionResponse::progress(&self.action_id, "Downloading", 0); + // Forward progress as action response to bridge + async fn forward_progress(&mut self, progress: u8) { + let status = ActionResponse::progress(&self.action_id, "Downloading", progress); let status = status.set_sequence(self.sequence()); self.bridge_tx.send_action_response(status).await; + } - // Ensure that directory for downloading file into, exists - let mut download_path = self.config.path.clone(); - download_path.push(&action.name); - - #[cfg(unix)] - self.create_dirs_with_perms( - download_path.as_path(), - std::os::unix::fs::PermissionsExt::from_mode(0o777), - )?; + // Accepts `DownloadState`, sets a timeout for the action, saves action for restart + async fn download(&mut self, mut download: DownloadState) -> Result<(), Error> { + let shutdown_rx = self.shutdown_rx.clone(); + let deadline = *download.current.action.deadline.as_ref().unwrap(); - #[cfg(not(unix))] - std::fs::create_dir_all(&download_path)?; + select! { + o = timeout_at(deadline, self.continuous_retry(&mut download)) => { - // Extract url information from action payload - let mut update = match serde_json::from_str::(&action.payload)? { - DownloadFile { file_name, .. } if file_name.is_empty() => { - return Err(Error::EmptyFileName) + // NOTE: if download has timedout don't do anything + match o { + Ok(r) => r?, + Err(_) => error!("Last download has timedout"), + } } - DownloadFile { content_length: 0, .. } => return Err(Error::EmptyFile), - u => u, - }; - self.check_disk_size(&update)?; + _ = shutdown_rx.recv_async() => { + if let Err(e) = download.save(&self.config) { + error!("Error saving current_download: {e}"); + } + } + } - let url = update.url.clone(); + Ok(()) + } - // Create file to actually download into - let (file, file_path) = self.create_file(&download_path, &update.file_name)?; + // A download must be retried with Range header when HTTP/reqwest errors are faced + async fn continuous_retry(&mut self, download: &mut DownloadState) -> Result<(), Error> { + loop { + let mut req = self.client.get(download.url()); + if let Some(range) = download.retry_range() { + req = req.header("Range", &range); + warn!("Retrying download; Continuing to download file from: {range}"); + } - // Retry downloading upto 3 times in case of connectivity issues - // TODO: Error out for 1XX/3XX responses - info!( - "Downloading from {} into {}; size = {}", - url, - file_path.display(), - human_bytes(update.content_length as f64) - ); - let download = DownloadState { - file, - bytes_written: 0, - bytes_downloaded: 0, - percentage_downloaded: 0, - content_length: update.content_length, - start_instant: Instant::now(), - }; - self.continuous_retry(&url, download).await?; + let mut stream = req.send().await?.error_for_status()?.bytes_stream(); + // Download and store to disk by streaming as chunks + while let Some(item) = stream.next().await { + let chunk = match item { + Ok(c) => c, + Err(e) => { + error!("Download failed: {e}"); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + }; + if let Some(percentage) = download.write_bytes(&chunk)? { + self.forward_progress(percentage).await; + } + } - // Update Action payload with `download_path`, i.e. downloaded file's location in fs - update.insert_path(file_path.clone()); - update.verify_checksum()?; + break; + } - action.payload = serde_json::to_string(&update)?; - let status = ActionResponse::done(&self.action_id, "Downloaded", Some(action)); + download.current.meta.verify_checksum()?; + info!("Firmware downloaded successfully"); - let status = status.set_sequence(self.sequence()); + let mut action = download.current.action.clone(); + action.payload = serde_json::to_string(&download.current.meta)?; + let status = ActionResponse::done(&self.action_id, "Downloaded", Some(action)) + .set_sequence(self.sequence()); self.bridge_tx.send_action_response(status).await; Ok(()) } - fn check_disk_size(&mut self, download: &DownloadFile) -> Result<(), Error> { - let disk_free_space = fs2::free_space(&self.config.path)? as usize; - - let req_size = human_bytes(download.content_length as f64); - let free_size = human_bytes(disk_free_space as f64); - debug!("Download requires {req_size}; Disk free space is {free_size}"); - - if download.content_length > disk_free_space { - return Err(Error::InsufficientDisk(free_size)); - } - - Ok(()) + fn sequence(&mut self) -> u32 { + self.sequence += 1; + self.sequence } +} - #[cfg(unix)] - /// Custom create_dir_all which sets permissions on each created directory, only works on unix - fn create_dirs_with_perms(&self, path: &Path, perms: Permissions) -> std::io::Result<()> { - let mut current_path = PathBuf::new(); - - for component in path.components() { - current_path.push(component); +fn check_disk_size(download: &DownloadFile, config: &DownloaderConfig) -> Result<(), Error> { + let disk_free_space = fs2::free_space(&config.path)? as usize; - if !current_path.exists() { - create_dir(¤t_path)?; - set_permissions(¤t_path, perms.clone())?; - } - } + let req_size = human_bytes(download.content_length as f64); + let free_size = human_bytes(disk_free_space as f64); + debug!("Download requires {req_size}; Disk free space is {free_size}"); - Ok(()) + if download.content_length > disk_free_space { + return Err(Error::InsufficientDisk(free_size)); } - /// Creates file to download into - fn create_file( - &self, - download_path: &PathBuf, - file_name: &str, - ) -> Result<(File, PathBuf), Error> { - let mut file_path = download_path.to_owned(); - file_path.push(file_name); - // NOTE: if file_path is occupied by a directory due to previous working of uplink, remove it - if let Ok(f) = metadata(&file_path) { - if f.is_dir() { - remove_dir_all(&file_path)?; - } - } - let file = File::create(&file_path)?; - #[cfg(unix)] - file.set_permissions(std::os::unix::fs::PermissionsExt::from_mode(0o666))?; + Ok(()) +} - Ok((file, file_path)) - } +#[cfg(unix)] +/// Custom create_dir_all which sets permissions on each created directory, only works on unix +fn create_dirs_with_perms(path: &Path, perms: Permissions) -> std::io::Result<()> { + let mut current_path = PathBuf::new(); - /// Downloads from server and stores into file - async fn download( - &mut self, - req: impl Future>, - download: &mut DownloadState, - ) -> Result<(), Error> { - let mut stream = req.await?.error_for_status()?.bytes_stream(); - - // Download and store to disk by streaming as chunks - while let Some(item) = stream.next().await { - let chunk = item?; - if let Some(percentage) = download.write_bytes(&chunk)? { - //TODO: Simplify progress by reusing action_id and state - //TODO: let response = self.response.progress(percentage);?? - let status = ActionResponse::progress(&self.action_id, "Downloading", percentage); - let status = status.set_sequence(self.sequence()); - self.bridge_tx.send_action_response(status).await; - } - } + for component in path.components() { + current_path.push(component); - info!("Firmware downloaded successfully"); - - Ok(()) + if !current_path.exists() { + create_dir(¤t_path)?; + set_permissions(¤t_path, perms.clone())?; + } } - fn sequence(&mut self) -> u32 { - self.sequence += 1; - self.sequence - } + Ok(()) } /// Expected JSON format of data contained in the [`payload`] of a download file [`Action`] @@ -365,10 +327,6 @@ pub struct DownloadFile { } impl DownloadFile { - fn insert_path(&mut self, download_path: PathBuf) { - self.download_path = Some(download_path); - } - fn verify_checksum(&self) -> Result<(), Error> { let Some(checksum) = &self.checksum else { return Ok(()) }; let path = self.download_path.as_ref().expect("Downloader didn't set \"download_path\""); @@ -385,26 +343,144 @@ impl DownloadFile { } } +#[derive(Clone, Debug, Serialize, Deserialize)] +struct CurrentDownload { + action: Action, + meta: DownloadFile, + time_left: Option, +} + // A temporary structure to help us retry downloads // that failed after partial completion. +#[derive(Debug)] struct DownloadState { + current: CurrentDownload, file: File, bytes_written: usize, - bytes_downloaded: usize, percentage_downloaded: u8, - content_length: usize, - start_instant: Instant, + start: Instant, } impl DownloadState { + fn load(config: &DownloaderConfig) -> Result { + let mut path = config.path.clone(); + path.push("current_download"); + + if !path.exists() { + return Err(Error::NoSave); + } + + let read = read(&path)?; + let mut current: CurrentDownload = serde_json::from_slice(&read)?; + // Calculate deadline based on written time left + current.action.deadline = current.time_left.map(|t| Instant::now() + t); + + let file = File::open(¤t.meta.download_path.as_ref().unwrap())?; + let bytes_written = file.metadata()?.len() as usize; + + remove_file(path)?; + + Ok(DownloadState { + current, + file, + bytes_written, + percentage_downloaded: 0, + start: Instant::now(), + }) + } + + fn save(&self, config: &DownloaderConfig) -> Result<(), Error> { + if self.bytes_written == self.current.meta.content_length { + return Ok(()); + } + + let mut current = self.current.clone(); + // Calculate time left based on deadline + current.time_left = current.action.deadline.map(|t| t.duration_since(Instant::now())); + let json = serde_json::to_vec(¤t)?; + + let mut path = config.path.clone(); + path.push("current_download"); + write(path, json)?; + + Ok(()) + } + + fn retry_range(&self) -> Option { + if self.bytes_written == 0 { + return None; + } + + Some(format!("bytes={}-{}", self.bytes_written, self.current.meta.content_length)) + } + + async fn prepare(action: Action, config: &DownloaderConfig) -> Result { + // Ensure that directory for downloading file into, exists + let mut download_path = config.path.clone(); + download_path.push(&action.name); + + #[cfg(unix)] + create_dirs_with_perms( + &download_path.as_path(), + std::os::unix::fs::PermissionsExt::from_mode(0o777), + )?; + + #[cfg(not(unix))] + std::fs::create_dir_all(&download_path)?; + + // Extract url information from action payload + let mut meta = match serde_json::from_str::(&action.payload)? { + DownloadFile { file_name, .. } if file_name.is_empty() => { + return Err(Error::EmptyFileName) + } + DownloadFile { content_length: 0, .. } => return Err(Error::EmptyFile), + u => u, + }; + + check_disk_size(&meta, &config)?; + + // Create file to actually download into + let mut file_path = download_path.to_owned(); + file_path.push(&meta.file_name); + + // NOTE: if file_path is occupied by a directory due to previous working of uplink, remove it + if let Ok(f) = metadata(&file_path) { + if f.is_dir() { + remove_dir_all(&file_path)?; + } + } + let file = File::create(&file_path)?; + #[cfg(unix)] + file.set_permissions(std::os::unix::fs::PermissionsExt::from_mode(0o666))?; + + // Retry downloading upto 3 times in case of connectivity issues + // TODO: Error out for 1XX/3XX responses + info!( + "Downloading from {} into {}; size = {}", + meta.url, + file_path.display(), + human_bytes(meta.content_length as f64) + ); + meta.download_path = Some(file_path); + let current = CurrentDownload { action, meta, time_left: None }; + + Ok(Self { + current, + file, + bytes_written: 0, + percentage_downloaded: 0, + start: Instant::now(), + }) + } + fn write_bytes(&mut self, buf: &[u8]) -> Result, Error> { - self.bytes_downloaded += buf.len(); + let bytes_downloaded = buf.len(); self.file.write_all(buf)?; - self.bytes_written = self.bytes_downloaded; - let size = human_bytes(self.content_length as f64); + self.bytes_written += bytes_downloaded; + let size = human_bytes(self.current.meta.content_length as f64); // Calculate percentage on the basis of content_length - let factor = self.bytes_downloaded as f32 / self.content_length as f32; + let factor = self.bytes_written as f32 / self.current.meta.content_length as f32; let percentage = (99.99 * factor) as u8; // NOTE: ensure lesser frequency of action responses, once every percentage points @@ -412,7 +488,7 @@ impl DownloadState { self.percentage_downloaded = percentage; debug!( "Downloading: size = {size}, percentage = {percentage}, elapsed = {}s", - self.start_instant.elapsed().as_secs() + self.start.elapsed().as_secs() ); Ok(Some(percentage)) @@ -420,15 +496,31 @@ impl DownloadState { trace!( "Downloading: size = {size}, percentage = {}, elapsed = {}s", self.percentage_downloaded, - self.start_instant.elapsed().as_secs() + self.start.elapsed().as_secs() ); Ok(None) } } - fn retry_range(&self) -> String { - format!("bytes={}-{}", self.bytes_written, self.content_length) + fn url(&self) -> &str { + &self.current.meta.url + } +} + +/// Command to remotely trigger `Downloader` shutdown +pub struct DownloaderShutdown; + +/// Handle to send control messages to `Downloader` +#[derive(Debug, Clone)] +pub struct CtrlTx { + pub(crate) inner: Sender, +} + +impl CtrlTx { + /// Triggers shutdown of `Downloader` + pub async fn trigger_shutdown(&self) { + self.inner.send_async(DownloaderShutdown).await.unwrap() } } @@ -493,7 +585,9 @@ mod test { // Create channels to forward and push actions on let (download_tx, download_rx) = bounded(1); - let downloader = FileDownloader::new(Arc::new(config), download_rx, bridge_tx).unwrap(); + let (_, ctrl_rx) = bounded(1); + let downloader = + FileDownloader::new(Arc::new(config), download_rx, bridge_tx, ctrl_rx).unwrap(); // Start FileDownloader in separate thread std::thread::spawn(|| downloader.start()); @@ -554,7 +648,9 @@ mod test { // Create channels to forward and push action_status on let (download_tx, download_rx) = bounded(1); - let downloader = FileDownloader::new(Arc::new(config), download_rx, bridge_tx).unwrap(); + let (_, ctrl_rx) = bounded(1); + let downloader = + FileDownloader::new(Arc::new(config), download_rx, bridge_tx, ctrl_rx).unwrap(); // Start FileDownloader in separate thread std::thread::spawn(|| downloader.start()); diff --git a/uplink/src/lib.rs b/uplink/src/lib.rs index 85959456..c536eb4a 100644 --- a/uplink/src/lib.rs +++ b/uplink/src/lib.rs @@ -61,7 +61,7 @@ use base::mqtt::Mqtt; use base::serializer::{Serializer, SerializerMetrics}; use base::CtrlTx; use collector::device_shadow::DeviceShadow; -use collector::downloader::FileDownloader; +use collector::downloader::{CtrlTx as DownloaderCtrlTx, FileDownloader}; use collector::installer::OTAInstaller; #[cfg(target_os = "linux")] use collector::journalctl::JournalCtl; @@ -130,7 +130,7 @@ impl Uplink { ) } - pub fn spawn(&mut self, bridge: Bridge) -> Result { + pub fn spawn(&mut self, mut bridge: Bridge) -> Result { let (mqtt_metrics_tx, mqtt_metrics_rx) = bounded(10); let (ctrl_actions_lane, ctrl_data_lane) = bridge.ctrl_tx(); @@ -146,6 +146,17 @@ impl Uplink { )?; let ctrl_serializer = serializer.ctrl_tx(); + let (ctrl_tx, ctrl_rx) = bounded(1); + let ctrl_downloader = DownloaderCtrlTx { inner: ctrl_tx }; + + // Downloader thread if configured + if !self.config.downloader.actions.is_empty() { + let actions_rx = bridge.register_action_routes(&self.config.downloader.actions)?; + let file_downloader = + FileDownloader::new(self.config.clone(), actions_rx, bridge.bridge_tx(), ctrl_rx)?; + spawn_named_thread("File Downloader", || file_downloader.start()); + } + // Serializer thread to handle network conditions state machine // and send data to mqtt thread spawn_named_thread("Serializer", || { @@ -219,6 +230,7 @@ impl Uplink { data_lane: ctrl_data_lane, mqtt: ctrl_mqtt, serializer: ctrl_serializer, + downloader: ctrl_downloader, }) } @@ -231,13 +243,6 @@ impl Uplink { let tunshell_client = TunshellClient::new(actions_rx, bridge_tx.clone()); spawn_named_thread("Tunshell Client", move || tunshell_client.start()); - if !self.config.downloader.actions.is_empty() { - let actions_rx = bridge.register_action_routes(&self.config.downloader.actions)?; - let file_downloader = - FileDownloader::new(self.config.clone(), actions_rx, bridge_tx.clone())?; - spawn_named_thread("File Downloader", || file_downloader.start()); - } - let device_shadow = DeviceShadow::new(self.config.device_shadow.clone(), bridge_tx.clone()); spawn_named_thread("Device Shadow Generator", move || device_shadow.start());