diff --git a/src/context.rs b/src/context.rs index ce200206..7a35dd40 100644 --- a/src/context.rs +++ b/src/context.rs @@ -5,9 +5,8 @@ use crate::modules; use clap::ArgMatches; use git2::{Repository, RepositoryState}; use once_cell::sync::OnceCell; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::env; -use std::ffi::OsStr; use std::fs; use std::path::{Path, PathBuf}; use std::string::String; @@ -23,8 +22,8 @@ pub struct Context<'a> { /// The current working directory that starship is being called in. pub current_dir: PathBuf, - /// A vector containing the full paths of all the files in `current_dir`. - dir_files: OnceCell>, + /// A struct containing directory contents in a lookup-optimised format. + dir_contents: OnceCell, /// Properties to provide to modules. pub properties: HashMap<&'a str, String>, @@ -80,7 +79,7 @@ impl<'a> Context<'a> { config, properties, current_dir, - dir_files: OnceCell::new(), + dir_contents: OnceCell::new(), repo: OnceCell::new(), shell, } @@ -117,7 +116,7 @@ impl<'a> Context<'a> { // see ScanDir for methods pub fn try_begin_scan(&'a self) -> Option> { Some(ScanDir { - dir_files: self.get_dir_files().ok()?, + dir_contents: self.dir_contents().ok()?, files: &[], folders: &[], extensions: &[], @@ -145,26 +144,101 @@ impl<'a> Context<'a> { }) } - pub fn get_dir_files(&self) -> Result<&Vec, std::io::Error> { - let start_time = SystemTime::now(); - let scan_timeout = Duration::from_millis(self.config.get_root_config().scan_timeout); + pub fn dir_contents(&self) -> Result<&DirContents, std::io::Error> { + self.dir_contents.get_or_try_init(|| { + let timeout = Duration::from_millis(self.config.get_root_config().scan_timeout); + DirContents::from_path_with_timeout(&self.current_dir, timeout) + }) + } +} - self.dir_files - .get_or_try_init(|| -> Result, std::io::Error> { - let dir_files = fs::read_dir(&self.current_dir)? - .take_while(|_item| { - SystemTime::now().duration_since(start_time).unwrap() < scan_timeout - }) - .filter_map(Result::ok) - .map(|entry| entry.path()) - .collect::>(); +#[derive(Debug)] +pub struct DirContents { + // HashSet of all files, no folders, relative to the base directory given at construction. + files: HashSet, + // HashSet of all file names, e.g. the last section without any folders, as strings. + file_names: HashSet, + // HashSet of all folders, relative to the base directory given at construction. + folders: HashSet, + // HashSet of all extensions found, without dots, e.g. "js" instead of ".js". + extensions: HashSet, +} - log::trace!( - "Building a vector of directory files took {:?}", - SystemTime::now().duration_since(start_time).unwrap() - ); - Ok(dir_files) - }) +impl DirContents { + fn from_path(base: &PathBuf) -> Result { + Self::from_path_with_timeout(base, Duration::from_secs(30)) + } + + fn from_path_with_timeout(base: &PathBuf, timeout: Duration) -> Result { + let start = SystemTime::now(); + + let mut folders: HashSet = HashSet::new(); + let mut files: HashSet = HashSet::new(); + let mut file_names: HashSet = HashSet::new(); + let mut extensions: HashSet = HashSet::new(); + + fs::read_dir(base)? + .take_while(|_| SystemTime::now().duration_since(start).unwrap() < timeout) + .filter_map(Result::ok) + .for_each(|entry| { + let path = PathBuf::from(entry.path().strip_prefix(base).unwrap()); + if entry.path().is_dir() { + folders.insert(path); + } else { + if !path.to_string_lossy().starts_with('.') { + path.extension() + .map(|ext| extensions.insert(ext.to_string_lossy().to_string())); + } + if let Some(file_name) = path.file_name() { + file_names.insert(file_name.to_string_lossy().to_string()); + } + files.insert(path); + } + }); + + log::trace!( + "Building HashSets of directory files, folders and extensions took {:?}", + SystemTime::now().duration_since(start).unwrap() + ); + + Ok(DirContents { + folders, + files, + file_names, + extensions, + }) + } + + pub fn files(&self) -> impl Iterator { + self.files.iter() + } + + pub fn has_file(&self, path: &str) -> bool { + self.files.contains(Path::new(path)) + } + + pub fn has_file_name(&self, name: &str) -> bool { + self.file_names.contains(name) + } + + pub fn has_any_file_name(&self, names: &[&str]) -> bool { + names.iter().any(|name| self.has_file_name(name)) + } + + pub fn has_folder(&self, path: &str) -> bool { + self.folders.contains(Path::new(path)) + } + + pub fn has_any_folder(&self, paths: &[&str]) -> bool { + paths.iter().any(|path| self.has_folder(path)) + } + + pub fn has_extension(&self, ext: &str) -> bool { + self.extensions.contains(ext) + } + + pub fn has_any_extension(&self, exts: &[&str]) -> bool { + exts.iter().any(|ext| self.has_extension(ext)) } fn get_shell() -> Shell { @@ -196,7 +270,7 @@ pub struct Repo { // A struct of Criteria which will be used to verify current PathBuf is // of X language, criteria can be set via the builder pattern pub struct ScanDir<'a> { - dir_files: &'a Vec, + dir_contents: &'a DirContents, files: &'a [&'a str], folders: &'a [&'a str], extensions: &'a [&'a str], @@ -221,48 +295,12 @@ impl<'a> ScanDir<'a> { /// based on the current Pathbuf check to see /// if any of this criteria match or exist and returning a boolean pub fn is_match(&self) -> bool { - self.dir_files.iter().any(|path| { - if path.is_dir() { - path_has_name(path, self.folders) - } else { - path_has_name(path, self.files) || has_extension(path, self.extensions) - } - }) + self.dir_contents.has_any_extension(self.extensions) + || self.dir_contents.has_any_folder(self.folders) + || self.dir_contents.has_any_file_name(self.files) } } -/// checks to see if the pathbuf matches a file or folder name -pub fn path_has_name<'a>(dir_entry: &PathBuf, names: &'a [&'a str]) -> bool { - let found_file_or_folder_name = names.iter().find(|file_or_folder_name| { - dir_entry - .file_name() - .and_then(OsStr::to_str) - .unwrap_or_default() - == **file_or_folder_name - }); - - match found_file_or_folder_name { - Some(name) => !name.is_empty(), - None => false, - } -} - -/// checks if pathbuf doesn't start with a dot and matches any provided extension -pub fn has_extension<'a>(dir_entry: &PathBuf, extensions: &'a [&'a str]) -> bool { - if let Some(file_name) = dir_entry.file_name() { - if file_name.to_string_lossy().starts_with('.') { - return false; - } - return extensions.iter().any(|ext| { - dir_entry - .extension() - .and_then(OsStr::to_str) - .map_or(false, |e| e == *ext) - }); - } - false -} - fn get_current_branch(repository: &Repository) -> Option { let head = repository.head().ok()?; let shorthand = head.shorthand(); @@ -284,69 +322,73 @@ pub enum Shell { mod tests { use super::*; - #[test] - fn test_path_has_name() { - let mut buf = PathBuf::from("/"); - let files = vec!["package.json"]; - - assert_eq!(path_has_name(&buf, &files), false); - - buf.set_file_name("some-file.js"); - assert_eq!(path_has_name(&buf, &files), false); - - buf.set_file_name("package.json"); - assert_eq!(path_has_name(&buf, &files), true); + fn testdir(paths: &[&str]) -> Result { + let dir = tempfile::tempdir()?; + for path in paths { + let p = dir.path().join(Path::new(path)); + if let Some(parent) = p.parent() { + fs::create_dir_all(parent)?; + } + fs::File::create(p)?.sync_all()?; + } + Ok(dir) } #[test] - fn test_has_extension() { - let mut buf = PathBuf::from("/"); - let extensions = vec!["js"]; + fn test_scan_dir() -> Result<(), Box> { + let empty = testdir(&[])?; + let empty_dc = DirContents::from_path(&PathBuf::from(empty.path()))?; - assert_eq!(has_extension(&buf, &extensions), false); + assert_eq!( + ScanDir { + dir_contents: &empty_dc, + files: &["package.json"], + extensions: &["js"], + folders: &["node_modules"], + } + .is_match(), + false + ); - buf.set_file_name("some-file.rs"); - assert_eq!(has_extension(&buf, &extensions), false); + let rust = testdir(&["README.md", "Cargo.toml", "src/main.rs"])?; + let rust_dc = DirContents::from_path(&PathBuf::from(rust.path()))?; + assert_eq!( + ScanDir { + dir_contents: &rust_dc, + files: &["package.json"], + extensions: &["js"], + folders: &["node_modules"], + } + .is_match(), + false + ); - buf.set_file_name(".some-file.js"); - assert_eq!(has_extension(&buf, &extensions), false); + let java = testdir(&["README.md", "src/com/test/Main.java", "pom.xml"])?; + let java_dc = DirContents::from_path(&PathBuf::from(java.path()))?; + assert_eq!( + ScanDir { + dir_contents: &java_dc, + files: &["package.json"], + extensions: &["js"], + folders: &["node_modules"], + } + .is_match(), + false + ); - buf.set_file_name("some-file.js"); - assert_eq!(has_extension(&buf, &extensions), true) - } + let node = testdir(&["README.md", "node_modules/lodash/main.js", "package.json"])?; + let node_dc = DirContents::from_path(&PathBuf::from(node.path()))?; + assert_eq!( + ScanDir { + dir_contents: &node_dc, + files: &["package.json"], + extensions: &["js"], + folders: &["node_modules"], + } + .is_match(), + true + ); - #[test] - fn test_criteria_scan_fails() { - let failing_criteria = ScanDir { - dir_files: &vec![PathBuf::new()], - files: &["package.json"], - extensions: &["js"], - folders: &["node_modules"], - }; - - // fails if buffer does not match any criteria - assert_eq!(failing_criteria.is_match(), false); - - let failing_dir_criteria = ScanDir { - dir_files: &vec![PathBuf::from("/package.js/dog.go")], - files: &["package.json"], - extensions: &["js"], - folders: &["node_modules"], - }; - - // fails when passed a pathbuf dir matches extension path - assert_eq!(failing_dir_criteria.is_match(), false); - } - - #[test] - fn test_criteria_scan_passes() { - let passing_criteria = ScanDir { - dir_files: &vec![PathBuf::from("package.json")], - files: &["package.json"], - extensions: &["js"], - folders: &["node_modules"], - }; - - assert_eq!(passing_criteria.is_match(), true); + Ok(()) } } diff --git a/src/modules/dotnet.rs b/src/modules/dotnet.rs index 738e38a9..5ca4d636 100644 --- a/src/modules/dotnet.rs +++ b/src/modules/dotnet.rs @@ -165,8 +165,8 @@ fn get_pinned_sdk_version(json: &str) -> Option { fn get_local_dotnet_files<'a>(context: &'a Context) -> Result>, std::io::Error> { Ok(context - .get_dir_files()? - .iter() + .dir_contents()? + .files() .filter_map(|p| { get_dotnet_file_type(p).map(|t| DotNetFile { path: p.as_ref(), diff --git a/src/modules/rust.rs b/src/modules/rust.rs index b18692b8..e11fa430 100644 --- a/src/modules/rust.rs +++ b/src/modules/rust.rs @@ -1,4 +1,3 @@ -use std::ffi::OsStr; use std::path::Path; use std::process::{Command, Output}; use std::{env, fs}; @@ -107,13 +106,11 @@ fn find_rust_toolchain_file(context: &Context) -> Option { Some(line.trim().to_owned()) } - if let Some(path) = context - .get_dir_files() - .ok()? - .iter() - .find(|p| p.file_name() == Some(OsStr::new("rust-toolchain"))) + if let Ok(true) = context + .dir_contents() + .map(|dir| dir.has_file("rust-toolchain")) { - if let Some(toolchain) = read_first_line(path) { + if let Some(toolchain) = read_first_line(Path::new("rust-toolchain")) { return Some(toolchain); } }