attempt at speed...

This commit is contained in:
Kaden Frisk 2024-10-15 17:53:16 -05:00
parent 16a2f6571a
commit ca402704ca
3 changed files with 196 additions and 135 deletions

62
Cargo.lock generated
View file

@ -2,6 +2,15 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 3
[[package]]
name = "aligned"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "377e4c0ba83e4431b10df45c1d4666f178ea9c552cac93e60c3a88bf32785923"
dependencies = [
"as-slice",
]
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.15" version = "0.6.15"
@ -51,6 +60,15 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "as-slice"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "516b6b4f0e40d50dcda9365d53964ec74560ad4284da2e7fc97122cd83174516"
dependencies = [
"stable_deref_trait",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "2.6.0" version = "2.6.0"
@ -63,6 +81,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.20" version = "4.5.20"
@ -131,6 +155,12 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]] [[package]]
name = "is_terminal_polyfill" name = "is_terminal_polyfill"
version = "1.70.1" version = "1.70.1"
@ -149,6 +179,28 @@ version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
[[package]]
name = "nix"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
"bitflags",
"cfg-if",
"cfg_aliases",
"libc",
]
[[package]]
name = "num_cpus"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
"hermit-abi",
"libc",
]
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.20.2" version = "1.20.2"
@ -186,6 +238,12 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.11.1" version = "0.11.1"
@ -314,6 +372,10 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
name = "zap" name = "zap"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"aligned",
"clap", "clap",
"libc",
"nix",
"num_cpus",
"tempfile", "tempfile",
] ]

View file

@ -4,5 +4,9 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
aligned = "0.4.2"
clap = { version = "4.5.20", features = ["derive"] } clap = { version = "4.5.20", features = ["derive"] }
libc = "0.2.159"
nix = { version = "0.29.0", features = ["fs", "uio"] }
num_cpus = "1.16.0"
tempfile = "3.13.0" tempfile = "3.13.0"

View file

