From 581e725b1e3b4e0e9b8d8860146cc1edd1411664 Mon Sep 17 00:00:00 2001
From: yushuoqi
Date: Fri, 5 Jun 2026 15:17:28 +0800
Subject: [PATCH] fast_io.rs: use sendfile or splice if copy_file_range is not supported
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Use `splice` and attempt to expand the pipe size when necessary.
- Modify the `reliable_copy_file_range()` function to sequentially attempt
  `copy_file_range` → `splice` → `sendfile` → `write`.
- Add Linux-based unit tests.

Closes: #443
---
 src/sed/fast_io.rs | 378 ++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 360 insertions(+), 18 deletions(-)

diff --git a/src/sed/fast_io.rs b/src/sed/fast_io.rs
index e6017bd..d95a943 100644
--- a/src/sed/fast_io.rs
+++ b/src/sed/fast_io.rs
@@ -861,7 +861,7 @@ impl OutputBuffer {
 #[cfg(unix)]
 fn reliable_write(fd: i32, ptr: *const u8, len: usize) -> std::io::Result<usize> {
     // A thin Write-compatible wrapper around a raw file descriptor
-    // This allows us to issue and utilize the write_all implementatin.
+    // This allows us to issue and utilize the write_all implementation.
     struct FdWriter(RawFd);
 
     impl Write for FdWriter {
@@ -885,6 +885,98 @@ fn reliable_write(fd: i32, ptr: *const u8, len: usize) -> std::io::Result<usize>
     Ok(len)
 }
 
+/// Return the pipe buffer size of `fd` (fcntl `F_GETPIPE_SZ`), or the OS error on failure.
+#[cfg(all(target_os = "linux", target_env = "gnu"))]
+fn get_pipe_size(fd: i32) -> std::io::Result<usize> {
+    let ret = unsafe { libc::fcntl(fd, libc::F_GETPIPE_SZ) };
+    if ret < 0 {
+        Err(io::Error::last_os_error())
+    } else {
+        Ok(ret as usize)
+    }
+}
+
+/// Set the pipe buffer size of `fd` to `size` (fcntl `F_SETPIPE_SZ`); `size` must fit a C int.
+#[cfg(all(target_os = "linux", target_env = "gnu"))]
+fn set_pipe_size(fd: i32, size: usize) -> std::io::Result<()> {
+    let ret = unsafe { libc::fcntl(fd, libc::F_SETPIPE_SZ, size as libc::c_int) };
+    if ret < 0 {
+        Err(io::Error::last_os_error())
+    } else {
+        Ok(())
+    }
+}
+
+/// Try to splice data from input file to output.
+/// Only works on Linux when both fds support splice.
+/// Return `Ok(Some(n))` for n bytes moved, or `Ok(None)` on ENOSYS/EOPNOTSUPP/EINVAL.
+#[cfg(all(target_os = "linux", target_env = "gnu"))]
+fn try_splice(
+    in_fd: i32,
+    in_off: *mut libc::off_t,
+    out_fd: i32,
+    len: usize,
+) -> std::io::Result<Option<usize>> {
+    // Best effort: grow the pipe buffer when `out_fd` is a pipe (the query fails otherwise)
+    // This helps improve throughput when writing to pipes
+    if let Ok(current_size) = get_pipe_size(out_fd) {
+        let desired_size = 1024 * 1024; // 1 MiB
+        if current_size < desired_size {
+            let _ = set_pipe_size(out_fd, desired_size);
+        }
+    }
+
+    let ret = unsafe {
+        libc::splice(
+            in_fd,
+            in_off,
+            out_fd,
+            std::ptr::null_mut(),
+            len,
+            libc::SPLICE_F_MOVE | libc::SPLICE_F_MORE,
+        )
+    };
+
+    if ret < 0 {
+        let err = io::Error::last_os_error();
+        match err.raw_os_error() {
+            Some(libc::ENOSYS) | Some(libc::EOPNOTSUPP) | Some(libc::EINVAL) => {
+                // splice not supported for these fds
+                Ok(None)
+            }
+            _ => Err(err),
+        }
+    } else {
+        Ok(Some(ret as usize))
+    }
+}
+
+/// Try sendfile(2) to copy up to `len` bytes from `in_fd`, starting at `in_off`, to `out_fd`.
+/// `in_off` is taken by value, so the caller must advance its own offset by the result.
+/// Return `Ok(Some(n))` for n bytes written, or `Ok(None)` on ENOSYS/EOPNOTSUPP/EINVAL.
+#[cfg(all(target_os = "linux", target_env = "gnu"))]
+fn try_sendfile(
+    in_fd: i32,
+    mut in_off: libc::off_t,
+    out_fd: i32,
+    len: usize,
+) -> std::io::Result<Option<usize>> {
+    let ret = unsafe { libc::sendfile(out_fd, in_fd, &raw mut in_off, len) }; // dest fd first
+
+    if ret < 0 {
+        let err = io::Error::last_os_error();
+        match err.raw_os_error() {
+            Some(libc::ENOSYS) | Some(libc::EOPNOTSUPP) | Some(libc::EINVAL) => {
+                // sendfile not supported for these fds
+                Ok(None)
+            }
+            _ => Err(err),
+        }
+    } else {
+        Ok(Some(ret as usize))
+    }
+}
+
 /// Copy efficiently len data from the input to the output file.
 /// Fall back to write(2) if the platform doesn't support copy_file_range(2).
 /// Return the number of bytes written.
@@ -914,7 +1006,7 @@ fn portable_copy_file_range(
 }
 
 /// Copy efficiently len data from the input to the output file.
-/// Handle partial copies and fall back to write(2) if the
+/// Handle partial copies and fall back to alternative methods if the
 /// file system or options don't support copy_file_range(2).
 /// Return the number of bytes written.
 #[cfg(all(target_os = "linux", target_env = "gnu"))]
@@ -926,33 +1018,79 @@ fn reliable_copy_file_range(
     len: usize,
 ) -> std::io::Result<usize> {
     let mut pending = len;
+    let mut copy_method: i32 = 0; // 0 = copy_file_range, 1 = splice, 2 = sendfile, 3 = write
+
     while pending > 0 {
-        let ret = unsafe {
-            libc::copy_file_range(
-                in_fd,
-                &raw mut in_off,
-                out_fd,
-                std::ptr::null_mut(), // Use and update output offset
-                pending,
-                0,
-            )
+        let ret = match copy_method {
+            0 => unsafe {
+                libc::copy_file_range(
+                    in_fd,
+                    &raw mut in_off,
+                    out_fd,
+                    std::ptr::null_mut(), // Use and update output offset
+                    pending,
+                    0,
+                )
+            },
+            1 => {
+                // splice(2) advances `in_off` itself through the pointer
+                match try_splice(in_fd, &raw mut in_off, out_fd, pending) {
+                    Ok(Some(written)) => written as isize,
+                    Ok(None) => {
+                        // Splice not supported, try sendfile
+                        copy_method = 2;
+                        continue;
+                    }
+                    Err(e) => return Err(e),
+                }
+            }
+            2 => {
+                // try_sendfile takes the offset by value, so advance it here
+                match try_sendfile(in_fd, in_off, out_fd, pending) {
+                    Ok(Some(written)) => {
+                        in_off += written as libc::off_t;
+                        written as isize
+                    }
+                    Ok(None) => {
+                        // Sendfile not supported, fall back to write
+                        copy_method = 3;
+                        continue;
+                    }
+                    Err(e) => return Err(e),
+                }
+            }
+            _ => {
+                // Final fallback: write(2) what is left, skipping the bytes already
+                // copied (assumes `in_ptr` addresses the same `len` bytes as `in_off`).
+                let done = len - pending;
+                reliable_write(out_fd, in_ptr.wrapping_add(done), pending)?;
+                return Ok(len);
+            }
         };
+
         if ret < 0 {
             let err = io::Error::last_os_error();
-            return match err.raw_os_error() {
-                Some(libc::ENOSYS) | Some(libc::EOPNOTSUPP) | Some(libc::EXDEV) => {
-                    // Fallback to write(2).
-                    reliable_write(out_fd, in_ptr, pending)
+            match err.raw_os_error() {
+                Some(libc::ENOSYS)
+                | Some(libc::EOPNOTSUPP)
+                | Some(libc::EXDEV)
+                | Some(libc::EINVAL) => {
+                    // Only copy_file_range (method 0) reports failure through `ret`;
+                    // splice and sendfile map theirs above, so this selects splice.
+                    copy_method = copy_method.saturating_add(1);
+                    continue;
                 }
-                _ => Err(err),
-            };
+                _ => return Err(err),
+            }
         } else if ret == 0 {
             // EOF reached
             break;
         }
+
         pending -= ret as usize;
     }
-    Ok(len)
+
+    Ok(len - pending)
 }
 
 /// Copy efficiently len data from the input to the output file.
@@ -2021,4 +2159,208 @@ mod tests {
         file.read_to_string(&mut out).unwrap();
         assert_eq!(out, "baz\n");
     }
+
+    ///////////////////////////////
+    // Unit tests for pipe optimization functions
+    ///////////////////////////////
+
+    #[cfg(all(target_os = "linux", target_env = "gnu"))]
+    #[test]
+    fn test_get_pipe_size_from_pipe() {
+        // Create a pipe
+        let mut fds = [0; 2];
+        assert_eq!(unsafe { libc::pipe(fds.as_mut_ptr()) }, 0);
+        let read_fd = fds[0];
+        let write_fd = fds[1];
+
+        // Get the pipe size - should succeed and return positive value
+        let size = get_pipe_size(write_fd).expect("get_pipe_size failed");
+        assert!(size > 0, "pipe size should be positive");
+
+        // Clean up
+        unsafe {
+            libc::close(read_fd);
+            libc::close(write_fd);
+        }
+    }
+
+    #[cfg(all(target_os = "linux", target_env = "gnu"))]
+    #[test]
+    fn test_set_pipe_size() {
+        // Create a pipe
+        let mut fds = [0; 2];
+        assert_eq!(unsafe { libc::pipe(fds.as_mut_ptr()) }, 0);
+        let read_fd = fds[0];
+        let write_fd = fds[1];
+
+        // Get original size
+        let original_size = get_pipe_size(write_fd).expect("get_pipe_size failed");
+
+        // Try to set a larger buffer size
+        let new_size = (original_size * 2).min(1024 * 1024); // Cap at 1 MiB
+        if new_size > original_size {
+            set_pipe_size(write_fd, new_size).ok(); // Ignore error - may fail on some systems
+
+            // Verify size was changed (if operation succeeded)
+            if let Ok(actual_size) = get_pipe_size(write_fd) {
+                assert!(
+                    actual_size >= original_size,
+                    "pipe size should not decrease"
+                );
+            }
+        }
+
+        // Clean up
+        unsafe {
+            libc::close(read_fd);
+            libc::close(write_fd);
+        }
+    }
+
+    #[cfg(all(target_os = "linux", target_env = "gnu"))]
+    #[test]
+    fn test_try_splice_to_pipe() {
+        use std::sync::Arc;
+        use std::sync::atomic::{AtomicUsize, Ordering};
+        use std::thread;
+
+        // Create input file
+        let mut infile = tempfile().unwrap();
+        let data = b"test splice data";
+        infile.write_all(data).unwrap();
+        infile.rewind().unwrap();
+
+        // Create a pipe for output
+        let mut fds = [0; 2];
+        assert_eq!(unsafe { libc::pipe(fds.as_mut_ptr()) }, 0);
+        let read_fd = fds[0];
+        let write_fd = fds[1];
+
+        let read_count = Arc::new(AtomicUsize::new(0));
+        let read_count_clone = Arc::clone(&read_count);
+
+        // Spawn a reader thread that will consume from the pipe
+        let reader_thread = thread::spawn(move || {
+            let mut buf = vec![0u8; 1024];
+            let mut total_read = 0;
+            loop {
+                let n = unsafe { libc::read(read_fd, buf.as_mut_ptr().cast(), buf.len()) };
+                if n <= 0 {
+                    break;
+                }
+                total_read += n as usize;
+            }
+            read_count_clone.store(total_read, Ordering::Relaxed);
+            unsafe { libc::close(read_fd) };
+        });
+
+        // Try splice
+        let in_fd = infile.as_raw_fd();
+        let mut in_off = 0i64;
+        let result = try_splice(in_fd, &raw mut in_off, write_fd, data.len());
+
+        // Close write end to signal EOF to reader
+        unsafe { libc::close(write_fd) };
+
+        // Wait for reader thread
+        reader_thread.join().unwrap();
+
+        // Check result - splice might not be supported on this system
+        match result {
+            Ok(Some(written)) => {
+                assert!(written > 0, "splice should write some data");
+                let bytes_read = read_count.load(Ordering::Relaxed);
+                assert!(bytes_read > 0, "data should be read from pipe");
+            }
+            Ok(None) => {
+                // splice not supported on this system - that's OK
+            }
+            Err(e) => {
+                // Some errors are acceptable (e.g., unsupported filesystem)
+                eprintln!("splice error: {}", e);
+            }
+        }
+    }
+
+    #[cfg(all(target_os = "linux", target_env = "gnu"))]
+    #[test]
+    fn test_try_sendfile_to_file() {
+        // Create input file
+        let mut infile = tempfile().unwrap();
+        let data = b"test sendfile data with some content";
+        infile.write_all(data).unwrap();
+        infile.rewind().unwrap();
+
+        // Create output file
+        let mut outfile = tempfile().unwrap();
+
+        let in_fd = infile.as_raw_fd();
+        let out_fd = outfile.as_raw_fd();
+
+        // Try sendfile
+        let result = try_sendfile(in_fd, 0, out_fd, data.len());
+
+        match result {
+            Ok(Some(written)) => {
+                // sendfile may return 0 on some filesystems like tmpfs
+                // In that case, try using write instead
+                if written == 0 {
+                    eprintln!("sendfile returned 0 (likely tmpfs or other special filesystem)");
+                } else {
+                    // Verify the data was written
+                    outfile.rewind().unwrap();
+                    let mut buf = Vec::new();
+                    outfile.read_to_end(&mut buf).unwrap();
+                    assert!(
+                        !buf.is_empty(),
+                        "output file should contain data when bytes written > 0"
+                    );
+                }
+            }
+            Ok(None) => {
+                // sendfile not supported on this system or with these file types
+                eprintln!("sendfile not supported on this system/filesystem combination");
+            }
+            Err(e) => {
+                // Some errors are acceptable (e.g., unsupported filesystem, EINVAL for temporary files)
+                eprintln!("sendfile error (acceptable): {}", e);
+            }
+        }
+    }
+
+    #[cfg(all(target_os = "linux", target_env = "gnu"))]
+    #[test]
+    fn test_reliable_copy_file_range_with_fallback() {
+        // Create input file
+        let mut infile = tempfile().unwrap();
+        let data = b"test copy_file_range fallback";
+        infile.write_all(data).unwrap();
+        infile.rewind().unwrap();
+
+        // Create output file
+        let mut outfile = tempfile().unwrap();
+
+        let in_fd = infile.as_raw_fd();
+        let out_fd = outfile.as_raw_fd();
+        let in_ptr = data.as_ptr();
+        let in_off = 0i64;
+
+        // This should try copy_file_range and fallback to other methods if needed
+        let written = reliable_copy_file_range(in_ptr, in_fd, in_off, out_fd, data.len());
+
+        match written {
+            Ok(bytes) => {
+                assert_eq!(bytes, data.len(), "should copy every byte");
+
+                // Verify the output is an exact copy of the input
+                outfile.rewind().unwrap();
+                let mut buf = Vec::new();
+                outfile.read_to_end(&mut buf).unwrap();
+                assert_eq!(buf, data, "output file should match the input");
+            }
+            Err(e) => {
+                panic!("copy failed: {}", e);
+            }
+        }
+    }
 }