Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix serial console breaking on Windows when a non-UTF8 sequence is output #12

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 41 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ pub struct Console<O: AsyncWriteExt + Unpin + Send> {
raw_guard: Option<RawModeGuard>,
}

impl Console<tokio::io::Stdout> {
/// Construct with the normal stdin and stdout file descriptors, for
/// typical use.
pub async fn new_stdio(escape: Option<EscapeSequence>) -> Result<Self, Error> {
Console::new(tokio::io::stdin(), tokio::io::stdout(), escape).await
}
}

impl<O: AsyncWriteExt + Unpin + Send + AsRawFdHandle> Console<O> {
/// Construct with arbitrary [AsyncReadExt] and [AsyncWriteExt] streams,
/// supporting use cases where we might be talking to something other than
Expand All @@ -86,15 +94,16 @@ impl<O: AsyncWriteExt + Unpin + Send + AsRawFdHandle> Console<O> {
}
}

impl Console<tokio::io::Stdout> {
/// Construct with the normal stdin and stdout file descriptors, for
/// typical use.
pub async fn new_stdio(escape: Option<EscapeSequence>) -> Result<Self, Error> {
Console::new(tokio::io::stdin(), tokio::io::stdout(), escape).await
}
}
// this is really silly. in order to test reasonably with a tokio::io::duplex,
// which doesn't impl AsRawHandle, we only use the trait bound required by the
// raw Win32 API call path in non-tests, and use a dummy trait here otherwise
// to reduce code duplication.
#[cfg(any(test, not(target_family = "windows")))]
use std::marker::Sized as MightBeRawHandle;
#[cfg(all(not(test), target_family = "windows"))]
use std::os::windows::io::AsRawHandle as MightBeRawHandle;

impl<O: AsyncWriteExt + Unpin + Send> Console<O> {
impl<O: AsyncWriteExt + Unpin + Send + MightBeRawHandle> Console<O> {
fn new_inner<I: AsyncReadExt + Unpin + Send + 'static>(
stdin: I,
stdout: O,
Expand Down Expand Up @@ -125,8 +134,30 @@ impl<O: AsyncWriteExt + Unpin + Send> Console<O> {

/// Write the given bytes to stdout.
pub async fn write_stdout(&mut self, bytes: &[u8]) -> Result<(), Error> {
self.stdout.write_all(bytes).await?;
self.stdout.flush().await?;
// windows io in rust fails if any byte sequences aren't valid utf8
#[cfg(all(not(test), target_family = "windows"))]
{
use winapi::shared::minwindef::LPDWORD;
use winapi::um::winnt::{HANDLE, VOID};
let mut _lp_num_of_chars_written = 0u32;
let res = unsafe {
winapi::um::consoleapi::WriteConsoleA(
self.stdout.as_raw_handle() as HANDLE,
bytes.as_ptr() as *const VOID,
bytes.len() as u32,
(&mut _lp_num_of_chars_written) as LPDWORD,
std::ptr::null_mut::<VOID>(),
)
};
if res == 0 {
return Err(std::io::Error::last_os_error().into());
}
}
#[cfg(any(test, not(target_family = "windows")))]
{
self.stdout.write_all(bytes).await?;
self.stdout.flush().await?;
}
Ok(())
}

Expand Down