Skip to content

Commit

Permalink
feat: restart download post reboot (#335)
Browse files Browse the repository at this point in the history
* feat: restart download post reboot

* refactor: call it state

* refactor: remove need for clone

* fix: set action_id on reload
  • Loading branch information
Devdutt Shenoi authored Apr 2, 2024
1 parent 500ca0d commit 6c723c9
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 44 deletions.
5 changes: 4 additions & 1 deletion uplink/src/base/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use tokio::join;
use self::bridge::{ActionsLaneCtrlTx, DataLaneCtrlTx};
use self::mqtt::CtrlTx as MqttCtrlTx;
use self::serializer::CtrlTx as SerializerCtrlTx;
use crate::collector::downloader::CtrlTx as DownloaderCtrlTx;

pub mod actions;
pub mod bridge;
Expand All @@ -26,6 +27,7 @@ pub struct CtrlTx {
pub data_lane: DataLaneCtrlTx,
pub mqtt: MqttCtrlTx,
pub serializer: SerializerCtrlTx,
pub downloader: DownloaderCtrlTx,
}

impl CtrlTx {
Expand All @@ -34,7 +36,8 @@ impl CtrlTx {
self.actions_lane.trigger_shutdown(),
self.data_lane.trigger_shutdown(),
self.mqtt.trigger_shutdown(),
self.serializer.trigger_shutdown()
self.serializer.trigger_shutdown(),
self.downloader.trigger_shutdown()
);
}
}
179 changes: 145 additions & 34 deletions uplink/src/collector/downloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,17 @@
//! [`action_redirections`]: Config#structfield.action_redirections
use bytes::BytesMut;
use flume::Receiver;
use flume::{Receiver, Sender};
use futures_util::StreamExt;
use human_bytes::human_bytes;
use log::{debug, error, info, trace, warn};
use reqwest::{Certificate, Client, ClientBuilder, Error as ReqwestError, Identity};
use rsa::sha2::{Digest, Sha256};
use serde::{Deserialize, Serialize};
use tokio::select;
use tokio::time::{timeout_at, Instant};

use std::fs::{metadata, remove_dir_all, File};
use std::fs::{metadata, read, remove_dir_all, remove_file, write, File};
use std::io;
use std::sync::Arc;
use std::time::Duration;
Expand Down Expand Up @@ -88,6 +89,10 @@ pub enum Error {
BadChecksum,
#[error("Disk space is insufficient: {0}")]
InsufficientDisk(String),
#[error("Save file is corrupted")]
BadSave,
#[error("Save file doesn't exist")]
NoSave,
}

/// This struct contains the necessary components to download and store file as notified by a download file
Expand All @@ -101,6 +106,7 @@ pub struct FileDownloader {
bridge_tx: BridgeTx,
client: Client,
sequence: u32,
shutdown_rx: Receiver<DownloaderShutdown>,
}

impl FileDownloader {
Expand All @@ -109,6 +115,7 @@ impl FileDownloader {
config: Arc<Config>,
actions_rx: Receiver<Action>,
bridge_tx: BridgeTx,
shutdown_rx: Receiver<DownloaderShutdown>,
) -> Result<Self, Error> {
// Authenticate with TLS certs from config
let client_builder = ClientBuilder::new();
Expand All @@ -132,13 +139,16 @@ impl FileDownloader {
bridge_tx,
sequence: 0,
action_id: String::default(),
shutdown_rx,
})
}

/// Spawn a thread to handle downloading files as notified by download actions and for forwarding the updated actions
/// back to bridge for further processing, e.g. OTA update installation.
#[tokio::main(flavor = "current_thread")]
pub async fn start(mut self) {
self.reload().await;

info!("Downloader thread is ready to receive download actions");
loop {
self.sequence = 0;
Expand Down Expand Up @@ -168,39 +178,71 @@ impl FileDownloader {
}

// Forward updated action as part of response
let DownloadState { action, .. } = state;
let DownloadState { current: CurrentDownload { action, .. }, .. } = state;
let status = ActionResponse::done(&self.action_id, "Downloaded", Some(action))
.set_sequence(self.sequence());
self.bridge_tx.send_action_response(status).await;
}
}

