diff --git a/src/line_collector.rs b/src/line_collector.rs index 4619051..0c97c1e 100644 --- a/src/line_collector.rs +++ b/src/line_collector.rs @@ -50,8 +50,17 @@ fn get_fixed_highlight(line: &str) -> Option<&str> { return None; } -fn print(stream: &mut BufWriter, text: &str) { - if let Err(error) = stream.write_all(text.as_bytes()) { +/// Write the string bytes to the stream. +fn print(stream: &mut BufWriter, text: &str, strip_color: bool) { + let result = if strip_color { + let mut bytes = text.as_bytes().to_vec(); + remove_ansi_escape_codes(&mut bytes); + stream.write_all(&bytes) + } else { + stream.write_all(text.as_bytes()) + }; + + if let Err(error) = result { if error.kind() == ErrorKind::BrokenPipe { // This is fine, somebody probably just quit their pager before it // was done reading our output. @@ -129,7 +138,7 @@ impl Drop for LineCollector { } impl LineCollector { - pub fn new(output: W) -> LineCollector { + pub fn new(output: W, color: bool) -> LineCollector { // This is how many entries we can look ahead. An "entry" in this case // being either a plain text section or an oldnew section. // @@ -158,7 +167,7 @@ impl LineCollector { // Secret handshake received, done! break; } - print(&mut output, print_me.get()); + print(&mut output, print_me.get(), !color); } } }) diff --git a/src/main.rs b/src/main.rs index 8b4741a..7e3b3d1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ extern crate lazy_static; use backtrace::Backtrace; use clap::CommandFactory; use clap::Parser; +use clap::ValueEnum; use git_version::git_version; use line_collector::LineCollector; use std::io::{self, IsTerminal}; @@ -71,9 +72,9 @@ const GIT_VERSION: &str = git_version!(cargo_prefix = ""); about = "Colors diff output, highlighting the changed parts of every line.", after_help = HELP_TEXT_FOOTER, override_usage = r#" - diff ... | riff [--no-pager] [--no-adds-only-special] - riff [-b] [--no-pager] [--no-adds-only-special] - riff [-b] [--no-pager] [--no-adds-only-special] --file "# + diff ... | riff [options...] + riff [-b] [options...] + riff [-b] [options...] --file "# )] struct Options { @@ -102,10 +103,33 @@ struct Options { #[arg(long)] no_adds_only_special: bool, + /// `auto` = color if stdout is a terminal + #[arg(long)] + color: Option, + #[arg(long, hide(true))] please_panic: bool, } +#[derive(ValueEnum, Clone, Default)] +enum ColorOption { + On, + Off, + + #[default] + Auto, +} + +impl ColorOption { + fn bool_or(self, default: bool) -> bool { + match self { + ColorOption::On => true, + ColorOption::Off => false, + ColorOption::Auto => default, + } + } +} + fn format_error(message: String, line_number: usize, line: &[u8]) -> Result<(), String> { return Err(format!( "ERROR on line {}: {}\n Line {}: {}", @@ -116,8 +140,12 @@ fn format_error(message: String, line_number: usize, line: &[u8]) -> Result<(), )); } -fn highlight_diff_or_exit(input: &mut dyn io::Read, output: W) { - if let Err(message) = highlight_diff(input, output) { +fn highlight_diff_or_exit( + input: &mut dyn io::Read, + output: W, + color: bool, +) { + if let Err(message) = highlight_diff(input, output, color) { eprintln!("{}", message); exit(1); } @@ -128,8 +156,9 @@ fn highlight_diff_or_exit(input: &mut dyn io::Rea fn highlight_diff( input: &mut dyn io::Read, output: W, + color: bool, ) -> Result<(), String> { - let mut line_collector = LineCollector::new(output); + let mut line_collector = LineCollector::new(output, color); // Read input line by line, using from_utf8_lossy() to convert lines into // strings while handling invalid UTF-8 without crashing @@ -183,7 +212,7 @@ fn highlight_diff( /// /// Returns `true` if the pager was found, `false` otherwise. #[must_use] -fn try_pager(input: &mut dyn io::Read, pager_name: &str) -> bool { +fn try_pager(input: &mut dyn io::Read, pager_name: &str, color: bool) -> bool { let mut command = Command::new(pager_name); if env::var(PAGER_FORKBOMB_STOP).is_ok() { @@ -208,7 +237,7 @@ fn try_pager(input: &mut dyn io::Read, pager_name: &str) -> bool { Ok(mut pager) => { let pager_stdin = pager.stdin.unwrap(); pager.stdin = None; - highlight_diff_or_exit(input, pager_stdin); + highlight_diff_or_exit(input, pager_stdin, color); // FIXME: Report pager exit status if non-zero, together with // contents of pager stderr as well if possible. @@ -243,20 +272,20 @@ fn panic_handler(panic_info: &panic::PanicInfo) { } /// Highlight the given stream, paging if stdout is a terminal -fn highlight_stream(input: &mut dyn io::Read, no_pager: bool) { +fn highlight_stream(input: &mut dyn io::Read, no_pager: bool, color: bool) { if !io::stdout().is_terminal() { // We're being piped, just do stdin -> stdout - highlight_diff_or_exit(input, io::stdout()); + highlight_diff_or_exit(input, io::stdout(), color); return; } if no_pager { - highlight_diff_or_exit(input, io::stdout()); + highlight_diff_or_exit(input, io::stdout(), color); return; } if let Ok(pager_value) = env::var("PAGER") { - if try_pager(input, &pager_value) { + if try_pager(input, &pager_value, color) { return; } @@ -264,16 +293,16 @@ fn highlight_stream(input: &mut dyn io::Read, no_pager: bool) { // doesn't exist. } - if try_pager(input, "moar") { + if try_pager(input, "moar", color) { return; } - if try_pager(input, "less") { + if try_pager(input, "less", color) { return; } // No pager found, wth? - highlight_diff_or_exit(input, io::stdout()); + highlight_diff_or_exit(input, io::stdout(), color); } /// `Not found`, `File`, `Directory` or `Not file not dir` @@ -305,7 +334,13 @@ fn ensure_listable(path: &path::Path) { } /// Run the `diff` binary on the two paths and highlight the output -fn exec_diff_highlight(path1: &str, path2: &str, ignore_space_change: bool, no_pager: bool) { +fn exec_diff_highlight( + path1: &str, + path2: &str, + ignore_space_change: bool, + no_pager: bool, + color: bool, +) { let path1 = path::Path::new(path1); let path2 = path::Path::new(path2); let both_paths_are_non_dirs = !path1.is_dir() && !path2.is_dir(); @@ -352,7 +387,7 @@ fn exec_diff_highlight(path1: &str, path2: &str, ignore_space_change: bool, no_p } let diff_stdout = diff_subprocess.stdout.as_mut().unwrap(); - highlight_stream(diff_stdout, no_pager); + highlight_stream(diff_stdout, no_pager, color); let diff_result = diff_subprocess.wait().unwrap(); let diff_exit_code = diff_result.code().unwrap_or(2); @@ -424,6 +459,10 @@ fn main() { &file2, options.ignore_space_change, options.no_pager, + options + .color + .unwrap_or(ColorOption::Auto) + .bool_or(io::stdout().is_terminal()), ); return; } @@ -442,7 +481,14 @@ fn main() { exit(1); } }; - highlight_stream(&mut diff_file, options.no_pager); + highlight_stream( + &mut diff_file, + options.no_pager, + options + .color + .unwrap_or(ColorOption::Auto) + .bool_or(io::stdout().is_terminal()), + ); return; } @@ -456,7 +502,14 @@ fn main() { exit(1); } - highlight_stream(&mut io::stdin().lock(), options.no_pager); + highlight_stream( + &mut io::stdin().lock(), + options.no_pager, + options + .color + .unwrap_or(ColorOption::Auto) + .bool_or(io::stdout().is_terminal()), + ); } #[cfg(test)] @@ -489,7 +542,7 @@ mod tests { ); let file = tempfile::NamedTempFile::new().unwrap(); - if let Err(error) = highlight_diff(&mut input, file.reopen().unwrap()) { + if let Err(error) = highlight_diff(&mut input, file.reopen().unwrap(), true) { panic!("{}", error); } let actual = fs::read_to_string(file.path()).unwrap(); @@ -594,6 +647,7 @@ mod tests { if let Err(error) = highlight_diff( &mut fs::File::open(&riff_input_file).unwrap(), file.reopen().unwrap(), + true, ) { if failing_example.is_none() { failing_example = Some(riff_input_file.to_str().unwrap().to_string());