diff --git a/include/hmll/linux/backend/iouring.h b/include/hmll/linux/backend/iouring.h index 828ebc2..f38a32d 100644 --- a/include/hmll/linux/backend/iouring.h +++ b/include/hmll/linux/backend/iouring.h @@ -105,8 +105,13 @@ static inline int hmll_io_uring_slot_find_available(const struct hmll_iouring_io { for (unsigned i = 0; i < HMLL_URING_IOBUSY_WORDS; ++i) { const int pos = __builtin_ffsll(~iobusy.bits[i]); - if (pos > 0) - return (int)(i * 64) + pos - 1; + if (pos > 0) { + const int slot = (int)(i * 64) + pos - 1; + // Ensure we don't return slots beyond QUEUE_DEPTH + if (slot >= (int)HMLL_URING_QUEUE_DEPTH) + return -1; + return slot; + } } return -1; } diff --git a/lib/linux/backend/iouring.c b/lib/linux/backend/iouring.c index ef2fb95..7e56e8e 100644 --- a/lib/linux/backend/iouring.c +++ b/lib/linux/backend/iouring.c @@ -1,4 +1,7 @@ +#include #include +#include +#include #include #include "hmll/hmll.h" #include "hmll/memory.h" @@ -14,6 +17,61 @@ #include #endif +/* ── NUMA topology helpers ──────────────────────────────────────────── */ + +/** + * Get the NUMA node for a CUDA device by reading sysfs via the PCI bus ID. + * Returns -1 on failure. + */ +static int hmll_get_gpu_numa_node(const int device_idx) +{ +#if defined(__HMLL_CUDA_ENABLED__) + char pci_bus_id[64] = {0}; + if (cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), device_idx) != cudaSuccess) + return -1; + + /* Convert to lowercase for sysfs path (CUDA returns uppercase hex) */ + for (char *p = pci_bus_id; *p; p++) + *p = (*p >= 'A' && *p <= 'Z') ? (*p + 32) : *p; + + char path[256]; + snprintf(path, sizeof(path), "/sys/bus/pci/devices/%s/numa_node", pci_bus_id); + + FILE *f = fopen(path, "r"); + if (!f) return -1; + + int node = -1; + if (fscanf(f, "%d", &node) != 1) node = -1; + fclose(f); + + return node; +#else + (void)device_idx; + return -1; +#endif +} + +/** + * Get the first CPU core on a given NUMA node by parsing sysfs. + * Returns -1 on failure. 
+ */ +static int hmll_get_first_cpu_on_node(const int numa_node) +{ + if (numa_node < 0) return -1; + + char path[256]; + snprintf(path, sizeof(path), "/sys/devices/system/node/node%d/cpulist", numa_node); + + FILE *f = fopen(path, "r"); + if (!f) return -1; + + int first_cpu = -1; + if (fscanf(f, "%d", &first_cpu) != 1) first_cpu = -1; + fclose(f); + + return first_cpu; +} + /* ── runtime kernel version detection ───────────────────────────────── */ static inline unsigned hmll_kernel_version_internal(unsigned maj, unsigned min) { @@ -80,6 +138,7 @@ static struct hmll_error hmll_io_uring_register_staging_buffers( } unsigned char *arena = hmll_alloc(HMLL_URING_QUEUE_DEPTH * HMLL_URING_BUFFER_SIZE, device, HMLL_MEM_STAGING); + if (!arena) { ctx->error = HMLL_ERR(HMLL_ERR_ALLOCATION_FAILED); return ctx->error; @@ -139,13 +198,14 @@ static inline void hmll_io_uring_reclaim_slots( struct hmll_io_uring_cuda_context *dctx = fetcher->device_ctx; for (size_t i = 0; i < HMLL_URING_QUEUE_DEPTH; ++i) { struct hmll_io_uring_cuda_context *cd = dctx + i; - if (hmll_io_uring_slot_is_busy(fetcher->iobusy, i)) { - if (cd->state == HMLL_CUDA_STREAM_MEMCPY && cudaEventQuery(cd->done) == cudaSuccess) { + if (hmll_io_uring_slot_is_busy(fetcher->iobusy, i) && cd->state == HMLL_CUDA_STREAM_MEMCPY) { + if (cudaEventQuery(cd->done) == cudaSuccess) { hmll_io_uring_cuda_stream_set_idle(&cd->state); hmll_io_uring_slot_set_available(&fetcher->iobusy, cd->slot); } } } + #else HMLL_UNUSED(fetcher); HMLL_UNUSED(device); @@ -196,6 +256,7 @@ static inline void hmll_io_uring_prep_sqe( #if defined(__HMLL_CUDA_ENABLED__) else if (hmll_device_is_cuda(device)) { struct hmll_io_uring_cuda_context *dctx = fetcher->device_ctx; + dctx[slot].offset = offset; io_uring_prep_read_fixed(sqe, iofile, fetcher->iovecs[slot].iov_base, len, offset, slot); io_uring_sqe_set_data(sqe, dctx + slot); @@ -296,7 +357,9 @@ static ssize_t hmll_io_uring_fetch_loop( if (count == 0 && n_inflight > 0) { struct io_uring_cqe *cqe; 
- if (unlikely(io_uring_wait_cqe(&fetcher->ioring, &cqe) < 0)) { - ctx->error = HMLL_ERR(HMLL_ERR_IO_ERROR); + /* liburing reports failure as a negative errno return value and does not set errno, so use the return value. */ + const int wait_rc = io_uring_wait_cqe(&fetcher->ioring, &cqe); + if (unlikely(wait_rc < 0)) { + ctx->error = HMLL_SYS_ERR(-wait_rc); return -1; } cqes[0] = cqe; @@ -758,8 +821,19 @@ static struct hmll_error hmll_io_uring_queue_init( const struct hmll_device device ) { (void)ctx; + + /* Detect NUMA node for the target device and pin SQPOLL thread accordingly */ + int numa_node = -1; + int sq_cpu = 0; + + if (hmll_device_is_cuda(device)) { + numa_node = hmll_get_gpu_numa_node(device.idx); + int cpu = hmll_get_first_cpu_on_node(numa_node); + if (cpu >= 0) sq_cpu = cpu; + } + struct io_uring_params params = { - .sq_thread_cpu = 0, + .sq_thread_cpu = (unsigned)sq_cpu, .flags = hmll_io_uring_get_setup_flags(), .sq_thread_idle = 500 }; @@ -771,6 +845,33 @@ static struct hmll_error hmll_io_uring_queue_init( return HMLL_ERR(HMLL_ERR_CUDA_SET_DEVICE_FAILED); } + /* Pin this thread to the GPU's NUMA node for optimal memory allocation */ + if (numa_node >= 0) { + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + char path[256]; + snprintf(path, sizeof(path), "/sys/devices/system/node/node%d/cpulist", numa_node); + FILE *f = fopen(path, "r"); + if (f) { + char buf[1024] = {0}; + if (fgets(buf, sizeof(buf), f)) { + /* Parse cpulist format: "0-23,48-71" */ + char *tok = strtok(buf, ",\n"); + while (tok) { + int lo, hi; + if (sscanf(tok, "%d-%d", &lo, &hi) == 2) { + for (int c = lo; c <= hi; c++) CPU_SET(c, &cpuset); + } else if (sscanf(tok, "%d", &lo) == 1) { + CPU_SET(lo, &cpuset); + } + tok = strtok(NULL, ",\n"); + } + } + fclose(f); + sched_setaffinity(0, sizeof(cpuset), &cpuset); + } + } + struct hmll_io_uring_cuda_context *data = calloc(HMLL_URING_QUEUE_DEPTH, sizeof(struct hmll_io_uring_cuda_context)); backend->device_ctx = (void *)data; @@ -887,7 +988,7 @@ void 
hmll_io_uring_destroy(void *ptr) } } - munmap(backend->iovecs[0].iov_base, HMLL_URING_QUEUE_DEPTH * sizeof(struct iovec)); + cudaFreeHost(backend->iovecs[0].iov_base); free(backend->device_ctx); backend->device_ctx 
= NULL; } diff --git a/lib/rust/hmll/examples/basic.rs b/lib/rust/hmll/examples/basic.rs index 0a2fd8b..579ed2a 100644 --- a/lib/rust/hmll/examples/basic.rs +++ b/lib/rust/hmll/examples/basic.rs @@ -30,18 +30,19 @@ fn main() -> Result<(), Box> { ); // Store in an array to ensure proper lifetime - let sources = [source]; + let source_size = source.size(); + let sources = vec![source]; // Create a weight loader println!("\nCreating weight loader..."); - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?; + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto)?; println!("✓ Loader created successfully"); println!(" Device: {}", loader.device()); println!(" Number of sources: {}", loader.num_sources()); // Fetch some data from the beginning of the file let fetch_size = end - start; - let actual_fetch_size = fetch_size.min(sources[0].size()); + let actual_fetch_size = fetch_size.min(source_size); println!( "\nFetching {} bytes ({:.2} MB)...", actual_fetch_size, diff --git a/lib/rust/hmll/examples/multi_files.rs b/lib/rust/hmll/examples/multi_files.rs index 88ebb18..fb10f9a 100644 --- a/lib/rust/hmll/examples/multi_files.rs +++ b/lib/rust/hmll/examples/multi_files.rs @@ -49,7 +49,7 @@ fn main() -> Result<(), Box> { // Create a weight loader for all sources println!("\nCreating weight loader for {} sources...", sources.len()); - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?; + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto)?; println!("✓ Loader created successfully"); // Display information about each source diff --git a/lib/rust/hmll/src/buffer.rs b/lib/rust/hmll/src/buffer.rs index 8486095..f0aaa33 100644 --- a/lib/rust/hmll/src/buffer.rs +++ b/lib/rust/hmll/src/buffer.rs @@ -90,16 +90,41 @@ impl From for ops::Range { /// - **Empty**: Zero-length buffer with no memory. /// - **Owned**: Allocated memory that is freed when the buffer is dropped. 
/// - **SourceView**: Zero-copy pointer into mmap'd memory, kept alive via Arc. -pub struct Buffer { +pub struct BufferInner { buf: hmll_iobuf, kind: BufferKind, } +impl Drop for BufferInner { + fn drop(&mut self) { + if let BufferKind::Owned = self.kind { + if !self.buf.ptr.is_null() { + unsafe { hmll_free_buffer(&mut self.buf) }; + } + } + // For SourceView: the Arc is dropped automatically, decrementing refcount. + // When the last Arc is dropped, SourceHandle::drop() unmaps the memory. + } +} + +// SAFETY: BufferInner is safe to send across threads because: +// - The buffer data is immutable after creation (read-only access) +// - Owned buffers: memory is allocated by hmll, only freed on drop +// - SourceView buffers: memory is mmap'd and kept alive by Arc +// - No internal mutation occurs after construction +// +// Callers must NOT mutate data through `as_ptr()` - doing so would be UB. +unsafe impl Send for BufferInner {} +unsafe impl Sync for BufferInner {} + +#[derive(Clone)] +pub struct Buffer(Arc); + impl std::fmt::Debug for Buffer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Buffer") - .field("size", &self.buf.size) - .field("ptr", &self.buf.ptr) + .field("size", &self.0.buf.size) + .field("ptr", &self.0.buf.ptr) .field("device", &self.device()) .field("owned", &self.is_owned()) .finish() @@ -112,14 +137,14 @@ impl Buffer { /// This is useful when you need to represent a zero-length fetch result. #[inline(always)] pub fn empty(device: Device) -> Self { - Self { + Self(Arc::new(BufferInner { buf: hmll_iobuf { size: 0, ptr: std::ptr::null_mut(), device: device.to_raw(), }, kind: BufferKind::Empty, - } + })) } /// Create a new owned buffer from an `hmll_iobuf`. @@ -132,10 +157,10 @@ impl Buffer { /// and that the memory was allocated via hmll allocation functions. 
#[inline(always)] pub(crate) unsafe fn from_raw_owned(buf: hmll_iobuf) -> Self { - Self { + Self(Arc::new(BufferInner { buf, kind: BufferKind::Owned, - } + })) } /// Create a zero-copy view into mmap'd source memory. @@ -153,28 +178,28 @@ impl Buffer { device: Device, source_handle: Arc, ) -> Self { - Self { + Self(Arc::new(BufferInner { buf: hmll_iobuf { size, ptr, device: device.to_raw(), }, kind: BufferKind::SourceView(source_handle), - } + })) } /// Get the buffer as a byte slice (CPU only). #[inline] pub fn as_slice(&self) -> Option<&[u8]> { if self.device() == Device::Cpu { - if self.buf.ptr.is_null() || self.buf.size == 0 { + if self.0.buf.ptr.is_null() || self.0.buf.size == 0 { // Return empty slice for empty/null buffers Some(&[]) } else { unsafe { Some(std::slice::from_raw_parts( - self.buf.ptr as *const u8, - self.buf.size, + self.0.buf.ptr as *const u8, + self.0.buf.size, )) } } @@ -185,26 +210,26 @@ impl Buffer { /// Get the size of the buffer in bytes. #[inline(always)] - pub const fn len(&self) -> usize { - self.buf.size + pub fn len(&self) -> usize { + self.0.buf.size } /// Check if the buffer is empty. #[inline(always)] - pub const fn is_empty(&self) -> bool { - self.buf.size == 0 + pub fn is_empty(&self) -> bool { + self.0.buf.size == 0 } /// Get the device where the buffer is located. #[inline(always)] pub fn device(&self) -> Device { - Device::from_raw(self.buf.device) + Device::from_raw(self.0.buf.device) } /// Get a raw pointer to the buffer. #[inline(always)] - pub const fn as_ptr(&self) -> *const u8 { - self.buf.ptr as *const u8 + pub fn as_ptr(&self) -> *const u8 { + self.0.buf.ptr as *const u8 } /// Convert to a Vec (copies data if on CPU). @@ -223,18 +248,6 @@ impl Buffer { /// and are kept alive by an Arc reference to the source. 
#[inline(always)] pub fn is_owned(&self) -> bool { - matches!(self.kind, BufferKind::Owned) - } -} - -impl Drop for Buffer { - fn drop(&mut self) { - if let BufferKind::Owned = self.kind { - if !self.buf.ptr.is_null() { - unsafe { hmll_free_buffer(&mut self.buf) }; - } - } - // For SourceView: the Arc is dropped automatically, decrementing refcount. - // When the last Arc is dropped, SourceHandle::drop() unmaps the memory. + matches!(self.0.kind, BufferKind::Owned) } } diff --git a/lib/rust/hmll/src/error.rs b/lib/rust/hmll/src/error.rs index adc3026..5bab340 100644 --- a/lib/rust/hmll/src/error.rs +++ b/lib/rust/hmll/src/error.rs @@ -39,8 +39,8 @@ pub enum Error { #[error("Buffer too small")] BufferTooSmall, - #[error("I/O error")] - IoError, + #[error("I/O error: {0}")] + IoError(String), #[error("No source provided")] NoSourceProvided, @@ -132,7 +132,17 @@ impl Error { HMLL_ERR_INVALID_RANGE => Error::InvalidRange, HMLL_ERR_BUFFER_ADDR_NOT_ALIGNED => Error::BufferAddrNotAligned, HMLL_ERR_BUFFER_TOO_SMALL => Error::BufferTooSmall, - HMLL_ERR_IO_ERROR => Error::IoError, + HMLL_ERR_IO_ERROR => { + let msg = unsafe { + let ptr = hmll_strerr(err); + if ptr.is_null() { + format!("errno {}", err.sys_err) + } else { + CStr::from_ptr(ptr).to_string_lossy().into_owned() + } + }; + Error::IoError(msg) + } HMLL_ERR_NO_SOURCE_PROVIDED => Error::NoSourceProvided, HMLL_ERR_FILE_NOT_FOUND => Error::FileNotFound(String::new()), HMLL_ERR_FILE_EMPTY => Error::FileEmpty, diff --git a/lib/rust/hmll/src/lib.rs b/lib/rust/hmll/src/lib.rs index f5bcbf5..3161eec 100644 --- a/lib/rust/hmll/src/lib.rs +++ b/lib/rust/hmll/src/lib.rs @@ -22,16 +22,16 @@ //! ```no_run //! # use hmll::{Source, WeightLoader, Device, LoaderKind}; //! # fn main() -> Result<(), Box> { -//! let sources = [ -//! Source::open("shard-00001.bin")?, -//! Source::open("shard-00002.bin")?, +//! let sources = vec![ +//! Source::open("model-00001.bin")?, +//! Source::open("model-00002.bin")?, //! ]; //! -//! 
let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?; +//! let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto)?; //! //! // Fetch bytes from specific file by index -//! let data = loader.fetch(0..1024, 0)?; // from shard 1 -//! let data = loader.fetch(0..1024, 1)?; // from shard 2 +//! let data = loader.fetch(0..1024, 0)?; // from source 1 +//! let data = loader.fetch(0..1024, 1)?; // from source 2 //! # Ok(()) //! # } //! ``` diff --git a/lib/rust/hmll/src/loader.rs b/lib/rust/hmll/src/loader.rs index d552fdf..bc5492d 100644 --- a/lib/rust/hmll/src/loader.rs +++ b/lib/rust/hmll/src/loader.rs @@ -2,12 +2,9 @@ use hmll_sys::hmll_iobuf; -use crate::source::SourceHandle; use crate::{Buffer, Device, Error, Range, Result, Source}; use std::collections::HashSet; -use std::marker::PhantomData; use std::ptr; -use std::sync::Arc; /// Loader backend kind. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -36,8 +33,6 @@ impl LoaderKind { impl Default for LoaderKind { /// Default loader kind is Auto. - /// - /// Hot path - inline always for zero-cost default. 
#[inline(always)] fn default() -> Self { LoaderKind::Auto @@ -59,11 +54,11 @@ impl Default for LoaderKind { /// let source1 = Source::open("model-00001-of-00003.safetensors")?; /// let source2 = Source::open("model-00002-of-00003.safetensors")?; /// let source3 = Source::open("model-00003-of-00003.safetensors")?; -/// let sources = [source1, source2, source3]; +/// let sources = vec![source1, source2, source3]; /// /// // Create a loader /// let mut loader = WeightLoader::new( -/// &sources, +/// sources, /// Device::Cpu, /// LoaderKind::Auto /// )?; @@ -74,17 +69,14 @@ impl Default for LoaderKind { /// # Ok(()) /// # } /// ``` -pub struct WeightLoader<'a> { +#[derive(Debug)] +pub struct WeightLoader { context: Box, - /// Raw sources passed to C layer - sources: Vec, - /// Arc handles for each source - keeps mmap alive while views exist - source_handles: Vec>, + sources: Vec, device: Device, - _marker: PhantomData<&'a ()>, } -impl<'a> WeightLoader<'a> { +impl WeightLoader { /// Create a new weight loader. /// /// # Arguments @@ -96,15 +88,11 @@ impl<'a> WeightLoader<'a> { /// # Errors /// /// Returns an error if the loader initialization fails. - pub fn new(sources: &'a [Source], device: Device, kind: LoaderKind) -> Result { + pub fn new(sources: Vec, device: Device, kind: LoaderKind) -> Result { if sources.is_empty() { return Err(Error::InvalidRange); } - // Clone Arc handles to keep sources alive while views exist - let source_handles: Vec> = - sources.iter().map(|s| s.handle().clone()).collect(); - let mut sources_vec: Vec = sources.iter().map(|s| *s.as_raw()).collect(); let mut context = Box::new(hmll_sys::hmll { @@ -130,19 +118,16 @@ impl<'a> WeightLoader<'a> { // For mmap backend, close fds immediately - mmap is independent of fd. // NOTE: Could detect Auto resolving to mmap and close too, but left out for now. 
 if kind == LoaderKind::Mmap { - for handle in &source_handles { - let inner_ptr = &handle.inner as *const _ as *mut hmll_sys::hmll_source; - hmll_sys::hmll_source_close(inner_ptr); + for source in &sources { + source.close_fd(); } } } Ok(Self { context, - sources: sources_vec, - source_handles, + sources, /* NOTE(review): `sources_vec` is no longer stored on the loader after this change; presumably it is dropped when `new` returns. Confirm the C context does not keep a pointer into it past initialization. */ device, - _marker: PhantomData, }) } @@ -172,8 +157,8 @@ impl<'a> WeightLoader<'a> { /// # use hmll::{Source, WeightLoader, Device, LoaderKind, Range}; /// # fn main() -> Result<(), Box> { /// # let source = Source::open("model.safetensors")?; - /// # let sources = [source]; - /// # let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?; + /// # let sources = vec![source]; + /// # let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto)?; /// /// // Fetch multiple weight tensors in a single batch /// let ranges = vec![ @@ -189,7 +174,7 @@ impl<'a> WeightLoader<'a> { /// # } /// ``` pub fn fetchv(&mut self, ranges: &[Range], file_index: usize) -> Result> { - if file_index >= self.sources.len() { + if file_index >= self.num_sources() { return Err(Error::InvalidFileIndex(file_index)); } @@ -289,8 +274,8 @@ impl<'a> WeightLoader<'a> { /// # use hmll::{Source, WeightLoader, Device, LoaderKind}; /// # fn main() -> Result<(), Box> { /// # let source = Source::open("model.safetensors")?; - /// # let sources = [source]; - /// # let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto)?; + /// # let sources = vec![source]; + /// # let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto)?; /// /// // Fetch first 1MB from the first file /// let data = loader.fetch(0..1024 * 1024, 0)?; @@ -301,7 +286,7 @@ impl<'a> WeightLoader<'a> { pub fn fetch>(&mut self, range: R, file_index: usize) -> Result { let range = range.into(); - if file_index >= self.sources.len() { + if file_index >= self.num_sources() { return Err(Error::InvalidFileIndex(file_index)); } @@ -364,8 +349,8 @@ impl<'a> WeightLoader<'a> { /// 
# use hmll::{Source, WeightLoader, Device, LoaderKind}; /// # fn main() -> Result<(), Box> { /// # let source = Source::open("model.safetensors")?; - /// # let sources = [source]; - /// # let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap)?; + /// # let sources = vec![source]; + /// # let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Mmap)?; /// /// // Get a zero-copy view into the first 1MB /// let view = loader.fetch_view(0..1024 * 1024, 0)?; @@ -378,10 +363,6 @@ impl<'a> WeightLoader<'a> { pub fn fetch_view>(&mut self, range: R, file_index: usize) -> Result { let range = range.into(); - if file_index >= self.sources.len() { - return Err(Error::InvalidFileIndex(file_index)); - } - if range.is_empty() { return Ok(Buffer::empty(self.device)); } @@ -392,10 +373,14 @@ impl<'a> WeightLoader<'a> { } // Clone the Arc to keep the source (and its mmap) alive - let source_handle = self.source_handles[file_index].clone(); + let source = self + .sources + .get(file_index) + .cloned() + .ok_or(Error::InvalidFileIndex(file_index))?; // Get the mmap'd content pointer directly from the source - let content_ptr = source_handle.inner.content; + let content_ptr = source.as_raw().content; if content_ptr.is_null() { return Err(Error::MmapFailed); } @@ -406,7 +391,7 @@ impl<'a> WeightLoader<'a> { let view_size = range.len(); // Create a view buffer - Arc keeps source (and mmap) alive - Ok(unsafe { Buffer::from_source_view(view_ptr, view_size, self.device, source_handle) }) + Ok(unsafe { Buffer::from_source_view(view_ptr, view_size, self.device, source.handle()) }) } /// Get the device this loader is configured for. @@ -424,11 +409,13 @@ impl<'a> WeightLoader<'a> { /// Get information about a specific source file. 
#[inline] pub fn source_info(&self, index: usize) -> Option { - self.sources.get(index).map(|s| SourceInfo { size: s.size }) + self.sources + .get(index) + .map(|s| SourceInfo { size: s.size() }) } } -impl<'a> Drop for WeightLoader<'a> { +impl Drop for WeightLoader { fn drop(&mut self) { unsafe { hmll_sys::hmll_destroy(self.context.as_mut()); @@ -459,7 +446,7 @@ mod tests { #[test] fn test_empty_sources() { - let result = WeightLoader::new(&[], Device::Cpu, LoaderKind::Auto); + let result = WeightLoader::new(vec![], Device::Cpu, LoaderKind::Auto); assert!(result.is_err()); } @@ -479,9 +466,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); assert_eq!(loader.device(), Device::Cpu); @@ -497,9 +484,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let buffer = loader @@ -519,9 +506,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let buffer = loader @@ -540,9 +527,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open 
source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let buffer = loader.fetch(5..5, 0).expect("Failed to fetch empty range"); @@ -557,9 +544,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let result = loader.fetch(0..10, 99); @@ -580,9 +567,9 @@ mod tests { let source2 = Source::open(temp2.path()).expect("Failed to open source 2"); let source3 = Source::open(temp3.path()).expect("Failed to open source 3"); - let sources = [source1, source2, source3]; + let sources = vec![source1, source2, source3]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); assert_eq!(loader.num_sources(), 3); @@ -609,9 +596,9 @@ mod tests { let temp_file = create_test_file(&content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let buffer = loader @@ -630,9 +617,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let loader = 
WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let info = loader.source_info(0); @@ -649,9 +636,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let buffer = loader.fetch(0..content.len(), 0).expect("Failed to fetch"); @@ -666,9 +653,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Mmap) .expect("Failed to create mmap loader"); let buffer = loader @@ -684,9 +671,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Mmap) .expect("Failed to create mmap loader"); let view = loader @@ -705,9 +692,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Mmap) .expect("Failed to create mmap loader"); // Get a view of just the uppercase letters @@ -726,9 +713,9 @@ mod tests { let temp_file = create_test_file(content); let source = 
Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Mmap) .expect("Failed to create mmap loader"); let view = loader @@ -745,9 +732,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Mmap) .expect("Failed to create mmap loader"); let result = loader.fetch_view(0..10, 99); @@ -762,9 +749,9 @@ mod tests { let view = { let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Mmap) .expect("Failed to create mmap loader"); loader @@ -784,9 +771,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Mmap) .expect("Failed to create mmap loader"); // Create multiple views from the same source @@ -817,9 +804,9 @@ mod tests { let source1 = Source::open(temp1.path()).expect("Failed to open source 1"); let source2 = Source::open(temp2.path()).expect("Failed to open source 2"); - let sources = [source1, source2]; + let sources = vec![source1, source2]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap) + let mut loader = WeightLoader::new(sources, 
Device::Cpu, LoaderKind::Mmap) .expect("Failed to create mmap loader"); let view1 = loader @@ -843,9 +830,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let ranges = vec![Range::new(0, 10), Range::new(10, 20), Range::new(20, 36)]; @@ -864,9 +851,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let ranges = vec![Range::new(0, content.len())]; @@ -882,9 +869,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let buffers = loader.fetchv(&[], 0).expect("Failed to fetchv"); @@ -897,9 +884,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let ranges = vec![ @@ -924,9 +911,9 @@ mod tests { let temp_file = create_test_file(content); let source = 
Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let ranges = vec![Range::new(0, 0), Range::new(5, 5), Range::new(10, 3)]; @@ -943,9 +930,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let ranges = vec![Range::new(0, 10)]; @@ -963,9 +950,9 @@ mod tests { let source1 = Source::open(temp1.path()).expect("Failed to open source 1"); let source2 = Source::open(temp2.path()).expect("Failed to open source 2"); - let sources = [source1, source2]; + let sources = vec![source1, source2]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let ranges1 = vec![Range::new(0, content1.len())]; @@ -985,9 +972,9 @@ mod tests { let temp_file = create_test_file(&content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); // Split into 16 equal chunks @@ -1012,9 +999,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = 
WeightLoader::new(&sources, Device::Cpu, LoaderKind::Auto) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Auto) .expect("Failed to create loader"); let ranges = vec![Range::new(0, 10), Range::new(10, 20), Range::new(36, 46)]; @@ -1044,9 +1031,9 @@ mod tests { let temp_file = create_test_file(content); let source = Source::open(temp_file.path()).expect("Failed to open source"); - let sources = [source]; + let sources = vec![source]; - let mut loader = WeightLoader::new(&sources, Device::Cpu, LoaderKind::Mmap) + let mut loader = WeightLoader::new(sources, Device::Cpu, LoaderKind::Mmap) .expect("Failed to create mmap loader"); let ranges = vec![Range::new(0, 10), Range::new(10, 20)]; diff --git a/lib/rust/hmll/src/source.rs b/lib/rust/hmll/src/source.rs index 1d9d6fa..c4ad9fa 100644 --- a/lib/rust/hmll/src/source.rs +++ b/lib/rust/hmll/src/source.rs @@ -100,8 +100,17 @@ impl Source { /// /// This is used internally for creating views that outlive the loader. #[inline(always)] - pub(crate) fn handle(&self) -> &Arc { - &self.handle + pub(crate) fn handle(&self) -> Arc { + self.handle.clone() + } + + /// Close the file descriptor associated to the [`Source`]. + /// Useful for mmap when we don't need a dangling file descriptor. + pub(crate) fn close_fd(&self) { + unsafe { + let inner_ptr = &self.handle.inner as *const _ as *mut hmll_sys::hmll_source; + hmll_sys::hmll_source_close(inner_ptr); + } } } diff --git a/lib/unix/memory.c b/lib/unix/memory.c index 5d79e5c..1ebc79b 100644 --- a/lib/unix/memory.c +++ b/lib/unix/memory.c @@ -1,6 +1,8 @@ // // Created by mfuntowicz on 1/13/26. 
// +#include +#include #include "hmll/hmll.h" @@ -43,8 +45,14 @@ void *hmll_alloc(const size_t size, const struct hmll_device device, const int f - if (hmll_device_is_cuda(device) && flags == HMLL_MEM_DEVICE) - cudaMalloc(&ptr, size); + if (hmll_device_is_cuda(device) && flags == HMLL_MEM_DEVICE) { + /* Report a failed device allocation as NULL, matching the staging path below. */ + if (cudaMalloc(&ptr, size) != cudaSuccess) + ptr = NULL; + } - if (hmll_device_is_cuda(device) && flags == HMLL_MEM_STAGING) - cudaHostAlloc(&ptr, size, cudaHostAllocDefault | cudaHostAllocPortable); + if (hmll_device_is_cuda(device) && flags == HMLL_MEM_STAGING) { + cudaError_t err = cudaHostAlloc(&ptr, size, cudaHostAllocDefault | cudaHostAllocPortable); + if (err != cudaSuccess) + ptr = NULL; + } #endif return ptr;