Walkthrough

Adds a full Rust wrapper for XGBoost: build automation that downloads and verifies headers and wheels, generates bindgen FFI, provides error and Booster model modules with load/predict/save, example programs, a CI workflow across OS/arch/XGBoost versions, and the project manifest and ignore rules.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor Dev
    participant CargoBuild as build.rs
    participant Network as Download Server
    participant FS as Filesystem
    participant Bindgen as bindgen
    participant Rustc as Rust compiler/linker
    Dev->>CargoBuild: cargo build (reads XGBOOST_VERSION, TARGET)
    CargoBuild->>Network: download c_api.h
    Network-->>CargoBuild: c_api.h bytes
    CargoBuild->>CargoBuild: verify SHA256
    CargoBuild->>FS: write OUT_DIR/include/c_api.h
    CargoBuild->>Network: download platform wheel (.whl)
    Network-->>CargoBuild: .whl file
    CargoBuild->>FS: extract libxgboost.* to OUT_DIR
    CargoBuild->>Bindgen: generate bindings from c_api.h
    Bindgen-->>CargoBuild: bindings.rs
    CargoBuild->>FS: write OUT_DIR/bindings.rs
    CargoBuild->>CargoBuild: patch rpath/install_name/patchelf if available
    CargoBuild->>Rustc: emit rustc-link-search, link flags, cfg(xgboost_thread_safe) if applicable
```
```mermaid
sequenceDiagram
    participant User as application
    participant Booster as Booster
    participant FFI as generated sys
    participant XGLib as libxgboost
    User->>Booster: Booster::load(path)
    Booster->>FFI: XGBoosterCreate / XGBoosterLoadModel
    FFI->>XGLib: C calls
    XGLib-->>FFI: status
    Booster-->>User: Booster instance or error
    User->>Booster: predict(data, rows, cols, options)
    Booster->>FFI: XGDMatrixCreateFromMat
    FFI->>XGLib: create DMatrix
    Booster->>FFI: XGBoosterPredict
    XGLib-->>FFI: predictions
    Booster->>FFI: XGDMatrixFree
    Booster-->>User: Vec<f32> or error
```
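The runtime flow in the second diagram can be mocked without any FFI. The sketch below is illustrative only: `Booster`, `load`, and `predict` are stand-ins for the crate's real types, the `handle` field stands in for the C `BoosterHandle`, and the returned scores are placeholders, not real predictions.

```rust
// Mock of the load -> predict -> drop lifecycle shown above (no real FFI).
struct Booster {
    handle: usize, // stands in for the C BoosterHandle
}

impl Booster {
    fn load(path: &str) -> Result<Booster, String> {
        if path.is_empty() {
            return Err("empty model path".to_string());
        }
        // Real code: XGBoosterCreate + XGBoosterLoadModel via the sys bindings.
        Ok(Booster { handle: 1 })
    }

    fn predict(&self, data: &[f32], rows: usize, cols: usize) -> Result<Vec<f32>, String> {
        // Real code builds a DMatrix, calls XGBoosterPredict, then frees the DMatrix.
        if rows.checked_mul(cols) != Some(data.len()) {
            return Err(format!(
                "shape {}x{} does not match data length {}",
                rows, cols, data.len()
            ));
        }
        Ok(vec![0.0; rows]) // placeholder output, one score per row
    }
}

impl Drop for Booster {
    fn drop(&mut self) {
        // Real code: XGBoosterFree(self.handle).
        let _ = self.handle;
    }
}

fn main() {
    let booster = Booster::load("model.json").expect("load failed");
    let preds = booster.predict(&[1.0, 2.0, 3.0, 4.0], 2, 2).expect("predict failed");
    assert_eq!(preds.len(), 2);
    assert!(booster.predict(&[1.0], 2, 2).is_err());
    println!("ok");
}
```

The `Drop` implementation is the key design point: cleanup of the C handle is tied to scope exit rather than to a manual `free` call, which is the pattern the review comments below examine in `src/model.rs`.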
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 inconclusive)
✅ Passed checks (2 passed)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 4
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- .github/workflows/ci.yml (1 hunks)
- .gitignore (1 hunks)
- Cargo.toml (1 hunks)
- README.md (2 hunks)
- build.rs (1 hunks)
- examples/advanced_usage.rs (1 hunks)
- examples/basic_usage.rs (1 hunks)
- src/error.rs (1 hunks)
- src/lib.rs (1 hunks)
- src/model.rs (1 hunks)
- src/sys.rs (1 hunks)
- wrapper.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/advanced_usage.rs (2)
- examples/basic_usage.rs (1): main (3-95)
- src/model.rs (3): load (77-107), num_features (253-264), load_from_buffer (122-144)

examples/basic_usage.rs (2)
- examples/advanced_usage.rs (1): main (3-84)
- src/model.rs (2): load (77-107), num_features (253-264)

src/model.rs (1)
- src/error.rs (1): check_return_value (15-21)
```yaml
      - os: macos
        arch: x86_64
        runner: macos-15-large
        xgboost_version: "3.1.1"
```
Switch off paid macOS large runner
macos-15-large targets a paid larger runner that only works when the repository belongs to an organization with the larger-runner entitlement; on a personal/open-source repo this job will fail before any step runs.(docs.github.com) Please move to the standard Intel label (for example macos-15-intel) so the Intel coverage remains but the workflow still executes.
```diff
-      - os: macos
-        arch: x86_64
-        runner: macos-15-large
+      - os: macos
+        arch: x86_64
+        runner: macos-15-intel
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```yaml
      - os: macos
        arch: x86_64
        runner: macos-15-intel
        xgboost_version: "3.1.1"
```
🤖 Prompt for AI Agents
.github/workflows/ci.yml around lines 36 to 40: the job is using the paid larger
runner "macos-15-large" which will fail on personal/open-source repos; edit the
workflow to use the standard Intel runner by replacing the runner value
"macos-15-large" with "macos-15-intel" (keep os: macos and arch: x86_64
unchanged) so the job runs on GitHub-hosted Intel macOS runners.
```yaml
      # Linux ARM64
      - os: linux
        arch: arm64
        runner: ubuntu-latest
        xgboost_version: "3.1.1"
```
Fix Linux ARM64 runner selection
Setting runner: ubuntu-latest still provisions an x86_64 VM; GitHub exposes dedicated labels (ubuntu-24.04-arm / ubuntu-22.04-arm) for hosted arm64 runners.(docs.github.com) Without switching labels, this matrix leg never validates arm64 builds despite claiming to. Update the label so the job actually runs on arm hardware.
```diff
-      - os: linux
-        arch: arm64
-        runner: ubuntu-latest
+      - os: linux
+        arch: arm64
+        runner: ubuntu-24.04-arm
```

📝 Committable suggestion

```yaml
      # Linux ARM64
      - os: linux
        arch: arm64
        runner: ubuntu-24.04-arm
        xgboost_version: "3.1.1"
```
🤖 Prompt for AI Agents
.github/workflows/ci.yml around lines 55 to 60: the matrix entry for "Linux
ARM64" incorrectly uses runner: ubuntu-latest which yields an x86_64 VM; replace
the runner label with a hosted ARM runner label (e.g., ubuntu-24.04-arm or
ubuntu-22.04-arm) so the job actually runs on arm64 hardware, and ensure any
OS/version matrix values remain consistent with the chosen ARM label.
```rust
fn get_platform_info() -> (String, String) {
    let target = env::var("TARGET").unwrap();

    // Determine OS
    let os = if target.contains("apple-darwin") {
        "darwin"
    } else if target.contains("linux") {
        "linux"
    } else if target.contains("windows") {
        "windows"
    } else {
        panic!("Unsupported target: {}", target);
    };

    // Determine architecture
    let arch = if target.contains("x86_64") {
        "x86_64"
    } else if target.contains("aarch64") || target.contains("arm64") {
        "aarch64"
    } else if target.contains("i686") || target.contains("i586") {
        "i686"
    } else {
        panic!("Unsupported architecture for target: {}", target);
    };

    (os.to_string(), arch.to_string())
}
Fail fast on musl targets
get_platform_info lumps every *-linux-* triple together, so x86_64-unknown-linux-musl ends up downloading a glibc-only wheel. That toolchain can’t link or run the resulting .so, yielding linker errors later. Please detect musl targets up front and abort with a clear message (or add native support) instead of silently fetching the incompatible glibc build.
🤖 Prompt for AI Agents
In build.rs around lines 82 to 108, get_platform_info currently treats all linux
triples the same and will select a glibc build for musl targets; update it to
detect musl targets (e.g. target.contains("musl") or target.ends_with("-musl"))
early and fail fast with a clear panic message indicating musl is unsupported
(or add separate musl handling if you plan to support it). Ensure the musl check
runs before selecting the OS/arch and includes the full TARGET string in the
panic so downstream users see which triple caused the abort.
```rust
    let download_url = format!(
        "https://files.pythonhosted.org/packages/py3/x/xgboost/{}",
        wheel_filename
    );

    println!("cargo:warning=Downloading XGBoost wheel from: {}", download_url);
    let buffer = download_with_retry(&download_url, 3)?;

    // Write atomically (temp file + rename)
    let temp_path = wheel_path.with_extension("tmp");
    {
        let mut temp_file = fs::File::create(&temp_path)?;
        temp_file.write_all(&buffer)?;
        temp_file.sync_all()?;
    }
    fs::rename(&temp_path, &wheel_path)?;

    println!("cargo:warning=✓ Downloaded and cached wheel");
    buffer
};

// Extract library from wheel
println!("cargo:warning=Extracting library from wheel");

let cursor = io::Cursor::new(wheel_buffer);
let mut archive = zip::ZipArchive::new(cursor)?;

// Search for the library file in the wheel
let mut found = false;
for i in 0..archive.len() {
    let mut file = archive.by_index(i)?;
    let file_path = file.name().to_string();

    // Look for the library file (usually in xgboost/lib/)
    if file_path.ends_with(lib_filename) {
        println!("cargo:warning=Found library at: {}", file_path);

        // Extract to temp file, then rename atomically
        let temp_dest_path = lib_dest_path.with_extension("tmp");
        {
            let mut dest = fs::File::create(&temp_dest_path)?;
            io::copy(&mut file, &mut dest)?;
            dest.sync_all()?;
        }
        fs::rename(&temp_dest_path, &lib_dest_path)?;

        found = true;
        break;
    }
}

if !found {
    return Err(format!("Library file {} not found in wheel", lib_filename).into());
}

println!("cargo:warning=✓ Successfully extracted XGBoost library to: {}", lib_dir.display());

Ok(())
```
Reject unsigned wheel downloads
We verify header hashes but skip integrity checks for the downloaded wheel, so a man-in-the-middle or mirror compromise could silently swap in a malicious libxgboost. PyPI publishes SHA256 digests for every wheel (for example the 3.1.1 wheels listed on Oct 21 2025), so we should treat a missing hash check as a release blocker: download to a temp file, compute its SHA256, and compare against an allowlist just like we do for the headers before extracting or linking the binary.
```diff
@@
-    println!("cargo:warning=Downloading XGBoost wheel from: {}", download_url);
-    let buffer = download_with_retry(&download_url, 3)?;
+    println!("cargo:warning=Downloading XGBoost wheel from: {}", download_url);
+    let buffer = download_with_retry(&download_url, 3)?;
+    verify_checksum(&buffer, expected_wheel_sha256(version, os.as_str(), arch.as_str())?,
+                    &wheel_filename)?;
```

Committable suggestion skipped: line range outside the PR's diff.
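The allowlist lookup and comparison half of this suggestion can be sketched with only the standard library; a real `build.rs` would additionally hash the downloaded bytes with a crate such as `sha2` before comparing. Everything here is hypothetical: `expected_wheel_sha256` and `verify_checksum` are the names proposed in the diff above, and the pinned digest value is a placeholder, not a real PyPI hash.

```rust
use std::collections::HashMap;

/// Look up the pinned SHA-256 for a (version, os, arch) wheel.
/// In practice this table would be generated from the digests PyPI publishes.
fn expected_wheel_sha256(version: &str, os: &str, arch: &str) -> Result<String, String> {
    let mut pinned = HashMap::new();
    // Placeholder digest, not a real PyPI hash:
    pinned.insert("3.1.1-linux-x86_64".to_string(), "0000placeholder".to_string());
    let key = format!("{}-{}-{}", version, os, arch);
    pinned
        .get(&key)
        .cloned()
        .ok_or_else(|| format!("no pinned wheel digest for {}", key))
}

/// Compare a computed digest (hex) against the pinned one; fail closed.
fn verify_checksum(actual_hex: &str, expected_hex: &str, name: &str) -> Result<(), String> {
    if !actual_hex.eq_ignore_ascii_case(expected_hex) {
        return Err(format!(
            "SHA-256 mismatch for {}: got {}, expected {}",
            name, actual_hex, expected_hex
        ));
    }
    Ok(())
}

fn main() {
    let expected = expected_wheel_sha256("3.1.1", "linux", "x86_64").unwrap();
    assert!(verify_checksum("0000placeholder", &expected, "wheel").is_ok());
    assert!(verify_checksum("deadbeef", &expected, "wheel").is_err());
    // Unpinned platforms fail closed rather than skipping the check:
    assert!(expected_wheel_sha256("3.1.1", "linux", "i686").is_err());
    println!("ok");
}
```

Failing closed on a missing table entry matters as much as the byte comparison: a platform without a pinned digest should abort the build, not silently skip verification.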
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/model.rs (1)

165-248: LGTM! Excellent use of RAII and thorough input validation.

The `predict` method demonstrates excellent practices:
- Overflow-safe dimension validation (lines 174-181)
- RAII guard pattern ensures DMatrix cleanup (lines 208-217)
- Comprehensive output validation (lines 235-240)
- Safe pointer-to-Vec conversion (line 243)

One minor optional improvement for the cast on line 227:

```diff
 XGBoostError::check_return_value(unsafe {
     sys::XGBoosterPredict(
         self.handle,
         dmatrix_handle,
-        option_mask as i32,
+        option_mask.try_into().unwrap_or(0),
         0, // ntree_limit (0 means use all trees)
         training as i32,
         &mut out_len,
         &mut out_result,
     )
 })?;
```

This prevents potential overflow if `option_mask` exceeds `i32::MAX`, though in practice the `predict_option` constants are small values.
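The behavioral difference between the `as i32` cast and the suggested `try_into` fallback can be shown in isolation (a standalone sketch, not code from the PR):

```rust
// `as` performs a lossy bit-reinterpreting cast; `try_from` detects overflow.
fn mask_as_i32_lossy(option_mask: u32) -> i32 {
    option_mask as i32 // wraps for values above i32::MAX
}

fn mask_as_i32_checked(option_mask: u32) -> i32 {
    i32::try_from(option_mask).unwrap_or(0) // falls back to 0 ("normal prediction")
}

fn main() {
    // Small mask values (the realistic case) behave identically:
    assert_eq!(mask_as_i32_lossy(2), 2);
    assert_eq!(mask_as_i32_checked(2), 2);
    // Above i32::MAX, `as` wraps to a negative value while the checked
    // conversion falls back to 0:
    assert_eq!(mask_as_i32_lossy(u32::MAX), -1);
    assert_eq!(mask_as_i32_checked(u32::MAX), 0);
    println!("ok");
}
```

Whether silently falling back to 0 is preferable to passing a wrapped negative mask to the C API is a judgment call; returning an error on overflow would be stricter still.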
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- build.rs (1 hunks)
- examples/advanced_usage.rs (1 hunks)
- examples/basic_usage.rs (1 hunks)
- src/error.rs (1 hunks)
- src/model.rs (1 hunks)
- src/sys.rs (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- build.rs
- examples/advanced_usage.rs
- src/sys.rs
🧰 Additional context used
🧬 Code graph analysis (2)
src/model.rs (1)
- src/error.rs (1): check_return_value (15-21)

examples/basic_usage.rs (2)
- examples/advanced_usage.rs (1): main (3-73)
- src/model.rs (2): load (77-104), num_features (254-262)
🔇 Additional comments (5)
src/error.rs (1)

1-39: LGTM! Clean error handling implementation.

The error module provides a well-designed abstraction over XGBoost's C API error handling:
- `check_return_value` correctly maps return codes to Rust `Result`
- `fetch_xgboost_error` safely handles potentially invalid UTF-8 with `unwrap_or`
- Standard trait implementations (`Display`, `Error`) follow Rust conventions
- Public `description` field enables flexible error construction

src/model.rs (4)

77-143: LGTM! Robust model loading with proper error handling.

Both `load` and `load_from_buffer` methods demonstrate solid FFI practices:
- Path validation prevents UTF-8 and NUL byte issues before FFI calls
- Consistent error handling via `check_return_value`
- Proper cleanup of resources on failure (lines 96-100, 135-139)
- Clear separation of concerns (create handle, load model, handle errors)

254-287: LGTM! Simple and correct implementations.

Both methods follow appropriate patterns:
- `num_features`: Straightforward FFI call with proper error handling
- `save`: Mirrors the path validation approach from `load` for consistency

290-296: LGTM! Correct RAII cleanup implementation.

The `Drop` implementation ensures the XGBoost handle is properly freed when the `Booster` goes out of scope. This is the standard pattern for resource management in FFI code.

53-60: The `xgboost_thread_safe` cfg flag mechanism is correctly implemented. The `build.rs` properly detects XGBoost version via the `XGBOOST_VERSION` environment variable (defaulting to "3.1.1"), parses it safely, and sets the cfg flag only when version ≥ 1.4 is detected using the condition `major > 1 || (major == 1 && minor >= 4)`. The implementation includes defensive parsing with fallbacks, proper cfg registration, and build-time warnings for transparency.
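That version gate can be reproduced as a self-contained sketch (the function name is an assumption; the PR's `build.rs` may structure this differently):

```rust
/// Returns true when the XGBoost version is >= 1.4, the threshold described
/// above for emitting cfg(xgboost_thread_safe).
fn enables_thread_safe_cfg(version: &str) -> bool {
    let mut parts = version.split('.');
    // Defensive parsing with fallbacks: unparseable components become 0,
    // so a malformed version string disables the cfg rather than panicking.
    let major: u32 = parts.next().and_then(|p| p.parse().ok()).unwrap_or(0);
    let minor: u32 = parts.next().and_then(|p| p.parse().ok()).unwrap_or(0);
    major > 1 || (major == 1 && minor >= 4)
}

fn main() {
    assert!(enables_thread_safe_cfg("3.1.1")); // the default version
    assert!(enables_thread_safe_cfg("1.4.0")); // boundary case
    assert!(!enables_thread_safe_cfg("1.3.3"));
    assert!(!enables_thread_safe_cfg("not-a-version")); // falls back to disabled
    println!("ok");
}
```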
```rust
// Get model information
let num_features = booster.num_features()?;
println!("Model expects {} features\n", num_features);

// Example prediction data (4 features for iris dataset)
// This is a sample from the iris dataset
let data = vec![
    5.1, 3.5, 1.4, 0.2, // Row 1: Setosa
    6.7, 3.0, 5.2, 2.3, // Row 2: Virginica
    5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor
];
let num_rows = 3;
let num_features = 4;
```
Remove variable shadowing that masks the retrieved feature count.
The num_features retrieved from the model on line 52 is immediately shadowed by a hardcoded value on line 63. This means the actual model feature count is retrieved, printed, but never used for prediction validation.
If the model expects a different number of features than 4, this will cause a runtime error or incorrect predictions.
Apply this diff to use the retrieved feature count:
```diff
 // Get model information
 let num_features = booster.num_features()?;
 println!("Model expects {} features\n", num_features);

 // Example prediction data (4 features for iris dataset)
 // This is a sample from the iris dataset
 let data = vec![
     5.1, 3.5, 1.4, 0.2, // Row 1: Setosa
     6.7, 3.0, 5.2, 2.3, // Row 2: Virginica
     5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor
 ];
 let num_rows = 3;
-let num_features = 4;

 println!("Making predictions on {} samples...", num_rows);
```

Alternatively, if you want to keep the hardcoded value for the example, validate it matches the model:
```diff
 // Get model information
-let num_features = booster.num_features()?;
-println!("Model expects {} features\n", num_features);
+let model_features = booster.num_features()?;
+println!("Model expects {} features\n", model_features);

 // Example prediction data (4 features for iris dataset)
 // This is a sample from the iris dataset
 let data = vec![
     5.1, 3.5, 1.4, 0.2, // Row 1: Setosa
     6.7, 3.0, 5.2, 2.3, // Row 2: Virginica
     5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor
 ];
 let num_rows = 3;
 let num_features = 4;
+
+if model_features != num_features {
+    return Err(xgboost_rust::XGBoostError {
+        description: format!(
+            "Model expects {} features but example provides {}",
+            model_features, num_features
+        ),
+    });
+}
```

📝 Committable suggestion
```rust
// Get model information
let num_features = booster.num_features()?;
println!("Model expects {} features\n", num_features);

// Example prediction data (4 features for iris dataset)
// This is a sample from the iris dataset
let data = vec![
    5.1, 3.5, 1.4, 0.2, // Row 1: Setosa
    6.7, 3.0, 5.2, 2.3, // Row 2: Virginica
    5.9, 3.0, 4.2, 1.5, // Row 3: Versicolor
];
let num_rows = 3;
```
🤖 Prompt for AI Agents
In examples/basic_usage.rs around lines 51 to 64, the variable num_features
returned by booster.num_features() is being shadowed by a hardcoded let
num_features = 4; which discards the real model feature count; remove the
hardcoded shadowing and reuse the retrieved num_features for prediction shape
checks (or, if you intentionally want a hardcoded example, validate that
booster.num_features()? == 4 and bail with a clear error if it doesn't). Ensure
any local variables are renamed if you need both values (e.g.,
expected_num_features) and update the code that constructs/validates the data to
use the verified feature count so runtime mismatches can't occur.
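The fix can be exercised without XGBoost at all: derive the per-row width from the flat data buffer and compare it to the model's reported count. The sketch below is hypothetical (`validate_shape` is not a function in the PR) and stands in for whatever validation the example adopts.

```rust
/// Check that a row-major data buffer matches the model's feature count.
/// Returns the verified per-row feature count on success.
fn validate_shape(
    model_features: usize,
    data_len: usize,
    num_rows: usize,
) -> Result<usize, String> {
    if num_rows == 0 || data_len % num_rows != 0 {
        return Err(format!("{} values cannot form {} equal rows", data_len, num_rows));
    }
    let num_features = data_len / num_rows;
    if num_features != model_features {
        return Err(format!(
            "Model expects {} features but data provides {} per row",
            model_features, num_features
        ));
    }
    Ok(num_features)
}

fn main() {
    // 3 rows x 4 iris features, matching a 4-feature model:
    assert_eq!(validate_shape(4, 12, 3), Ok(4));
    // A model expecting a different width rejects the same data:
    assert!(validate_shape(5, 12, 3).is_err());
    // A buffer that does not divide evenly into rows is also rejected:
    assert!(validate_shape(4, 11, 3).is_err());
    println!("ok");
}
```

Running this check before calling `predict` turns a silent shape mismatch into an immediate, descriptive error.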
Summary by CodeRabbit

- New Features
- Documentation
- Chores