Skip to content

Commit

Permalink
lazily initialize raw-mode guard, only when we have something to outp…
Browse files Browse the repository at this point in the history
…ut to the terminal. (if an error occurs, we don't have anything to reset)
  • Loading branch information
lif committed May 15, 2024
1 parent f960190 commit 00cf488
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 9 deletions.
4 changes: 4 additions & 0 deletions src/bin/test_raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ use thouart::Console;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let mut cons = Console::new_stdio(None).await?;

// force raw-mode initialization
cons.write_stdout(&[]).await?;

let mut buffer = vec![];
loop {
let (read_fut, write_fut) = if buffer.is_empty() {
Expand Down
59 changes: 50 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,43 @@ pub enum Error {
Signal(&'static str),
}

enum RawModeGuardState {
None,
#[cfg(target_family = "windows")]
PreInit {
stdin: std::os::windows::raw::HANDLE,
stdout: std::os::windows::raw::HANDLE,
},
#[cfg(target_family = "unix")]
PreInit(std::os::fd::RawFd),
Some(RawModeGuard),
}

// HANDLE (here used as a file descriptor) is a void pointer.
#[cfg(target_family = "windows")]
unsafe impl Send for RawModeGuardState {}

impl RawModeGuardState {
fn ensure_enabled(&mut self) -> Result<(), Error> {
match self {
Self::None | Self::Some(_) => Ok(()),
#[cfg(target_family = "windows")]
Self::PreInit { stdin, stdout } => {
*self = Self::Some(RawModeGuard::new(*stdin, *stdout)?);
Ok(())
}
#[cfg(target_family = "unix")]
Self::PreInit(fd) => {
*self = Self::Some(RawModeGuard::new(*fd)?);
Ok(())
}
}
}
fn take(&mut self) {
*self = Self::None;
}
}

/// A simple abstraction over a TTY's async I/O streams.
///
/// It provides:
Expand All @@ -72,7 +109,7 @@ pub struct Console<O: AsyncWriteExt + Unpin + Send> {
relay_rx: mpsc::Receiver<Vec<u8>>,
read_handle: task::JoinHandle<()>,
relay_handle: task::JoinHandle<()>,
raw_guard: Option<RawModeGuard>,
raw_guard: RawModeGuardState,
}

impl Console<tokio::io::Stdout> {
Expand All @@ -93,12 +130,12 @@ impl<O: AsyncWriteExt + Unpin + Send + AsRawFdHandle> Console<O> {
escape: Option<EscapeSequence>,
) -> Result<Self, Error> {
#[cfg(target_family = "unix")]
let raw_guard = Some(RawModeGuard::new(stdout.as_raw_fd())?);
let raw_guard = RawModeGuardState::PreInit(stdout.as_raw_fd());
#[cfg(target_family = "windows")]
let raw_guard = Some(RawModeGuard::new(
stdin.as_raw_handle(),
stdout.as_raw_handle(),
)?);
let raw_guard = RawModeGuardState::PreInit {
stdin: stdin.as_raw_handle(),
stdout: stdout.as_raw_handle(),
};
Ok(Self::new_inner(stdin, stdout, escape, raw_guard))
}
}
Expand All @@ -117,7 +154,7 @@ impl<O: AsyncWriteExt + Unpin + Send + MightBeRawHandle> Console<O> {
stdin: I,
stdout: O,
escape: Option<EscapeSequence>,
raw_guard: Option<RawModeGuard>,
raw_guard: RawModeGuardState,
) -> Self {
let (read_tx, read_rx) = mpsc::channel(16);
let (relay_tx, relay_rx) = mpsc::channel(16);
Expand All @@ -143,6 +180,7 @@ impl<O: AsyncWriteExt + Unpin + Send + MightBeRawHandle> Console<O> {

/// Write the given bytes to stdout.
pub async fn write_stdout(&mut self, bytes: &[u8]) -> Result<(), Error> {
self.raw_guard.ensure_enabled()?;
// windows io in rust fails if any byte sequences aren't valid utf8
#[cfg(all(not(test), target_family = "windows"))]
{
Expand Down Expand Up @@ -293,6 +331,7 @@ impl<O: AsyncWriteExt + Unpin + Send> Drop for Console<O> {

#[cfg(test)]
mod tests {
use super::RawModeGuardState;
use crate::{Console, EscapeSequence};
use futures::{SinkExt, StreamExt};
use std::time::Duration;
Expand All @@ -311,7 +350,8 @@ mod tests {
let mut ws = WebSocketStream::from_raw_socket(ws_testdrv, Role::Server, None).await;

let escape = Some(EscapeSequence::new(vec![1, 2, 3], 1).unwrap());
let mut console = Console::new_inner(in_console, out_console, escape, None);
let mut console =
Console::new_inner(in_console, out_console, escape, RawModeGuardState::None);

let join_handle = tokio::spawn(async move {
console.attach_to_websocket(ws_console).await.unwrap();
Expand Down Expand Up @@ -360,7 +400,8 @@ mod tests {
let (ws_testdrv, ws_console) = tokio::io::duplex(16);

let mut ws = WebSocketStream::from_raw_socket(ws_testdrv, Role::Server, None).await;
let mut console = Console::new_inner(in_console, out_console, None, None);
let mut console =
Console::new_inner(in_console, out_console, None, RawModeGuardState::None);

let join_handle =
tokio::spawn(async move { console.attach_to_websocket(ws_console).await });
Expand Down

0 comments on commit 00cf488

Please sign in to comment.