From 17b40802d138a4f16767a31d843d27c2a6644d5e Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 11:12:24 +0200 Subject: [PATCH 01/11] add bindings and build.rs script. --- .gitignore | 8 ++ build.rs | 288 +++++++++++++++++++++++++++++++++++++ examples/advanced_usage.rs | 80 +++++++++++ examples/basic_usage.rs | 95 ++++++++++++ src/error.rs | 39 +++++ src/lib.rs | 22 +++ src/model.rs | 243 +++++++++++++++++++++++++++++++ src/sys.rs | 5 + wrapper.h | 1 + 9 files changed, 781 insertions(+) create mode 100644 .gitignore create mode 100644 build.rs create mode 100644 examples/advanced_usage.rs create mode 100644 examples/basic_usage.rs create mode 100644 src/error.rs create mode 100644 src/lib.rs create mode 100644 src/model.rs create mode 100644 src/sys.rs create mode 100644 wrapper.h diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1138cd6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +/target +Cargo.lock +*.json +*.bin +*.model +*.dylib +*.so +*.dll diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..93bed5f --- /dev/null +++ b/build.rs @@ -0,0 +1,288 @@ +extern crate bindgen; + +use std::env; +use std::path::{Path, PathBuf}; +use std::fs; +use std::io; + +fn get_xgboost_version() -> String { + env::var("XGBOOST_VERSION").unwrap_or_else(|_| "3.1.1".to_string()) +} + +fn get_platform_info() -> (String, String) { + let target = env::var("TARGET").unwrap(); + + // Determine OS + let os = if target.contains("apple-darwin") { + "darwin" + } else if target.contains("linux") { + "linux" + } else if target.contains("windows") { + "windows" + } else { + panic!("Unsupported target: {}", target); + }; + + // Determine architecture + let arch = if target.contains("x86_64") { + "x86_64" + } else if target.contains("aarch64") || target.contains("arm64") { + "aarch64" + } else if target.contains("i686") || target.contains("i586") { + "i686" + } else { + panic!("Unsupported architecture for target: {}", target); + }; + + (os.to_string(), arch.to_string()) +} + +fn download_xgboost_headers(out_dir: &Path) -> Result<(), Box> { + let version = get_xgboost_version(); + + // Create the include/xgboost directory + let include_dir = out_dir.join("include/xgboost"); + fs::create_dir_all(&include_dir)?; + + // Download the c_api.h file + let c_api_url = format!( + "https://raw.githubusercontent.com/dmlc/xgboost/v{}/include/xgboost/c_api.h", + version + ); + + println!("cargo:warning=Downloading c_api.h from: {}", c_api_url); + + let response = ureq::get(&c_api_url).call()?; + let status = response.status(); + if status < 200 || status >= 300 { + return Err(format!("Failed to download c_api.h: HTTP {}", status).into()); + } + + let c_api_path = include_dir.join("c_api.h"); + let mut file = fs::File::create(&c_api_path)?; + io::copy(&mut response.into_reader(), &mut file)?; + + // Also download base.h which is referenced by c_api.h + let base_url = format!( + "https://raw.githubusercontent.com/dmlc/xgboost/v{}/include/xgboost/base.h", + version + ); + + println!("cargo:warning=Downloading base.h from: {}", base_url); + + let response = ureq::get(&base_url).call()?; + let status = response.status(); + if status < 200 || status >= 300 { + return Err(format!("Failed to download base.h: HTTP {}", status).into()); + } + + let base_path = include_dir.join("base.h"); + let mut file = fs::File::create(&base_path)?; + io::copy(&mut response.into_reader(), &mut file)?; + + Ok(()) +} + +fn download_and_extract_wheel(out_dir: &Path) -> Result<(), Box> { + let (os, arch) = get_platform_info(); + let version = get_xgboost_version(); + + // Determine wheel filename based on platform + let wheel_filename = match (os.as_str(), arch.as_str()) { + ("linux", "x86_64") => format!("xgboost-{}-py3-none-manylinux_2_28_x86_64.whl", version), + ("linux", "aarch64") => format!("xgboost-{}-py3-none-manylinux_2_28_aarch64.whl", version), + ("darwin", "x86_64") => format!("xgboost-{}-py3-none-macosx_10_15_x86_64.whl", version), + ("darwin", "aarch64") => format!("xgboost-{}-py3-none-macosx_12_0_arm64.whl", version), + ("windows", "x86_64") => format!("xgboost-{}-py3-none-win_amd64.whl", version), + _ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()), + }; + + let download_url = format!( + "https://files.pythonhosted.org/packages/py3/x/xgboost/{}", + wheel_filename + ); + + println!("cargo:warning=Downloading XGBoost wheel from: {}", download_url); + + // Download the wheel + let wheel_dir = out_dir.join("wheel"); + fs::create_dir_all(&wheel_dir)?; + let wheel_path = wheel_dir.join(&wheel_filename); + + let response = ureq::get(&download_url).call()?; + let status = response.status(); + if status < 200 || status >= 300 { + return Err(format!("Failed to download wheel: HTTP {}", status).into()); + } + + let mut wheel_file = fs::File::create(&wheel_path)?; + io::copy(&mut response.into_reader(), &mut wheel_file)?; + drop(wheel_file); + + println!("cargo:warning=Extracting wheel: {}", wheel_path.display()); + + // Extract the wheel (it's a ZIP file) + let file = fs::File::open(&wheel_path)?; + let mut archive = zip::ZipArchive::new(file)?; + + // Create libs directory + let lib_dir = out_dir.join("libs"); + fs::create_dir_all(&lib_dir)?; + + // Determine library filename based on OS + let lib_filename = match os.as_str() { + "windows" => "xgboost.dll", + "darwin" => "libxgboost.dylib", + _ => "libxgboost.so", + }; + + // Search for the library file in the wheel + let mut found = false; + for i in 0..archive.len() { + let mut file = archive.by_index(i)?; + let file_path = file.name().to_string(); + + // Look for the library file (usually in xgboost/lib/) + if file_path.ends_with(lib_filename) { + println!("cargo:warning=Found library at: {}", file_path); + let dest_path = lib_dir.join(lib_filename); + let mut dest = fs::File::create(&dest_path)?; + io::copy(&mut file, &mut dest)?; + found = true; + break; + } + } + + if !found { + return Err(format!("Library file {} not found in wheel", lib_filename).into()); + } + + println!("cargo:warning=Successfully extracted XGBoost library to: {}", lib_dir.display()); + + Ok(()) +} + +fn main() { + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let xgb_include_root = out_dir.join("include"); + + // Download the headers + if let Err(e) = download_xgboost_headers(&out_dir) { + eprintln!("Failed to download XGBoost headers: {}", e); + panic!("Cannot proceed without headers"); + } + + // Download and extract the wheel + if let Err(e) = download_and_extract_wheel(&out_dir) { + eprintln!("Failed to download and extract wheel: {}", e); + panic!("Cannot proceed without compiled library"); + } + + let bindings = bindgen::Builder::default() + .header("wrapper.h") + .clang_arg(format!("-I{}", xgb_include_root.display())) + // Generate bindings for XGB and XGD functions (Booster and DMatrix) + .allowlist_function("XGB.*") + .allowlist_function("XGD.*") + // Allowlist the main types we need + .allowlist_type("BoosterHandle") + .allowlist_type("DMatrixHandle") + .allowlist_type("bst_ulong") + .size_t_is_usize(true) + .generate() + .expect("Unable to generate bindings."); + + bindings + .write_to_file(out_dir.join("bindings.rs")) + .expect("Couldn't write bindings."); + + // Get platform info + let (os, _arch) = get_platform_info(); + + // Determine the library filename based on the OS + let lib_filename = match os.as_str() { + "windows" => "xgboost.dll", + "darwin" => "libxgboost.dylib", + _ => "libxgboost.so", + }; + + // Copy the library from OUT_DIR/libs to the final target directory + let lib_source_path = out_dir.join("libs").join(lib_filename); + + // Find the final output directory (e.g., target/release) + let target_dir = out_dir + .ancestors() + .find(|p| p.ends_with("target")) + .unwrap() + .join(env::var("PROFILE").unwrap()); + + let lib_dest_path = target_dir.join(lib_filename); + fs::copy(&lib_source_path, &lib_dest_path) + .expect("Failed to copy library to target directory"); + + // On macOS/Linux, change the install name/soname to use @loader_path/$ORIGIN + if os == "darwin" { + use std::process::Command; + let _ = Command::new("install_name_tool") + .arg("-id") + .arg(format!("@loader_path/{}", lib_filename)) + .arg(&lib_source_path) + .status(); + let _ = Command::new("install_name_tool") + .arg("-id") + .arg(format!("@loader_path/{}", lib_filename)) + .arg(&lib_dest_path) + .status(); + } else if os == "linux" { + use std::process::Command; + // Use patchelf to set soname (if available) + let _ = Command::new("patchelf") + .arg("--set-soname") + .arg(&lib_filename) + .arg(&lib_source_path) + .output(); + let _ = Command::new("patchelf") + .arg("--set-soname") + .arg(&lib_filename) + .arg(&lib_dest_path) + .output(); + } + + // Set the library search path for the build-time linker + let lib_search_path = out_dir.join("libs"); + println!( + "cargo:rustc-link-search=native={}", + lib_search_path.display() + ); + + // Set the rpath for the run-time linker based on the OS + match os.as_str() { + "darwin" => { + // For macOS, add multiple rpath entries for IDE compatibility + println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path"); + println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../.."); + println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path"); + println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path/../.."); + println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_search_path.display()); + // Add the target directory to rpath as well + if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { + println!("cargo:rustc-link-arg=-Wl,-rpath,{}/debug", target_root.display()); + println!("cargo:rustc-link-arg=-Wl,-rpath,{}/release", target_root.display()); + } + }, + "linux" => { + // For Linux, use $ORIGIN + println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN"); + println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../.."); + println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_search_path.display()); + // Add the target directory to rpath as well + if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { + println!("cargo:rustc-link-arg=-Wl,-rpath,{}/debug", target_root.display()); + println!("cargo:rustc-link-arg=-Wl,-rpath,{}/release", target_root.display()); + } + }, + _ => {} // No rpath needed for Windows + } + + println!("cargo:rustc-link-lib=dylib=xgboost"); +} diff --git a/examples/advanced_usage.rs b/examples/advanced_usage.rs new file mode 100644 index 0000000..ab1acde --- /dev/null +++ b/examples/advanced_usage.rs @@ -0,0 +1,80 @@ +use xgboost_rust::{Booster, XGBoostResult}; + +fn main() -> XGBoostResult<()> { + println!("XGBoost Rust Bindings - Advanced Usage Example"); + println!("===============================================\n"); + + // Load model + let model_path = "iris_model.json"; + println!("Loading model from: {}", model_path); + + let booster = match Booster::load(model_path) { + Ok(b) => { + println!("✓ Model loaded successfully\n"); + b + } + Err(e) => { + eprintln!("✗ Failed to load model: {}", e); + eprintln!("\nPlease create a model file first. See basic_usage.rs for instructions."); + return Err(e); + } + }; + + // Example: Binary classification prediction + println!("=== Example 1: Normal Prediction ==="); + let data = vec![ + 5.1, 3.5, 1.4, 0.2, + 6.7, 3.0, 5.2, 2.3, + ]; + + let predictions = booster.predict(&data, 2, 4, 0, false)?; + println!("Normal predictions (probabilities): {:?}\n", predictions); + + // Example: Get SHAP values (feature contributions) + println!("=== Example 2: Feature Contributions (SHAP) ==="); + use xgboost_rust::predict_option; + + let shap_values = booster.predict( + &data, + 2, + 4, + predict_option::PRED_CONTRIBS, + false + )?; + + println!("SHAP values for first sample:"); + // SHAP values include one extra value for bias term + let num_features = 4; + for i in 0..=num_features { + if i < num_features { + println!(" Feature {}: {:.4}", i, shap_values[i]); + } else { + println!(" Bias term: {:.4}", shap_values[i]); + } + } + println!(); + + // Example: Save model to a new location + println!("=== Example 3: Save Model ==="); + let save_path = "iris_model_copy.json"; + booster.save(save_path)?; + println!("✓ Model saved to: {}\n", save_path); + + // Example: Load from buffer (useful for embedded models) + println!("=== Example 4: Load from Buffer ==="); + let buffer = std::fs::read(model_path)?; + let booster_from_buffer = Booster::load_from_buffer(&buffer)?; + println!("✓ Model loaded from buffer ({} bytes)", buffer.len()); + + let num_features = booster_from_buffer.num_features()?; + println!(" Model has {} features\n", num_features); + + // Make a prediction with the buffer-loaded model + let test_data = vec![5.1, 3.5, 1.4, 0.2]; + let pred = booster_from_buffer.predict(&test_data, 1, 4, 0, false)?; + println!("Prediction from buffer-loaded model: {:?}\n", pred); + + println!("Advanced examples completed successfully!"); + + Ok(()) +} diff --git a/examples/basic_usage.rs b/examples/basic_usage.rs new file mode 100644 index 0000000..9fa7bd4 --- /dev/null +++ b/examples/basic_usage.rs @@ -0,0 +1,95 @@ +use xgboost_rust::{Booster, XGBoostResult}; + +fn main() -> XGBoostResult<()> { + println!("XGBoost Rust Bindings - Basic Usage Example"); + println!("============================================\n"); + + // Note: This example assumes you have a trained XGBoost model file. + // To create one, you can use Python: + // + // ```python + // import xgboost as xgb + // from sklearn.datasets import load_iris + // from sklearn.model_selection import train_test_split + // + // # Load iris dataset + // iris = load_iris() + // X_train, X_test, y_train, y_test = train_test_split( + // iris.data, iris.target, test_size=0.2, random_state=42 + // ) + // + // # Train model + // dtrain = xgb.DMatrix(X_train, label=y_train) + // params = { + // 'objective': 'multi:softprob', + // 'num_class': 3, + // 'max_depth': 3, + // 'eta': 0.3 + // } + // bst = xgb.train(params, dtrain, num_boost_round=10) + // + // # Save model + // bst.save_model('iris_model.json') + // ``` + + // Load a pre-trained model + let model_path = "iris_model.json"; + + println!("Loading model from: {}", model_path); + let booster = match Booster::load(model_path) { + Ok(b) => { + println!("✓ Model loaded successfully\n"); + b + } + Err(e) => { + eprintln!("✗ Failed to load model: {}", e); + eprintln!("\nPlease create a model file first using the Python code in the example."); + return Err(e); + } + }; + + // Get model information + let num_features = booster.num_features()?; + println!("Model expects {} features\n", num_features); + + // Example prediction data (4 features for iris dataset) + // This is a sample from the iris dataset + let data = vec![ + 5.1, 3.5, 1.4, 0.2, // Row 1: Setosa + 6.7, 3.0, 5.2, 2.3, // Row 2: Virginica + 5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor + ]; + let num_rows = 3; + let num_features = 4; + + println!("Making predictions on {} samples...", num_rows); + + // Make predictions + let predictions = booster.predict(&data, num_rows, num_features, 0, false)?; + + println!("✓ Predictions complete\n"); + + // Print results + println!("Predictions (probabilities for 3 classes):"); + for i in 0..num_rows { + println!("Sample {}:", i + 1); + println!(" Class 0 (Setosa): {:.4}", predictions[i * 3]); + println!(" Class 1 (Versicolor): {:.4}", predictions[i * 3 + 1]); + println!(" Class 2 (Virginica): {:.4}", predictions[i * 3 + 2]); + + // Find predicted class (argmax) + let mut max_prob = predictions[i * 3]; + let mut predicted_class = 0; + for j in 1..3 { + if predictions[i * 3 + j] > max_prob { + max_prob = predictions[i * 3 + j]; + predicted_class = j; + } + } + println!(" → Predicted class: {}\n", predicted_class); + } + + println!("Example completed successfully!"); + + Ok(()) +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..f2bfc45 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,39 @@ +use std::ffi::CStr; +use std::fmt; +use crate::sys; + +pub type XGBoostResult = std::result::Result; + +#[derive(Debug, Eq, PartialEq)] +pub struct XGBoostError { + pub description: String, +} + +impl XGBoostError { + /// Check the return value from an XGBoost FFI call, and return the last error message on error. + /// Return values of 0 are treated as success, non-zero values are treated as errors. + pub fn check_return_value(ret_val: i32) -> XGBoostResult<()> { + if ret_val == 0 { + Ok(()) + } else { + Err(XGBoostError::fetch_xgboost_error()) + } + } + + /// Fetch current error message from XGBoost. + fn fetch_xgboost_error() -> Self { + let c_str = unsafe { CStr::from_ptr(sys::XGBGetLastError()) }; + let str_slice = c_str.to_str().unwrap_or("Unknown error"); + XGBoostError { + description: str_slice.to_owned(), + } + } +} + +impl fmt::Display for XGBoostError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.description) + } +} + +impl std::error::Error for XGBoostError {} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..60d3931 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,22 @@ +// Include the XGBoost C API bindings +mod sys; + +mod error; +pub use crate::error::{XGBoostError, XGBoostResult}; + +mod model; +pub use crate::model::Booster; + +// Re-export prediction option constants for convenience +pub mod predict_option { + /// Normal prediction, output is the transformed probability + pub const OUTPUT_MARGIN: u32 = 0x01; + /// Output the untransformed margin value + pub const PRED_LEAF: u32 = 0x02; + /// Output the leaf index of trees + pub const PRED_CONTRIBS: u32 = 0x04; + /// Output feature contributions (SHAP values) + pub const PRED_APPROX_CONTRIBS: u32 = 0x08; + /// Output feature interaction contributions + pub const PRED_INTERACTIONS: u32 = 0x10; +} diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..d8da5b3 --- /dev/null +++ b/src/model.rs @@ -0,0 +1,243 @@ +use crate::error::{XGBoostError, XGBoostResult}; +use crate::sys; +use std::ffi::CString; +use std::path::Path; +use std::ptr; + +/// An XGBoost Booster for making predictions. +/// +/// # Thread Safety +/// +/// **Thread safety depends on XGBoost version:** +/// - **XGBoost ≥ 1.4**: Predictions are thread-safe for tree models (gbtree/dart). +/// However, prediction cache may still have edge cases. +/// - **XGBoost < 1.4**: NOT thread-safe for concurrent predictions. +/// +/// This wrapper does NOT implement `Send` or `Sync` to be conservative. +/// For multi-threaded use cases, use one of these approaches: +/// +/// 1. **Create one Booster per thread** (recommended, works with all versions): +/// ```ignore +/// let booster = Booster::load("model.json")?; +/// thread::spawn(move || { +/// booster.predict(...); // Each thread owns its Booster +/// }); +/// ``` +/// +/// 2. **Wrap in Arc>** for shared access: +/// ```ignore +/// use std::sync::{Arc, Mutex}; +/// let booster = Arc::new(Mutex::new(Booster::load("model.json")?)); +/// let booster_clone = booster.clone(); +/// thread::spawn(move || { +/// let booster = booster_clone.lock().unwrap(); +/// booster.predict(...); +/// }); +/// ``` +pub struct Booster { + handle: sys::BoosterHandle, +} + +// NOTE: We do NOT implement Send or Sync for Booster because: +// 1. Thread safety guarantees vary by XGBoost version (≥1.4 is safer) +// 2. C API documentation doesn't explicitly guarantee thread safety +// 3. Users should explicitly choose synchronization strategy (one-per-thread or Mutex) +// +// If you need to share a Booster across threads, wrap it in Arc>. + +impl Booster { + /// Load a model from a file + /// + /// # Arguments + /// * `path` - Path to the model file (can be JSON, binary, or deprecated text format) + /// + /// # Example + /// ```no_run + /// use xgboost_rust::Booster; + /// + /// let booster = Booster::load("model.json").unwrap(); + /// ``` + pub fn load>(path: P) -> XGBoostResult { + let path_str = path.as_ref().to_str() + .ok_or_else(|| XGBoostError { + description: "Path contains invalid UTF-8 characters".to_string(), + })?; + let path_c_str = CString::new(path_str) + .map_err(|e| XGBoostError { + description: format!("Path contains NUL byte: {}", e), + })?; + + // Create a booster first + let mut handle: sys::BoosterHandle = ptr::null_mut(); + XGBoostError::check_return_value(unsafe { + sys::XGBoosterCreate(ptr::null(), 0, &mut handle) + })?; + + // Load model into the booster + XGBoostError::check_return_value(unsafe { + sys::XGBoosterLoadModel( + handle, + path_c_str.as_ptr(), + ) + })?; + + Ok(Booster { handle }) + } + + /// Load a model from a memory buffer + /// + /// # Arguments + /// * `buffer` - Model content as bytes + /// + /// # Example + /// ```no_run + /// use xgboost_rust::Booster; + /// use std::fs; + /// + /// let buffer = fs::read("model.json").unwrap(); + /// let booster = Booster::load_from_buffer(&buffer).unwrap(); + /// ``` + pub fn load_from_buffer(buffer: &[u8]) -> XGBoostResult { + // Create a booster first + let mut handle: sys::BoosterHandle = ptr::null_mut(); + XGBoostError::check_return_value(unsafe { + sys::XGBoosterCreate(ptr::null(), 0, &mut handle) + })?; + + // Load model from buffer into the booster + XGBoostError::check_return_value(unsafe { + sys::XGBoosterLoadModelFromBuffer( + handle, + buffer.as_ptr() as *const std::os::raw::c_void, + buffer.len() as u64, + ) + })?; + + Ok(Booster { handle }) + } + + /// Make predictions on data + /// + /// # Arguments + /// * `data` - 2D array of features (row-major, num_rows x num_features) + /// * `num_rows` - Number of rows in the data + /// * `num_features` - Number of features per row + /// * `option_mask` - Prediction options (see `predict_option` module) + /// * `training` - Whether this is for training (false for inference) + /// + /// # Returns + /// A vector of prediction values + /// + /// # Example + /// ```no_run + /// use xgboost_rust::Booster; + /// + /// let booster = Booster::load("model.json").unwrap(); + /// let data = vec![1.0, 2.0, 3.0, 4.0]; // 2 rows, 2 features + /// let predictions = booster.predict(&data, 2, 2, 0, false).unwrap(); + /// ``` + pub fn predict( + &self, + data: &[f32], + num_rows: usize, + num_features: usize, + option_mask: u32, + training: bool, + ) -> XGBoostResult> { + // Create DMatrix from data + let mut dmatrix_handle: sys::DMatrixHandle = ptr::null_mut(); + + XGBoostError::check_return_value(unsafe { + sys::XGDMatrixCreateFromMat( + data.as_ptr(), + num_rows as u64, + num_features as u64, + f32::NAN, + &mut dmatrix_handle, + ) + })?; + + // Make prediction + let mut out_len: u64 = 0; + let mut out_result: *const f32 = ptr::null(); + + XGBoostError::check_return_value(unsafe { + sys::XGBoosterPredict( + self.handle, + dmatrix_handle, + option_mask as i32, + 0, // ntree_limit (0 means use all trees) + training as i32, + &mut out_len, + &mut out_result, + ) + })?; + + // Copy results to a Vec + let results = unsafe { + std::slice::from_raw_parts(out_result, out_len as usize).to_vec() + }; + + // Free DMatrix + unsafe { + sys::XGDMatrixFree(dmatrix_handle); + } + + Ok(results) + } + + /// Get the number of features the model expects + /// + /// # Returns + /// The number of features + pub fn num_features(&self) -> XGBoostResult { + let mut out_num_features: u64 = 0; + + XGBoostError::check_return_value(unsafe { + sys::XGBoosterGetNumFeature( + self.handle, + &mut out_num_features, + ) + })?; + + Ok(out_num_features as usize) + } + + /// Save the model to a file + /// + /// # Arguments + /// * `path` - Path where to save the model + /// + /// # Example + /// ```no_run + /// use xgboost_rust::Booster; + /// + /// let booster = Booster::load("model.json").unwrap(); + /// booster.save("model_copy.json").unwrap(); + /// ``` + pub fn save>(&self, path: P) -> XGBoostResult<()> { + let path_str = path.as_ref().to_str() + .ok_or_else(|| XGBoostError { + description: "Path contains invalid UTF-8 characters".to_string(), + })?; + let path_c_str = CString::new(path_str) + .map_err(|e| XGBoostError { + description: format!("Path contains NUL byte: {}", e), + })?; + + XGBoostError::check_return_value(unsafe { + sys::XGBoosterSaveModel( + self.handle, + path_c_str.as_ptr(), + ) + }) + } +} + +impl Drop for Booster { + fn drop(&mut self) { + unsafe { + sys::XGBoosterFree(self.handle); + } + } +} diff --git a/src/sys.rs b/src/sys.rs new file mode 100644 index 0000000..a38a13a --- /dev/null +++ b/src/sys.rs @@ -0,0 +1,5 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +include!(concat!(env!("OUT_DIR"), "/bindings.rs")); diff --git a/wrapper.h b/wrapper.h new file mode 100644 index 0000000..1213e62 --- /dev/null +++ b/wrapper.h @@ -0,0 +1 @@ +#include From b632569cbae7cc13906ea71d723494167df434f2 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 11:27:22 +0200 Subject: [PATCH 02/11] forgot to push cargo.toml --- Cargo.toml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 Cargo.toml diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..c0de980 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "xgboost-rust" +version = "0.1.0" +edition = "2021" +description = "Rust bindings for XGBoost, a gradient boosting library for machine learning. Downloads XGBoost binaries at build time for cross-platform compatibility." +license = "Apache-2.0" +keywords = ["machine-learning", "gradient-boosting", "xgboost", "ml"] +categories = ["science"] +readme = "README.md" +rust-version = "1.70" + +[build-dependencies] +bindgen = "0.72.0" +ureq = "2.0" +zip = "0.6" + +[features] +default = [] +gpu = [] + +[[example]] +name = "basic_usage" +path = "examples/basic_usage.rs" + +[[example]] +name = "advanced_usage" +path = "examples/advanced_usage.rs" From 1c4842dd1501ed3792649772f19f471cba4dcfc6 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 11:32:02 +0200 Subject: [PATCH 03/11] support thread safety for versions 1.4 and above. --- src/model.rs | 51 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/src/model.rs b/src/model.rs index d8da5b3..74994fe 100644 --- a/src/model.rs +++ b/src/model.rs @@ -10,13 +10,30 @@ use std::ptr; /// /// **Thread safety depends on XGBoost version:** /// - **XGBoost ≥ 1.4**: Predictions are thread-safe for tree models (gbtree/dart). -/// However, prediction cache may still have edge cases. -/// - **XGBoost < 1.4**: NOT thread-safe for concurrent predictions. +/// `Send` and `Sync` are automatically implemented for these versions. +/// You can safely share `Arc` across threads. +/// - **XGBoost < 1.4**: NOT thread-safe. `Send` and `Sync` are NOT implemented. /// -/// This wrapper does NOT implement `Send` or `Sync` to be conservative. -/// For multi-threaded use cases, use one of these approaches: +/// ## Usage with XGBoost ≥ 1.4 /// -/// 1. **Create one Booster per thread** (recommended, works with all versions): +/// ```ignore +/// use std::sync::Arc; +/// use std::thread; +/// +/// let booster = Arc::new(Booster::load("model.json")?); +/// let booster_clone = booster.clone(); +/// +/// thread::spawn(move || { +/// // Safe to call predict concurrently +/// booster_clone.predict(...); +/// }); +/// ``` +/// +/// ## Usage with XGBoost < 1.4 +/// +/// For older versions, use one of these approaches: +/// +/// 1. **Create one Booster per thread** (recommended): /// ```ignore /// let booster = Booster::load("model.json")?; /// thread::spawn(move || { @@ -24,26 +41,26 @@ use std::ptr; /// }); /// ``` /// -/// 2. **Wrap in Arc>** for shared access: +/// 2. **Wrap in Arc>**: /// ```ignore /// use std::sync::{Arc, Mutex}; /// let booster = Arc::new(Mutex::new(Booster::load("model.json")?)); -/// let booster_clone = booster.clone(); -/// thread::spawn(move || { -/// let booster = booster_clone.lock().unwrap(); -/// booster.predict(...); -/// }); /// ``` pub struct Booster { handle: sys::BoosterHandle, } -// NOTE: We do NOT implement Send or Sync for Booster because: -// 1. Thread safety guarantees vary by XGBoost version (≥1.4 is safer) -// 2. C API documentation doesn't explicitly guarantee thread safety -// 3. Users should explicitly choose synchronization strategy (one-per-thread or Mutex) -// -// If you need to share a Booster across threads, wrap it in Arc>. +// Thread safety implementation based on XGBoost version +// XGBoost 1.4.0+ supports thread-safe predictions for tree models (gbtree/dart) +// See: https://github.com/dmlc/xgboost/issues/5339 +#[cfg(xgboost_thread_safe)] +unsafe impl Send for Booster {} + +#[cfg(xgboost_thread_safe)] +unsafe impl Sync for Booster {} + +// For XGBoost < 1.4, Send and Sync are NOT implemented. +// Users should wrap in Arc> or use one Booster per thread. impl Booster { /// Load a model from a file From 5fabbdbe88bd50f66e65ddee5c4329a12f81f6a4 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 11:32:22 +0200 Subject: [PATCH 04/11] support thread safety for versions 1.4 and above. --- build.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/build.rs b/build.rs index 93bed5f..8776b47 100644 --- a/build.rs +++ b/build.rs @@ -9,6 +9,27 @@ fn get_xgboost_version() -> String { env::var("XGBOOST_VERSION").unwrap_or_else(|_| "3.1.1".to_string()) } +fn parse_version(version: &str) -> (u32, u32, u32) { + let parts: Vec<&str> = version.split('.').collect(); + let major = parts.get(0).and_then(|s| s.parse().ok()).unwrap_or(0); + let minor = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0); + let patch = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0); + (major, minor, patch) +} + +fn emit_version_cfg_flags(version: &str) { + let (major, minor, _patch) = parse_version(version); + + // XGBoost 1.4.0+ has thread-safe predictions for tree models + // See: https://github.com/dmlc/xgboost/issues/5339 + if major > 1 || (major == 1 && minor >= 4) { + println!("cargo:rustc-cfg=xgboost_thread_safe"); + println!("cargo:warning=XGBoost version {} supports thread-safe predictions", version); + } else { + println!("cargo:warning=XGBoost version {} does NOT support thread-safe predictions", version); + } +} + fn get_platform_info() -> (String, String) { let target = env::var("TARGET").unwrap(); @@ -166,6 +187,10 @@ fn main() { let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); let xgb_include_root = out_dir.join("include"); + // Get version and emit cfg flags for thread safety + let version = get_xgboost_version(); + emit_version_cfg_flags(&version); + // Download the headers if let Err(e) = download_xgboost_headers(&out_dir) { eprintln!("Failed to download XGBoost headers: {}", e); From bb0d3741b57a180af681736a4de5883ddfb29ac0 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 12:02:08 +0200 Subject: [PATCH 05/11] add checksum + avoid memory leak --- Cargo.toml | 1 + build.rs | 247 +++++++++++++++++++++++++++++++++++++++------------ src/model.rs | 41 ++++++++- 3 files changed, 226 insertions(+), 63 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c0de980..d1848e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ rust-version = "1.70" bindgen = "0.72.0" ureq = "2.0" zip = "0.6" +sha2 = "0.10" [features] default = [] diff --git a/build.rs b/build.rs index 8776b47..ab614e8 100644 --- a/build.rs +++ b/build.rs @@ -1,14 +1,63 @@ extern crate bindgen; +use sha2::{Sha256, Digest}; +use std::collections::HashMap; use std::env; use std::path::{Path, PathBuf}; use std::fs; -use std::io; +use std::io::{self, Read, Write}; +use std::thread; +use std::time::Duration; fn get_xgboost_version() -> String { env::var("XGBOOST_VERSION").unwrap_or_else(|_| "3.1.1".to_string()) } +// Known SHA256 checksums for header files by version +fn get_header_checksums() -> HashMap<&'static str, (&'static str, &'static str)> { + let mut checksums = HashMap::new(); + // Format: version => (c_api.h SHA256, base.h SHA256) + checksums.insert("3.1.1", ( + "c0f0a98eb36fb5e451fdd3e9ead2d185f4c61be2a6997fc295e5d1a94f3096e2", + "8d771fb20e03f3443e21cfdcd26ac5cd880be585b8817f2e0d146e7c5c7bb63a" + )); + checksums.insert("3.0.5", ( + "2ccec6e5301fa5a1324f60af48b9c6be5879e590ed583ec9d74297e6018860bc", + "47f0148706907ccecb72b8484687524bc36d58b4c6fe5e7b81e59de157261ea7" + )); + checksums.insert("2.1.4", ( + "b804850ec6c7a00f8e36f139dfce7fe348fc9ad066ff4cb7ac44a4f5420ec1dd", + "525c4a2ba2f6bd9b17a299978e16f91897d497d6ae0ae5df2335dd059f00d0ce" + )); + checksums.insert("1.7.6", ( + "145ed1df652937122b6f6bc31331051eabc02226a0b62349ea593cdbe841c20d", + "b26e17eadbcc6350dc900b35d164eedc02b1cd2a64913c560d4d416c81a68935" + )); + checksums.insert("1.4.2", ( + "3f5de5d046a3c9576e0c560abe5fa1e889f72b4b18ff2bf73e5f98290d47d0dc", + "e3abfcc730eee86acf44124d5496a2b41413f963c4bbf560513eeae0b7d12fb7" + )); + checksums +} + +fn compute_sha256(data: &[u8]) -> String { + let mut hasher = Sha256::new(); + hasher.update(data); + format!("{:x}", hasher.finalize()) +} + +fn verify_checksum(data: &[u8], expected: &str, filename: &str) -> Result<(), Box> { + let actual = compute_sha256(data); + if actual != expected { + return Err(format!( + "SHA256 checksum mismatch for {}:\n Expected: {}\n Got: {}", + filename, expected, actual + ).into()); + } + println!("cargo:warning=✓ Verified SHA256 for {}", filename); + Ok(()) +} + fn parse_version(version: &str) -> (u32, u32, u32) { let parts: Vec<&str> = version.split('.').collect(); let major = parts.get(0).and_then(|s| s.parse().ok()).unwrap_or(0); @@ -60,48 +109,106 @@ fn get_platform_info() -> (String, String) { fn download_xgboost_headers(out_dir: &Path) -> Result<(), Box> { let version = get_xgboost_version(); + let checksums = get_header_checksums(); + + // Get expected checksums for this version + let (c_api_expected, base_expected) = checksums.get(version.as_str()) + .ok_or_else(|| format!( + "No known SHA256 checksums for XGBoost version {}. \ + Please verify this version manually or add checksums to build.rs", + version + ))?; // Create the include/xgboost directory let include_dir = out_dir.join("include/xgboost"); fs::create_dir_all(&include_dir)?; - // Download the c_api.h file - let c_api_url = format!( - "https://raw.githubusercontent.com/dmlc/xgboost/v{}/include/xgboost/c_api.h", - version - ); + // Download and verify c_api.h + let c_api_path = include_dir.join("c_api.h"); + download_and_verify_file( + &format!("https://raw.githubusercontent.com/dmlc/xgboost/v{}/include/xgboost/c_api.h", version), + &c_api_path, + c_api_expected, + "c_api.h" + )?; + + // Download and verify base.h + let base_path = include_dir.join("base.h"); + download_and_verify_file( + &format!("https://raw.githubusercontent.com/dmlc/xgboost/v{}/include/xgboost/base.h", version), + &base_path, + base_expected, + "base.h" + )?; + + Ok(()) +} - println!("cargo:warning=Downloading c_api.h from: {}", c_api_url); +fn download_and_verify_file( + url: &str, + dest_path: &Path, + expected_sha256: &str, + filename: &str +) -> Result<(), Box> { + println!("cargo:warning=Downloading {} from: {}", filename, url); - let response = ureq::get(&c_api_url).call()?; + // Download into memory buffer + let response = ureq::get(url).call()?; let status = response.status(); if status < 200 || status >= 300 { - return Err(format!("Failed to download c_api.h: HTTP {}", status).into()); + return Err(format!("Failed to download {}: HTTP {}", filename, status).into()); } - let c_api_path = include_dir.join("c_api.h"); - let mut file = fs::File::create(&c_api_path)?; - io::copy(&mut response.into_reader(), &mut file)?; + let mut buffer = Vec::new(); + response.into_reader().read_to_end(&mut buffer)?; - // Also download base.h which is referenced by c_api.h - let base_url = format!( - "https://raw.githubusercontent.com/dmlc/xgboost/v{}/include/xgboost/base.h", - version - ); + // Verify SHA256 checksum + verify_checksum(&buffer, expected_sha256, filename)?; - println!("cargo:warning=Downloading base.h from: {}", base_url); + // Only write file after successful verification + let mut file = fs::File::create(dest_path)?; + file.write_all(&buffer)?; - let response = ureq::get(&base_url).call()?; - let status = response.status(); - if status < 200 || status >= 300 { - return Err(format!("Failed to download base.h: HTTP {}", status).into()); - } + Ok(()) +} - let base_path = include_dir.join("base.h"); - let mut file = fs::File::create(&base_path)?; - io::copy(&mut response.into_reader(), &mut file)?; +fn download_with_retry(url: &str, max_retries: u32) -> Result, Box> { + let mut last_error = None; - Ok(()) + for attempt in 0..max_retries { + if attempt > 0 { + let backoff = Duration::from_millis(100 * 2_u64.pow(attempt)); + println!("cargo:warning=Retry attempt {} after {:?}", attempt + 1, backoff); + thread::sleep(backoff); + } + + match ureq::get(url).call() { + Ok(response) => { + let status = response.status(); + if status < 200 || status >= 300 { + last_error = Some(format!("HTTP {}", status)); + continue; + } + + let mut buffer = Vec::new(); + if let Err(e) = response.into_reader().read_to_end(&mut buffer) { + last_error = Some(e.to_string()); + continue; + } + + return Ok(buffer); + } + Err(e) => { + last_error = Some(e.to_string()); + } + } + } + + Err(format!( + "Failed to download after {} attempts. Last error: {}", + max_retries, + last_error.unwrap_or_else(|| "Unknown error".to_string()) + ).into()) } fn download_and_extract_wheel(out_dir: &Path) -> Result<(), Box> { @@ -118,44 +225,59 @@ fn download_and_extract_wheel(out_dir: &Path) -> Result<(), Box return Err(format!("Unsupported platform: {}-{}", os, arch).into()), }; - let download_url = format!( - "https://files.pythonhosted.org/packages/py3/x/xgboost/{}", - wheel_filename - ); - - println!("cargo:warning=Downloading XGBoost wheel from: {}", download_url); + let lib_filename = match os.as_str() { + "windows" => "xgboost.dll", + "darwin" => "libxgboost.dylib", + _ => "libxgboost.so", + }; - // Download the wheel + // Setup paths let wheel_dir = out_dir.join("wheel"); + let lib_dir = out_dir.join("libs"); fs::create_dir_all(&wheel_dir)?; + fs::create_dir_all(&lib_dir)?; + let wheel_path = wheel_dir.join(&wheel_filename); + let lib_dest_path = lib_dir.join(lib_filename); - let response = ureq::get(&download_url).call()?; - let status = response.status(); - if status < 200 || status >= 300 { - return Err(format!("Failed to download wheel: HTTP {}", status).into()); + // Check if library already exists and is valid + if lib_dest_path.exists() { + println!("cargo:warning=Using cached XGBoost library at: {}", lib_dest_path.display()); + return Ok(()); } - let mut wheel_file = fs::File::create(&wheel_path)?; - io::copy(&mut response.into_reader(), &mut wheel_file)?; - drop(wheel_file); + // Check if wheel is cached + let wheel_buffer = if wheel_path.exists() { + println!("cargo:warning=Using cached wheel at: {}", wheel_path.display()); + fs::read(&wheel_path)? + } else { + // Download wheel with retry + let download_url = format!( + "https://files.pythonhosted.org/packages/py3/x/xgboost/{}", + wheel_filename + ); + + println!("cargo:warning=Downloading XGBoost wheel from: {}", download_url); + let buffer = download_with_retry(&download_url, 3)?; + + // Write atomically (temp file + rename) + let temp_path = wheel_path.with_extension("tmp"); + { + let mut temp_file = fs::File::create(&temp_path)?; + temp_file.write_all(&buffer)?; + temp_file.sync_all()?; + } + fs::rename(&temp_path, &wheel_path)?; - println!("cargo:warning=Extracting wheel: {}", wheel_path.display()); + println!("cargo:warning=✓ Downloaded and cached wheel"); + buffer + }; - // Extract the wheel (it's a ZIP file) - let file = fs::File::open(&wheel_path)?; - let mut archive = zip::ZipArchive::new(file)?; + // Extract library from wheel + println!("cargo:warning=Extracting library from wheel"); - // Create libs directory - let lib_dir = out_dir.join("libs"); - fs::create_dir_all(&lib_dir)?; - - // Determine library filename based on OS - let lib_filename = match os.as_str() { - "windows" => "xgboost.dll", - "darwin" => "libxgboost.dylib", - _ => "libxgboost.so", - }; + let cursor = io::Cursor::new(wheel_buffer); + let mut archive = zip::ZipArchive::new(cursor)?; // Search for the library file in the wheel let mut found = false; @@ -166,9 +288,16 @@ fn download_and_extract_wheel(out_dir: &Path) -> Result<(), Box Result<(), Box XGBoostResult> { + // Validate input dimensions + let expected_len = num_rows.checked_mul(num_features) + .ok_or_else(|| XGBoostError { + description: format!( + "Integer overflow: num_rows ({}) * num_features ({}) exceeds usize::MAX", + num_rows, num_features + ), + })?; + + if data.len() != expected_len { + return Err(XGBoostError { + description: format!( + "Data length mismatch: expected {} elements ({}×{}), got {}", + expected_len, num_rows, num_features, data.len() + ), + }); + } + // Create DMatrix from data let mut dmatrix_handle: sys::DMatrixHandle = ptr::null_mut(); @@ -174,6 +192,17 @@ impl Booster { ) })?; + // RAII guard to ensure DMatrix is always freed + struct DMatrixGuard(sys::DMatrixHandle); + impl Drop for DMatrixGuard { + fn drop(&mut self) { + unsafe { + sys::XGDMatrixFree(self.0); + } + } + } + let _guard = DMatrixGuard(dmatrix_handle); + // Make prediction let mut out_len: u64 = 0; let mut out_result: *const f32 = ptr::null(); @@ -190,15 +219,19 @@ impl Booster { ) })?; + // Validate output pointers + if out_result.is_null() || out_len == 0 { + return Err(XGBoostError { + description: "XGBoost returned null or empty prediction result".to_string(), + }); + } + // Copy results to a Vec let results = unsafe { std::slice::from_raw_parts(out_result, out_len as usize).to_vec() }; - // Free DMatrix - unsafe { - sys::XGDMatrixFree(dmatrix_handle); - } + // DMatrix will be automatically freed when _guard goes out of scope Ok(results) } From 9853ff350321f44573e188a4c3f806472af600ca Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 12:09:28 +0200 Subject: [PATCH 06/11] update README.md --- README.md | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 00a8046..72a68bf 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ Rust bindings for XGBoost, a gradient boosting library for machine learning. - **Automatic Binary Download**: Downloads XGBoost binaries at build time from PyPI wheels - **Cross-Platform**: Supports Linux (x86_64, aarch64), macOS (x86_64, arm64), and Windows (x86_64) - **Version Control**: Specify XGBoost version via `XGBOOST_VERSION` environment variable +- **Version-Aware Thread Safety**: Automatically enables `Send + Sync` for XGBoost ≥ 1.4 - **Easy to Use**: Simple, safe Rust API wrapping the XGBoost C API ## Installation @@ -85,10 +86,28 @@ This crate downloads the appropriate XGBoost Python wheel from PyPI during the b ## Thread Safety -The `Booster` type is **not thread-safe**. For multi-threaded usage: +Thread safety is **version-aware**: -1. **Recommended**: Create one `Booster` per thread -2. **Alternative**: Wrap in `Arc>` for shared access +- **XGBoost ≥ 1.4**: `Booster` implements `Send + Sync` and is thread-safe for predictions on tree models. You can safely share `Arc` across threads. +- **XGBoost < 1.4**: `Booster` does NOT implement `Send + Sync`. Use one booster per thread or wrap in `Arc>`. + +### Example with XGBoost ≥ 1.4 + +```rust +use std::sync::Arc; +use std::thread; +use xgboost_rust::Booster; + +let booster = Arc::new(Booster::load("model.json")?); +let booster_clone = booster.clone(); + +thread::spawn(move || { + // Safe concurrent predictions with XGBoost ≥ 1.4 + let predictions = booster_clone.predict(&data, rows, cols, 0, false)?; +}); +``` + +The version check happens automatically at build time based on the `XGBOOST_VERSION` environment variable. ## Examples From 40ae111d0e12f9541e98f0404e8897e0ca3065d3 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 12:19:38 +0200 Subject: [PATCH 07/11] avoid memory leak when error reading the model. --- src/model.rs | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/model.rs b/src/model.rs index 7abefe8..47c135e 100644 --- a/src/model.rs +++ b/src/model.rs @@ -91,12 +91,17 @@ impl Booster { })?; // Load model into the booster - XGBoostError::check_return_value(unsafe { + let result = XGBoostError::check_return_value(unsafe { sys::XGBoosterLoadModel( handle, path_c_str.as_ptr(), ) - })?; + }); + + if let Err(e) = result { + unsafe { sys::XGBoosterFree(handle); } + return Err(e); + } Ok(Booster { handle }) } @@ -122,13 +127,18 @@ impl Booster { })?; // Load model from buffer into the booster - XGBoostError::check_return_value(unsafe { + let result = XGBoostError::check_return_value(unsafe { sys::XGBoosterLoadModelFromBuffer( handle, buffer.as_ptr() as *const std::os::raw::c_void, buffer.len() as u64, ) - })?; + }); + + if let Err(e) = result { + unsafe { sys::XGBoosterFree(handle); } + return Err(e); + } Ok(Booster { handle }) } From b06cc1dd1a9633712f13f4416305e8e60b612570 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Mon, 10 Nov 2025 22:17:55 +0200 Subject: [PATCH 08/11] add ci to check builds. --- .github/workflows/ci.yml | 252 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..a7714b8 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,252 @@ +name: CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +env: + CARGO_TERM_COLOR: always + +jobs: + test: + name: Test ${{ matrix.os }} (${{ matrix.arch }}) - XGBoost ${{ matrix.xgboost_version }} + runs-on: ${{ matrix.runner }} + strategy: + fail-fast: false + matrix: + include: + # macOS ARM64 (M1/M2/M3) - Test multiple versions + - os: macos + arch: arm64 + runner: macos-14 + xgboost_version: "3.1.1" + - os: macos + arch: arm64 + runner: macos-14 + xgboost_version: "3.0.5" + - os: macos + arch: arm64 + runner: macos-14 + xgboost_version: "2.1.4" + + # macOS x86_64 (Intel) + - os: macos + arch: x86_64 + runner: macos-15-large + xgboost_version: "3.1.1" + + # Linux x86_64 - Test multiple versions including thread-safety boundary + - os: linux + arch: x86_64 + runner: ubuntu-latest + xgboost_version: "3.1.1" + - os: linux + arch: x86_64 + runner: ubuntu-latest + xgboost_version: "1.7.6" + - os: linux + arch: x86_64 + runner: ubuntu-latest + xgboost_version: "1.4.2" # First thread-safe version + + # Linux ARM64 + - os: linux + arch: arm64 + runner: ubuntu-latest + xgboost_version: "3.1.1" + + # Windows x86_64 + - os: windows + arch: x86_64 + runner: windows-latest + xgboost_version: "3.1.1" + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-registry- + + - name: Cache cargo index + uses: actions/cache@v4 + with: + path: ~/.cargo/git + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-index-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-index- + + - name: Cache cargo build + uses: actions/cache@v4 + with: + path: target + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-build-target-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-build-target- + + - name: Install system dependencies (Linux) + if: matrix.os == 'linux' + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: Install system dependencies (macOS) + if: matrix.os == 'macos' + run: | + brew install libomp + + - name: Check build (no features) + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: cargo check --verbose + + - name: Build (no features) + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: cargo build --verbose + + - name: Run tests (no features) + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: cargo test --verbose + + - name: Build examples + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: | + cargo build --example basic_usage --verbose + cargo build --example advanced_usage --verbose + + - name: Verify library architecture (macOS) + - name: Verify library architecture (macOS) + if: matrix.os == 'macos' + run: | + echo "Checking library architecture..." + file target/debug/libxgboost.dylib + lipo -info target/debug/libxgboost.dylib || otool -L target/debug/libxgboost.dylib + + - name: Verify library architecture (Linux) + if: matrix.os == 'linux' + run: | + echo "Checking library architecture..." + file target/debug/libxgboost.so + readelf -h target/debug/libxgboost.so | grep Machine + + - name: Verify library exists (Windows) + if: matrix.os == 'windows' + run: | + echo "Checking library exists..." + Get-Item target/debug/xgboost.dll + + - name: Verify thread safety detection + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: | + echo "XGBoost version: ${{ matrix.xgboost_version }}" + cargo build --verbose 2>&1 | grep "thread-safe" || echo "No thread-safe message found" + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: Run clippy (no features) + run: cargo clippy -- -D warnings + + fmt: + name: Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Check formatting + run: cargo fmt -- --check + + # Test checksum verification + security-checksums: + name: Verify SHA256 checksums + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: Test checksum verification for XGBoost 3.1.1 + env: + XGBOOST_VERSION: "3.1.1" + run: | + cargo check --verbose 2>&1 | tee build.log + grep "✓ Verified SHA256" build.log + echo "Checksum verification working for 3.1.1" + + - name: Test checksum verification for XGBoost 2.1.4 + env: + XGBOOST_VERSION: "2.1.4" + run: | + cargo clean + cargo check --verbose 2>&1 | tee build.log + grep "✓ Verified SHA256" build.log + echo "Checksum verification working for 2.1.4" + + # Test caching behavior + caching-test: + name: Test wheel caching + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: First build (should download) + env: + XGBOOST_VERSION: "3.1.1" + run: | + cargo check --verbose 2>&1 | tee build1.log + grep "Downloading XGBoost wheel" build1.log || true + + - name: Second build (should use cache) + env: + XGBOOST_VERSION: "3.1.1" + run: | + touch src/lib.rs + cargo check --verbose 2>&1 | tee build2.log + grep "Using cached XGBoost library" build2.log + echo "Caching is working correctly" From 7b262c0252d6f0bf11bedb2f59c8a9d0fa31b0de Mon Sep 17 00:00:00 2001 From: aryehlev Date: Mon, 10 Nov 2025 22:20:40 +0200 Subject: [PATCH 09/11] fix type in ci. --- .github/workflows/ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a7714b8..26ec020 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -127,7 +127,6 @@ jobs: cargo build --example basic_usage --verbose cargo build --example advanced_usage --verbose - - name: Verify library architecture (macOS) - name: Verify library architecture (macOS) if: matrix.os == 'macos' run: | From c790f02b4184d4543de30e439e9fe9de955cf701 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Tue, 11 Nov 2025 20:14:15 +0200 Subject: [PATCH 10/11] dix advanced_usage.rs. --- examples/advanced_usage.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/advanced_usage.rs b/examples/advanced_usage.rs index ab1acde..d500446 100644 --- a/examples/advanced_usage.rs +++ b/examples/advanced_usage.rs @@ -62,7 +62,11 @@ fn main() -> XGBoostResult<()> { // Example: Load from buffer (useful for embedded models) println!("=== Example 4: Load from Buffer ==="); - let buffer = std::fs::read(model_path)?; + let buffer = std::fs::read(model_path).map_err(|e| { + xgboost_rust::XGBoostError { + description: format!("Failed to read model file: {}", e), + } + })?; let booster_from_buffer = Booster::load_from_buffer(&buffer)?; println!("✓ Model loaded from buffer ({} bytes)", buffer.len()); From 2f808fdaf08de945270018726e61bd7c728c3901 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Tue, 11 Nov 2025 21:34:31 +0200 Subject: [PATCH 11/11] fix fmt and clippy --- build.rs | 180 +++++++++++++++++++++++++------------ examples/advanced_usage.rs | 25 ++---- examples/basic_usage.rs | 6 +- src/error.rs | 2 +- src/model.rs | 63 ++++++------- src/sys.rs | 1 + 6 files changed, 165 insertions(+), 112 deletions(-) diff --git a/build.rs b/build.rs index ab614e8..c76c9f5 100644 --- a/build.rs +++ b/build.rs @@ -1,11 +1,11 @@ extern crate bindgen; -use sha2::{Sha256, Digest}; +use sha2::{Digest, Sha256}; use std::collections::HashMap; use std::env; -use std::path::{Path, PathBuf}; use std::fs; use std::io::{self, Read, Write}; +use std::path::{Path, PathBuf}; use std::thread; use std::time::Duration; @@ -17,26 +17,41 @@ fn get_xgboost_version() -> String { fn get_header_checksums() -> HashMap<&'static str, (&'static str, &'static str)> { let mut checksums = HashMap::new(); // Format: version => (c_api.h SHA256, base.h SHA256) - checksums.insert("3.1.1", ( - "c0f0a98eb36fb5e451fdd3e9ead2d185f4c61be2a6997fc295e5d1a94f3096e2", - "8d771fb20e03f3443e21cfdcd26ac5cd880be585b8817f2e0d146e7c5c7bb63a" - )); - checksums.insert("3.0.5", ( - "2ccec6e5301fa5a1324f60af48b9c6be5879e590ed583ec9d74297e6018860bc", - "47f0148706907ccecb72b8484687524bc36d58b4c6fe5e7b81e59de157261ea7" - )); - checksums.insert("2.1.4", ( - "b804850ec6c7a00f8e36f139dfce7fe348fc9ad066ff4cb7ac44a4f5420ec1dd", - "525c4a2ba2f6bd9b17a299978e16f91897d497d6ae0ae5df2335dd059f00d0ce" - )); - checksums.insert("1.7.6", ( - "145ed1df652937122b6f6bc31331051eabc02226a0b62349ea593cdbe841c20d", - "b26e17eadbcc6350dc900b35d164eedc02b1cd2a64913c560d4d416c81a68935" - )); - checksums.insert("1.4.2", ( - "3f5de5d046a3c9576e0c560abe5fa1e889f72b4b18ff2bf73e5f98290d47d0dc", - "e3abfcc730eee86acf44124d5496a2b41413f963c4bbf560513eeae0b7d12fb7" - )); + checksums.insert( + "3.1.1", + ( + "c0f0a98eb36fb5e451fdd3e9ead2d185f4c61be2a6997fc295e5d1a94f3096e2", + "8d771fb20e03f3443e21cfdcd26ac5cd880be585b8817f2e0d146e7c5c7bb63a", + ), + ); + checksums.insert( + "3.0.5", + ( + "2ccec6e5301fa5a1324f60af48b9c6be5879e590ed583ec9d74297e6018860bc", + "47f0148706907ccecb72b8484687524bc36d58b4c6fe5e7b81e59de157261ea7", + ), + ); + checksums.insert( + "2.1.4", + ( + "b804850ec6c7a00f8e36f139dfce7fe348fc9ad066ff4cb7ac44a4f5420ec1dd", + "525c4a2ba2f6bd9b17a299978e16f91897d497d6ae0ae5df2335dd059f00d0ce", + ), + ); + checksums.insert( + "1.7.6", + ( + "145ed1df652937122b6f6bc31331051eabc02226a0b62349ea593cdbe841c20d", + "b26e17eadbcc6350dc900b35d164eedc02b1cd2a64913c560d4d416c81a68935", + ), + ); + checksums.insert( + "1.4.2", + ( + "3f5de5d046a3c9576e0c560abe5fa1e889f72b4b18ff2bf73e5f98290d47d0dc", + "e3abfcc730eee86acf44124d5496a2b41413f963c4bbf560513eeae0b7d12fb7", + ), + ); checksums } @@ -46,13 +61,18 @@ fn compute_sha256(data: &[u8]) -> String { format!("{:x}", hasher.finalize()) } -fn verify_checksum(data: &[u8], expected: &str, filename: &str) -> Result<(), Box> { +fn verify_checksum( + data: &[u8], + expected: &str, + filename: &str, +) -> Result<(), Box> { let actual = compute_sha256(data); if actual != expected { return Err(format!( "SHA256 checksum mismatch for {}:\n Expected: {}\n Got: {}", filename, expected, actual - ).into()); + ) + .into()); } println!("cargo:warning=✓ Verified SHA256 for {}", filename); Ok(()) @@ -60,7 +80,7 @@ fn verify_checksum(data: &[u8], expected: &str, filename: &str) -> Result<(), Bo fn parse_version(version: &str) -> (u32, u32, u32) { let parts: Vec<&str> = version.split('.').collect(); - let major = parts.get(0).and_then(|s| s.parse().ok()).unwrap_or(0); + let major = parts.first().and_then(|s| s.parse().ok()).unwrap_or(0); let minor = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0); let patch = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0); (major, minor, patch) @@ -73,9 +93,15 @@ fn emit_version_cfg_flags(version: &str) { // See: https://github.com/dmlc/xgboost/issues/5339 if major > 1 || (major == 1 && minor >= 4) { println!("cargo:rustc-cfg=xgboost_thread_safe"); - println!("cargo:warning=XGBoost version {} supports thread-safe predictions", version); + println!( + "cargo:warning=XGBoost version {} supports thread-safe predictions", + version + ); } else { - println!("cargo:warning=XGBoost version {} does NOT support thread-safe predictions", version); + println!( + "cargo:warning=XGBoost version {} does NOT support thread-safe predictions", + version + ); } } @@ -112,12 +138,13 @@ fn download_xgboost_headers(out_dir: &Path) -> Result<(), Box Result<(), Box Result<(), Box> { println!("cargo:warning=Downloading {} from: {}", filename, url); // Download into memory buffer let response = ureq::get(url).call()?; let status = response.status(); - if status < 200 || status >= 300 { + if !(200..300).contains(&status) { return Err(format!("Failed to download {}: HTTP {}", filename, status).into()); } @@ -178,14 +211,18 @@ fn download_with_retry(url: &str, max_retries: u32) -> Result, Box 0 { let backoff = Duration::from_millis(100 * 2_u64.pow(attempt)); - println!("cargo:warning=Retry attempt {} after {:?}", attempt + 1, backoff); + println!( + "cargo:warning=Retry attempt {} after {:?}", + attempt + 1, + backoff + ); thread::sleep(backoff); } match ureq::get(url).call() { Ok(response) => { let status = response.status(); - if status < 200 || status >= 300 { + if !(200..300).contains(&status) { last_error = Some(format!("HTTP {}", status)); continue; } @@ -208,7 +245,8 @@ fn download_with_retry(url: &str, max_retries: u32) -> Result, Box Result<(), Box> { @@ -242,13 +280,19 @@ fn download_and_extract_wheel(out_dir: &Path) -> Result<(), Box Result<(), Box Result<(), Box { // For Linux, use $ORIGIN println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN"); println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../.."); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_search_path.display()); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}", + lib_search_path.display() + ); // Add the target directory to rpath as well if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { - println!("cargo:rustc-link-arg=-Wl,-rpath,{}/debug", target_root.display()); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}/release", target_root.display()); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/debug", + target_root.display() + ); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/release", + target_root.display() + ); } - }, + } _ => {} // No rpath needed for Windows } diff --git a/examples/advanced_usage.rs b/examples/advanced_usage.rs index d500446..8ad0ba9 100644 --- a/examples/advanced_usage.rs +++ b/examples/advanced_usage.rs @@ -22,10 +22,7 @@ fn main() -> XGBoostResult<()> { // Example: Binary classification prediction println!("=== Example 1: Normal Prediction ==="); - let data = vec![ - 5.1, 3.5, 1.4, 0.2, - 6.7, 3.0, 5.2, 2.3, - ]; + let data = vec![5.1, 3.5, 1.4, 0.2, 6.7, 3.0, 5.2, 2.3]; let predictions = booster.predict(&data, 2, 4, 0, false)?; println!("Normal predictions (probabilities): {:?}\n", predictions); @@ -34,22 +31,16 @@ fn main() -> XGBoostResult<()> { println!("=== Example 2: Feature Contributions (SHAP) ==="); use xgboost_rust::predict_option; - let shap_values = booster.predict( - &data, - 2, - 4, - predict_option::PRED_CONTRIBS, - false - )?; + let shap_values = booster.predict(&data, 2, 4, predict_option::PRED_CONTRIBS, false)?; println!("SHAP values for first sample:"); // SHAP values include one extra value for bias term let num_features = 4; - for i in 0..=num_features { + for (i, &value) in shap_values.iter().enumerate().take(num_features + 1) { if i < num_features { - println!(" Feature {}: {:.4}", i, shap_values[i]); + println!(" Feature {}: {:.4}", i, value); } else { - println!(" Bias term: {:.4}", shap_values[i]); + println!(" Bias term: {:.4}", value); } } println!(); @@ -62,10 +53,8 @@ fn main() -> XGBoostResult<()> { // Example: Load from buffer (useful for embedded models) println!("=== Example 4: Load from Buffer ==="); - let buffer = std::fs::read(model_path).map_err(|e| { - xgboost_rust::XGBoostError { - description: format!("Failed to read model file: {}", e), - } + let buffer = std::fs::read(model_path).map_err(|e| xgboost_rust::XGBoostError { + description: format!("Failed to read model file: {}", e), })?; let booster_from_buffer = Booster::load_from_buffer(&buffer)?; println!("✓ Model loaded from buffer ({} bytes)", buffer.len()); diff --git a/examples/basic_usage.rs b/examples/basic_usage.rs index 9fa7bd4..d167a0a 100644 --- a/examples/basic_usage.rs +++ b/examples/basic_usage.rs @@ -55,9 +55,9 @@ fn main() -> XGBoostResult<()> { // Example prediction data (4 features for iris dataset) // This is a sample from the iris dataset let data = vec![ - 5.1, 3.5, 1.4, 0.2, // Row 1: Setosa - 6.7, 3.0, 5.2, 2.3, // Row 2: Virginica - 5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor + 5.1, 3.5, 1.4, 0.2, // Row 1: Setosa + 6.7, 3.0, 5.2, 2.3, // Row 2: Virginica + 5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor ]; let num_rows = 3; let num_features = 4; diff --git a/src/error.rs b/src/error.rs index f2bfc45..88e17d5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ +use crate::sys; use std::ffi::CStr; use std::fmt; -use crate::sys; pub type XGBoostResult = std::result::Result; diff --git a/src/model.rs b/src/model.rs index 47c135e..9ab8a5d 100644 --- a/src/model.rs +++ b/src/model.rs @@ -75,14 +75,12 @@ impl Booster { /// let booster = Booster::load("model.json").unwrap(); /// ``` pub fn load>(path: P) -> XGBoostResult { - let path_str = path.as_ref().to_str() - .ok_or_else(|| XGBoostError { - description: "Path contains invalid UTF-8 characters".to_string(), - })?; - let path_c_str = CString::new(path_str) - .map_err(|e| XGBoostError { - description: format!("Path contains NUL byte: {}", e), - })?; + let path_str = path.as_ref().to_str().ok_or_else(|| XGBoostError { + description: "Path contains invalid UTF-8 characters".to_string(), + })?; + let path_c_str = CString::new(path_str).map_err(|e| XGBoostError { + description: format!("Path contains NUL byte: {}", e), + })?; // Create a booster first let mut handle: sys::BoosterHandle = ptr::null_mut(); @@ -92,14 +90,13 @@ impl Booster { // Load model into the booster let result = XGBoostError::check_return_value(unsafe { - sys::XGBoosterLoadModel( - handle, - path_c_str.as_ptr(), - ) + sys::XGBoosterLoadModel(handle, path_c_str.as_ptr()) }); if let Err(e) = result { - unsafe { sys::XGBoosterFree(handle); } + unsafe { + sys::XGBoosterFree(handle); + } return Err(e); } @@ -136,7 +133,9 @@ impl Booster { }); if let Err(e) = result { - unsafe { sys::XGBoosterFree(handle); } + unsafe { + sys::XGBoosterFree(handle); + } return Err(e); } @@ -172,7 +171,8 @@ impl Booster { training: bool, ) -> XGBoostResult> { // Validate input dimensions - let expected_len = num_rows.checked_mul(num_features) + let expected_len = num_rows + .checked_mul(num_features) .ok_or_else(|| XGBoostError { description: format!( "Integer overflow: num_rows ({}) * num_features ({}) exceeds usize::MAX", @@ -184,7 +184,10 @@ impl Booster { return Err(XGBoostError { description: format!( "Data length mismatch: expected {} elements ({}×{}), got {}", - expected_len, num_rows, num_features, data.len() + expected_len, + num_rows, + num_features, + data.len() ), }); } @@ -237,9 +240,7 @@ impl Booster { } // Copy results to a Vec - let results = unsafe { - std::slice::from_raw_parts(out_result, out_len as usize).to_vec() - }; + let results = unsafe { std::slice::from_raw_parts(out_result, out_len as usize).to_vec() }; // DMatrix will be automatically freed when _guard goes out of scope @@ -254,10 +255,7 @@ impl Booster { let mut out_num_features: u64 = 0; XGBoostError::check_return_value(unsafe { - sys::XGBoosterGetNumFeature( - self.handle, - &mut out_num_features, - ) + sys::XGBoosterGetNumFeature(self.handle, &mut out_num_features) })?; Ok(out_num_features as usize) @@ -276,20 +274,15 @@ impl Booster { /// booster.save("model_copy.json").unwrap(); /// ``` pub fn save>(&self, path: P) -> XGBoostResult<()> { - let path_str = path.as_ref().to_str() - .ok_or_else(|| XGBoostError { - description: "Path contains invalid UTF-8 characters".to_string(), - })?; - let path_c_str = CString::new(path_str) - .map_err(|e| XGBoostError { - description: format!("Path contains NUL byte: {}", e), - })?; + let path_str = path.as_ref().to_str().ok_or_else(|| XGBoostError { + description: "Path contains invalid UTF-8 characters".to_string(), + })?; + let path_c_str = CString::new(path_str).map_err(|e| XGBoostError { + description: format!("Path contains NUL byte: {}", e), + })?; XGBoostError::check_return_value(unsafe { - sys::XGBoosterSaveModel( - self.handle, - path_c_str.as_ptr(), - ) + sys::XGBoosterSaveModel(self.handle, path_c_str.as_ptr()) }) } } diff --git a/src/sys.rs b/src/sys.rs index a38a13a..cd503e4 100644 --- a/src/sys.rs +++ b/src/sys.rs @@ -1,5 +1,6 @@ #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] +#![allow(dead_code)] include!(concat!(env!("OUT_DIR"), "/bindings.rs"));