Skip to content

Commit

Permalink
Add unit test for signal handling, and handle Windows's closest analogue
Browse files Browse the repository at this point in the history
  • Loading branch information
lif committed Feb 8, 2024
1 parent 3a49d0d commit 59333e5
Showing 1 changed file with 74 additions and 12 deletions.
86 changes: 74 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use std::os::windows::io::AsRawHandle as AsRawFdHandle;
use futures::stream::FuturesUnordered;
#[cfg(target_family = "unix")]
use tokio::signal::unix::{signal, SignalKind};
#[cfg(target_family = "windows")]
use tokio::signal::windows::*;

use futures::{FutureExt, SinkExt, StreamExt};
use thiserror::Error;
Expand All @@ -42,7 +44,7 @@ pub enum Error {
StdoutWrite(#[from] std::io::Error),
#[error("Server error: {0}")]
ServerError(String),
#[error("Terminated by SIG{0}")]
#[error("Terminated by signal: {0}")]
Signal(&'static str),
}

Expand Down Expand Up @@ -149,7 +151,6 @@ impl<O: AsyncWriteExt + Unpin + Send> Console<O> {
upgraded: impl AsyncRead + AsyncWrite + Unpin,
) -> Result<(), Error> {
// need Signal structs to live at least as long as their futures
#[cfg(target_family = "unix")]
let mut signal_storage = Vec::new();

let mut signaled = FuturesUnordered::new();
Expand All @@ -161,12 +162,19 @@ impl<O: AsyncWriteExt + Unpin + Send> Console<O> {
signal_storage.push((signal(SignalKind::pipe())?, "PIPE"));
signal_storage.push((signal(SignalKind::quit())?, "QUIT"));
signal_storage.push((signal(SignalKind::terminate())?, "TERM"));
for (s_fut, s_name) in &mut signal_storage {
signaled.push(s_fut.recv().then(|opt| async move { opt.map(|_| s_name) }));
}
}
#[cfg(not(target_family = "unix"))]
signaled.push(std::future::pending());
#[cfg(target_family = "windows")]
{
// no ctrl_c(), we're already in VT100 mode, and raw mode in that
signal_storage.push((WinCtrlSignal::CBreak(ctrl_break()?), "CTRL-BREAK"));
signal_storage.push((WinCtrlSignal::CClose(ctrl_close()?), "CTRL-CLOSE"));
signal_storage.push((WinCtrlSignal::CLogoff(ctrl_logoff()?), "CTRL-LOGOFF"));
signal_storage.push((WinCtrlSignal::CShutdown(ctrl_shutdown()?), "CTRL-SHUTDOWN"));
}

for (s_fut, s_name) in &mut signal_storage {
signaled.push(s_fut.recv().then(|opt| async move { opt.map(|_| s_name) }));
}

let mut ws_stream = WebSocketStream::from_raw_socket(upgraded, Role::Client, None).await;

Expand Down Expand Up @@ -211,11 +219,8 @@ impl<O: AsyncWriteExt + Unpin + Send> Console<O> {
}
}
Some(Some(signal_name)) = signaled.next() => {
#[cfg(target_family = "unix")]
{
eprint!("\r\nExiting on signal.\r\n");
return Err(Error::Signal(signal_name));
}
eprint!("\r\nExiting on signal.\r\n");
return Err(Error::Signal(signal_name));
}
}
}
Expand All @@ -225,6 +230,28 @@ impl<O: AsyncWriteExt + Unpin + Send> Console<O> {
}
}

// unfortunately tokio::signal makes these all separate types...
#[cfg(target_family = "windows")]
enum WinCtrlSignal {
CC(CtrlC),
CBreak(CtrlBreak),
CClose(CtrlClose),
CLogoff(CtrlLogoff),
CShutdown(CtrlShutdown),
}
#[cfg(target_family = "windows")]
impl WinCtrlSignal {
async fn recv(&mut self) -> Option<()> {
match self {
CC(c) => c.recv().await,
CBreak(c) => c.recv().await,
CClose(c) => c.recv().await,
CLogoff(c) => c.recv().await,
CShutdown(c) => c.recv().await,
}
}
}

impl<O: AsyncWriteExt + Unpin + Send> Drop for Console<O> {
fn drop(&mut self) {
self.relay_handle.abort();
Expand Down Expand Up @@ -293,4 +320,39 @@ mod tests {
// ...and end the event loop.
timeout(ONE_SEC, join_handle).await.unwrap().unwrap();
}

#[cfg(target_family = "unix")]
#[tokio::test]
async fn test_cleanup_on_signal() {
let (_in_testdrv, in_console) = tokio::io::duplex(16);
let (mut out_testdrv, out_console) = tokio::io::duplex(16);
let (ws_testdrv, ws_console) = tokio::io::duplex(16);

let mut ws = WebSocketStream::from_raw_socket(ws_testdrv, Role::Server, None).await;
let mut console = Console::new_inner(in_console, out_console, None, None);

let join_handle =
tokio::spawn(async move { console.attach_to_websocket(ws_console).await });

ws.send(Message::Binary(vec![1, 2, 3, 4, 5, 6]))
.await
.unwrap();

let mut read_buf = [0u8; 6];
const ONE_SEC: Duration = Duration::from_secs(1);
timeout(ONE_SEC, out_testdrv.read_exact(&mut read_buf))
.await
.unwrap()
.unwrap();
assert_eq!(read_buf, [1, 2, 3, 4, 5, 6]);

let syscall_return = unsafe { libc::kill(std::process::id() as libc::c_int, libc::SIGINT) };
assert_eq!(syscall_return, 0);

let Err(super::Error::Signal("INT")) =
timeout(ONE_SEC, join_handle).await.unwrap().unwrap()
else {
panic!("Expected SIGINT!")
};
}
}

0 comments on commit 59333e5

Please sign in to comment.