From 00cf48859437a5a2891f26b9f0457864a7ce7e2c Mon Sep 17 00:00:00 2001 From: lif <> Date: Wed, 15 May 2024 14:03:14 -0700 Subject: [PATCH] lazily initialize raw-mode guard, only when we have something to output to the terminal. (if an error occurs, we don't have anything to reset) --- src/bin/test_raw.rs | 4 +++ src/lib.rs | 59 ++++++++++++++++++++++++++++++++++++++------- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/src/bin/test_raw.rs b/src/bin/test_raw.rs index 6cbf2cd..f0b63e2 100644 --- a/src/bin/test_raw.rs +++ b/src/bin/test_raw.rs @@ -5,6 +5,10 @@ use thouart::Console; #[tokio::main] async fn main() -> Result<(), Box> { let mut cons = Console::new_stdio(None).await?; + + // force raw-mode initialization + cons.write_stdout(&[]).await?; + let mut buffer = vec![]; loop { let (read_fut, write_fut) = if buffer.is_empty() { diff --git a/src/lib.rs b/src/lib.rs index 6a4e31c..a56fbf7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,43 @@ pub enum Error { Signal(&'static str), } +enum RawModeGuardState { + None, + #[cfg(target_family = "windows")] + PreInit { + stdin: std::os::windows::raw::HANDLE, + stdout: std::os::windows::raw::HANDLE, + }, + #[cfg(target_family = "unix")] + PreInit(std::os::fd::RawFd), + Some(RawModeGuard), +} + +// HANDLE (here used as a file descriptor) is a void pointer. +#[cfg(target_family = "windows")] +unsafe impl Send for RawModeGuardState {} + +impl RawModeGuardState { + fn ensure_enabled(&mut self) -> Result<(), Error> { + match self { + Self::None | Self::Some(_) => Ok(()), + #[cfg(target_family = "windows")] + Self::PreInit { stdin, stdout } => { + *self = Self::Some(RawModeGuard::new(*stdin, *stdout)?); + Ok(()) + } + #[cfg(target_family = "unix")] + Self::PreInit(fd) => { + *self = Self::Some(RawModeGuard::new(*fd)?); + Ok(()) + } + } + } + fn take(&mut self) { + *self = Self::None; + } +} + /// A simple abstraction over a TTY's async I/O streams. /// /// It provides: @@ -72,7 +109,7 @@ pub struct Console { relay_rx: mpsc::Receiver>, read_handle: task::JoinHandle<()>, relay_handle: task::JoinHandle<()>, - raw_guard: Option, + raw_guard: RawModeGuardState, } impl Console { @@ -93,12 +130,12 @@ impl Console { escape: Option, ) -> Result { #[cfg(target_family = "unix")] - let raw_guard = Some(RawModeGuard::new(stdout.as_raw_fd())?); + let raw_guard = RawModeGuardState::PreInit(stdout.as_raw_fd()); #[cfg(target_family = "windows")] - let raw_guard = Some(RawModeGuard::new( - stdin.as_raw_handle(), - stdout.as_raw_handle(), - )?); + let raw_guard = RawModeGuardState::PreInit { + stdin: stdin.as_raw_handle(), + stdout: stdout.as_raw_handle(), + }; Ok(Self::new_inner(stdin, stdout, escape, raw_guard)) } } @@ -117,7 +154,7 @@ impl Console { stdin: I, stdout: O, escape: Option, - raw_guard: Option, + raw_guard: RawModeGuardState, ) -> Self { let (read_tx, read_rx) = mpsc::channel(16); let (relay_tx, relay_rx) = mpsc::channel(16); @@ -143,6 +180,7 @@ impl Console { /// Write the given bytes to stdout. pub async fn write_stdout(&mut self, bytes: &[u8]) -> Result<(), Error> { + self.raw_guard.ensure_enabled()?; // windows io in rust fails if any byte sequences aren't valid utf8 #[cfg(all(not(test), target_family = "windows"))] { @@ -293,6 +331,7 @@ impl Drop for Console { #[cfg(test)] mod tests { + use super::RawModeGuardState; use crate::{Console, EscapeSequence}; use futures::{SinkExt, StreamExt}; use std::time::Duration; @@ -311,7 +350,8 @@ mod tests { let mut ws = WebSocketStream::from_raw_socket(ws_testdrv, Role::Server, None).await; let escape = Some(EscapeSequence::new(vec![1, 2, 3], 1).unwrap()); - let mut console = Console::new_inner(in_console, out_console, escape, None); + let mut console = + Console::new_inner(in_console, out_console, escape, RawModeGuardState::None); let join_handle = tokio::spawn(async move { console.attach_to_websocket(ws_console).await.unwrap(); @@ -360,7 +400,8 @@ mod tests { let (ws_testdrv, ws_console) = tokio::io::duplex(16); let mut ws = WebSocketStream::from_raw_socket(ws_testdrv, Role::Server, None).await; - let mut console = Console::new_inner(in_console, out_console, None, None); + let mut console = + Console::new_inner(in_console, out_console, None, RawModeGuardState::None); let join_handle = tokio::spawn(async move { console.attach_to_websocket(ws_console).await });