diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..82112dc --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,207 @@ +name: CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +env: + CARGO_TERM_COLOR: always + +jobs: + test: + name: Test ${{ matrix.os }} (${{ matrix.arch }}) - LightGBM ${{ matrix.lightgbm_version }} + runs-on: ${{ matrix.runner }} + strategy: + fail-fast: false + matrix: + include: + # macOS ARM64 (M1/M2/M3) + - os: macos + arch: arm64 + runner: macos-14 + lightgbm_version: "4.6.0" + - os: macos + arch: arm64 + runner: macos-14 + lightgbm_version: "4.5.0" + + # macOS x86_64 (Intel) + - os: macos + arch: x86_64 + runner: macos-15-large + lightgbm_version: "4.6.0" + + # Linux x86_64 + - os: linux + arch: x86_64 + runner: ubuntu-latest + lightgbm_version: "4.6.0" + - os: linux + arch: x86_64 + runner: ubuntu-latest + lightgbm_version: "4.5.0" + + - os: linux + arch: arm64 + runner: ubuntu-latest + lightgbm_version: "4.6.0" + + # Windows x86_64 + - os: windows + arch: x86_64 + runner: windows-latest + lightgbm_version: "4.6.0" + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-registry- + + - name: Cache cargo index + uses: actions/cache@v4 + with: + path: ~/.cargo/git + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-index-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-index- + + - name: Cache cargo build + uses: actions/cache@v4 + with: + path: target + key: ${{ runner.os }}-${{ matrix.arch }}-cargo-build-target-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-${{ matrix.arch }}-cargo-build-target- + + - name: Install system dependencies (Linux) + if: matrix.os == 'linux' + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + - name: Install system dependencies (macOS) + if: matrix.os == 'macos' + run: | + brew install libomp + + - name: Check build (no features) + env: + LIGHTGBM_VERSION: ${{ matrix.lightgbm_version }} + run: cargo check --verbose + + - name: Build (no features) + env: + LIGHTGBM_VERSION: ${{ matrix.lightgbm_version }} + run: cargo build --verbose + + - name: Run tests (no features) + env: + LIGHTGBM_VERSION: ${{ matrix.lightgbm_version }} + run: cargo test --verbose + + - name: Build examples + env: + LIGHTGBM_VERSION: ${{ matrix.lightgbm_version }} + run: | + cargo build --example basic_usage --verbose + cargo build --example advanced_usage --verbose + + - name: Verify library architecture (macOS) + if: matrix.os == 'macos' + run: | + echo "Checking library architecture..." + file target/debug/lib_lightgbm.dylib + lipo -info target/debug/lib_lightgbm.dylib + + - name: Verify library architecture (Linux) + if: matrix.os == 'linux' + run: | + echo "Checking library architecture..." + file target/debug/lib_lightgbm.so + readelf -h target/debug/lib_lightgbm.so | grep Machine + + - name: Verify library exists (Windows) + if: matrix.os == 'windows' + run: | + echo "Checking library exists..." + Get-Item target/debug/lib_lightgbm.dll + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Run clippy (no features) + run: cargo clippy -- -D warnings + + fmt: + name: Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Check formatting + run: cargo fmt -- --check + + # Test minimum supported LightGBM version + min-version: + name: Test minimum LightGBM version + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Test with LightGBM 4.0.0 + env: + LIGHTGBM_VERSION: "4.0.0" + run: | + cargo check --verbose + cargo build --verbose + + # Test unsupported platform (should fail gracefully) + # Commented out because it requires manual verification + unsupported-platform: + name: Test Windows ARM64 (unsupported) + runs-on: windows-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Add ARM64 target + run: rustup target add aarch64-pc-windows-msvc + + - name: Test unsupported platform detection + continue-on-error: true + run: cargo build --target aarch64-pc-windows-msvc --verbose + id: unsupported_build + + - name: Verify build failed with correct error + if: steps.unsupported_build.outcome == 'failure' + run: echo "Build failed as expected for Windows ARM64" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..1d7812a --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,206 @@ +name: Release + +on: + workflow_dispatch: + inputs: + release_type: + description: 'Release type' + required: true + type: choice + options: + - patch + - minor + - major + release: + types: [published] + +jobs: + # Create release with version bump + create-release: + name: Create Release + runs-on: ubuntu-latest + if: github.event_name == 'workflow_dispatch' + outputs: + version: ${{ steps.version.outputs.version }} + tag: ${{ steps.version.outputs.tag }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Configure Git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Bump version + id: version + run: | + # Get current version from Cargo.toml + CURRENT_VERSION=$(grep "^version" Cargo.toml | head -1 | sed 's/version = "\(.*\)"/\1/') + echo "Current version: $CURRENT_VERSION" + + # Split version into major.minor.patch + IFS='.' read -r MAJOR MINOR PATCH <<< "$CURRENT_VERSION" + + # Bump version based on input + case "${{ github.event.inputs.release_type }}" in + major) + MAJOR=$((MAJOR + 1)) + MINOR=0 + PATCH=0 + ;; + minor) + MINOR=$((MINOR + 1)) + PATCH=0 + ;; + patch) + PATCH=$((PATCH + 1)) + ;; + esac + + NEW_VERSION="${MAJOR}.${MINOR}.${PATCH}" + echo "New version: $NEW_VERSION" + + # Update Cargo.toml + sed -i "s/^version = \".*\"/version = \"$NEW_VERSION\"/" Cargo.toml + + # Output for later steps + echo "version=$NEW_VERSION" >> $GITHUB_OUTPUT + echo "tag=v$NEW_VERSION" >> $GITHUB_OUTPUT + + - name: Commit version bump + run: | + git add Cargo.toml + git commit -m "chore: bump version to ${{ steps.version.outputs.version }}" + git push origin ${{ github.ref_name }} + + - name: Create tag + run: | + git tag ${{ steps.version.outputs.tag }} + git push origin ${{ steps.version.outputs.tag }} + + - name: Create GitHub Release + uses: softprops/action-gh-release@v1 + with: + tag_name: ${{ steps.version.outputs.tag }} + name: Release ${{ steps.version.outputs.version }} + draft: false + prerelease: false + generate_release_notes: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + # Comprehensive pre-release testing across all platforms + pre-release-test: + name: Pre-release test ${{ matrix.os }} (${{ matrix.arch }}) + runs-on: ${{ matrix.runner }} + needs: [create-release] + if: always() && (needs.create-release.result == 'success' || needs.create-release.result == 'skipped') + strategy: + fail-fast: false + matrix: + include: + - os: macos + arch: arm64 + runner: macos-14 + - os: macos + arch: x86_64 + runner: macos-13 + - os: linux + arch: x86_64 + runner: ubuntu-latest + - os: linux + arch: arm64 + runner: ubuntu-24.04-arm64 + - os: windows + arch: x86_64 + runner: windows-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install system dependencies (Linux) + if: matrix.os == 'linux' + run: | + sudo apt-get update + sudo apt-get install -y libclang-dev + + # Test with latest LightGBM version + - name: Build and test (latest LightGBM) + run: | + cargo build --release --verbose + cargo test --release --verbose + + # Build all examples + - name: Build examples + run: | + cargo build --release --example basic_usage + cargo build --release --example advanced_usage + + # Archive release artifacts + - name: Archive artifacts (Unix) + if: matrix.os != 'windows' + run: | + mkdir -p artifacts + cp target/release/lib_lightgbm.* artifacts/ || true + tar -czf lightgbm-rust-${{ matrix.os }}-${{ matrix.arch }}.tar.gz artifacts/ + + - name: Archive artifacts (Windows) + if: matrix.os == 'windows' + run: | + New-Item -ItemType Directory -Force -Path artifacts + Copy-Item target/release/lib_lightgbm.* artifacts/ -ErrorAction SilentlyContinue + Compress-Archive -Path artifacts/* -DestinationPath lightgbm-rust-${{ matrix.os }}-${{ matrix.arch }}.zip + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: lightgbm-rust-${{ matrix.os }}-${{ matrix.arch }} + path: | + lightgbm-rust-${{ matrix.os }}-${{ matrix.arch }}.* + + # Upload artifacts to the release + upload-release-artifacts: + name: Upload Release Artifacts + needs: [create-release, pre-release-test] + runs-on: ubuntu-latest + if: always() && needs.pre-release-test.result == 'success' && (github.event_name == 'release' || github.event_name == 'workflow_dispatch') + steps: + - uses: actions/checkout@v4 + + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts + + - name: Upload artifacts to release + uses: softprops/action-gh-release@v1 + with: + tag_name: ${{ github.event_name == 'workflow_dispatch' && needs.create-release.outputs.tag || github.event.release.tag_name }} + files: artifacts/**/* + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + # Publish to crates.io + publish: + name: Publish to crates.io + needs: [create-release, pre-release-test] + runs-on: ubuntu-latest + if: always() && needs.pre-release-test.result == 'success' && ((github.event_name == 'release' && !github.event.release.prerelease) || github.event_name == 'workflow_dispatch') + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'workflow_dispatch' && needs.create-release.outputs.tag || github.ref }} + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Publish to crates.io + run: cargo publish --token ${{ secrets.CARGO_REGISTRY_TOKEN }} + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} diff --git a/Cargo.toml b/Cargo.toml index e783256..6ca22de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,9 +9,12 @@ categories = ["science"] readme = "README.md" rust-version = "1.70" +[dependencies] + [build-dependencies] bindgen = "0.72.0" ureq = "2.0" +zip = "2.2" [features] default = [] diff --git a/build.rs b/build.rs index 3f577d3..82b3df7 100644 --- a/build.rs +++ b/build.rs @@ -1,9 +1,9 @@ extern crate bindgen; use std::env; -use std::path::{Path, PathBuf}; use std::fs; use std::io; +use std::path::{Path, PathBuf}; fn get_lightgbm_version() -> String { env::var("LIGHTGBM_VERSION").unwrap_or_else(|_| "4.6.0".to_string()) @@ -54,7 +54,7 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box= 300 { + if !(200..300).contains(&status) { return Err(format!("Failed to download c_api.h: HTTP {}", status).into()); } @@ -72,7 +72,7 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box= 300 { + if !(200..300).contains(&status) { return Err(format!("Failed to download export.h: HTTP {}", status).into()); } @@ -87,7 +87,10 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box= 200 && response.status() < 300 => { @@ -102,7 +105,10 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box= 200 && resp.status() < 300 => { @@ -117,7 +123,9 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box { - println!("cargo:warning=arrow.h not available for this version (optional, only in v4.2.0+)"); + println!( + "cargo:warning=arrow.h not available for this version (optional, only in v4.2.0+)" + ); } } @@ -125,60 +133,194 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box Result<(), Box> { - let (os, _arch) = get_platform_info(); + let (os, arch) = get_platform_info(); let version = get_lightgbm_version(); - // LightGBM release binaries (platform-specific) - let (lib_filename, download_url) = match os.as_str() { - "linux" => ( - "lib_lightgbm.so".to_string(), - format!( - "https://github.com/microsoft/LightGBM/releases/download/v{}/lib_lightgbm.so", - version - ), - ), - "darwin" => ( - "lib_lightgbm.dylib".to_string(), - format!( - "https://github.com/microsoft/LightGBM/releases/download/v{}/lib_lightgbm.dylib", - version - ), - ), - "windows" => ( - "lib_lightgbm.dll".to_string(), - format!( - "https://github.com/microsoft/LightGBM/releases/download/v{}/lib_lightgbm.dll", - version - ), - ), - _ => return Err(format!("Unsupported platform: {}", os).into()), - }; - - println!( - "cargo:warning=Downloading LightGBM v{} library from: {}", - version, download_url - ); - // Create the library directory let lib_dir = out_dir.join("libs"); fs::create_dir_all(&lib_dir)?; - // Download the library directly into the `libs` directory with its correct name - let lib_path = lib_dir.join(&lib_filename); - let mut dest = fs::File::create(&lib_path)?; + // For macOS and Linux, extract from Python wheel to get architecture-specific binaries + match (os.as_str(), arch.as_str()) { + // macOS - both x86_64 and ARM64 available + ("darwin", "aarch64") | ("darwin", "x86_64") => { + let wheel_arch = if arch == "aarch64" { "arm64" } else { "x86_64" }; + let macos_version = if arch == "aarch64" { "12_0" } else { "10_15" }; + let wheel_url = format!( + "https://github.com/microsoft/LightGBM/releases/download/v{}/lightgbm-{}-py3-none-macosx_{}_{}.whl", + version, version, macos_version, wheel_arch + ); + + println!( + "cargo:warning=Downloading LightGBM v{} macOS {} wheel from: {}", + version, wheel_arch, wheel_url + ); - let response = ureq::get(&download_url).call()?; + download_and_extract_from_wheel(&wheel_url, out_dir, &lib_dir, "lib_lightgbm.dylib")?; + } + + // Linux - both x86_64 and ARM64 available + ("linux", "aarch64") | ("linux", "x86_64") => { + let (wheel_platform, lib_pattern) = if arch == "aarch64" { + ("manylinux2014_aarch64", "lib_lightgbm.so") + } else { + ("manylinux_2_28_x86_64", "lib_lightgbm.so") + }; + + let wheel_url = format!( + "https://github.com/microsoft/LightGBM/releases/download/v{}/lightgbm-{}-py3-none-{}.whl", + version, version, wheel_platform + ); + + println!( + "cargo:warning=Downloading LightGBM v{} Linux {} wheel from: {}", + version, arch, wheel_url + ); + + download_and_extract_from_wheel(&wheel_url, out_dir, &lib_dir, lib_pattern)?; + } + + // Windows - only x86_64 available + ("windows", "x86_64") => { + // For Windows, extract from wheel - need both DLL and import library + let wheel_url = format!( + "https://github.com/microsoft/LightGBM/releases/download/v{}/lightgbm-{}-py3-none-win_amd64.whl", + version, version + ); + + println!( + "cargo:warning=Downloading LightGBM v{} Windows x86_64 wheel from: {}", + version, wheel_url + ); + + download_and_extract_windows_libs(&wheel_url, out_dir, &lib_dir)?; + } + + ("windows", "i686") => { + return Err("Windows 32-bit (i686) is not supported by LightGBM releases. Please use x86_64 Windows or compile LightGBM from source.".into()); + } + + ("windows", "aarch64") => { + return Err("Windows ARM64 is not currently supported by LightGBM releases. Please use x86_64 Windows or compile LightGBM from source.".into()); + } + + _ => { + return Err(format!( + "Unsupported platform/architecture combination: {} / {}", + os, arch + ) + .into()); + } + } + + Ok(()) +} + +fn download_and_extract_from_wheel( + wheel_url: &str, + out_dir: &Path, + lib_dir: &Path, + lib_filename: &str, +) -> Result<(), Box> { + // Download the wheel to a temp file + let wheel_path = out_dir.join("lightgbm.whl"); + let mut dest = fs::File::create(&wheel_path)?; + + let response = ureq::get(wheel_url).call()?; let status = response.status(); - if status < 200 || status >= 300 { - return Err(format!("Failed to download library: HTTP {}", status).into()); + if !(200..300).contains(&status) { + return Err(format!("Failed to download wheel: HTTP {}", status).into()); } io::copy(&mut response.into_reader(), &mut dest)?; + drop(dest); // Close file before reading + + // Extract the library from the wheel + // Wheels are just zip files + let wheel_file = fs::File::open(&wheel_path)?; + let mut archive = zip::ZipArchive::new(wheel_file)?; + + // Find and extract the library + for i in 0..archive.len() { + let mut file = archive.by_index(i)?; + if file.name().ends_with(lib_filename) { + let lib_path = lib_dir.join(lib_filename); + let mut outfile = fs::File::create(&lib_path)?; + io::copy(&mut file, &mut outfile)?; + + println!( + "cargo:warning=Extracted LightGBM library to: {}", + lib_path.display() + ); + return Ok(()); + } + } - println!( - "cargo:warning=Downloaded LightGBM library to: {}", - lib_path.display() - ); + Err(format!("{} not found in wheel", lib_filename).into()) +} + +fn download_and_extract_windows_libs( + wheel_url: &str, + out_dir: &Path, + lib_dir: &Path, +) -> Result<(), Box> { + // Download the wheel to a temp file + let wheel_path = out_dir.join("lightgbm.whl"); + let mut dest = fs::File::create(&wheel_path)?; + + let response = ureq::get(wheel_url).call()?; + let status = response.status(); + if !(200..300).contains(&status) { + return Err(format!("Failed to download wheel: HTTP {}", status).into()); + } + + io::copy(&mut response.into_reader(), &mut dest)?; + drop(dest); // Close file before reading + + // Extract both the DLL and the import library from the wheel + // Wheels are just zip files + let wheel_file = fs::File::open(&wheel_path)?; + let mut archive = zip::ZipArchive::new(wheel_file)?; + + let mut dll_found = false; + let mut lib_found = false; + + // Find and extract both lib_lightgbm.dll and lib_lightgbm.lib + for i in 0..archive.len() { + let mut file = archive.by_index(i)?; + let filename = file.name(); + + if filename.ends_with("lib_lightgbm.dll") { + let lib_path = lib_dir.join("lib_lightgbm.dll"); + let mut outfile = fs::File::create(&lib_path)?; + io::copy(&mut file, &mut outfile)?; + println!( + "cargo:warning=Extracted LightGBM DLL to: {}", + lib_path.display() + ); + dll_found = true; + } else if filename.ends_with("lib_lightgbm.lib") { + let lib_path = lib_dir.join("lib_lightgbm.lib"); + let mut outfile = fs::File::create(&lib_path)?; + io::copy(&mut file, &mut outfile)?; + println!( + "cargo:warning=Extracted LightGBM import library to: {}", + lib_path.display() + ); + lib_found = true; + } + + if dll_found && lib_found { + return Ok(()); + } + } + + if !dll_found { + return Err("lib_lightgbm.dll not found in wheel".into()); + } + if !lib_found { + return Err("lib_lightgbm.lib not found in wheel".into()); + } Ok(()) } @@ -203,7 +345,7 @@ fn main() { .header("wrapper.h") .clang_arg(format!("-I{}", lgbm_include_root.display())) .clang_arg("-xc++") - .clang_arg("-std=c++11") + .clang_arg("-std=c++17") // Only generate bindings for functions starting with LGBM_ .allowlist_function("LGBM_.*") // Allowlist the main types we need @@ -224,7 +366,6 @@ fn main() { .blocklist_type(".*_Tp.*") .blocklist_type(".*_Pred.*") .size_t_is_usize(true) - .rustfmt_bindings(true) .generate() .expect("Unable to generate bindings."); @@ -253,8 +394,19 @@ fn main() { .join(env::var("PROFILE").unwrap()); let lib_dest_path = target_dir.join(lib_filename); - fs::copy(&lib_source_path, &lib_dest_path) - .expect("Failed to copy library to target directory"); + fs::copy(&lib_source_path, &lib_dest_path).expect("Failed to copy library to target directory"); + + // On Windows, also copy the import library (.lib) to the libs directory for linking + if os == "windows" { + let import_lib_source = out_dir.join("libs").join("lib_lightgbm.lib"); + if import_lib_source.exists() { + // No need to copy the .lib to target dir, it's only used during linking + println!( + "cargo:warning=Found import library at: {}", + import_lib_source.display() + ); + } + } // Set the library search path for the build-time linker let lib_search_path = out_dir.join("libs"); @@ -269,21 +421,38 @@ fn main() { // For macOS, add multiple rpath entries for IDE compatibility println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path"); println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../.."); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_search_path.display()); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}", + lib_search_path.display() + ); // Add the target directory to rpath as well if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { - println!("cargo:rustc-link-arg=-Wl,-rpath,{}/debug", target_root.display()); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}/release", target_root.display()); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/debug", + target_root.display() + ); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/release", + target_root.display() + ); } - }, + println!("cargo:rustc-link-lib=dylib=_lightgbm"); + } "linux" => { // For Linux, use $ORIGIN println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN"); println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../.."); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_search_path.display()); - }, - _ => {} // No rpath needed for Windows + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}", + lib_search_path.display() + ); + println!("cargo:rustc-link-lib=dylib=_lightgbm"); + } + "windows" => { + // On Windows, we need to tell the linker where to find the DLL at runtime + // Copy the DLL to the output directory (already done above) + println!("cargo:rustc-link-lib=dylib=lib_lightgbm"); + } + _ => {} } - - println!("cargo:rustc-link-lib=dylib=lib_lightgbm"); } diff --git a/examples/advanced_usage.rs b/examples/advanced_usage.rs index 6c2b70d..2638e08 100644 --- a/examples/advanced_usage.rs +++ b/examples/advanced_usage.rs @@ -1,4 +1,4 @@ -use lightgbm_rust::{Booster, predict_type}; +use lightgbm_rust::{predict_type, Booster}; fn main() -> Result<(), Box> { // Load a trained LightGBM model @@ -16,10 +16,7 @@ fn main() -> Result<(), Box> { println!(" Classes: {}", num_classes); // Example data with f32 (more memory efficient for large datasets) - let data_f32: Vec = vec![ - 1.0, 2.0, 3.0, 4.0, 5.0, - 2.0, 3.0, 4.0, 5.0, 6.0, - ]; + let data_f32: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let num_rows = 2; let num_cols = 5; @@ -32,11 +29,13 @@ fn main() -> Result<(), Box> { println!("Raw scores: {:?}", raw_scores); println!("\n--- Leaf Index Prediction ---"); - let leaf_indices = booster.predict_f32(&data_f32, num_rows, num_cols, predict_type::LEAF_INDEX)?; + let leaf_indices = + booster.predict_f32(&data_f32, num_rows, num_cols, predict_type::LEAF_INDEX)?; println!("Leaf indices: {:?}", leaf_indices); println!("\n--- Feature Contribution (SHAP) ---"); - let contributions = booster.predict_f32(&data_f32, num_rows, num_cols, predict_type::CONTRIB)?; + let contributions = + booster.predict_f32(&data_f32, num_rows, num_cols, predict_type::CONTRIB)?; println!("Feature contributions: {:?}", contributions); Ok(()) diff --git a/examples/basic_usage.rs b/examples/basic_usage.rs index f4c3037..5cc36f8 100644 --- a/examples/basic_usage.rs +++ b/examples/basic_usage.rs @@ -1,4 +1,4 @@ -use lightgbm_rust::{Booster, predict_type}; +use lightgbm_rust::{predict_type, Booster}; fn main() -> Result<(), Box> { // Load a trained LightGBM model @@ -29,15 +29,16 @@ fn main() -> Result<(), Box> { // Example: Predict for multiple samples (batch prediction) let batch_data = vec![ - 1.0, 2.0, 3.0, 4.0, // Sample 1 - 2.0, 3.0, 4.0, 5.0, // Sample 2 - 3.0, 4.0, 5.0, 6.0, // Sample 3 + 1.0, 2.0, 3.0, 4.0, // Sample 1 + 2.0, 3.0, 4.0, 5.0, // Sample 2 + 3.0, 4.0, 5.0, 6.0, // Sample 3 ]; let num_rows = 3; let num_cols = 4; println!("\nMaking batch prediction..."); - let batch_predictions = booster.predict(&batch_data, num_rows, num_cols, predict_type::NORMAL)?; + let batch_predictions = + booster.predict(&batch_data, num_rows, num_cols, predict_type::NORMAL)?; println!("Batch predictions: {:?}", batch_predictions); diff --git a/src/error.rs b/src/error.rs index 37793a2..3e78ca8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ +use crate::sys; use std::ffi::CStr; use std::fmt; -use crate::sys; pub type LightGBMResult = std::result::Result; diff --git a/src/model.rs b/src/model.rs index 4df0943..20ff678 100644 --- a/src/model.rs +++ b/src/model.rs @@ -49,14 +49,12 @@ pub struct Booster { impl Booster { /// Load a model from a file pub fn load>(path: P) -> LightGBMResult { - let path_str = path.as_ref().to_str() - .ok_or_else(|| LightGBMError { - description: "Path contains invalid UTF-8 characters".to_string(), - })?; - let path_c_str = CString::new(path_str) - .map_err(|e| LightGBMError { - description: format!("Path contains NUL byte: {}", e), - })?; + let path_str = path.as_ref().to_str().ok_or_else(|| LightGBMError { + description: "Path contains invalid UTF-8 characters".to_string(), + })?; + let path_c_str = CString::new(path_str).map_err(|e| LightGBMError { + description: format!("Path contains NUL byte: {}", e), + })?; let mut handle: sys::BoosterHandle = ptr::null_mut(); let mut num_iterations = 0i32; @@ -85,10 +83,9 @@ impl Booster { /// let booster = Booster::load_from_string(&model_string).unwrap(); /// ``` pub fn load_from_string(model_str: &str) -> LightGBMResult { - let model_c_str = CString::new(model_str) - .map_err(|e| LightGBMError { - description: format!("Model string contains NUL byte: {}", e), - })?; + let model_c_str = CString::new(model_str).map_err(|e| LightGBMError { + description: format!("Model string contains NUL byte: {}", e), + })?; let mut handle: sys::BoosterHandle = ptr::null_mut(); let mut num_iterations = 0i32; @@ -118,10 +115,9 @@ impl Booster { /// ``` pub fn load_from_buffer(buffer: &[u8]) -> LightGBMResult { // Convert bytes to string (LightGBM models are text-based) - let model_str = std::str::from_utf8(buffer) - .map_err(|e| LightGBMError { - description: format!("Invalid UTF-8 in model buffer: {}", e), - })?; + let model_str = std::str::from_utf8(buffer).map_err(|e| LightGBMError { + description: format!("Invalid UTF-8 in model buffer: {}", e), + })?; Self::load_from_string(model_str) } @@ -173,7 +169,10 @@ impl Booster { return Err(LightGBMError { description: format!( "Input data size mismatch: expected {} elements ({}×{}), got {}", - expected_len, num_rows, num_cols, data.len() + expected_len, + num_rows, + num_cols, + data.len() ), }); } @@ -243,7 +242,10 @@ impl Booster { return Err(LightGBMError { description: format!( "Input data size mismatch: expected {} elements ({}×{}), got {}", - expected_len, num_rows, num_cols, data.len() + expected_len, + num_rows, + num_cols, + data.len() ), }); } diff --git a/src/sys.rs b/src/sys.rs index a38a13a..cd503e4 100644 --- a/src/sys.rs +++ b/src/sys.rs @@ -1,5 +1,6 @@ #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] +#![allow(dead_code)] include!(concat!(env!("OUT_DIR"), "/bindings.rs"));