/// Resumes a download that wasn't completed during the previous run of uplink.
///
/// Attempts to restore the saved state through `DownloadState::load`:
/// - when there is no save file (`Error::NoSave`) there is nothing to resume and
///   the method returns silently;
/// - any other load failure is logged as a warning and the reload is abandoned.
///
/// On a successful load, `action_id` is set from the restored action so that
/// responses are attributed to it, and the download is continued. A failure is
/// handed to `forward_error` and nothing else is sent; only a successful download
/// results in the "Downloaded" response carrying the updated action.
async fn reload(&mut self) {
    let mut state = match DownloadState::load(&self.config) {
        Ok(s) => s,
        // No save file, i.e. no download was in progress at last shutdown
        Err(Error::NoSave) => return,
        Err(e) => {
            warn!("Couldn't reload current_download: {e}");
            return;
        }
    };
    // Responses below must carry the id of the reloaded action
    self.action_id = state.current.action.action_id.clone();

    if let Err(e) = self.download(&mut state).await {
        self.forward_error(e).await;
        // A failed download must not also be reported as "Downloaded"
        return;
    }

    // Forward updated action as part of response
    let DownloadState { current: CurrentDownload { action, .. }, .. } = state;
    let status = ActionResponse::done(&self.action_id, "Downloaded", Some(action))
        .set_sequence(self.sequence());
    self.bridge_tx.send_action_response(status).await;
}
// Accepts `DownloadState`, sets a timeout for the action
async fn download(&mut self, state: &mut DownloadState) -> Result<(), Error> {
let deadline = match &state.action.deadline {
let shutdown_rx = self.shutdown_rx.clone();
let deadline = match &state.current.action.deadline {
Some(d) => *d,
_ => {
error!("Unconfigured deadline: {}", state.action.name);
error!("Unconfigured deadline: {}", state.current.action.name);
return Ok(());
}
};
// NOTE: if download has timedout don't do anything, else ensure errors are forwarded after three retries
match timeout_at(deadline, self.continuous_retry(state)).await {
Ok(r) => r?,
Err(_) => error!("Last download has timedout"),
select! {
// NOTE: if download has timedout don't do anything, else ensure errors are forwarded after three retries
o = timeout_at(deadline, self.continuous_retry(state)) => match o {
Ok(r) => r?,
Err(_) => error!("Last download has timedout"),
},

_ = shutdown_rx.recv_async() => {
if let Err(e) = state.save(&self.config) {
error!("Error saving current_download: {e}");
}
}

}

state.meta.verify_checksum()?;
state.current.meta.verify_checksum()?;
// Update Action payload with `download_path`, i.e. downloaded file's location in fs
state.action.payload = serde_json::to_string(&state.meta)?;
state.current.action.payload = serde_json::to_string(&state.current.meta)?;

Ok(())
}

// A download must be retried with Range header when HTTP/reqwest errors are faced
async fn continuous_retry(&mut self, state: &mut DownloadState) -> Result<(), Error> {
'outer: loop {
let mut req = self.client.get(&state.meta.url);
let mut req = self.client.get(&state.current.meta.url);
if let Some(range) = state.retry_range() {
warn!("Retrying download; Continuing to download file from: {range}");
req = req.header("Range", range);
Expand Down Expand Up @@ -334,19 +376,77 @@ impl DownloadFile {
}
}

/// Serializable snapshot of a download that is in progress.
///
/// `DownloadState::save` writes it as JSON to the `current_download` file and
/// `DownloadState::load` reads it back to resume the download.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct CurrentDownload {
/// The download action being serviced; its payload is replaced with the
/// serialized `meta` once the download is done.
action: Action,
/// Details of the file being downloaded (`url`, `content_length`, `download_path`).
meta: DownloadFile,
/// Time remaining until the action's deadline, computed at the moment the state
/// is saved. `None` for a freshly created state or when the action has no deadline.
time_left: Option<Duration>,
}

// A temporary structure to help us retry downloads
// that failed after partial completion.
#[derive(Debug)]
struct DownloadState {
action: Action,
meta: DownloadFile,
current: CurrentDownload,
file: File,
bytes_written: usize,
bytes_downloaded: usize,
percentage_downloaded: u8,
start_instant: Instant,
start: Instant,
}

impl DownloadState {
fn load(config: &DownloaderConfig) -> Result<Self, Error> {
let mut path = config.path.clone();
path.push("current_download");

if !path.exists() {
return Err(Error::NoSave);
}

let read = read(&path)?;
let mut current: CurrentDownload = serde_json::from_slice(&read)?;
// Calculate deadline based on written time left
current.action.deadline = current.time_left.map(|t| Instant::now() + t);

let file = File::open(current.meta.download_path.as_ref().unwrap())?;
let bytes_written = file.metadata()?.len() as usize;

remove_file(path)?;

Ok(DownloadState {
current,
file,
bytes_written,
percentage_downloaded: 0,
start: Instant::now(),
})
}

fn save(&self, config: &DownloaderConfig) -> Result<(), Error> {
if self.bytes_written == self.current.meta.content_length {
return Ok(());
}

let mut current = self.current.clone();
// Calculate time left based on deadline
current.time_left = current.action.deadline.map(|t| t.duration_since(Instant::now()));
let json = serde_json::to_vec(&current)?;

let mut path = config.path.clone();
path.push("current_download");
write(path, json)?;

Ok(())
}

/// Builds the value of the HTTP `Range` header used to continue a partially
/// completed download, or `None` when nothing has been written yet and the
/// file must be requested from the start.
fn retry_range(&self) -> Option<String> {
    let resume_from = self.bytes_written;
    let total = self.current.meta.content_length;

    (resume_from != 0).then(|| format!("bytes={}-{}", resume_from, total))
}

fn new(action: Action, config: &DownloaderConfig) -> Result<Self, Error> {
// Ensure that directory for downloading file into, exists
let mut path = config.path.clone();
Expand Down Expand Up @@ -385,54 +485,61 @@ impl DownloadState {
human_bytes(meta.content_length as f64)
);
meta.download_path = Some(file_path);
let current = CurrentDownload { action, meta, time_left: None };

Ok(DownloadState {
action,
meta,
Ok(Self {
current,
file,
bytes_written: 0,
bytes_downloaded: 0,
percentage_downloaded: 0,
start_instant: Instant::now(),
start: Instant::now(),
})
}

fn write_bytes(&mut self, buf: &[u8]) -> Result<Option<u8>, Error> {
self.bytes_downloaded += buf.len();
let bytes_downloaded = buf.len();
self.file.write_all(buf)?;
self.bytes_written = self.bytes_downloaded;
let size = human_bytes(self.meta.content_length as f64);
self.bytes_written += bytes_downloaded;
let size = human_bytes(self.current.meta.content_length as f64);

// Calculate percentage on the basis of content_length
let factor = self.bytes_downloaded as f32 / self.meta.content_length as f32;
let factor = self.bytes_written as f32 / self.current.meta.content_length as f32;
let percentage = (99.99 * factor) as u8;

// NOTE: ensure lesser frequency of action responses, once every percentage points
if percentage > self.percentage_downloaded {
self.percentage_downloaded = percentage;
debug!(
"Downloading: size = {size}, percentage = {percentage}, elapsed = {}s",
self.start_instant.elapsed().as_secs()
self.start.elapsed().as_secs()
);

Ok(Some(percentage))
} else {
trace!(
"Downloading: size = {size}, percentage = {}, elapsed = {}s",
self.percentage_downloaded,
self.start_instant.elapsed().as_secs()
self.start.elapsed().as_secs()
);

Ok(None)
}
}
}

fn retry_range(&self) -> Option<String> {
if self.bytes_written == self.meta.content_length {
return None;
}
/// Command to remotely trigger `Downloader` shutdown
///
/// Sent through [`CtrlTx::trigger_shutdown`]; when it is received while a download
/// is in progress, the download's state is saved so it can be reloaded later.
pub struct DownloaderShutdown;

/// Handle to send control messages to `Downloader`
#[derive(Debug, Clone)]
pub struct CtrlTx {
// Sending half of the channel on which `DownloaderShutdown` commands are delivered
pub(crate) inner: Sender<DownloaderShutdown>,
}

Some(format!("bytes={}-{}", self.bytes_written, self.meta.content_length))
impl CtrlTx {
    /// Requests the `Downloader` to shut down by sending it a [`DownloaderShutdown`].
    ///
    /// # Panics
    /// Panics if the command can't be sent over the channel
    /// (presumably when the receiving end has been dropped).
    pub async fn trigger_shutdown(&self) {
        let sent = self.inner.send_async(DownloaderShutdown).await;
        sent.unwrap()
    }
}

Expand Down Expand Up @@ -497,7 +604,9 @@ mod test {

// Create channels to forward and push actions on
let (download_tx, download_rx) = bounded(1);
let downloader = FileDownloader::new(Arc::new(config), download_rx, bridge_tx).unwrap();
let (_, ctrl_rx) = bounded(1);
let downloader =
FileDownloader::new(Arc::new(config), download_rx, bridge_tx, ctrl_rx).unwrap();

// Start FileDownloader in separate thread
std::thread::spawn(|| downloader.start());
Expand Down Expand Up @@ -557,7 +666,9 @@ mod test {

// Create channels to forward and push action_status on
let (download_tx, download_rx) = bounded(1);
let downloader = FileDownloader::new(Arc::new(config), download_rx, bridge_tx).unwrap();
let (_, ctrl_rx) = bounded(1);
let downloader =
FileDownloader::new(Arc::new(config), download_rx, bridge_tx, ctrl_rx).unwrap();

// Start FileDownloader in separate thread
std::thread::spawn(|| downloader.start());
Expand Down
Loading

0 comments on commit 6c723c9

Please sign in to comment.