Skip to content

Commit

Permalink
lazily initialize raw-mode guard
Browse files Browse the repository at this point in the history
only construct the RawModeGuard when we have something to output to the terminal.
if an error occurs before we ever receive output -- such as when an instance doesn't exist -- we don't have any terminal state to reset.
  • Loading branch information
lifning authored May 15, 2024
1 parent f960190 commit 1f22ae2
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 9 deletions.
4 changes: 4 additions & 0 deletions src/bin/test_raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ use thouart::Console;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let mut cons = Console::new_stdio(None).await?;

// force raw-mode initialization
cons.write_stdout(&[]).await?;

let mut buffer = vec![];
loop {
let (read_fut, write_fut) = if buffer.is_empty() {
Expand Down
59 changes: 50 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,43 @@ pub enum Error {
Signal(&'static str),
}

enum RawModeGuardState {
Empty,
#[cfg(target_family = "windows")]
PreInit {
stdin: std::os::windows::raw::HANDLE,
stdout: std::os::windows::raw::HANDLE,
},
#[cfg(target_family = "unix")]
PreInit(std::os::fd::RawFd),
Initialized(RawModeGuard),
}

// HANDLE (here used as a file descriptor) is a void pointer.
#[cfg(target_family = "windows")]
unsafe impl Send for RawModeGuardState {}

impl RawModeGuardState {
fn ensure_enabled(&mut self) -> Result<(), Error> {
match self {
Self::Empty | Self::Initialized(_) => Ok(()),
#[cfg(target_family = "windows")]
Self::PreInit { stdin, stdout } => {
*self = Self::Initialized(RawModeGuard::new(*stdin, *stdout)?);
Ok(())
}
#[cfg(target_family = "unix")]
Self::PreInit(fd) => {
*self = Self::Initialized(RawModeGuard::new(*fd)?);
Ok(())
}
}
}
fn take(&mut self) {
*self = Self::Empty;
}
}

/// A simple abstraction over a TTY's async I/O streams.
///
/// It provides:
Expand All @@ -72,7 +109,7 @@ pub struct Console<O: AsyncWriteExt + Unpin + Send> {
relay_rx: mpsc::Receiver<Vec<u8>>,
read_handle: task::JoinHandle<()>,
relay_handle: task::JoinHandle<()>,
raw_guard: Option<RawModeGuard>,
raw_guard: RawModeGuardState,
}

impl Console<tokio::io::Stdout> {
Expand All @@ -93,12 +130,12 @@ impl<O: AsyncWriteExt + Unpin + Send + AsRawFdHandle> Console<O> {
escape: Option<EscapeSequence>,
) -> Result<Self, Error> {
#[cfg(target_family = "unix")]
let raw_guard = Some(RawModeGuard::new(stdout.as_raw_fd())?);
let raw_guard = RawModeGuardState::PreInit(stdout.as_raw_fd());
#[cfg(target_family = "windows")]
let raw_guard = Some(RawModeGuard::new(
stdin.as_raw_handle(),
stdout.as_raw_handle(),
)?);
let raw_guard = RawModeGuardState::PreInit {
stdin: stdin.as_raw_handle(),
stdout: stdout.as_raw_handle(),
};
Ok(Self::new_inner(stdin, stdout, escape, raw_guard))
}
}
Expand All @@ -117,7 +154,7 @@ impl<O: AsyncWriteExt + Unpin + Send + MightBeRawHandle> Console<O> {
stdin: I,
stdout: O,
escape: Option<EscapeSequence>,
raw_guard: Option<RawModeGuard>,
raw_guard: RawModeGuardState,
) -> Self {
let (read_tx, read_rx) = mpsc::channel(16);
let (relay_tx, relay_rx) = mpsc::channel(16);
Expand All @@ -143,6 +180,7 @@ impl<O: AsyncWriteExt + Unpin + Send + MightBeRawHandle> Console<O> {

/// Write the given bytes to stdout.
pub async fn write_stdout(&mut self, bytes: &[u8]) -> Result<(), Error> {
self.raw_guard.ensure_enabled()?;
// windows io in rust fails if any byte sequences aren't valid utf8
#[cfg(all(not(test), target_family = "windows"))]
{
Expand Down Expand Up @@ -293,6 +331,7 @@ impl<O: AsyncWriteExt + Unpin + Send> Drop for Console<O> {

#[cfg(test)]
mod tests {
use super::RawModeGuardState;
use crate::{Console, EscapeSequence};
use futures::{SinkExt, StreamExt};
use std::time::Duration;
Expand All @@ -311,7 +350,8 @@ mod tests {
let mut ws = WebSocketStream::from_raw_socket(ws_testdrv, Role::Server, None).await;

let escape = Some(EscapeSequence::new(vec![1, 2, 3], 1).unwrap());
let mut console = Console::new_inner(in_console, out_console, escape, None);
let mut console =
Console::new_inner(in_console, out_console, escape, RawModeGuardState::Empty);

let join_handle = tokio::spawn(async move {
console.attach_to_websocket(ws_console).await.unwrap();
Expand Down Expand Up @@ -360,7 +400,8 @@ mod tests {
let (ws_testdrv, ws_console) = tokio::io::duplex(16);

let mut ws = WebSocketStream::from_raw_socket(ws_testdrv, Role::Server, None).await;
let mut console = Console::new_inner(in_console, out_console, None, None);
let mut console =
Console::new_inner(in_console, out_console, None, RawModeGuardState::Empty);

let join_handle =
tokio::spawn(async move { console.attach_to_websocket(ws_console).await });
Expand Down

0 comments on commit 1f22ae2

Please sign in to comment.