From 34619a2e89b8f65d362a596b86337ca17409f13c Mon Sep 17 00:00:00 2001 From: sharkdp Date: Tue, 21 Apr 2020 08:19:24 +0200 Subject: Small refactoring, handle invalid UTF-8 filenames --- src/assets.rs | 42 +++++++++++++++++++++++++++++++++--------- src/bin/bat/app.rs | 7 ++++--- src/controller.rs | 2 +- src/inputfile.rs | 15 ++++++--------- src/printer.rs | 6 +++--- 5 files changed, 47 insertions(+), 25 deletions(-) diff --git a/src/assets.rs b/src/assets.rs index e374d5ed..6aca40be 100644 --- a/src/assets.rs +++ b/src/assets.rs @@ -1,4 +1,5 @@ use std::collections::BTreeMap; +use std::ffi::OsStr; use std::fs::{self, File}; use std::io::BufReader; use std::path::Path; @@ -183,8 +184,7 @@ impl HighlightingAssets { let syntax = match (language, file) { (Some(language), _) => self.syntax_set.find_syntax_by_token(language), (None, InputFile::Ordinary(ofile)) => { - let path = Path::new(ofile.filename()); - let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or(""); + let path = Path::new(ofile.provided_path()); let line_syntax = self.get_first_line_syntax(reader); let absolute_path = path.canonicalize().ok().unwrap_or(path.to_owned()); @@ -195,14 +195,17 @@ impl HighlightingAssets { self.syntax_set.find_syntax_by_name(syntax_name) } Some(MappingTarget::MapToUnknown) => line_syntax, - None => self.get_extension_syntax(file_name).or(line_syntax), + None => { + let file_name = path.file_name().unwrap_or_default(); + self.get_extension_syntax(file_name).or(line_syntax) + } } } (None, InputFile::StdIn(None)) => String::from_utf8(reader.first_line.clone()) .ok() .and_then(|l| self.syntax_set.find_syntax_by_first_line(&l)), (None, InputFile::StdIn(Some(file_name))) => self - .get_extension_syntax(file_name.to_str().unwrap()) + .get_extension_syntax(file_name) .or(self.get_first_line_syntax(reader)), (_, InputFile::ThemePreviewFile) => self.syntax_set.find_syntax_by_name("Rust"), }; @@ -210,15 +213,15 @@ impl HighlightingAssets { syntax.unwrap_or_else(|| self.syntax_set.find_syntax_plain_text()) } - fn get_extension_syntax(&self, file_name: &str) -> Option<&SyntaxReference> { + fn get_extension_syntax(&self, file_name: &OsStr) -> Option<&SyntaxReference> { self.syntax_set - .find_syntax_by_extension(file_name) + .find_syntax_by_extension(file_name.to_str().unwrap_or_default()) .or_else(|| { self.syntax_set.find_syntax_by_extension( Path::new(file_name) .extension() .and_then(|x| x.to_str()) - .unwrap_or(""), + .unwrap_or_default(), ) }) } @@ -259,14 +262,14 @@ mod tests { } } - fn syntax_for_file_with_content(&self, file_name: &str, first_line: &str) -> String { + fn syntax_for_file_with_content_os(&self, file_name: &OsStr, first_line: &str) -> String { let file_path = self.temp_dir.path().join(file_name); { let mut temp_file = File::create(&file_path).unwrap(); writeln!(temp_file, "{}", first_line).unwrap(); } - let input_file = InputFile::Ordinary(OrdinaryFile::from_path(OsStr::new(&file_path))); + let input_file = InputFile::Ordinary(OrdinaryFile::from_path(file_path.as_os_str())); let syntax = self.assets.get_syntax( None, input_file, @@ -277,6 +280,14 @@ mod tests { syntax.name.clone() } + fn syntax_for_file_os(&self, file_name: &OsStr) -> String { + self.syntax_for_file_with_content_os(file_name, "") + } + + fn syntax_for_file_with_content(&self, file_name: &str, first_line: &str) -> String { + self.syntax_for_file_with_content_os(OsStr::new(file_name), first_line) + } + fn syntax_for_file(&self, file_name: &str) -> String { self.syntax_for_file_with_content(file_name, "") } @@ -308,6 +319,19 @@ mod tests { assert_eq!(test.syntax_for_file("Makefile"), "Makefile"); } + #[cfg(unix)] + #[test] + fn syntax_detection_invalid_utf8() { + use std::os::unix::ffi::OsStrExt; + + let test = SyntaxDetectionTest::new(); + + assert_eq!( + test.syntax_for_file_os(OsStr::from_bytes(b"invalid_\xFEutf8_filename.rs")), + "Rust" + ); + } + #[test] fn syntax_detection_well_defined_mapping_for_duplicate_extensions() { let test = SyntaxDetectionTest::new(); diff --git a/src/bin/bat/app.rs b/src/bin/bat/app.rs index cbcf0fbc..7c72f213 100644 --- a/src/bin/bat/app.rs +++ b/src/bin/bat/app.rs @@ -264,9 +264,10 @@ impl App { if input.to_str().unwrap() == "-" { file_input.push(InputFile::StdIn(name)); } else { - let ofile = name.map_or(OrdinaryFile::from_path(input), |n| { - OrdinaryFile::from_path_with_name(input, n) - }); + let mut ofile = OrdinaryFile::from_path(input); + if let Some(path) = name { + ofile.set_provided_path(path); + } file_input.push(InputFile::Ordinary(ofile)) } } diff --git a/src/controller.rs b/src/controller.rs index 11d4d3a3..10230562 100644 --- a/src/controller.rs +++ b/src/controller.rs @@ -36,7 +36,7 @@ impl<'b> Controller<'b> { if self.config.paging_mode != PagingMode::Never { let call_pager = self.config.files.iter().any(|file| { if let InputFile::Ordinary(ofile) = file { - return Path::new(ofile.filename()).exists(); + return Path::new(ofile.provided_path()).exists(); } else { return true; } diff --git a/src/inputfile.rs b/src/inputfile.rs index 737d3ea3..e7ed739d 100644 --- a/src/inputfile.rs +++ b/src/inputfile.rs @@ -55,26 +55,23 @@ impl<'a> InputFileReader<'a> { #[derive(Debug, Clone, Copy, PartialEq)] pub struct OrdinaryFile<'a> { path: &'a OsStr, - user_provided_name: Option<&'a OsStr>, + user_provided_path: Option<&'a OsStr>, } impl<'a> OrdinaryFile<'a> { pub fn from_path(path: &'a OsStr) -> OrdinaryFile<'a> { OrdinaryFile { path, - user_provided_name: None, + user_provided_path: None, } } - pub fn from_path_with_name(path: &'a OsStr, user_provided_name: &'a OsStr) -> OrdinaryFile<'a> { - OrdinaryFile { - path, - user_provided_name: Some(user_provided_name), - } + pub fn set_provided_path(&mut self, user_provided_path: &'a OsStr) { + self.user_provided_path = Some(user_provided_path); } - pub(crate) fn filename(&self) -> &'a OsStr { - self.user_provided_name.unwrap_or_else(|| self.path) + pub(crate) fn provided_path(&self) -> &'a OsStr { + self.user_provided_path.unwrap_or_else(|| self.path) } } diff --git a/src/printer.rs b/src/printer.rs index 8ea12407..ab64808d 100644 --- a/src/printer.rs +++ b/src/printer.rs @@ -161,7 +161,7 @@ impl<'a> InteractivePrinter<'a> { { if config.style_components.changes() { if let InputFile::Ordinary(ofile) = file { - line_changes = get_git_diff(ofile.filename()); + line_changes = get_git_diff(ofile.provided_path()); } } } @@ -235,7 +235,7 @@ impl<'a> Printer for InteractivePrinter<'a> { if Some(ContentType::BINARY) == self.content_type && !self.config.show_nonprintable { let input = match file { InputFile::Ordinary(ofile) => { - format!("file '{}'", &ofile.filename().to_string_lossy()) + format!("file '{}'", &ofile.provided_path().to_string_lossy()) } InputFile::StdIn(Some(name)) => name.to_string_lossy().into_owned(), InputFile::StdIn(None) => "STDIN".to_owned(), @@ -276,7 +276,7 @@ impl<'a> Printer for InteractivePrinter<'a> { let (prefix, name) = match file { InputFile::Ordinary(ofile) => ( "File: ", - Cow::from(ofile.filename().to_string_lossy().to_owned()), + Cow::from(ofile.provided_path().to_string_lossy().to_owned()), ), InputFile::StdIn(Some(name)) => { ("File: ", Cow::from(name.to_string_lossy().to_owned())) -- cgit v1.2.3