diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..ed2b0b3 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,252 @@ +name: CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +env: + CARGO_TERM_COLOR: always + +jobs: + test: + name: Test ${{ matrix.os }} (${{ matrix.arch }}) - XGBoost ${{ matrix.xgboost_version }} + runs-on: ${{ matrix.runner }} + strategy: + fail-fast: false + matrix: + include: + # macOS ARM64 (M1/M2/M3) - Test multiple versions + - os: macos + arch: arm64 + runner: macos-14 + xgboost_version: "3.1.1" + - os: macos + arch: arm64 + runner: macos-14 + xgboost_version: "3.0.5" + - os: macos + arch: arm64 + runner: macos-14 + xgboost_version: "2.1.4" + + # macOS x86_64 (Intel) + - os: macos + arch: x86_64 + runner: macos-15-intel + xgboost_version: "3.1.1" + + # Linux x86_64 - Test multiple versions including thread-safety boundary + - os: linux + arch: x86_64 + runner: ubuntu-latest + xgboost_version: "3.1.1" + - os: linux + arch: x86_64 + runner: ubuntu-latest + xgboost_version: "1.7.6" + - os: linux + arch: x86_64 + runner: ubuntu-latest + xgboost_version: "1.4.2" # First thread-safe version + + # Linux ARM64 + - os: linux + arch: arm64 + runner: ubuntu-latest + xgboost_version: "3.1.1" + + # Windows x86_64 - Disabled: Python wheels don't include import libraries (.lib) + # needed for MSVC linking. Enable when we add Windows-specific build support. + # - os: windows + # arch: x86_64 + # runner: windows-latest + # xgboost_version: "3.1.1" + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-registry- + + - name: Cache cargo index + uses: actions/cache@v4 + with: + path: ~/.cargo/git + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-index-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-index- + + - name: Cache cargo build + uses: actions/cache@v4 + with: + path: target + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-build-target-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-build-target- + + - name: Install system dependencies (Linux) + if: matrix.os == 'linux' + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: Install system dependencies (macOS) + if: matrix.os == 'macos' + run: | + brew install libomp + + - name: Check build (no features) + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: cargo check --verbose + + - name: Build (no features) + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: cargo build --verbose + + - name: Run tests (no features) + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: cargo test --verbose + + - name: Build examples + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: | + cargo build --example basic_usage --verbose + cargo build --example advanced_usage --verbose + + - name: Verify library architecture (macOS) + if: matrix.os == 'macos' + run: | + echo "Checking library architecture..." + file target/debug/libxgboost.dylib + lipo -info target/debug/libxgboost.dylib || otool -L target/debug/libxgboost.dylib + + - name: Verify library architecture (Linux) + if: matrix.os == 'linux' + run: | + echo "Checking library architecture..." + file target/debug/libxgboost.so + readelf -h target/debug/libxgboost.so | grep Machine + + - name: Verify library exists (Windows) + if: matrix.os == 'windows' + run: | + echo "Checking library exists..." + Get-Item target/debug/xgboost.dll + + - name: Verify thread safety detection + env: + XGBOOST_VERSION: ${{ matrix.xgboost_version }} + run: | + echo "XGBoost version: ${{ matrix.xgboost_version }}" + cargo build --verbose 2>&1 | grep "thread-safe" || echo "No thread-safe message found" + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: Run clippy (no features) + run: cargo clippy -- -D warnings + + fmt: + name: Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Check formatting + run: cargo fmt -- --check + + # Test checksum verification + security-checksums: + name: Verify SHA256 checksums + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: Test checksum verification for XGBoost 3.1.1 + env: + XGBOOST_VERSION: "3.1.1" + run: | + cargo check --verbose 2>&1 | tee build.log + grep "✓ Verified SHA256" build.log + echo "Checksum verification working for 3.1.1" + + - name: Test checksum verification for XGBoost 2.1.4 + env: + XGBOOST_VERSION: "2.1.4" + run: | + cargo clean + cargo check --verbose 2>&1 | tee build.log + grep "✓ Verified SHA256" build.log + echo "Checksum verification working for 2.1.4" + + # Test caching behavior + caching-test: + name: Test wheel caching + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: First build (should download) + env: + XGBOOST_VERSION: "3.1.1" + run: | + cargo check --verbose 2>&1 | tee build1.log + grep "Downloading XGBoost wheel" build1.log || true + + - name: Second build (should use cache) + env: + XGBOOST_VERSION: "3.1.1" + run: | + touch src/lib.rs + cargo check --verbose 2>&1 | tee build2.log + grep "Using cached XGBoost library" build2.log + echo "Caching is working correctly" diff --git a/README.md b/README.md index 72a68bf..036049e 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Rust bindings for XGBoost, a gradient boosting library for machine learning. ## Features - **Automatic Binary Download**: Downloads XGBoost binaries at build time from PyPI wheels -- **Cross-Platform**: Supports Linux (x86_64, aarch64), macOS (x86_64, arm64), and Windows (x86_64) +- **Cross-Platform**: Supports Linux (x86_64, aarch64) and macOS (x86_64, arm64) - **Version Control**: Specify XGBoost version via `XGBOOST_VERSION` environment variable - **Version-Aware Thread Safety**: Automatically enables `Send + Sync` for XGBoost ≥ 1.4 - **Easy to Use**: Simple, safe Rust API wrapping the XGBoost C API @@ -33,6 +33,11 @@ xgboost-rust = "0.1.0" **Linux**: - Optional: `patchelf` (for setting SONAME, but not required) +**Windows**: +- ⚠️ **Not currently supported via automatic download** +- Python wheels don't include import libraries (`.lib`) needed for MSVC linking +- Alternative: Build XGBoost from source or use WSL/MinGW + ## Usage ### Basic Example diff --git a/build.rs b/build.rs index ab614e8..04691f9 100644 --- a/build.rs +++ b/build.rs @@ -1,11 +1,11 @@ extern crate bindgen; -use sha2::{Sha256, Digest}; +use sha2::{Digest, Sha256}; use std::collections::HashMap; use std::env; -use std::path::{Path, PathBuf}; use std::fs; use std::io::{self, Read, Write}; +use std::path::{Path, PathBuf}; use std::thread; use std::time::Duration; @@ -17,26 +17,41 @@ fn get_xgboost_version() -> String { fn get_header_checksums() -> HashMap<&'static str, (&'static str, &'static str)> { let mut checksums = HashMap::new(); // Format: version => (c_api.h SHA256, base.h SHA256) - checksums.insert("3.1.1", ( - "c0f0a98eb36fb5e451fdd3e9ead2d185f4c61be2a6997fc295e5d1a94f3096e2", - "8d771fb20e03f3443e21cfdcd26ac5cd880be585b8817f2e0d146e7c5c7bb63a" - )); - checksums.insert("3.0.5", ( - "2ccec6e5301fa5a1324f60af48b9c6be5879e590ed583ec9d74297e6018860bc", - "47f0148706907ccecb72b8484687524bc36d58b4c6fe5e7b81e59de157261ea7" - )); - checksums.insert("2.1.4", ( - "b804850ec6c7a00f8e36f139dfce7fe348fc9ad066ff4cb7ac44a4f5420ec1dd", - "525c4a2ba2f6bd9b17a299978e16f91897d497d6ae0ae5df2335dd059f00d0ce" - )); - checksums.insert("1.7.6", ( - "145ed1df652937122b6f6bc31331051eabc02226a0b62349ea593cdbe841c20d", - "b26e17eadbcc6350dc900b35d164eedc02b1cd2a64913c560d4d416c81a68935" - )); - checksums.insert("1.4.2", ( - "3f5de5d046a3c9576e0c560abe5fa1e889f72b4b18ff2bf73e5f98290d47d0dc", - "e3abfcc730eee86acf44124d5496a2b41413f963c4bbf560513eeae0b7d12fb7" - )); + checksums.insert( + "3.1.1", + ( + "c0f0a98eb36fb5e451fdd3e9ead2d185f4c61be2a6997fc295e5d1a94f3096e2", + "8d771fb20e03f3443e21cfdcd26ac5cd880be585b8817f2e0d146e7c5c7bb63a", + ), + ); + checksums.insert( + "3.0.5", + ( + "2ccec6e5301fa5a1324f60af48b9c6be5879e590ed583ec9d74297e6018860bc", + "47f0148706907ccecb72b8484687524bc36d58b4c6fe5e7b81e59de157261ea7", + ), + ); + checksums.insert( + "2.1.4", + ( + "b804850ec6c7a00f8e36f139dfce7fe348fc9ad066ff4cb7ac44a4f5420ec1dd", + "525c4a2ba2f6bd9b17a299978e16f91897d497d6ae0ae5df2335dd059f00d0ce", + ), + ); + checksums.insert( + "1.7.6", + ( + "145ed1df652937122b6f6bc31331051eabc02226a0b62349ea593cdbe841c20d", + "b26e17eadbcc6350dc900b35d164eedc02b1cd2a64913c560d4d416c81a68935", + ), + ); + checksums.insert( + "1.4.2", + ( + "3f5de5d046a3c9576e0c560abe5fa1e889f72b4b18ff2bf73e5f98290d47d0dc", + "e3abfcc730eee86acf44124d5496a2b41413f963c4bbf560513eeae0b7d12fb7", + ), + ); checksums } @@ -46,13 +61,18 @@ fn compute_sha256(data: &[u8]) -> String { format!("{:x}", hasher.finalize()) } -fn verify_checksum(data: &[u8], expected: &str, filename: &str) -> Result<(), Box> { +fn verify_checksum( + data: &[u8], + expected: &str, + filename: &str, +) -> Result<(), Box> { let actual = compute_sha256(data); if actual != expected { return Err(format!( "SHA256 checksum mismatch for {}:\n Expected: {}\n Got: {}", filename, expected, actual - ).into()); + ) + .into()); } println!("cargo:warning=✓ Verified SHA256 for {}", filename); Ok(()) @@ -60,7 +80,7 @@ fn verify_checksum(data: &[u8], expected: &str, filename: &str) -> Result<(), Bo fn parse_version(version: &str) -> (u32, u32, u32) { let parts: Vec<&str> = version.split('.').collect(); - let major = parts.get(0).and_then(|s| s.parse().ok()).unwrap_or(0); + let major = parts.first().and_then(|s| s.parse().ok()).unwrap_or(0); let minor = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0); let patch = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0); (major, minor, patch) @@ -73,9 +93,15 @@ fn emit_version_cfg_flags(version: &str) { // See: https://github.com/dmlc/xgboost/issues/5339 if major > 1 || (major == 1 && minor >= 4) { println!("cargo:rustc-cfg=xgboost_thread_safe"); - println!("cargo:warning=XGBoost version {} supports thread-safe predictions", version); + println!( + "cargo:warning=XGBoost version {} supports thread-safe predictions", + version + ); } else { - println!("cargo:warning=XGBoost version {} does NOT support thread-safe predictions", version); + println!( + "cargo:warning=XGBoost version {} does NOT support thread-safe predictions", + version + ); } } @@ -112,12 +138,13 @@ fn download_xgboost_headers(out_dir: &Path) -> Result<(), Box Result<(), Box Result<(), Box> { println!("cargo:warning=Downloading {} from: {}", filename, url); // Download into memory buffer let response = ureq::get(url).call()?; let status = response.status(); - if status < 200 || status >= 300 { + if !(200..300).contains(&status) { return Err(format!("Failed to download {}: HTTP {}", filename, status).into()); } @@ -178,14 +211,18 @@ fn download_with_retry(url: &str, max_retries: u32) -> Result, Box 0 { let backoff = Duration::from_millis(100 * 2_u64.pow(attempt)); - println!("cargo:warning=Retry attempt {} after {:?}", attempt + 1, backoff); + println!( + "cargo:warning=Retry attempt {} after {:?}", + attempt + 1, + backoff + ); thread::sleep(backoff); } match ureq::get(url).call() { Ok(response) => { let status = response.status(); - if status < 200 || status >= 300 { + if !(200..300).contains(&status) { last_error = Some(format!("HTTP {}", status)); continue; } @@ -208,19 +245,58 @@ fn download_with_retry(url: &str, max_retries: u32) -> Result, Box Result<(), Box> { let (os, arch) = get_platform_info(); let version = get_xgboost_version(); + let (major, minor, _patch) = parse_version(&version); - // Determine wheel filename based on platform + // Determine wheel filename based on platform and version + // Different XGBoost versions use different manylinux tags let wheel_filename = match (os.as_str(), arch.as_str()) { - ("linux", "x86_64") => format!("xgboost-{}-py3-none-manylinux_2_28_x86_64.whl", version), - ("linux", "aarch64") => format!("xgboost-{}-py3-none-manylinux_2_28_aarch64.whl", version), - ("darwin", "x86_64") => format!("xgboost-{}-py3-none-macosx_10_15_x86_64.whl", version), - ("darwin", "aarch64") => format!("xgboost-{}-py3-none-macosx_12_0_arm64.whl", version), + ("linux", "x86_64") => { + // Choose manylinux tag based on version + let manylinux_tag = if major >= 3 { + "manylinux_2_28" + } else if major == 1 && minor == 4 { + "manylinux2010" + } else { + "manylinux2014" + }; + format!("xgboost-{}-py3-none-{}_x86_64.whl", version, manylinux_tag) + } + ("linux", "aarch64") => { + let manylinux_tag = if major >= 3 { + "manylinux_2_28" + } else { + "manylinux2014" + }; + format!("xgboost-{}-py3-none-{}_aarch64.whl", version, manylinux_tag) + } + ("darwin", "x86_64") => { + // macOS x86_64 wheel names changed between versions + if major >= 3 { + format!("xgboost-{}-py3-none-macosx_10_15_x86_64.whl", version) + } else if major == 1 && minor == 4 { + format!("xgboost-{}-py3-none-macosx_10_14_x86_64.macosx_10_15_x86_64.macosx_11_0_x86_64.whl", version) + } else { + // Versions 1.7.x and 2.x use multi-platform tag + format!("xgboost-{}-py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.macosx_12_0_x86_64.whl", version) + } + } + ("darwin", "aarch64") => { + // macOS arm64 support started with version 1.5.0 + if major == 1 && minor < 5 { + return Err(format!( + "XGBoost {} does not have macOS arm64 support. Minimum version for arm64 is 1.5.0", + version + ).into()); + } + format!("xgboost-{}-py3-none-macosx_12_0_arm64.whl", version) + } ("windows", "x86_64") => format!("xgboost-{}-py3-none-win_amd64.whl", version), _ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()), }; @@ -242,13 +318,19 @@ fn download_and_extract_wheel(out_dir: &Path) -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box { // For Linux, use $ORIGIN println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN"); println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../.."); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_search_path.display()); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}", + lib_search_path.display() + ); // Add the target directory to rpath as well if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { - println!("cargo:rustc-link-arg=-Wl,-rpath,{}/debug", target_root.display()); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}/release", target_root.display()); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/debug", + target_root.display() + ); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/release", + target_root.display() + ); } - }, + } _ => {} // No rpath needed for Windows } diff --git a/examples/advanced_usage.rs b/examples/advanced_usage.rs index ab1acde..8ad0ba9 100644 --- a/examples/advanced_usage.rs +++ b/examples/advanced_usage.rs @@ -22,10 +22,7 @@ fn main() -> XGBoostResult<()> { // Example: Binary classification prediction println!("=== Example 1: Normal Prediction ==="); - let data = vec![ - 5.1, 3.5, 1.4, 0.2, - 6.7, 3.0, 5.2, 2.3, - ]; + let data = vec![5.1, 3.5, 1.4, 0.2, 6.7, 3.0, 5.2, 2.3]; let predictions = booster.predict(&data, 2, 4, 0, false)?; println!("Normal predictions (probabilities): {:?}\n", predictions); @@ -34,22 +31,16 @@ fn main() -> XGBoostResult<()> { println!("=== Example 2: Feature Contributions (SHAP) ==="); use xgboost_rust::predict_option; - let shap_values = booster.predict( - &data, - 2, - 4, - predict_option::PRED_CONTRIBS, - false - )?; + let shap_values = booster.predict(&data, 2, 4, predict_option::PRED_CONTRIBS, false)?; println!("SHAP values for first sample:"); // SHAP values include one extra value for bias term let num_features = 4; - for i in 0..=num_features { + for (i, &value) in shap_values.iter().enumerate().take(num_features + 1) { if i < num_features { - println!(" Feature {}: {:.4}", i, shap_values[i]); + println!(" Feature {}: {:.4}", i, value); } else { - println!(" Bias term: {:.4}", shap_values[i]); + println!(" Bias term: {:.4}", value); } } println!(); @@ -62,7 +53,9 @@ fn main() -> XGBoostResult<()> { // Example: Load from buffer (useful for embedded models) println!("=== Example 4: Load from Buffer ==="); - let buffer = std::fs::read(model_path)?; + let buffer = std::fs::read(model_path).map_err(|e| xgboost_rust::XGBoostError { + description: format!("Failed to read model file: {}", e), + })?; let booster_from_buffer = Booster::load_from_buffer(&buffer)?; println!("✓ Model loaded from buffer ({} bytes)", buffer.len()); diff --git a/examples/basic_usage.rs b/examples/basic_usage.rs index 9fa7bd4..d167a0a 100644 --- a/examples/basic_usage.rs +++ b/examples/basic_usage.rs @@ -55,9 +55,9 @@ fn main() -> XGBoostResult<()> { // Example prediction data (4 features for iris dataset) // This is a sample from the iris dataset let data = vec![ - 5.1, 3.5, 1.4, 0.2, // Row 1: Setosa - 6.7, 3.0, 5.2, 2.3, // Row 2: Virginica - 5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor + 5.1, 3.5, 1.4, 0.2, // Row 1: Setosa + 6.7, 3.0, 5.2, 2.3, // Row 2: Virginica + 5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor ]; let num_rows = 3; let num_features = 4; diff --git a/src/error.rs b/src/error.rs index f2bfc45..88e17d5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ +use crate::sys; use std::ffi::CStr; use std::fmt; -use crate::sys; pub type XGBoostResult = std::result::Result; diff --git a/src/model.rs b/src/model.rs index 47c135e..9ab8a5d 100644 --- a/src/model.rs +++ b/src/model.rs @@ -75,14 +75,12 @@ impl Booster { /// let booster = Booster::load("model.json").unwrap(); /// ``` pub fn load>(path: P) -> XGBoostResult { - let path_str = path.as_ref().to_str() - .ok_or_else(|| XGBoostError { - description: "Path contains invalid UTF-8 characters".to_string(), - })?; - let path_c_str = CString::new(path_str) - .map_err(|e| XGBoostError { - description: format!("Path contains NUL byte: {}", e), - })?; + let path_str = path.as_ref().to_str().ok_or_else(|| XGBoostError { + description: "Path contains invalid UTF-8 characters".to_string(), + })?; + let path_c_str = CString::new(path_str).map_err(|e| XGBoostError { + description: format!("Path contains NUL byte: {}", e), + })?; // Create a booster first let mut handle: sys::BoosterHandle = ptr::null_mut(); @@ -92,14 +90,13 @@ impl Booster { // Load model into the booster let result = XGBoostError::check_return_value(unsafe { - sys::XGBoosterLoadModel( - handle, - path_c_str.as_ptr(), - ) + sys::XGBoosterLoadModel(handle, path_c_str.as_ptr()) }); if let Err(e) = result { - unsafe { sys::XGBoosterFree(handle); } + unsafe { + sys::XGBoosterFree(handle); + } return Err(e); } @@ -136,7 +133,9 @@ impl Booster { }); if let Err(e) = result { - unsafe { sys::XGBoosterFree(handle); } + unsafe { + sys::XGBoosterFree(handle); + } return Err(e); } @@ -172,7 +171,8 @@ impl Booster { training: bool, ) -> XGBoostResult> { // Validate input dimensions - let expected_len = num_rows.checked_mul(num_features) + let expected_len = num_rows + .checked_mul(num_features) .ok_or_else(|| XGBoostError { description: format!( "Integer overflow: num_rows ({}) * num_features ({}) exceeds usize::MAX", @@ -184,7 +184,10 @@ impl Booster { return Err(XGBoostError { description: format!( "Data length mismatch: expected {} elements ({}×{}), got {}", - expected_len, num_rows, num_features, data.len() + expected_len, + num_rows, + num_features, + data.len() ), }); } @@ -237,9 +240,7 @@ impl Booster { } // Copy results to a Vec - let results = unsafe { - std::slice::from_raw_parts(out_result, out_len as usize).to_vec() - }; + let results = unsafe { std::slice::from_raw_parts(out_result, out_len as usize).to_vec() }; // DMatrix will be automatically freed when _guard goes out of scope @@ -254,10 +255,7 @@ impl Booster { let mut out_num_features: u64 = 0; XGBoostError::check_return_value(unsafe { - sys::XGBoosterGetNumFeature( - self.handle, - &mut out_num_features, - ) + sys::XGBoosterGetNumFeature(self.handle, &mut out_num_features) })?; Ok(out_num_features as usize) @@ -276,20 +274,15 @@ impl Booster { /// booster.save("model_copy.json").unwrap(); /// ``` pub fn save>(&self, path: P) -> XGBoostResult<()> { - let path_str = path.as_ref().to_str() - .ok_or_else(|| XGBoostError { - description: "Path contains invalid UTF-8 characters".to_string(), - })?; - let path_c_str = CString::new(path_str) - .map_err(|e| XGBoostError { - description: format!("Path contains NUL byte: {}", e), - })?; + let path_str = path.as_ref().to_str().ok_or_else(|| XGBoostError { + description: "Path contains invalid UTF-8 characters".to_string(), + })?; + let path_c_str = CString::new(path_str).map_err(|e| XGBoostError { + description: format!("Path contains NUL byte: {}", e), + })?; XGBoostError::check_return_value(unsafe { - sys::XGBoosterSaveModel( - self.handle, - path_c_str.as_ptr(), - ) + sys::XGBoosterSaveModel(self.handle, path_c_str.as_ptr()) }) } } diff --git a/src/sys.rs b/src/sys.rs index a38a13a..cd503e4 100644 --- a/src/sys.rs +++ b/src/sys.rs @@ -1,5 +1,6 @@ #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] +#![allow(dead_code)] include!(concat!(env!("OUT_DIR"), "/bindings.rs"));