diff --git a/Cargo.toml b/Cargo.toml index b434509..6c8c2d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,9 @@ categories = ["science"] readme = "README.md" rust-version = "1.85" +[dependencies] +libc = "0.2.178" + [build-dependencies] bindgen = "0.72.0" ureq = "2.0" diff --git a/examples/advanced_usage.rs b/examples/advanced_usage.rs index 8365194..eb16d4f 100644 --- a/examples/advanced_usage.rs +++ b/examples/advanced_usage.rs @@ -106,7 +106,10 @@ fn main() -> Result<(), CatBoostError> { Err(e) => println!(" Batch prediction error: {}", e), } - // Example 4: Model validation + // Example 4: Feature names and indices + display_feature_names(&model); + + // Example 5: Model validation println!("\n=== Model Validation ==="); validate_model(&model)?; @@ -153,6 +156,37 @@ fn display_model_info(model: &Model) -> Result<(), CatBoostError> { Ok(()) } +#[cfg(catboost_feature_indices)] +fn display_feature_names(model: &Model) { + println!("\n=== Feature Names ==="); + match model.get_feature_names() { + Ok(names) => println!(" All feature names: {:?}", names), + Err(e) => println!(" get_feature_names error: {}", e), + } + match model.get_float_feature_names() { + Ok(names) => println!(" Float feature names: {:?}", names), + Err(e) => println!(" get_float_feature_names error: {}", e), + } + match model.get_cat_feature_names() { + Ok(names) => println!(" Cat feature names: {:?}", names), + Err(e) => println!(" get_cat_feature_names error: {}", e), + } + match model.get_text_feature_names() { + Ok(names) => println!(" Text feature names: {:?}", names), + Err(e) => println!(" get_text_feature_names error: {}", e), + } + match model.get_embedding_feature_names() { + Ok(names) => println!(" Embedding feature names: {:?}", names), + Err(e) => println!(" get_embedding_feature_names error: {}", e), + } +} + +#[cfg(not(catboost_feature_indices))] +fn display_feature_names(_model: &Model) { + println!("\n=== Feature Names ==="); + println!(" Feature name queries not available in this CatBoost version."); +} + fn validate_model(model: &Model) -> Result<(), CatBoostError> { println!("Validating model..."); diff --git a/src/model.rs b/src/model.rs index cbd106b..fc24c08 100644 --- a/src/model.rs +++ b/src/model.rs @@ -5,6 +5,15 @@ use std::ffi::{CStr, CString}; use std::os::raw::c_char; use std::path::Path; +/// RAII guard that frees a C-allocated pointer on drop. +struct CFreeGuard(*mut T); + +impl Drop for CFreeGuard { + fn drop(&mut self) { + unsafe { libc::free(self.0 as *mut libc::c_void) }; + } +} + pub struct Model { handle: *mut sys::ModelCalcerHandle, /// Buffer owner for zero-copy loading - keeps the buffer alive for model's lifetime @@ -311,6 +320,187 @@ impl Model { }) } + /// # Safety + /// + /// This function is unsafe because it dereferences a raw pointer and assumes a memory + /// allocation contract with an external C API. + /// + /// - `ptr` must be a valid pointer to a C-allocated buffer containing `count` elements of type `T`, + /// or it must be a null pointer if `count` is 0. + /// - The buffer must have been allocated by a `malloc`-compatible allocator, as the CatBoost C API + /// documentation for functions like `GetFloatFeatureIndices` and `GetModelUsedFeaturesNames` + /// stipulates that the caller is responsible for freeing the returned buffer. The standard C + /// mechanism for this is `free()`. + /// (Source: https://github.com/catboost/catboost/blob/master/catboost/libs/model_interface/c_api.h) + /// + /// This function takes ownership of the buffer and frees it with `libc::free` after copying + /// the data into a Rust `Vec`. + unsafe fn from_c_allocated_buffer(ptr: *mut T, count: usize) -> Vec { + if ptr.is_null() { + return Vec::new(); + } + let _guard = CFreeGuard(ptr); + unsafe { std::slice::from_raw_parts(ptr, count) }.to_vec() + } + + /// Converts a C-style array of feature indices into a `Vec`, freeing the C buffer. + fn get_feature_indices_from_c( + indices_ptr: *mut usize, + count: usize, + err_msg: &str, + ) -> CatBoostResult> { + if indices_ptr.is_null() { + if count == 0 { + return Ok(Vec::new()); + } + return Err(CatBoostError { + description: err_msg.to_owned(), + }); + } + // SAFETY: The contract for CatBoost functions like `GetFloatFeatureIndices` is that they + // return a `malloc`-allocated buffer that the caller must free. `from_c_allocated_buffer` + // upholds this contract by copying the data and then calling `libc::free`. + let indices = unsafe { Self::from_c_allocated_buffer(indices_ptr, count) }; + Ok(indices) + } + + /// Converts a C-style array of C strings into a `Vec`, freeing all associated C memory. + /// + /// Uses a drop guard to ensure all C memory is freed even if a panic occurs mid-iteration + /// (e.g. OOM during `into_owned()` or `Vec::push`). + fn get_feature_names_from_c( + names_ptr: *mut *mut std::ffi::c_char, + count: usize, + err_msg: &str, + ) -> CatBoostResult> { + if names_ptr.is_null() { + if count == 0 { + return Ok(Vec::new()); + } + return Err(CatBoostError { + description: err_msg.to_owned(), + }); + } + + let str_ptrs = unsafe { Self::from_c_allocated_buffer(names_ptr, count) }; + let guards: Vec> = str_ptrs.into_iter().map(CFreeGuard).collect(); + guards + .iter() + .map(|g| { + if g.0.is_null() { + return Err(CatBoostError { + description: err_msg.to_owned(), + }); + } + unsafe { CStr::from_ptr(g.0) } + .to_str() + .map(|s| s.to_owned()) + .map_err(|_| CatBoostError { + description: err_msg.to_owned(), + }) + }) + .collect() + } + + /// Get names of specific type of features used in model, + /// returns error if index out of bounds + #[cfg(catboost_feature_indices)] + fn get_specific_feature_names( + &self, + indices_fn: unsafe extern "C" fn( + *mut sys::ModelCalcerHandle, + *mut *mut usize, + *mut usize, + ) -> bool, + err_msg: &str, + ) -> CatBoostResult> { + let all_names = self.get_feature_names()?; + let indices = self.get_feature_indices(indices_fn, err_msg)?; + indices + .into_iter() + .map(|i| { + all_names + .get(i) + .ok_or_else(|| CatBoostError { + description: format!("feature index {} out of bounds", i), + }) + .cloned() + }) + .collect() + } + + /// Get names of features used in model + #[cfg(catboost_feature_indices)] + pub fn get_feature_names(&self) -> CatBoostResult> { + unsafe { + let mut names_ptr: *mut *mut std::ffi::c_char = std::ptr::null_mut(); + let mut count: usize = 0; + + let ok = sys::GetModelUsedFeaturesNames(self.handle, &mut names_ptr, &mut count); + CatBoostError::check_return_value(ok)?; + + Self::get_feature_names_from_c( + names_ptr, + count, + "GetModelUsedFeaturesNames returned null pointer", + ) + } + } + + #[cfg(catboost_feature_indices)] + fn get_feature_indices( + &self, + indices_fn: unsafe extern "C" fn( + *mut sys::ModelCalcerHandle, + *mut *mut usize, + *mut usize, + ) -> bool, + err_msg: &str, + ) -> CatBoostResult> { + let mut indices_ptr: *mut usize = std::ptr::null_mut(); + let mut count: usize = 0; + CatBoostError::check_return_value(unsafe { + indices_fn(self.handle, &mut indices_ptr, &mut count) + })?; + Self::get_feature_indices_from_c(indices_ptr, count, err_msg) + } + + /// Get names of float features used in model + #[cfg(catboost_feature_indices)] + pub fn get_float_feature_names(&self) -> CatBoostResult> { + self.get_specific_feature_names( + sys::GetFloatFeatureIndices, + "GetFloatFeatureIndices returned null pointer", + ) + } + + /// Get names of cat features used in model + #[cfg(catboost_feature_indices)] + pub fn get_cat_feature_names(&self) -> CatBoostResult> { + self.get_specific_feature_names( + sys::GetCatFeatureIndices, + "GetCatFeatureIndices returned null pointer", + ) + } + + /// Get names of text features used in model + #[cfg(catboost_feature_indices)] + pub fn get_text_feature_names(&self) -> CatBoostResult> { + self.get_specific_feature_names( + sys::GetTextFeatureIndices, + "GetTextFeatureIndices returned null pointer", + ) + } + + /// Get names of embedding features used in model + #[cfg(catboost_feature_indices)] + pub fn get_embedding_feature_names(&self) -> CatBoostResult> { + self.get_specific_feature_names( + sys::GetEmbeddingFeatureIndices, + "GetEmbeddingFeatureIndices returned null pointer", + ) + } + /// Get expected float feature count for model pub fn get_float_features_count(&self) -> usize { unsafe { sys::GetFloatFeaturesCount(self.handle) } @@ -367,3 +557,182 @@ impl Drop for Model { unsafe { sys::ModelCalcerDelete(self.handle) }; } } + +#[cfg(test)] +mod tests { + use super::*; + use std::ptr; + + unsafe fn make_c_buffer(data: &[T]) -> *mut T { + unsafe { + let ptr = libc::malloc(std::mem::size_of::() * data.len()) as *mut T; + assert!(!ptr.is_null(), "malloc failed"); + for (i, val) in data.iter().enumerate() { + ptr.add(i).write(*val); + } + ptr + } + } + + unsafe fn make_c_string_array(strings: &[&str]) -> *mut *mut c_char { + unsafe { + let outer = libc::malloc(std::mem::size_of::<*mut c_char>() * strings.len()) + as *mut *mut c_char; + assert!(!outer.is_null()); + for (i, s) in strings.iter().enumerate() { + let len = s.len(); + let buf = libc::malloc(len + 1) as *mut c_char; + assert!(!buf.is_null()); + std::ptr::copy_nonoverlapping(s.as_ptr() as *const c_char, buf, len); + *buf.add(len) = 0; + *outer.add(i) = buf; + } + outer + } + } + + // from_c_allocated_buffer tests + + #[test] + fn test_buffer_null_ptr() { + let result = unsafe { Model::from_c_allocated_buffer::(ptr::null_mut(), 0) }; + assert!(result.is_empty()); + } + + #[test] + fn test_buffer_null_ptr_nonzero_count() { + let result = unsafe { Model::from_c_allocated_buffer::(ptr::null_mut(), 5) }; + assert!(result.is_empty()); + } + + #[test] + fn test_buffer_single_element() { + let ptr = unsafe { make_c_buffer(&[42usize]) }; + let result = unsafe { Model::from_c_allocated_buffer(ptr, 1) }; + assert_eq!(result, vec![42usize]); + } + + #[test] + fn test_buffer_multiple_elements() { + let data: Vec = (0..100).collect(); + let ptr = unsafe { make_c_buffer(&data) }; + let result = unsafe { Model::from_c_allocated_buffer(ptr, data.len()) }; + assert_eq!(result, data); + } + + #[test] + fn test_buffer_f64() { + let data = [1.5f64, 2.7, 3.14, 0.0, -1.0]; + let ptr = unsafe { make_c_buffer(&data) }; + let result = unsafe { Model::from_c_allocated_buffer(ptr, data.len()) }; + assert_eq!(result, data.to_vec()); + } + + #[test] + fn test_buffer_i32() { + let data = [-1i32, 0, 1, i32::MAX, i32::MIN]; + let ptr = unsafe { make_c_buffer(&data) }; + let result = unsafe { Model::from_c_allocated_buffer(ptr, data.len()) }; + assert_eq!(result, data.to_vec()); + } + + #[test] + fn test_buffer_u8() { + let data: Vec = (0..=255).collect(); + let ptr = unsafe { make_c_buffer(&data) }; + let result = unsafe { Model::from_c_allocated_buffer(ptr, data.len()) }; + assert_eq!(result, data); + } + + #[test] + fn test_buffer_zero_count() { + let ptr = unsafe { libc::malloc(8) as *mut usize }; + assert!(!ptr.is_null()); + let result = unsafe { Model::from_c_allocated_buffer(ptr, 0) }; + assert!(result.is_empty()); + } + + #[test] + fn test_buffer_large() { + let data: Vec = (0..10_000).collect(); + let ptr = unsafe { make_c_buffer(&data) }; + let result = unsafe { Model::from_c_allocated_buffer(ptr, data.len()) }; + assert_eq!(result, data); + } + + // get_feature_indices_from_c tests + + #[test] + fn test_indices_null_zero_count() { + let result = Model::get_feature_indices_from_c(ptr::null_mut(), 0, "error"); + assert_eq!(result.unwrap(), Vec::::new()); + } + + #[test] + fn test_indices_null_nonzero_count() { + let result = Model::get_feature_indices_from_c(ptr::null_mut(), 3, "custom error"); + assert_eq!(result.unwrap_err().description, "custom error"); + } + + #[test] + fn test_indices_valid() { + let data = [0usize, 2, 5, 10]; + let ptr = unsafe { make_c_buffer(&data) }; + let result = Model::get_feature_indices_from_c(ptr, data.len(), "error").unwrap(); + assert_eq!(result, data.to_vec()); + } + + // get_feature_names_from_c tests + + #[test] + fn test_names_null_zero_count() { + let result = Model::get_feature_names_from_c(ptr::null_mut(), 0, "error"); + assert_eq!(result.unwrap(), Vec::::new()); + } + + #[test] + fn test_names_null_nonzero_count() { + let result = Model::get_feature_names_from_c(ptr::null_mut(), 3, "custom error"); + assert_eq!(result.unwrap_err().description, "custom error"); + } + + #[test] + fn test_names_single() { + let ptr = unsafe { make_c_string_array(&["feature_0"]) }; + let result = Model::get_feature_names_from_c(ptr, 1, "error").unwrap(); + assert_eq!(result, vec!["feature_0"]); + } + + #[test] + fn test_names_multiple() { + let names = ["alpha", "beta", "gamma", "delta"]; + let ptr = unsafe { make_c_string_array(&names) }; + let result = Model::get_feature_names_from_c(ptr, names.len(), "error").unwrap(); + assert_eq!(result, names.map(String::from).to_vec()); + } + + #[test] + fn test_names_empty_strings() { + let names = ["", "", "nonempty", ""]; + let ptr = unsafe { make_c_string_array(&names) }; + let result = Model::get_feature_names_from_c(ptr, names.len(), "error").unwrap(); + assert_eq!(result, names.map(String::from).to_vec()); + } + + #[test] + fn test_names_unicode() { + let names = ["café", "naïve", "日本語"]; + let ptr = unsafe { make_c_string_array(&names) }; + let result = Model::get_feature_names_from_c(ptr, names.len(), "error").unwrap(); + assert_eq!(result, names.map(String::from).to_vec()); + } + + #[test] + fn test_names_many() { + let names: Vec = (0..500).map(|i| format!("feature_{}", i)).collect(); + let name_refs: Vec<&str> = names.iter().map(|s| s.as_str()).collect(); + let ptr = unsafe { make_c_string_array(&name_refs) }; + let result = Model::get_feature_names_from_c(ptr, names.len(), "error").unwrap(); + assert_eq!(result, names); + } +}