diff --git a/src/main.rs b/src/main.rs index 752f1a7..fed72aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,6 @@ use std::fs::File; use std::io::Write; use std::os::fd::AsRawFd; use std::path::PathBuf; -use std::sync::Arc; use clap::Parser; use miette::{IntoDiagnostic, Result}; @@ -67,70 +66,72 @@ fn main() -> Result<()> { }; let mut file = File::open(args.file).into_diagnostic()?; - let strings = Arc::new(lib::get_strings(&mut file, args.min_str_len).into_diagnostic()?); + let strings = lib::get_strings(&mut file, args.min_str_len).into_diagnostic()?; eprintln!("Found {} strings", strings.len()); - let pointers = Arc::new(lib::get_pointers(&mut file, args.big_endian).into_diagnostic()?); + let pointers = lib::get_pointers(&mut file, args.big_endian).into_diagnostic()?; eprintln!("Found {} pointers", pointers.len()); let start = 0x0000_0000u64; let end = 0x1_0000_0000u64; let total_pages = ((end - start) / (args.page_size as u64)) as usize; let chunk_size = (end - start) / (args.threads as u64); - let ranges = (start..=end) - .step_by(chunk_size as usize) - .zip((start + chunk_size..=end).step_by(chunk_size as usize)); - let tasks: Vec<_> = ranges - .map(|(start, end)| { - let progress = Arc::new(lib::ComputeProgress::new()); + let progress_counters: Vec<_> = + (0..args.threads).map(|_| lib::ComputeProgress::new()).collect(); - let strings = Arc::clone(&strings); - let pointers = Arc::clone(&pointers); - let child_progress = Arc::clone(&progress); + let mut thread_results = std::thread::scope(|s| { + let ranges = (start..=end) + .step_by(chunk_size as usize) + .zip((start + chunk_size..=end).step_by(chunk_size as usize)); - let thread = std::thread::spawn(move || { - lib::compute_matches( - &strings, - &pointers, - start, - end, - args.page_size, - Some(&child_progress), - ) - }); + let tasks: Vec<_> = ranges.zip(progress_counters.iter()) + .map(|((start, end), counter)| { + let strings = &strings; + let pointers = &pointers; + s.spawn(move || { + lib::compute_matches( + strings, + pointers, + start, + end, + args.page_size, + Some(counter), + ) + }) + }) + .collect(); - (thread, progress) - }) - .collect(); + if let Some(mut term) = term { + loop { + std::thread::sleep(std::time::Duration::from_millis(100)); + if tasks.iter().any(|thread| !thread.is_finished()) { + term.carriage_return().unwrap(); + term.delete_line().unwrap(); - if let Some(mut term) = term { - loop { - std::thread::sleep(std::time::Duration::from_millis(100)); - if tasks.iter().any(|(thread, _)| !thread.is_finished()) { - term.carriage_return().unwrap(); - term.delete_line().unwrap(); + let completed_pages: usize = + progress_counters.iter().map(|prg| prg.num_completed()).sum(); - let completed_pages: usize = tasks.iter().map(|(_, prg)| prg.num_completed()).sum(); - - eprint!("{} / {} ({}%)", completed_pages, total_pages, - completed_pages * 100 / total_pages); - std::io::stderr().flush().unwrap(); - } else { - break; + eprint!("{} / {} ({}%)", completed_pages, total_pages, + completed_pages * 100 / total_pages); + std::io::stderr().flush().unwrap(); + } else { + break; + } } + + term.carriage_return().unwrap(); + term.delete_line().unwrap(); + eprintln!("Scan complete"); + } else { + eprintln!("Scanning..."); } - term.carriage_return().unwrap(); - term.delete_line().unwrap(); - eprintln!("Scan complete"); - } else { - eprintln!("Scanning..."); - } - - let mut thread_results: Vec<_> = tasks - .into_iter() - .map(|(thread, _)| thread.join().map_err(std::panic::resume_unwind).unwrap()) - .collect::>(); + let thread_results: Vec<_> = tasks + .into_iter() + .map(|thread| thread.join().map_err(std::panic::resume_unwind).unwrap()) + .collect::>(); + thread_results + }); eprintln!("Results:");