diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..26ec020 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,251 @@ +name: CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +env: + CARGO_TERM_COLOR: always + +jobs: + test: + name: Test ${{ matrix.os }} (${{ matrix.arch }}) - XGBoost ${{ matrix.xgboost_version }} + runs-on: ${{ matrix.runner }} + strategy: + fail-fast: false + matrix: + include: + # macOS ARM64 (M1/M2/M3) - Test multiple versions + - os: macos + arch: arm64 + runner: macos-14 + xgboost_version: "3.1.1" + - os: macos + arch: arm64 + runner: macos-14 + xgboost_version: "3.0.5" + - os: macos + arch: arm64 + runner: macos-14 + xgboost_version: "2.1.4" + + # macOS x86_64 (Intel) + - os: macos + arch: x86_64 + runner: macos-15-large + xgboost_version: "3.1.1" + + # Linux x86_64 - Test multiple versions including thread-safety boundary + - os: linux + arch: x86_64 + runner: ubuntu-latest + xgboost_version: "3.1.1" + - os: linux + arch: x86_64 + runner: ubuntu-latest + xgboost_version: "1.7.6" + - os: linux + arch: x86_64 + runner: ubuntu-latest + xgboost_version: "1.4.2" # First thread-safe version + + # Linux ARM64 + - os: linux + arch: arm64 + runner: ubuntu-latest + xgboost_version: "3.1.1" + + # Windows x86_64 + - os: windows + arch: x86_64 + runner: windows-latest + xgboost_version: "3.1.1" + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-registry- + + - name: Cache cargo index + uses: actions/cache@v4 + with: + path: ~/.cargo/git + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-index-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-index- + + - name: Cache cargo build + uses: actions/cache@v4 + with: + path: target + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-build-target-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-build-target- + + - name: Install system dependencies (Linux) + if: matrix.os == 'linux' + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: Install system dependencies (macOS) + if: matrix.os == 'macos' + run: | + brew install libomp + + - name: Check build (no features) + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: cargo check --verbose + + - name: Build (no features) + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: cargo build --verbose + + - name: Run tests (no features) + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: cargo test --verbose + + - name: Build examples + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: | + cargo build --example basic_usage --verbose + cargo build --example advanced_usage --verbose + + - name: Verify library architecture (macOS) + if: matrix.os == 'macos' + run: | + echo "Checking library architecture..." + file target/debug/libxgboost.dylib + lipo -info target/debug/libxgboost.dylib || otool -L target/debug/libxgboost.dylib + + - name: Verify library architecture (Linux) + if: matrix.os == 'linux' + run: | + echo "Checking library architecture..." + file target/debug/libxgboost.so + readelf -h target/debug/libxgboost.so | grep Machine + + - name: Verify library exists (Windows) + if: matrix.os == 'windows' + run: | + echo "Checking library exists..." + Get-Item target/debug/xgboost.dll + + - name: Verify thread safety detection + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: | + echo "XGBoost version: ${{ matrix.xgboost_version }}" + cargo build --verbose 2>&1 | grep "thread-safe" || echo "No thread-safe message found" + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: Run clippy (no features) + run: cargo clippy -- -D warnings + + fmt: + name: Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Check formatting + run: cargo fmt -- --check + + # Test checksum verification + security-checksums: + name: Verify SHA256 checksums + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: Test checksum verification for XGBoost 3.1.1 + env: + XGBOOST_VERSION: "3.1.1" + run: | + cargo check --verbose 2>&1 | tee build.log + grep "✓ Verified SHA256" build.log + echo "Checksum verification working for 3.1.1" + + - name: Test checksum verification for XGBoost 2.1.4 + env: + XGBOOST_VERSION: "2.1.4" + run: | + cargo clean + cargo check --verbose 2>&1 | tee build.log + grep "✓ Verified SHA256" build.log + echo "Checksum verification working for 2.1.4" + + # Test caching behavior + caching-test: + name: Test wheel caching + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: First build (should download) + env: + XGBOOST_VERSION: "3.1.1" + run: | + cargo check --verbose 2>&1 | tee build1.log + grep "Downloading XGBoost wheel" build1.log || true + + - name: Second build (should use cache) + env: + XGBOOST_VERSION: "3.1.1" + run: | + touch src/lib.rs + cargo check --verbose 2>&1 | tee build2.log + grep "Using cached XGBoost library" build2.log + echo "Caching is working correctly" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1138cd6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +/target +Cargo.lock +*.json +*.bin +*.model +*.dylib +*.so +*.dll diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..d1848e0 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "xgboost-rust" +version = "0.1.0" +edition = "2021" +description = "Rust bindings for XGBoost, a gradient boosting library for machine learning. Downloads XGBoost binaries at build time for cross-platform compatibility." +license = "Apache-2.0" +keywords = ["machine-learning", "gradient-boosting", "xgboost", "ml"] +categories = ["science"] +readme = "README.md" +rust-version = "1.70" + +[build-dependencies] +bindgen = "0.72.0" +ureq = "2.0" +zip = "0.6" +sha2 = "0.10" + +[features] +default = [] +gpu = [] + +[[example]] +name = "basic_usage" +path = "examples/basic_usage.rs" + +[[example]] +name = "advanced_usage" +path = "examples/advanced_usage.rs" diff --git a/README.md b/README.md index 00a8046..72a68bf 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ Rust bindings for XGBoost, a gradient boosting library for machine learning. - **Automatic Binary Download**: Downloads XGBoost binaries at build time from PyPI wheels - **Cross-Platform**: Supports Linux (x86_64, aarch64), macOS (x86_64, arm64), and Windows (x86_64) - **Version Control**: Specify XGBoost version via `XGBOOST_VERSION` environment variable +- **Version-Aware Thread Safety**: Automatically enables `Send + Sync` for XGBoost ≥ 1.4 - **Easy to Use**: Simple, safe Rust API wrapping the XGBoost C API ## Installation @@ -85,10 +86,28 @@ This crate downloads the appropriate XGBoost Python wheel from PyPI during the b ## Thread Safety -The `Booster` type is **not thread-safe**. For multi-threaded usage: +Thread safety is **version-aware**: -1. **Recommended**: Create one `Booster` per thread -2. **Alternative**: Wrap in `Arc>` for shared access +- **XGBoost ≥ 1.4**: `Booster` implements `Send + Sync` and is thread-safe for predictions on tree models. You can safely share `Arc` across threads. +- **XGBoost < 1.4**: `Booster` does NOT implement `Send + Sync`. Use one booster per thread or wrap in `Arc>`. + +### Example with XGBoost ≥ 1.4 + +```rust +use std::sync::Arc; +use std::thread; +use xgboost_rust::Booster; + +let booster = Arc::new(Booster::load("model.json")?); +let booster_clone = booster.clone(); + +thread::spawn(move || { + // Safe concurrent predictions with XGBoost ≥ 1.4 + let predictions = booster_clone.predict(&data, rows, cols, 0, false)?; +}); +``` + +The version check happens automatically at build time based on the `XGBOOST_VERSION` environment variable. ## Examples diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..c76c9f5 --- /dev/null +++ b/build.rs @@ -0,0 +1,512 @@ +extern crate bindgen; + +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::env; +use std::fs; +use std::io::{self, Read, Write}; +use std::path::{Path, PathBuf}; +use std::thread; +use std::time::Duration; + +fn get_xgboost_version() -> String { + env::var("XGBOOST_VERSION").unwrap_or_else(|_| "3.1.1".to_string()) +} + +// Known SHA256 checksums for header files by version +fn get_header_checksums() -> HashMap<&'static str, (&'static str, &'static str)> { + let mut checksums = HashMap::new(); + // Format: version => (c_api.h SHA256, base.h SHA256) + checksums.insert( + "3.1.1", + ( + "c0f0a98eb36fb5e451fdd3e9ead2d185f4c61be2a6997fc295e5d1a94f3096e2", + "8d771fb20e03f3443e21cfdcd26ac5cd880be585b8817f2e0d146e7c5c7bb63a", + ), + ); + checksums.insert( + "3.0.5", + ( + "2ccec6e5301fa5a1324f60af48b9c6be5879e590ed583ec9d74297e6018860bc", + "47f0148706907ccecb72b8484687524bc36d58b4c6fe5e7b81e59de157261ea7", + ), + ); + checksums.insert( + "2.1.4", + ( + "b804850ec6c7a00f8e36f139dfce7fe348fc9ad066ff4cb7ac44a4f5420ec1dd", + "525c4a2ba2f6bd9b17a299978e16f91897d497d6ae0ae5df2335dd059f00d0ce", + ), + ); + checksums.insert( + "1.7.6", + ( + "145ed1df652937122b6f6bc31331051eabc02226a0b62349ea593cdbe841c20d", + "b26e17eadbcc6350dc900b35d164eedc02b1cd2a64913c560d4d416c81a68935", + ), + ); + checksums.insert( + "1.4.2", + ( + "3f5de5d046a3c9576e0c560abe5fa1e889f72b4b18ff2bf73e5f98290d47d0dc", + "e3abfcc730eee86acf44124d5496a2b41413f963c4bbf560513eeae0b7d12fb7", + ), + ); + checksums +} + +fn compute_sha256(data: &[u8]) -> String { + let mut hasher = Sha256::new(); + hasher.update(data); + format!("{:x}", hasher.finalize()) +} + +fn verify_checksum( + data: &[u8], + expected: &str, + filename: &str, +) -> Result<(), Box> { + let actual = compute_sha256(data); + if actual != expected { + return Err(format!( + "SHA256 checksum mismatch for {}:\n Expected: {}\n Got: {}", + filename, expected, actual + ) + .into()); + } + println!("cargo:warning=✓ Verified SHA256 for {}", filename); + Ok(()) +} + +fn parse_version(version: &str) -> (u32, u32, u32) { + let parts: Vec<&str> = version.split('.').collect(); + let major = parts.first().and_then(|s| s.parse().ok()).unwrap_or(0); + let minor = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0); + let patch = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0); + (major, minor, patch) +} + +fn emit_version_cfg_flags(version: &str) { + let (major, minor, _patch) = parse_version(version); + + // XGBoost 1.4.0+ has thread-safe predictions for tree models + // See: https://github.com/dmlc/xgboost/issues/5339 + if major > 1 || (major == 1 && minor >= 4) { + println!("cargo:rustc-cfg=xgboost_thread_safe"); + println!( + "cargo:warning=XGBoost version {} supports thread-safe predictions", + version + ); + } else { + println!( + "cargo:warning=XGBoost version {} does NOT support thread-safe predictions", + version + ); + } +} + +fn get_platform_info() -> (String, String) { + let target = env::var("TARGET").unwrap(); + + // Determine OS + let os = if target.contains("apple-darwin") { + "darwin" + } else if target.contains("linux") { + "linux" + } else if target.contains("windows") { + "windows" + } else { + panic!("Unsupported target: {}", target); + }; + + // Determine architecture + let arch = if target.contains("x86_64") { + "x86_64" + } else if target.contains("aarch64") || target.contains("arm64") { + "aarch64" + } else if target.contains("i686") || target.contains("i586") { + "i686" + } else { + panic!("Unsupported architecture for target: {}", target); + }; + + (os.to_string(), arch.to_string()) +} + +fn download_xgboost_headers(out_dir: &Path) -> Result<(), Box> { + let version = get_xgboost_version(); + let checksums = get_header_checksums(); + + // Get expected checksums for this version + let (c_api_expected, base_expected) = checksums.get(version.as_str()).ok_or_else(|| { + format!( + "No known SHA256 checksums for XGBoost version {}. \ + Please verify this version manually or add checksums to build.rs", + version + ) + })?; + + // Create the include/xgboost directory + let include_dir = out_dir.join("include/xgboost"); + fs::create_dir_all(&include_dir)?; + + // Download and verify c_api.h + let c_api_path = include_dir.join("c_api.h"); + download_and_verify_file( + &format!( + "https://raw.githubusercontent.com/dmlc/xgboost/v{}/include/xgboost/c_api.h", + version + ), + &c_api_path, + c_api_expected, + "c_api.h", + )?; + + // Download and verify base.h + let base_path = include_dir.join("base.h"); + download_and_verify_file( + &format!( + "https://raw.githubusercontent.com/dmlc/xgboost/v{}/include/xgboost/base.h", + version + ), + &base_path, + base_expected, + "base.h", + )?; + + Ok(()) +} + +fn download_and_verify_file( + url: &str, + dest_path: &Path, + expected_sha256: &str, + filename: &str, +) -> Result<(), Box> { + println!("cargo:warning=Downloading {} from: {}", filename, url); + + // Download into memory buffer + let response = ureq::get(url).call()?; + let status = response.status(); + if !(200..300).contains(&status) { + return Err(format!("Failed to download {}: HTTP {}", filename, status).into()); + } + + let mut buffer = Vec::new(); + response.into_reader().read_to_end(&mut buffer)?; + + // Verify SHA256 checksum + verify_checksum(&buffer, expected_sha256, filename)?; + + // Only write file after successful verification + let mut file = fs::File::create(dest_path)?; + file.write_all(&buffer)?; + + Ok(()) +} + +fn download_with_retry(url: &str, max_retries: u32) -> Result, Box> { + let mut last_error = None; + + for attempt in 0..max_retries { + if attempt > 0 { + let backoff = Duration::from_millis(100 * 2_u64.pow(attempt)); + println!( + "cargo:warning=Retry attempt {} after {:?}", + attempt + 1, + backoff + ); + thread::sleep(backoff); + } + + match ureq::get(url).call() { + Ok(response) => { + let status = response.status(); + if !(200..300).contains(&status) { + last_error = Some(format!("HTTP {}", status)); + continue; + } + + let mut buffer = Vec::new(); + if let Err(e) = response.into_reader().read_to_end(&mut buffer) { + last_error = Some(e.to_string()); + continue; + } + + return Ok(buffer); + } + Err(e) => { + last_error = Some(e.to_string()); + } + } + } + + Err(format!( + "Failed to download after {} attempts. Last error: {}", + max_retries, + last_error.unwrap_or_else(|| "Unknown error".to_string()) + ) + .into()) +} + +fn download_and_extract_wheel(out_dir: &Path) -> Result<(), Box> { + let (os, arch) = get_platform_info(); + let version = get_xgboost_version(); + + // Determine wheel filename based on platform + let wheel_filename = match (os.as_str(), arch.as_str()) { + ("linux", "x86_64") => format!("xgboost-{}-py3-none-manylinux_2_28_x86_64.whl", version), + ("linux", "aarch64") => format!("xgboost-{}-py3-none-manylinux_2_28_aarch64.whl", version), + ("darwin", "x86_64") => format!("xgboost-{}-py3-none-macosx_10_15_x86_64.whl", version), + ("darwin", "aarch64") => format!("xgboost-{}-py3-none-macosx_12_0_arm64.whl", version), + ("windows", "x86_64") => format!("xgboost-{}-py3-none-win_amd64.whl", version), + _ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()), + }; + + let lib_filename = match os.as_str() { + "windows" => "xgboost.dll", + "darwin" => "libxgboost.dylib", + _ => "libxgboost.so", + }; + + // Setup paths + let wheel_dir = out_dir.join("wheel"); + let lib_dir = out_dir.join("libs"); + fs::create_dir_all(&wheel_dir)?; + fs::create_dir_all(&lib_dir)?; + + let wheel_path = wheel_dir.join(&wheel_filename); + let lib_dest_path = lib_dir.join(lib_filename); + + // Check if library already exists and is valid + if lib_dest_path.exists() { + println!( + "cargo:warning=Using cached XGBoost library at: {}", + lib_dest_path.display() + ); + return Ok(()); + } + + // Check if wheel is cached + let wheel_buffer = if wheel_path.exists() { + println!( + "cargo:warning=Using cached wheel at: {}", + wheel_path.display() + ); + fs::read(&wheel_path)? + } else { + // Download wheel with retry + let download_url = format!( + "https://files.pythonhosted.org/packages/py3/x/xgboost/{}", + wheel_filename + ); + + println!( + "cargo:warning=Downloading XGBoost wheel from: {}", + download_url + ); + let buffer = download_with_retry(&download_url, 3)?; + + // Write atomically (temp file + rename) + let temp_path = wheel_path.with_extension("tmp"); + { + let mut temp_file = fs::File::create(&temp_path)?; + temp_file.write_all(&buffer)?; + temp_file.sync_all()?; + } + fs::rename(&temp_path, &wheel_path)?; + + println!("cargo:warning=✓ Downloaded and cached wheel"); + buffer + }; + + // Extract library from wheel + println!("cargo:warning=Extracting library from wheel"); + + let cursor = io::Cursor::new(wheel_buffer); + let mut archive = zip::ZipArchive::new(cursor)?; + + // Search for the library file in the wheel + let mut found = false; + for i in 0..archive.len() { + let mut file = archive.by_index(i)?; + let file_path = file.name().to_string(); + + // Look for the library file (usually in xgboost/lib/) + if file_path.ends_with(lib_filename) { + println!("cargo:warning=Found library at: {}", file_path); + + // Extract to temp file, then rename atomically + let temp_dest_path = lib_dest_path.with_extension("tmp"); + { + let mut dest = fs::File::create(&temp_dest_path)?; + io::copy(&mut file, &mut dest)?; + dest.sync_all()?; + } + fs::rename(&temp_dest_path, &lib_dest_path)?; + + found = true; + break; + } + } + + if !found { + return Err(format!("Library file {} not found in wheel", lib_filename).into()); + } + + println!( + "cargo:warning=✓ Successfully extracted XGBoost library to: {}", + lib_dir.display() + ); + + Ok(()) +} + +fn main() { + // Tell cargo about custom cfg flags we emit + println!("cargo:rustc-check-cfg=cfg(xgboost_thread_safe)"); + + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let xgb_include_root = out_dir.join("include"); + + // Get version and emit cfg flags for thread safety + let version = get_xgboost_version(); + emit_version_cfg_flags(&version); + + // Download the headers + if let Err(e) = download_xgboost_headers(&out_dir) { + eprintln!("Failed to download XGBoost headers: {}", e); + panic!("Cannot proceed without headers"); + } + + // Download and extract the wheel + if let Err(e) = download_and_extract_wheel(&out_dir) { + eprintln!("Failed to download and extract wheel: {}", e); + panic!("Cannot proceed without compiled library"); + } + + let bindings = bindgen::Builder::default() + .header("wrapper.h") + .clang_arg(format!("-I{}", xgb_include_root.display())) + // Generate bindings for XGB and XGD functions (Booster and DMatrix) + .allowlist_function("XGB.*") + .allowlist_function("XGD.*") + // Allowlist the main types we need + .allowlist_type("BoosterHandle") + .allowlist_type("DMatrixHandle") + .allowlist_type("bst_ulong") + .size_t_is_usize(true) + .generate() + .expect("Unable to generate bindings."); + + bindings + .write_to_file(out_dir.join("bindings.rs")) + .expect("Couldn't write bindings."); + + // Get platform info + let (os, _arch) = get_platform_info(); + + // Determine the library filename based on the OS + let lib_filename = match os.as_str() { + "windows" => "xgboost.dll", + "darwin" => "libxgboost.dylib", + _ => "libxgboost.so", + }; + + // Copy the library from OUT_DIR/libs to the final target directory + let lib_source_path = out_dir.join("libs").join(lib_filename); + + // Find the final output directory (e.g., target/release) + let target_dir = out_dir + .ancestors() + .find(|p| p.ends_with("target")) + .unwrap() + .join(env::var("PROFILE").unwrap()); + + let lib_dest_path = target_dir.join(lib_filename); + fs::copy(&lib_source_path, &lib_dest_path).expect("Failed to copy library to target directory"); + + // On macOS/Linux, change the install name/soname to use @loader_path/$ORIGIN + if os == "darwin" { + use std::process::Command; + let _ = Command::new("install_name_tool") + .arg("-id") + .arg(format!("@loader_path/{}", lib_filename)) + .arg(&lib_source_path) + .status(); + let _ = Command::new("install_name_tool") + .arg("-id") + .arg(format!("@loader_path/{}", lib_filename)) + .arg(&lib_dest_path) + .status(); + } else if os == "linux" { + use std::process::Command; + // Use patchelf to set soname (if available) + let _ = Command::new("patchelf") + .arg("--set-soname") + .arg(lib_filename) + .arg(&lib_source_path) + .output(); + let _ = Command::new("patchelf") + .arg("--set-soname") + .arg(lib_filename) + .arg(&lib_dest_path) + .output(); + } + + // Set the library search path for the build-time linker + let lib_search_path = out_dir.join("libs"); + println!( + "cargo:rustc-link-search=native={}", + lib_search_path.display() + ); + + // Set the rpath for the run-time linker based on the OS + match os.as_str() { + "darwin" => { + // For macOS, add multiple rpath entries for IDE compatibility + println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path"); + println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../.."); + println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path"); + println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path/../.."); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}", + lib_search_path.display() + ); + // Add the target directory to rpath as well + if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/debug", + target_root.display() + ); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/release", + target_root.display() + ); + } + } + "linux" => { + // For Linux, use $ORIGIN + println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN"); + println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../.."); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}", + lib_search_path.display() + ); + // Add the target directory to rpath as well + if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/debug", + target_root.display() + ); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/release", + target_root.display() + ); + } + } + _ => {} // No rpath needed for Windows + } + + println!("cargo:rustc-link-lib=dylib=xgboost"); +} diff --git a/examples/advanced_usage.rs b/examples/advanced_usage.rs new file mode 100644 index 0000000..8ad0ba9 --- /dev/null +++ b/examples/advanced_usage.rs @@ -0,0 +1,73 @@ +use xgboost_rust::{Booster, XGBoostResult}; + +fn main() -> XGBoostResult<()> { + println!("XGBoost Rust Bindings - Advanced Usage Example"); + println!("===============================================\n"); + + // Load model + let model_path = "iris_model.json"; + println!("Loading model from: {}", model_path); + + let booster = match Booster::load(model_path) { + Ok(b) => { + println!("✓ Model loaded successfully\n"); + b + } + Err(e) => { + eprintln!("✗ Failed to load model: {}", e); + eprintln!("\nPlease create a model file first. See basic_usage.rs for instructions."); + return Err(e); + } + }; + + // Example: Binary classification prediction + println!("=== Example 1: Normal Prediction ==="); + let data = vec![5.1, 3.5, 1.4, 0.2, 6.7, 3.0, 5.2, 2.3]; + + let predictions = booster.predict(&data, 2, 4, 0, false)?; + println!("Normal predictions (probabilities): {:?}\n", predictions); + + // Example: Get SHAP values (feature contributions) + println!("=== Example 2: Feature Contributions (SHAP) ==="); + use xgboost_rust::predict_option; + + let shap_values = booster.predict(&data, 2, 4, predict_option::PRED_CONTRIBS, false)?; + + println!("SHAP values for first sample:"); + // SHAP values include one extra value for bias term + let num_features = 4; + for (i, &value) in shap_values.iter().enumerate().take(num_features + 1) { + if i < num_features { + println!(" Feature {}: {:.4}", i, value); + } else { + println!(" Bias term: {:.4}", value); + } + } + println!(); + + // Example: Save model to a new location + println!("=== Example 3: Save Model ==="); + let save_path = "iris_model_copy.json"; + booster.save(save_path)?; + println!("✓ Model saved to: {}\n", save_path); + + // Example: Load from buffer (useful for embedded models) + println!("=== Example 4: Load from Buffer ==="); + let buffer = std::fs::read(model_path).map_err(|e| xgboost_rust::XGBoostError { + description: format!("Failed to read model file: {}", e), + })?; + let booster_from_buffer = Booster::load_from_buffer(&buffer)?; + println!("✓ Model loaded from buffer ({} bytes)", buffer.len()); + + let num_features = booster_from_buffer.num_features()?; + println!(" Model has {} features\n", num_features); + + // Make a prediction with the buffer-loaded model + let test_data = vec![5.1, 3.5, 1.4, 0.2]; + let pred = booster_from_buffer.predict(&test_data, 1, 4, 0, false)?; + println!("Prediction from buffer-loaded model: {:?}\n", pred); + + println!("Advanced examples completed successfully!"); + + Ok(()) +} diff --git a/examples/basic_usage.rs b/examples/basic_usage.rs new file mode 100644 index 0000000..d167a0a --- /dev/null +++ b/examples/basic_usage.rs @@ -0,0 +1,95 @@ +use xgboost_rust::{Booster, XGBoostResult}; + +fn main() -> XGBoostResult<()> { + println!("XGBoost Rust Bindings - Basic Usage Example"); + println!("============================================\n"); + + // Note: This example assumes you have a trained XGBoost model file. + // To create one, you can use Python: + // + // ```python + // import xgboost as xgb + // from sklearn.datasets import load_iris + // from sklearn.model_selection import train_test_split + // + // # Load iris dataset + // iris = load_iris() + // X_train, X_test, y_train, y_test = train_test_split( + // iris.data, iris.target, test_size=0.2, random_state=42 + // ) + // + // # Train model + // dtrain = xgb.DMatrix(X_train, label=y_train) + // params = { + // 'objective': 'multi:softprob', + // 'num_class': 3, + // 'max_depth': 3, + // 'eta': 0.3 + // } + // bst = xgb.train(params, dtrain, num_boost_round=10) + // + // # Save model + // bst.save_model('iris_model.json') + // ``` + + // Load a pre-trained model + let model_path = "iris_model.json"; + + println!("Loading model from: {}", model_path); + let booster = match Booster::load(model_path) { + Ok(b) => { + println!("✓ Model loaded successfully\n"); + b + } + Err(e) => { + eprintln!("✗ Failed to load model: {}", e); + eprintln!("\nPlease create a model file first using the Python code in the example."); + return Err(e); + } + }; + + // Get model information + let num_features = booster.num_features()?; + println!("Model expects {} features\n", num_features); + + // Example prediction data (4 features for iris dataset) + // This is a sample from the iris dataset + let data = vec![ + 5.1, 3.5, 1.4, 0.2, // Row 1: Setosa + 6.7, 3.0, 5.2, 2.3, // Row 2: Virginica + 5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor + ]; + let num_rows = 3; + let num_features = 4; + + println!("Making predictions on {} samples...", num_rows); + + // Make predictions + let predictions = booster.predict(&data, num_rows, num_features, 0, false)?; + + println!("✓ Predictions complete\n"); + + // Print results + println!("Predictions (probabilities for 3 classes):"); + for i in 0..num_rows { + println!("Sample {}:", i + 1); + println!(" Class 0 (Setosa): {:.4}", predictions[i * 3]); + println!(" Class 1 (Versicolor): {:.4}", predictions[i * 3 + 1]); + println!(" Class 2 (Virginica): {:.4}", predictions[i * 3 + 2]); + + // Find predicted class (argmax) + let mut max_prob = predictions[i * 3]; + let mut predicted_class = 0; + for j in 1..3 { + if predictions[i * 3 + j] > max_prob { + max_prob = predictions[i * 3 + j]; + predicted_class = j; + } + } + println!(" → Predicted class: {}\n", predicted_class); + } + + println!("Example completed successfully!"); + + Ok(()) +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..88e17d5 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,39 @@ +use crate::sys; +use std::ffi::CStr; +use std::fmt; + +pub type XGBoostResult = std::result::Result; + +#[derive(Debug, Eq, PartialEq)] +pub struct XGBoostError { + pub description: String, +} + +impl XGBoostError { + /// Check the return value from an XGBoost FFI call, and return the last error message on error. + /// Return values of 0 are treated as success, non-zero values are treated as errors. + pub fn check_return_value(ret_val: i32) -> XGBoostResult<()> { + if ret_val == 0 { + Ok(()) + } else { + Err(XGBoostError::fetch_xgboost_error()) + } + } + + /// Fetch current error message from XGBoost. + fn fetch_xgboost_error() -> Self { + let c_str = unsafe { CStr::from_ptr(sys::XGBGetLastError()) }; + let str_slice = c_str.to_str().unwrap_or("Unknown error"); + XGBoostError { + description: str_slice.to_owned(), + } + } +} + +impl fmt::Display for XGBoostError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.description) + } +} + +impl std::error::Error for XGBoostError {} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..60d3931 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,22 @@ +// Include the XGBoost C API bindings +mod sys; + +mod error; +pub use crate::error::{XGBoostError, XGBoostResult}; + +mod model; +pub use crate::model::Booster; + +// Re-export prediction option constants for convenience +pub mod predict_option { + /// Normal prediction, output is the transformed probability + pub const OUTPUT_MARGIN: u32 = 0x01; + /// Output the untransformed margin value + pub const PRED_LEAF: u32 = 0x02; + /// Output the leaf index of trees + pub const PRED_CONTRIBS: u32 = 0x04; + /// Output feature contributions (SHAP values) + pub const PRED_APPROX_CONTRIBS: u32 = 0x08; + /// Output feature interaction contributions + pub const PRED_INTERACTIONS: u32 = 0x10; +} diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..9ab8a5d --- /dev/null +++ b/src/model.rs @@ -0,0 +1,296 @@ +use crate::error::{XGBoostError, XGBoostResult}; +use crate::sys; +use std::ffi::CString; +use std::path::Path; +use std::ptr; + +/// An XGBoost Booster for making predictions. +/// +/// # Thread Safety +/// +/// **Thread safety depends on XGBoost version:** +/// - **XGBoost ≥ 1.4**: Predictions are thread-safe for tree models (gbtree/dart). +/// `Send` and `Sync` are automatically implemented for these versions. +/// You can safely share `Arc` across threads. +/// - **XGBoost < 1.4**: NOT thread-safe. `Send` and `Sync` are NOT implemented. +/// +/// ## Usage with XGBoost ≥ 1.4 +/// +/// ```ignore +/// use std::sync::Arc; +/// use std::thread; +/// +/// let booster = Arc::new(Booster::load("model.json")?); +/// let booster_clone = booster.clone(); +/// +/// thread::spawn(move || { +/// // Safe to call predict concurrently +/// booster_clone.predict(...); +/// }); +/// ``` +/// +/// ## Usage with XGBoost < 1.4 +/// +/// For older versions, use one of these approaches: +/// +/// 1. **Create one Booster per thread** (recommended): +/// ```ignore +/// let booster = Booster::load("model.json")?; +/// thread::spawn(move || { +/// booster.predict(...); // Each thread owns its Booster +/// }); +/// ``` +/// +/// 2. **Wrap in Arc>**: +/// ```ignore +/// use std::sync::{Arc, Mutex}; +/// let booster = Arc::new(Mutex::new(Booster::load("model.json")?)); +/// ``` +pub struct Booster { + handle: sys::BoosterHandle, +} + +// Thread safety implementation based on XGBoost version +// XGBoost 1.4.0+ supports thread-safe predictions for tree models (gbtree/dart) +// See: https://github.com/dmlc/xgboost/issues/5339 +#[cfg(xgboost_thread_safe)] +unsafe impl Send for Booster {} + +#[cfg(xgboost_thread_safe)] +unsafe impl Sync for Booster {} + +// For XGBoost < 1.4, Send and Sync are NOT implemented. +// Users should wrap in Arc> or use one Booster per thread. + +impl Booster { + /// Load a model from a file + /// + /// # Arguments + /// * `path` - Path to the model file (can be JSON, binary, or deprecated text format) + /// + /// # Example + /// ```no_run + /// use xgboost_rust::Booster; + /// + /// let booster = Booster::load("model.json").unwrap(); + /// ``` + pub fn load>(path: P) -> XGBoostResult { + let path_str = path.as_ref().to_str().ok_or_else(|| XGBoostError { + description: "Path contains invalid UTF-8 characters".to_string(), + })?; + let path_c_str = CString::new(path_str).map_err(|e| XGBoostError { + description: format!("Path contains NUL byte: {}", e), + })?; + + // Create a booster first + let mut handle: sys::BoosterHandle = ptr::null_mut(); + XGBoostError::check_return_value(unsafe { + sys::XGBoosterCreate(ptr::null(), 0, &mut handle) + })?; + + // Load model into the booster + let result = XGBoostError::check_return_value(unsafe { + sys::XGBoosterLoadModel(handle, path_c_str.as_ptr()) + }); + + if let Err(e) = result { + unsafe { + sys::XGBoosterFree(handle); + } + return Err(e); + } + + Ok(Booster { handle }) + } + + /// Load a model from a memory buffer + /// + /// # Arguments + /// * `buffer` - Model content as bytes + /// + /// # Example + /// ```no_run + /// use xgboost_rust::Booster; + /// use std::fs; + /// + /// let buffer = fs::read("model.json").unwrap(); + /// let booster = Booster::load_from_buffer(&buffer).unwrap(); + /// ``` + pub fn load_from_buffer(buffer: &[u8]) -> XGBoostResult { + // Create a booster first + let mut handle: sys::BoosterHandle = ptr::null_mut(); + XGBoostError::check_return_value(unsafe { + sys::XGBoosterCreate(ptr::null(), 0, &mut handle) + })?; + + // Load model from buffer into the booster + let result = XGBoostError::check_return_value(unsafe { + sys::XGBoosterLoadModelFromBuffer( + handle, + buffer.as_ptr() as *const std::os::raw::c_void, + buffer.len() as u64, + ) + }); + + if let Err(e) = result { + unsafe { + sys::XGBoosterFree(handle); + } + return Err(e); + } + + Ok(Booster { handle }) + } + + /// Make predictions on data + /// + /// # Arguments + /// * `data` - 2D array of features (row-major, num_rows x num_features) + /// * `num_rows` - Number of rows in the data + /// * `num_features` - Number of features per row + /// * `option_mask` - Prediction options (see `predict_option` module) + /// * `training` - Whether this is for training (false for inference) + /// + /// # Returns + /// A vector of prediction values + /// + /// # Example + /// ```no_run + /// use xgboost_rust::Booster; + /// + /// let booster = Booster::load("model.json").unwrap(); + /// let data = vec![1.0, 2.0, 3.0, 4.0]; // 2 rows, 2 features + /// let predictions = booster.predict(&data, 2, 2, 0, false).unwrap(); + /// ``` + pub fn predict( + &self, + data: &[f32], + num_rows: usize, + num_features: usize, + option_mask: u32, + training: bool, + ) -> XGBoostResult> { + // Validate input dimensions + let expected_len = num_rows + .checked_mul(num_features) + .ok_or_else(|| XGBoostError { + description: format!( + "Integer overflow: num_rows ({}) * num_features ({}) exceeds usize::MAX", + num_rows, num_features + ), + })?; + + if data.len() != expected_len { + return Err(XGBoostError { + description: format!( + "Data length mismatch: expected {} elements ({}×{}), got {}", + expected_len, + num_rows, + num_features, + data.len() + ), + }); + } + + // Create DMatrix from data + let mut dmatrix_handle: sys::DMatrixHandle = ptr::null_mut(); + + XGBoostError::check_return_value(unsafe { + sys::XGDMatrixCreateFromMat( + data.as_ptr(), + num_rows as u64, + num_features as u64, + f32::NAN, + &mut dmatrix_handle, + ) + })?; + + // RAII guard to ensure DMatrix is always freed + struct DMatrixGuard(sys::DMatrixHandle); + impl Drop for DMatrixGuard { + fn drop(&mut self) { + unsafe { + sys::XGDMatrixFree(self.0); + } + } + } + let _guard = DMatrixGuard(dmatrix_handle); + + // Make prediction + let mut out_len: u64 = 0; + let mut out_result: *const f32 = ptr::null(); + + XGBoostError::check_return_value(unsafe { + sys::XGBoosterPredict( + self.handle, + dmatrix_handle, + option_mask as i32, + 0, // ntree_limit (0 means use all trees) + training as i32, + &mut out_len, + &mut out_result, + ) + })?; + + // Validate output pointers + if out_result.is_null() || out_len == 0 { + return Err(XGBoostError { + description: "XGBoost returned null or empty prediction result".to_string(), + }); + } + + // Copy results to a Vec + let results = unsafe { std::slice::from_raw_parts(out_result, out_len as usize).to_vec() }; + + // DMatrix will be automatically freed when _guard goes out of scope + + Ok(results) + } + + /// Get the number of features the model expects + /// + /// # Returns + /// The number of features + pub fn num_features(&self) -> XGBoostResult { + let mut out_num_features: u64 = 0; + + XGBoostError::check_return_value(unsafe { + sys::XGBoosterGetNumFeature(self.handle, &mut out_num_features) + })?; + + Ok(out_num_features as usize) + } + + /// Save the model to a file + /// + /// # Arguments + /// * `path` - Path where to save the model + /// + /// # Example + /// ```no_run + /// use xgboost_rust::Booster; + /// + /// let booster = Booster::load("model.json").unwrap(); + /// booster.save("model_copy.json").unwrap(); + /// ``` + pub fn save>(&self, path: P) -> XGBoostResult<()> { + let path_str = path.as_ref().to_str().ok_or_else(|| XGBoostError { + description: "Path contains invalid UTF-8 characters".to_string(), + })?; + let path_c_str = CString::new(path_str).map_err(|e| XGBoostError { + description: format!("Path contains NUL byte: {}", e), + })?; + + XGBoostError::check_return_value(unsafe { + sys::XGBoosterSaveModel(self.handle, path_c_str.as_ptr()) + }) + } +} + +impl Drop for Booster { + fn drop(&mut self) { + unsafe { + sys::XGBoosterFree(self.handle); + } + } +} diff --git a/src/sys.rs b/src/sys.rs new file mode 100644 index 0000000..cd503e4 --- /dev/null +++ b/src/sys.rs @@ -0,0 +1,6 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(dead_code)] + +include!(concat!(env!("OUT_DIR"), "/bindings.rs")); diff --git a/wrapper.h b/wrapper.h new file mode 100644 index 0000000..1213e62 --- /dev/null +++ b/wrapper.h @@ -0,0 +1 @@ +#include