
add ci. #2

Closed
aryehlev wants to merge 11 commits into main from bindings

Conversation

@aryehlev
Owner

@aryehlev aryehlev commented Nov 10, 2025

Summary by CodeRabbit

  • New Features

    • Rust XGBoost library: load/save models, buffer-based loading, prediction (including SHAP/feature contributions, margins, leaf outputs)
    • Exposed user-facing error type and prediction option flags
    • Version-aware thread-safety: safe concurrent sharing for newer XGBoost versions
  • Documentation

    • Added basic and advanced usage examples and clearer README guidance for thread-safety
  • Chores

    • Added CI pipeline, build automation and project manifest; updated ignore patterns

@coderabbitai

coderabbitai Bot commented Nov 10, 2025

Walkthrough

Adds a full Rust wrapper for XGBoost: build automation that downloads/verifies headers and wheels, generates bindgen FFI, provides error and Booster model modules with load/predict/save, example programs, CI workflow across OS/arch/XGBoost versions, and project manifest and ignores.

Changes

Cohort / File(s) Summary
CI workflow
.github/workflows/ci.yml
Adds multi-job GitHub Actions pipeline: matrix across OS (macos/linux/windows), arch (arm64/x86_64), multiple XGBoost versions; jobs for tests, clippy, fmt, checksum, caching, and platform-specific verification.
Build & packaging
build.rs, wrapper.h, .gitignore, Cargo.toml
Adds build script that reads XGBOOST_VERSION, downloads/verifies c_api.h and platform wheel, extracts shared library, runs bindgen to emit bindings, patches rpaths, emits linker cfgs (including thread-safety flag), and caches artifacts. Adds C API include and ignore patterns; introduces Cargo manifest and features.
Core library
src/lib.rs, src/sys.rs, src/error.rs, src/model.rs
Adds generated FFI include, XGBoostError/XGBoostResult error types, Booster struct with load/load_from_buffer/predict/num_features/save, Drop and DMatrix guard, conditional unsafe Send/Sync based on build-time cfg, and re-exports plus predict option constants.
Examples
examples/basic_usage.rs, examples/advanced_usage.rs
Adds basic and advanced examples demonstrating model load, buffer load, prediction, SHAP (PRED_CONTRIBS), save, and error handling.
Documentation
README.md
Updates README to describe version-aware thread-safety (XGBoost ≥1.4), example using Arc, and build-time XGBOOST_VERSION behavior.
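The version-aware thread-safety summarized in the table above can be sketched roughly as follows: build.rs emits an `xgboost_thread_safe` cfg for XGBoost ≥ 1.4, and the library gates its `unsafe impl Send`/`Sync` on that cfg. The struct and field names here are stand-ins for illustration, not the crate's actual definitions.

```rust
use std::os::raw::c_void;

// Stand-in for the FFI handle type from the generated bindings.
type BoosterHandle = *mut c_void;

// Minimal sketch of the Booster wrapper: the raw FFI handle makes the
// type !Send/!Sync by default.
pub struct Booster {
    pub handle: BoosterHandle,
}

// build.rs emits `cargo:rustc-cfg=xgboost_thread_safe` only when the
// linked XGBoost is >= 1.4, so these impls exist only for versions
// whose C API supports concurrent prediction.
#[cfg(xgboost_thread_safe)]
unsafe impl Send for Booster {}
#[cfg(xgboost_thread_safe)]
unsafe impl Sync for Booster {}
```

With the cfg unset, a `Booster` cannot cross threads; with it set, sharing via `Arc` (as the README example does) compiles.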

Sequence Diagram(s)

sequenceDiagram
    actor Dev
    participant CargoBuild as build.rs
    participant Network as Download Server
    participant FS as Filesystem
    participant Bindgen as bindgen
    participant Rustc as Rust compiler/linker

    Dev->>CargoBuild: cargo build (reads XGBOOST_VERSION, TARGET)
    CargoBuild->>Network: download c_api.h
    Network-->>CargoBuild: c_api.h bytes
    CargoBuild->>CargoBuild: verify SHA256
    CargoBuild->>FS: write OUT_DIR/include/c_api.h
    CargoBuild->>Network: download platform wheel (.whl)
    Network-->>CargoBuild: .whl file
    CargoBuild->>FS: extract libxgboost.* to OUT_DIR
    CargoBuild->>Bindgen: generate bindings from c_api.h
    Bindgen-->>CargoBuild: bindings.rs
    CargoBuild->>FS: write OUT_DIR/bindings.rs
    CargoBuild->>CargoBuild: patch rpath/install_name/patchelf if available
    CargoBuild->>Rustc: emit rustc-link-search, link flags, cfg(xgboost_thread_safe) if applicable
sequenceDiagram
    participant User as application
    participant Booster as Booster
    participant FFI as generated sys
    participant XGLib as libxgboost

    User->>Booster: Booster::load(path)
    Booster->>FFI: XGBoosterCreate / XGBoosterLoadModel
    FFI->>XGLib: C calls
    XGLib-->>FFI: status
    Booster-->>User: Booster instance or error

    User->>Booster: predict(data, rows, cols, options)
    Booster->>FFI: XGDMatrixCreateFromMat
    FFI->>XGLib: create DMatrix
    Booster->>FFI: XGBoosterPredict
    XGLib-->>FFI: predictions
    Booster->>FFI: XGDMatrixFree
    Booster-->>User: Vec<f32> or error
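The predict leg of this diagram (create DMatrix, predict, free DMatrix) is where the wrapper's RAII guard matters: the DMatrix must be freed on every exit path. A minimal sketch of that guard pattern, with the FFI free call simulated by a shared flag so it stays self-contained:

```rust
use std::cell::Cell;
use std::rc::Rc;

// RAII guard sketch: in src/model.rs the guard wraps the DMatrix
// handle and calls sys::XGDMatrixFree in Drop. Here the free call is
// simulated with a shared flag so the pattern is testable.
struct DMatrixGuard {
    freed: Rc<Cell<bool>>,
}

impl Drop for DMatrixGuard {
    fn drop(&mut self) {
        // Real code: unsafe { sys::XGDMatrixFree(self.handle) };
        self.freed.set(true);
    }
}

fn predict_with_guard(freed: Rc<Cell<bool>>) -> Result<Vec<f32>, String> {
    let _guard = DMatrixGuard { freed };
    // XGBoosterPredict would run here; whether we return Ok or bail
    // early with `?`, the guard's Drop still frees the DMatrix.
    Ok(vec![0.1, 0.7, 0.2])
}
```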

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • review focus:
    • build.rs: checksum map, download/extract retry, platform detection, rpath/install_name/patchelf usage, emitted cfg for thread-safety.
    • src/model.rs & src/error.rs: FFI safety, pointer handling, CString/NUL handling, RAII guards, error propagation.
    • CI workflow: matrix correctness and platform-specific verification steps.

Possibly related PRs

  • add bindings and build.rs script. #1 — Appears to modify the same files (build.rs, src/{lib.rs,model.rs,error.rs,sys.rs}, wrapper.h, examples, Cargo.toml, .gitignore) and likely overlaps in initial bindings and build infrastructure.

Poem

🐰 I fetched headers, checked their sum,
I spun the bindings, one by one,
Boosters now hop safe across threads,
Examples sing and CI treads,
A tiny rabbit cheers: build is done!

Pre-merge checks and finishing touches

❌ Failed checks (1 inconclusive)
Check name Status Explanation Resolution
Title check ❓ Inconclusive The title 'add ci.' is vague and generic. While CI workflows are added, the title does not convey meaningful information about the changeset's scope or purpose. Consider a more descriptive title such as 'Add GitHub Actions CI workflow' or 'Set up CI pipeline with multi-OS matrix testing' to better reflect the comprehensive changes beyond just the workflow file.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5e3d55a and c790f02.

📒 Files selected for processing (12)
  • .github/workflows/ci.yml (1 hunks)
  • .gitignore (1 hunks)
  • Cargo.toml (1 hunks)
  • README.md (2 hunks)
  • build.rs (1 hunks)
  • examples/advanced_usage.rs (1 hunks)
  • examples/basic_usage.rs (1 hunks)
  • src/error.rs (1 hunks)
  • src/lib.rs (1 hunks)
  • src/model.rs (1 hunks)
  • src/sys.rs (1 hunks)
  • wrapper.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/advanced_usage.rs (2)
examples/basic_usage.rs (1)
  • main (3-95)
src/model.rs (3)
  • load (77-107)
  • num_features (253-264)
  • load_from_buffer (122-144)
examples/basic_usage.rs (2)
examples/advanced_usage.rs (1)
  • main (3-84)
src/model.rs (2)
  • load (77-107)
  • num_features (253-264)
src/model.rs (1)
src/error.rs (1)
  • check_return_value (15-21)

Comment thread .github/workflows/ci.yml
Comment on lines +36 to +40
          - os: macos
            arch: x86_64
            runner: macos-15-large
            xgboost_version: "3.1.1"


⚠️ Potential issue | 🔴 Critical

Switch off paid macOS large runner

macos-15-large targets a paid larger runner that only works when the repository belongs to an organization with the larger-runner entitlement; on a personal/open-source repo this job will fail before any step runs.(docs.github.com) Please move to the standard Intel label (for example macos-15-intel) so the Intel coverage remains but the workflow still executes.

-          - os: macos
-            arch: x86_64
-            runner: macos-15-large
+          - os: macos
+            arch: x86_64
+            runner: macos-15-intel
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
          - os: macos
            arch: x86_64
            runner: macos-15-large
            xgboost_version: "3.1.1"
          - os: macos
            arch: x86_64
            runner: macos-15-intel
            xgboost_version: "3.1.1"
🤖 Prompt for AI Agents
.github/workflows/ci.yml around lines 36 to 40: the job is using the paid larger
runner "macos-15-large" which will fail on personal/open-source repos; edit the
workflow to use the standard Intel runner by replacing the runner value
"macos-15-large" with "macos-15-intel" (keep os: macos and arch: x86_64
unchanged) so the job runs on GitHub-hosted Intel macOS runners.

Comment thread .github/workflows/ci.yml
Comment on lines +55 to +60
          # Linux ARM64
          - os: linux
            arch: arm64
            runner: ubuntu-latest
            xgboost_version: "3.1.1"


⚠️ Potential issue | 🟠 Major

Fix linux arm64 runner selection

Setting runner: ubuntu-latest still provisions an x86_64 VM; GitHub exposes dedicated labels (ubuntu-24.04-arm / ubuntu-22.04-arm) for hosted arm64 runners.(docs.github.com) Without switching labels, this matrix leg never validates arm64 builds despite claiming to. Update the label so the job actually runs on arm hardware.

-          - os: linux
-            arch: arm64
-            runner: ubuntu-latest
+          - os: linux
+            arch: arm64
+            runner: ubuntu-24.04-arm
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
          # Linux ARM64
          - os: linux
            arch: arm64
            runner: ubuntu-latest
            xgboost_version: "3.1.1"
          # Linux ARM64
          - os: linux
            arch: arm64
            runner: ubuntu-24.04-arm
            xgboost_version: "3.1.1"
🤖 Prompt for AI Agents
.github/workflows/ci.yml around lines 55 to 60: the matrix entry for "Linux
ARM64" incorrectly uses runner: ubuntu-latest which yields an x86_64 VM; replace
the runner label with a hosted ARM runner label (e.g., ubuntu-24.04-arm or
ubuntu-22.04-arm) so the job actually runs on arm64 hardware, and ensure any
OS/version matrix values remain consistent with the chosen ARM label.

Comment thread build.rs
Comment on lines +82 to +108
fn get_platform_info() -> (String, String) {
    let target = env::var("TARGET").unwrap();

    // Determine OS
    let os = if target.contains("apple-darwin") {
        "darwin"
    } else if target.contains("linux") {
        "linux"
    } else if target.contains("windows") {
        "windows"
    } else {
        panic!("Unsupported target: {}", target);
    };

    // Determine architecture
    let arch = if target.contains("x86_64") {
        "x86_64"
    } else if target.contains("aarch64") || target.contains("arm64") {
        "aarch64"
    } else if target.contains("i686") || target.contains("i586") {
        "i686"
    } else {
        panic!("Unsupported architecture for target: {}", target);
    };

    (os.to_string(), arch.to_string())
}

⚠️ Potential issue | 🟠 Major

Fail fast on musl targets

get_platform_info lumps every *-linux-* triple together, so x86_64-unknown-linux-musl ends up downloading a glibc-only wheel. That toolchain can’t link or run the resulting .so, yielding linker errors later. Please detect musl targets up front and abort with a clear message (or add native support) instead of silently fetching the incompatible glibc build.

🤖 Prompt for AI Agents
In build.rs around lines 82 to 108, get_platform_info currently treats all linux
triples the same and will select a glibc build for musl targets; update it to
detect musl targets (e.g. target.contains("musl") or target.ends_with("-musl"))
early and fail fast with a clear panic message indicating musl is unsupported
(or add separate musl handling if you plan to support it). Ensure the musl check
runs before selecting the OS/arch and includes the full TARGET string in the
panic so downstream users see which triple caused the abort.
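A hedged sketch of the requested fail-fast, returning `Result` instead of panicking so it can be exercised directly (the function name `classify_target` is hypothetical, and the i686 arm is dropped for brevity):

```rust
// Sketch of the suggested fail-fast: detect musl before OS/arch
// selection so a glibc-only wheel is never downloaded. The target
// triple is a parameter here instead of env::var("TARGET") so the
// logic is testable outside a build script.
fn classify_target(target: &str) -> Result<(String, String), String> {
    if target.contains("musl") {
        return Err(format!(
            "Unsupported target {}: prebuilt XGBoost wheels link against glibc",
            target
        ));
    }
    let os = if target.contains("apple-darwin") {
        "darwin"
    } else if target.contains("linux") {
        "linux"
    } else if target.contains("windows") {
        "windows"
    } else {
        return Err(format!("Unsupported target: {}", target));
    };
    let arch = if target.contains("x86_64") {
        "x86_64"
    } else if target.contains("aarch64") || target.contains("arm64") {
        "aarch64"
    } else {
        return Err(format!("Unsupported architecture for target: {}", target));
    };
    Ok((os.to_string(), arch.to_string()))
}
```

The key point is ordering: the musl check runs before OS/arch selection, and the error message carries the full triple.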

Comment thread build.rs
Comment on lines +255 to +312
        let download_url = format!(
            "https://files.pythonhosted.org/packages/py3/x/xgboost/{}",
            wheel_filename
        );

        println!("cargo:warning=Downloading XGBoost wheel from: {}", download_url);
        let buffer = download_with_retry(&download_url, 3)?;

        // Write atomically (temp file + rename)
        let temp_path = wheel_path.with_extension("tmp");
        {
            let mut temp_file = fs::File::create(&temp_path)?;
            temp_file.write_all(&buffer)?;
            temp_file.sync_all()?;
        }
        fs::rename(&temp_path, &wheel_path)?;

        println!("cargo:warning=✓ Downloaded and cached wheel");
        buffer
    };

    // Extract library from wheel
    println!("cargo:warning=Extracting library from wheel");

    let cursor = io::Cursor::new(wheel_buffer);
    let mut archive = zip::ZipArchive::new(cursor)?;

    // Search for the library file in the wheel
    let mut found = false;
    for i in 0..archive.len() {
        let mut file = archive.by_index(i)?;
        let file_path = file.name().to_string();

        // Look for the library file (usually in xgboost/lib/)
        if file_path.ends_with(lib_filename) {
            println!("cargo:warning=Found library at: {}", file_path);

            // Extract to temp file, then rename atomically
            let temp_dest_path = lib_dest_path.with_extension("tmp");
            {
                let mut dest = fs::File::create(&temp_dest_path)?;
                io::copy(&mut file, &mut dest)?;
                dest.sync_all()?;
            }
            fs::rename(&temp_dest_path, &lib_dest_path)?;

            found = true;
            break;
        }
    }

    if !found {
        return Err(format!("Library file {} not found in wheel", lib_filename).into());
    }

    println!("cargo:warning=✓ Successfully extracted XGBoost library to: {}", lib_dir.display());

    Ok(())

⚠️ Potential issue | 🔴 Critical

Reject unsigned wheel downloads

We verify header hashes but skip integrity checks for the downloaded wheel, so a man-in-the-middle or mirror compromise could silently swap in a malicious libxgboost. PyPI publishes SHA256 digests for every wheel (for example the 3.1.1 wheels listed on Oct 21 2025), so we should treat a missing hash check as a release blocker: download to a temp file, compute its SHA256, and compare against an allowlist just like we do for the headers before extracting or linking the binary.

@@
-        println!("cargo:warning=Downloading XGBoost wheel from: {}", download_url);
-        let buffer = download_with_retry(&download_url, 3)?;
+        println!("cargo:warning=Downloading XGBoost wheel from: {}", download_url);
+        let buffer = download_with_retry(&download_url, 3)?;
+        verify_checksum(&buffer, expected_wheel_sha256(version, os.as_str(), arch.as_str())?,
+                        &wheel_filename)?;

Committable suggestion skipped: line range outside the PR's diff.
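One possible shape for the requested allowlist, sketched with a placeholder digest: the names `expected_wheel_sha256` and `verify_checksum` are hypothetical, the digest below is not a real PyPI value, and a real build.rs would pin the SHA256 values published on PyPI and hash the downloaded bytes (e.g. with the sha2 crate) before comparing.

```rust
// Hypothetical allowlist lookup: the digest here is a placeholder,
// not a real PyPI value. A real build.rs would pin the SHA256
// published on PyPI for each (version, os, arch) wheel.
fn expected_wheel_sha256(version: &str, os: &str, arch: &str) -> Option<&'static str> {
    match (version, os, arch) {
        ("3.1.1", "linux", "x86_64") => {
            Some("0000000000000000000000000000000000000000000000000000000000000000")
        }
        _ => None,
    }
}

// Compare a computed hex digest against the pinned value; mismatch
// aborts the build before the wheel is ever extracted or linked.
fn verify_checksum(actual_hex: &str, expected_hex: &str, name: &str) -> Result<(), String> {
    if actual_hex.eq_ignore_ascii_case(expected_hex) {
        Ok(())
    } else {
        Err(format!(
            "checksum mismatch for {}: got {}, want {}",
            name, actual_hex, expected_hex
        ))
    }
}
```

A missing allowlist entry should also be treated as an error, so an unknown (version, os, arch) combination fails fast rather than skipping verification.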


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
src/model.rs (1)

165-248: LGTM! Excellent use of RAII and thorough input validation.

The predict method demonstrates excellent practices:

  • Overflow-safe dimension validation (lines 174-181)
  • RAII guard pattern ensures DMatrix cleanup (lines 208-217)
  • Comprehensive output validation (lines 235-240)
  • Safe pointer-to-Vec conversion (line 243)

One minor optional improvement for the cast on line 227:

         XGBoostError::check_return_value(unsafe {
             sys::XGBoosterPredict(
                 self.handle,
                 dmatrix_handle,
-                option_mask as i32,
+                option_mask.try_into().unwrap_or(0),
                 0, // ntree_limit (0 means use all trees)
                 training as i32,
                 &mut out_len,
                 &mut out_result,
             )
         })?;

This prevents potential overflow if option_mask exceeds i32::MAX, though in practice the predict_option constants are small values.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c790f02 and 2f808fd.

📒 Files selected for processing (6)
  • build.rs (1 hunks)
  • examples/advanced_usage.rs (1 hunks)
  • examples/basic_usage.rs (1 hunks)
  • src/error.rs (1 hunks)
  • src/model.rs (1 hunks)
  • src/sys.rs (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • build.rs
  • examples/advanced_usage.rs
  • src/sys.rs
🧰 Additional context used
🧬 Code graph analysis (2)
src/model.rs (1)
src/error.rs (1)
  • check_return_value (15-21)
examples/basic_usage.rs (2)
examples/advanced_usage.rs (1)
  • main (3-73)
src/model.rs (2)
  • load (77-104)
  • num_features (254-262)
🔇 Additional comments (5)
src/error.rs (1)

1-39: LGTM! Clean error handling implementation.

The error module provides a well-designed abstraction over XGBoost's C API error handling:

  • check_return_value correctly maps return codes to Rust Result
  • fetch_xgboost_error safely handles potentially invalid UTF-8 with unwrap_or
  • Standard trait implementations (Display, Error) follow Rust conventions
  • Public description field enables flexible error construction
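Based on the description above, the error module's core contract can be sketched like this — the `XGBGetLastError` call is stubbed out, so only the shape (return code 0 means success, anything else becomes an `XGBoostError`) is meant to match the real src/error.rs:

```rust
use std::fmt;

// Sketch of the error type described in the review: a public
// description field plus the standard Display/Error impls.
#[derive(Debug)]
pub struct XGBoostError {
    pub description: String,
}

impl fmt::Display for XGBoostError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "XGBoost error: {}", self.description)
    }
}

impl std::error::Error for XGBoostError {}

pub type XGBoostResult<T> = Result<T, XGBoostError>;

fn check_return_value(ret: i32) -> XGBoostResult<()> {
    if ret == 0 {
        Ok(())
    } else {
        // Real code calls sys::XGBGetLastError() and copies the C
        // string, handling potentially invalid UTF-8 with a fallback.
        Err(XGBoostError {
            description: "stubbed XGBGetLastError message".to_string(),
        })
    }
}
```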
src/model.rs (4)

77-143: LGTM! Robust model loading with proper error handling.

Both load and load_from_buffer methods demonstrate solid FFI practices:

  • Path validation prevents UTF-8 and NUL byte issues before FFI calls
  • Consistent error handling via check_return_value
  • Proper cleanup of resources on failure (lines 96-100, 135-139)
  • Clear separation of concerns (create handle, load model, handle errors)

254-287: LGTM! Simple and correct implementations.

Both methods follow appropriate patterns:

  • num_features: Straightforward FFI call with proper error handling
  • save: Mirrors the path validation approach from load for consistency

290-296: LGTM! Correct RAII cleanup implementation.

The Drop implementation ensures the XGBoost handle is properly freed when the Booster goes out of scope. This is the standard pattern for resource management in FFI code.


53-60: The xgboost_thread_safe cfg flag mechanism is correctly implemented. The build.rs properly detects XGBoost version via the XGBOOST_VERSION environment variable (defaulting to "3.1.1"), parses it safely, and sets the cfg flag only when version ≥ 1.4 is detected using the condition major > 1 || (major == 1 && minor >= 4). The implementation includes defensive parsing with fallbacks, proper cfg registration, and build-time warnings for transparency.
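The version gate verified above can be sketched as a small pure function plus the cfg emission. Names are hypothetical; the real build.rs reads XGBOOST_VERSION from the environment (defaulting to "3.1.1") before applying this condition.

```rust
// Defensive parse: missing or unparsable components fall back to 0,
// which disables the thread-safety cfg rather than guessing.
fn is_thread_safe_version(version: &str) -> bool {
    let mut parts = version.split('.');
    let major: u32 = parts.next().and_then(|p| p.parse().ok()).unwrap_or(0);
    let minor: u32 = parts.next().and_then(|p| p.parse().ok()).unwrap_or(0);
    // Same condition the review verified: thread-safe for >= 1.4.
    major > 1 || (major == 1 && minor >= 4)
}

fn emit_thread_safety_cfg(version: &str) {
    // Registering the cfg avoids unexpected-cfg lints on newer rustc.
    println!("cargo:rustc-check-cfg=cfg(xgboost_thread_safe)");
    if is_thread_safe_version(version) {
        println!("cargo:rustc-cfg=xgboost_thread_safe");
    }
}
```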

Comment thread examples/basic_usage.rs
Comment on lines +51 to +64
    // Get model information
    let num_features = booster.num_features()?;
    println!("Model expects {} features\n", num_features);

    // Example prediction data (4 features for iris dataset)
    // This is a sample from the iris dataset
    let data = vec![
        5.1, 3.5, 1.4, 0.2, // Row 1: Setosa
        6.7, 3.0, 5.2, 2.3, // Row 2: Virginica
        5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor
    ];
    let num_rows = 3;
    let num_features = 4;


⚠️ Potential issue | 🟠 Major

Remove variable shadowing that masks the retrieved feature count.

The num_features retrieved from the model on line 52 is immediately shadowed by a hardcoded value on line 63. This means the actual model feature count is retrieved, printed, but never used for prediction validation.

If the model expects a different number of features than 4, this will cause a runtime error or incorrect predictions.

Apply this diff to use the retrieved feature count:

     // Get model information
     let num_features = booster.num_features()?;
     println!("Model expects {} features\n", num_features);
 
     // Example prediction data (4 features for iris dataset)
     // This is a sample from the iris dataset
     let data = vec![
         5.1, 3.5, 1.4, 0.2, // Row 1: Setosa
         6.7, 3.0, 5.2, 2.3, // Row 2: Virginica
         5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor
     ];
     let num_rows = 3;
-    let num_features = 4;
 
     println!("Making predictions on {} samples...", num_rows);

Alternatively, if you want to keep the hardcoded value for the example, validate it matches the model:

     // Get model information
-    let num_features = booster.num_features()?;
-    println!("Model expects {} features\n", num_features);
+    let model_features = booster.num_features()?;
+    println!("Model expects {} features\n", model_features);
 
     // Example prediction data (4 features for iris dataset)
     // This is a sample from the iris dataset
     let data = vec![
         5.1, 3.5, 1.4, 0.2, // Row 1: Setosa
         6.7, 3.0, 5.2, 2.3, // Row 2: Virginica
         5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor
     ];
     let num_rows = 3;
     let num_features = 4;
+    
+    if model_features != num_features {
+        return Err(xgboost_rust::XGBoostError {
+            description: format!(
+                "Model expects {} features but example provides {}",
+                model_features, num_features
+            ),
+        });
+    }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    // Get model information
    let num_features = booster.num_features()?;
    println!("Model expects {} features\n", num_features);

    // Example prediction data (4 features for iris dataset)
    // This is a sample from the iris dataset
    let data = vec![
        5.1, 3.5, 1.4, 0.2, // Row 1: Setosa
        6.7, 3.0, 5.2, 2.3, // Row 2: Virginica
        5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor
    ];
    let num_rows = 3;
    let num_features = 4;
    // Get model information
    let num_features = booster.num_features()?;
    println!("Model expects {} features\n", num_features);

    // Example prediction data (4 features for iris dataset)
    // This is a sample from the iris dataset
    let data = vec![
        5.1, 3.5, 1.4, 0.2, // Row 1: Setosa
        6.7, 3.0, 5.2, 2.3, // Row 2: Virginica
        5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor
    ];
    let num_rows = 3;
🤖 Prompt for AI Agents
In examples/basic_usage.rs around lines 51 to 64, the variable num_features
returned by booster.num_features() is being shadowed by a hardcoded let
num_features = 4; which discards the real model feature count; remove the
hardcoded shadowing and reuse the retrieved num_features for prediction shape
checks (or, if you intentionally want a hardcoded example, validate that
booster.num_features()? == 4 and bail with a clear error if it doesn't). Ensure
any local variables are renamed if you need both values (e.g.,
expected_num_features) and update the code that constructs/validates the data to
use the verified feature count so runtime mismatches can't occur.

@aryehlev aryehlev closed this Nov 11, 2025