From 1fd2d6f9bb87ff2eef23e1f40dd5f63482a05bf3 Mon Sep 17 00:00:00 2001 From: lif <> Date: Tue, 13 Feb 2024 21:27:11 -0800 Subject: [PATCH] Fix serial console breaking on Windows when a non-UTF8 sequence is output --- src/lib.rs | 51 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 968442c..7e8ccf5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,6 +66,14 @@ pub struct Console { raw_guard: Option, } +impl Console { + /// Construct with the normal stdin and stdout file descriptors, for + /// typical use. + pub async fn new_stdio(escape: Option) -> Result { + Console::new(tokio::io::stdin(), tokio::io::stdout(), escape).await + } +} + impl Console { /// Construct with arbitrary [AsyncReadExt] and [AsyncWriteExt] streams, /// supporting use cases where we might be talking to something other than @@ -86,15 +94,16 @@ impl Console { } } -impl Console { - /// Construct with the normal stdin and stdout file descriptors, for - /// typical use. - pub async fn new_stdio(escape: Option) -> Result { - Console::new(tokio::io::stdin(), tokio::io::stdout(), escape).await - } -} +// this is really silly. in order to test reasonably with a tokio::io::duplex, +// which doesn't impl AsRawHandle, we only use the trait bound required by the +// raw Win32 API call path in non-tests, and use a dummy trait here otherwise +// to reduce code duplication. +#[cfg(any(test, not(target_family = "windows")))] +use std::marker::Sized as MightBeRawHandle; +#[cfg(all(not(test), target_family = "windows"))] +use std::os::windows::io::AsRawHandle as MightBeRawHandle; -impl Console { +impl Console { fn new_inner( stdin: I, stdout: O, @@ -125,8 +134,30 @@ impl Console { /// Write the given bytes to stdout. pub async fn write_stdout(&mut self, bytes: &[u8]) -> Result<(), Error> { - self.stdout.write_all(bytes).await?; - self.stdout.flush().await?; + // windows io in rust fails if any byte sequences aren't valid utf8 + #[cfg(all(not(test), target_family = "windows"))] + { + use winapi::shared::minwindef::LPDWORD; + use winapi::um::winnt::{HANDLE, VOID}; + let mut _lp_num_of_chars_written = 0u32; + let res = unsafe { + winapi::um::consoleapi::WriteConsoleA( + self.stdout.as_raw_handle() as HANDLE, + bytes.as_ptr() as *const VOID, + bytes.len() as u32, + (&mut _lp_num_of_chars_written) as LPDWORD, + std::ptr::null_mut::(), + ) + }; + if res == 0 { + return Err(std::io::Error::last_os_error().into()); + } + } + #[cfg(any(test, not(target_family = "windows")))] + { + self.stdout.write_all(bytes).await?; + self.stdout.flush().await?; + } Ok(()) }