@ -1,10 +1,18 @@
use std::fs::{File, OpenOptions}; use libc;
use std::io::{self, BufReader, Error, ErrorKind, Read, Write}; use nix::fcntl::{open, OFlag};
use nix::sys::stat::Mode;
use nix::unistd::{close, fsync};
use std::fs::File;
use std::io::{self, Error, ErrorKind, Write};
use std::os::fd::AsRawFd;
use std::path::Path; use std::path::Path;
use std::process::{exit, Command}; use std::process::Command;
use std::sync::{Arc, Mutex};
use std::thread;
const BUFFER_SIZE: usize = 4096; // Define a buffer size constant
fn validate_input(iso_path: &str, destination: &str) -> Option<Result<(), Error>> { fn validate_input(iso_path: &str, destination: &str) -> Option<Result<(), Error>> {
// Check if the ISO file exists
if !Path::new(iso_path).exists() { if !Path::new(iso_path).exists() {
return Some(Err(Error::new( return Some(Err(Error::new(
ErrorKind::NotFound, ErrorKind::NotFound,
@ -12,7 +20,6 @@ fn validate_input(iso_path: &str, destination: &str) -> Option<Result<(), Error>
))); )));
} }
// Check if the destination is the root directory
if destination == "/" { if destination == "/" {
return Some(Err(Error::new( return Some(Err(Error::new(
ErrorKind::InvalidInput, ErrorKind::InvalidInput,
@ -20,12 +27,10 @@ fn validate_input(iso_path: &str, destination: &str) -> Option<Result<(), Error>
))); )));
} }
// Skip checks for devices like /dev/sda
if destination.starts_with("/dev/") { if destination.starts_with("/dev/") {
return None; // Allow flashing to block devices return None;
} }
// Check if the destination exists
if !Path::new(destination).exists() { if !Path::new(destination).exists() {
return Some(Err(Error::new( return Some(Err(Error::new(
ErrorKind::NotFound, ErrorKind::NotFound,
@ -33,7 +38,6 @@ fn validate_input(iso_path: &str, destination: &str) -> Option<Result<(), Error>
))); )));
} }
// Ensure the destination is not a regular file
if Path::new(destination).is_file() { if Path::new(destination).is_file() {
return Some(Err(Error::new( return Some(Err(Error::new(
ErrorKind::InvalidInput, ErrorKind::InvalidInput,
@ -44,7 +48,6 @@ fn validate_input(iso_path: &str, destination: &str) -> Option<Result<(), Error>
))); )));
} }
// Check for permission to write to the destination
if let Err(err) = File::create(Path::new(destination).join("test_permission")) { if let Err(err) = File::create(Path::new(destination).join("test_permission")) {
return Some(Err(Error::new( return Some(Err(Error::new(
ErrorKind::PermissionDenied, ErrorKind::PermissionDenied,
@ -54,6 +57,7 @@ fn validate_input(iso_path: &str, destination: &str) -> Option<Result<(), Error>
), ),
))); )));
} }
if Path::new(destination).join("test_permission").exists() { if Path::new(destination).join("test_permission").exists() {
let _ = std::fs::remove_file(Path::new(destination).join("test_permission")); let _ = std::fs::remove_file(Path::new(destination).join("test_permission"));
} }
@ -63,7 +67,6 @@ fn validate_input(iso_path: &str, destination: &str) -> Option<Result<(), Error>
fn print_metadata(iso_file: File) -> Result<(), Error> { fn print_metadata(iso_file: File) -> Result<(), Error> {
let metadata = iso_file.metadata()?; let metadata = iso_file.metadata()?;
println!("ISO file size: {} bytes", metadata.len());
println!("ISO file last modified: {:?}", metadata.modified()?); println!("ISO file last modified: {:?}", metadata.modified()?);
println!("ISO file permissions: {:?}", metadata.permissions()); println!("ISO file permissions: {:?}", metadata.permissions());
println!( println!(
@ -73,155 +76,147 @@ fn print_metadata(iso_file: File) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn open_file_with_direct_io(path: &str, write: bool) -> Result<i32, nix::Error> {
let flags = if write {
OFlag::O_WRONLY | OFlag::O_DIRECT
} else {
OFlag::O_RDONLY | OFlag::O_DIRECT
};
open(path, flags, Mode::empty())
}
pub fn flash_iso(iso_path: &str, destination: &str) -> Result<(), io::Error> { pub fn flash_iso(iso_path: &str, destination: &str) -> Result<(), io::Error> {
// Add a confirmation step if the destination is /dev/sda
if destination == "/dev/sda" {
println!(
"Warning: You are about to flash to /dev/sda. This could overwrite important data."
);
println!("Are you sure you want to continue? (y/n): ");
let mut input = String::new();
io::stdin().read_line(&mut input)?;
if input.trim().to_lowercase() != "y" {
println!("Operation cancelled.");
exit(0);
}
}
// Validate input, returns if invalid
if let Some(value) = validate_input(iso_path, destination) { if let Some(value) = validate_input(iso_path, destination) {
return value; return value;
} }
// Print metadata of the ISO file
let iso_file = File::open(iso_path)?; let iso_file = File::open(iso_path)?;
let metadata = iso_file.metadata()?; let metadata = iso_file.metadata()?;
let total_size = metadata.len(); let total_size = metadata.len();
println!("ISO file size: {} bytes", total_size);
print_metadata(iso_file)?; print_metadata(iso_file)?;
// Open the ISO file for reading // Open files with direct I/O
let iso_file = File::open(iso_path)?; let iso_fd = open_file_with_direct_io(iso_path, false)
let mut reader = BufReader::new(iso_file); .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let dest_fd = open_file_with_direct_io(destination, true)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
// Open the destination block device for writing let iso_fd = Arc::new(Mutex::new(iso_fd));
let mut dest_file = OpenOptions::new().write(true).open(destination)?; let dest_fd = Arc::new(Mutex::new(dest_fd));
// Buffer to hold chunks of data while reading/writing // Get the number of CPUs
let mut buffer = [0u8; 4096]; let num_threads: u64 = num_cpus::get() as u64;
let mut bytes_written: u64 = 0; let chunk_size = total_size / num_threads as u64;
let progress = Arc::new(Mutex::new(0u64));
println!("Flashing {} to {}", iso_path, destination); println!("Flashing {} to {}", iso_path, destination);
// Copy data from ISO to the destination in chunks let mut handles = vec![];
loop {
let bytes_read = reader.read(&mut buffer)?; for i in 0..num_threads {
let iso_fd = Arc::clone(&iso_fd);
let dest_fd = Arc::clone(&dest_fd);
let progress = Arc::clone(&progress);
let handle = thread::spawn(move || -> Result<(), io::Error> {
let start = i as u64 * chunk_size;
let end = if i == num_threads - 1 {
total_size
} else {
(i as u64 + 1) * chunk_size
};
// Get the filesystem block size
let mut stat = std::mem::MaybeUninit::<libc::stat>::uninit();
if unsafe { libc::fstat(*iso_fd.lock().unwrap(), stat.as_mut_ptr()) } != 0 {
return Err(io::Error::last_os_error());
}
let block_size = unsafe { stat.assume_init().st_blksize } as usize;
// Use aligned buffer
let mut buffer = vec![0u8; block_size];
let mut current = start;
while current < end {
let bytes_to_read = std::cmp::min(BUFFER_SIZE, (end - current) as usize);
// Read
let bytes_read = {
let fd = iso_fd.lock().unwrap();
nix::sys::uio::pread(
unsafe { std::os::unix::io::BorrowedFd::borrow_raw(fd.as_raw_fd()) },
&mut buffer[..bytes_to_read],
current as i64,
)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
};
if bytes_read == 0 { if bytes_read == 0 {
break; // EOF reached break; // EOF reached
} }
// Log the number of bytes read // Ensure buffer is filled to the required size
// println!("Read {} bytes from ISO", bytes_read); if bytes_read < block_size {
for i in bytes_read..block_size {
dest_file.write_all(&buffer[..bytes_read])?; buffer[i] = 0;
bytes_written += bytes_read as u64;
// Calculate and display progress every 1 MB (1_048_576 bytes)
if bytes_written % (1 << 20) == 0 || bytes_written == total_size {
let progress = (bytes_written * 100) / total_size;
println!("Progress: {}%", progress);
} }
} }
// Ensure all data is flushed to the device // Write
dest_file.flush()?; {
let fd = dest_fd.lock().unwrap();
nix::sys::uio::pwrite(
unsafe { std::os::unix::io::BorrowedFd::borrow_raw(*fd) },
&buffer[..block_size],
current as i64,
)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
};
current += bytes_read as u64;
// Update progress
let mut progress = progress.lock().unwrap();
*progress += bytes_read as u64;
// Sync after every buffer write
{
let fd = dest_fd.lock().unwrap();
fsync(*fd).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
}
if *progress % (1 << 20) == 0 || *progress == total_size {
let percent = (*progress * 100) / total_size;
print!("\rProgress: {}%", percent);
io::stdout().flush()?;
}
}
Ok(())
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap()?;
}
println!();
// Close file descriptors
close(*iso_fd.lock().unwrap()).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
close(*dest_fd.lock().unwrap()).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
// Sync data to disk explicitly
println!("Syncing data to disk...");
Command::new("sync")
.status()
.expect("Failed to flush data to disk");
// Now that flushing is complete, print success message
println!( println!(
"ISO file {} flashed successfully to {}", "ISO file {} flashed successfully to {}",
iso_path, destination iso_path, destination
); );
// Explicitly exit if desired (generally unnecessary)
// std::process::exit(0);
// Log that it's going to sync data to disk
println!("Syncing data to disk, process will not close until this is complete.");
Command::new("sync")
.status()
.expect("Failed to flush data to disk");
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[test]
fn test_flash_iso_nonexistent_iso() {
let result = flash_iso("nonexistent.iso", "/tmp");
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), ErrorKind::NotFound);
}
#[test]
fn test_flash_iso_nonexistent_destination() {
let temp_file = tempfile::NamedTempFile::new().unwrap();
let result = flash_iso(temp_file.path().to_str().unwrap(), "/nonexistent");
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), ErrorKind::NotFound);
}
#[test]
fn test_flash_iso_destination_is_file() {
let temp_file = tempfile::NamedTempFile::new().unwrap();
let result = flash_iso(
temp_file.path().to_str().unwrap(),
temp_file.path().to_str().unwrap(),
);
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidInput);
}
#[test]
fn test_flash_iso_fail() {
let temp_file = tempfile::NamedTempFile::new().unwrap();
let temp_dir = tempdir().unwrap();
let result = flash_iso(
temp_file.path().to_str().unwrap(),
temp_dir.path().to_str().unwrap(),
);
assert!(!result.is_ok());
}
#[test]
fn test_flash_iso_to_root() {
let temp_file = tempfile::NamedTempFile::new().unwrap();
let result = flash_iso(temp_file.path().to_str().unwrap(), "/");
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidInput);
}
#[test]
fn test_flash_iso_no_permission() {
let temp_file = tempfile::NamedTempFile::new().unwrap();
let temp_dir = tempdir().unwrap();
let no_permission_dir = temp_dir.path().join("no_permission");
std::fs::create_dir(&no_permission_dir).unwrap();
let _ = std::fs::set_permissions(
&no_permission_dir,
<std::fs::Permissions as std::os::unix::fs::PermissionsExt>::from_mode(0o000),
);
let result = flash_iso(
temp_file.path().to_str().unwrap(),
no_permission_dir.to_str().unwrap(),
);
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), ErrorKind::PermissionDenied);
}
}