diff --git a/Cargo.toml b/Cargo.toml index e783256..39ac8f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,13 +9,20 @@ categories = ["science"] readme = "README.md" rust-version = "1.70" +[dependencies] +polars = { version = "0.45", optional = true, default-features = false, features = ["dtype-full"] } +rayon = "1.10" + [build-dependencies] bindgen = "0.72.0" -ureq = "2.0" +ureq = { version = "2.0", features = ["json"] } +zip = "0.6" +serde_json = "1.0" [features] default = [] gpu = [] +polars = ["dep:polars"] [[example]] name = "basic_usage" diff --git a/build.rs b/build.rs index 3f577d3..b436daf 100644 --- a/build.rs +++ b/build.rs @@ -1,9 +1,10 @@ extern crate bindgen; use std::env; -use std::path::{Path, PathBuf}; use std::fs; use std::io; +use std::path::{Path, PathBuf}; +use serde_json::Value; fn get_lightgbm_version() -> String { env::var("LIGHTGBM_VERSION").unwrap_or_else(|_| "4.6.0".to_string()) @@ -54,7 +55,7 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box= 300 { + if !(200..300).contains(&status) { return Err(format!("Failed to download c_api.h: HTTP {}", status).into()); } @@ -72,7 +73,7 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box= 300 { + if !(200..300).contains(&status) { return Err(format!("Failed to download export.h: HTTP {}", status).into()); } @@ -87,7 +88,10 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box= 200 && response.status() < 300 => { @@ -102,7 +106,10 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box= 200 && resp.status() < 300 => { @@ -117,70 +124,240 @@ fn download_lightgbm_headers(out_dir: &Path) -> Result<(), Box { - println!("cargo:warning=arrow.h not available for this version (optional, only in v4.2.0+)"); + println!( + "cargo:warning=arrow.h not available for this version (optional, only in v4.2.0+)" + ); } } Ok(()) } -fn download_compiled_library(out_dir: &Path) -> Result<(), Box> { - let (os, _arch) = get_platform_info(); - let version = get_lightgbm_version(); +/// Try to find the wheel URL from PyPI JSON API +fn find_wheel_url_from_pypi( + version: &str, + os: &str, + arch: &str, +) -> Result> { + let pypi_api_url = format!("https://pypi.org/pypi/lightgbm/{}/json", version); + println!("cargo:warning=Querying PyPI API: {}", pypi_api_url); + + let response = ureq::get(&pypi_api_url).call()?; + let json: Value = response.into_json()?; + + // Get the URLs array for this version + let urls = json["urls"] + .as_array() + .ok_or("No URLs found in PyPI response")?; + + // Determine the wheel pattern to match + let wheel_patterns: Vec = match (os, arch) { + ("darwin", "aarch64") => vec![ + format!("lightgbm-{}-py3-none-macosx_12_0_arm64.whl", version), + format!("lightgbm-{}-py3-none-macosx_11_0_arm64.whl", version), + ], + ("darwin", "x86_64") => vec![ + format!("lightgbm-{}-py3-none-macosx_10_15_x86_64.whl", version), + format!("lightgbm-{}-py3-none-macosx_10_14_x86_64.whl", version), + ], + ("linux", "aarch64") => vec![ + format!("lightgbm-{}-py3-none-manylinux2014_aarch64.whl", version), + format!("lightgbm-{}-py3-none-manylinux_2_17_aarch64.whl", version), + ], + ("linux", "x86_64") => vec![ + format!("lightgbm-{}-py3-none-manylinux_2_28_x86_64.whl", version), + format!("lightgbm-{}-py3-none-manylinux2014_x86_64.whl", version), + ], + ("windows", _) => vec![format!("lightgbm-{}-py3-none-win_amd64.whl", version)], + _ => return Err(format!("Unsupported platform: {} {}", os, arch).into()), + }; - // LightGBM release binaries (platform-specific) - let (lib_filename, download_url) = match os.as_str() { - "linux" => ( - "lib_lightgbm.so".to_string(), - format!( - "https://github.com/microsoft/LightGBM/releases/download/v{}/lib_lightgbm.so", - version - ), - ), - "darwin" => ( - "lib_lightgbm.dylib".to_string(), - format!( - "https://github.com/microsoft/LightGBM/releases/download/v{}/lib_lightgbm.dylib", - version - ), - ), - "windows" => ( - "lib_lightgbm.dll".to_string(), - format!( - "https://github.com/microsoft/LightGBM/releases/download/v{}/lib_lightgbm.dll", - version - ), - ), - _ => return Err(format!("Unsupported platform: {}", os).into()), + // Try to find a matching wheel + for pattern in &wheel_patterns { + for url_obj in urls { + if let Some(filename) = url_obj["filename"].as_str() { + if filename == pattern { + if let Some(url) = url_obj["url"].as_str() { + println!("cargo:warning=Found wheel URL from PyPI: {}", url); + return Ok(url.to_string()); + } + } + } + } + } + + Err(format!( + "No suitable wheel found in PyPI for {} {} (tried: {:?})", + os, arch, wheel_patterns + ) + .into()) +} + +/// Try to download library directly from GitHub releases +fn try_github_release( + version: &str, + os: &str, + arch: &str, + lib_dir: &Path, +) -> Result<(), Box> { + // Determine the library filename and extension + let (lib_name, extension) = match os { + "darwin" => ("lib_lightgbm", "dylib"), + "linux" => ("lib_lightgbm", "so"), + "windows" => ("lib_lightgbm", "dll"), + _ => return Err(format!("Unsupported OS: {}", os).into()), }; - println!( - "cargo:warning=Downloading LightGBM v{} library from: {}", - version, download_url - ); + // Try various GitHub release asset naming patterns + let possible_names = vec![ + format!("{}_{}.{}", lib_name, os, extension), + format!("{}_{}_{}.{}", lib_name, os, arch, extension), + format!("{}.{}", lib_name, extension), + ]; + + for asset_name in &possible_names { + let github_url = format!( + "https://github.com/microsoft/LightGBM/releases/download/v{}/{}", + version, asset_name + ); + + println!( + "cargo:warning=Trying GitHub release asset: {}", + github_url + ); + + match ureq::get(&github_url).call() { + Ok(response) if response.status() >= 200 && response.status() < 300 => { + println!("cargo:warning=Found GitHub release asset: {}", github_url); + + let lib_path = lib_dir.join(format!("{}.{}", lib_name, extension)); + let mut file = fs::File::create(&lib_path)?; + io::copy(&mut response.into_reader(), &mut file)?; + + println!( + "cargo:warning=✓ Successfully downloaded library from GitHub to: {}", + lib_dir.display() + ); + return Ok(()); + } + _ => continue, + } + } + + Err(format!( + "No GitHub release asset found for LightGBM v{} (tried: {:?})", + version, possible_names + ) + .into()) +} + +fn download_compiled_library(out_dir: &Path) -> Result<(), Box> { + let (os, arch) = get_platform_info(); + let version = get_lightgbm_version(); // Create the library directory let lib_dir = out_dir.join("libs"); fs::create_dir_all(&lib_dir)?; - // Download the library directly into the `libs` directory with its correct name - let lib_path = lib_dir.join(&lib_filename); - let mut dest = fs::File::create(&lib_path)?; + // Strategy 1: Try PyPI JSON API to find the correct wheel URL (most reliable) + println!("cargo:warning=Querying PyPI for LightGBM wheel..."); + match find_wheel_url_from_pypi(&version, &os, &arch) { + Ok(wheel_url) => { + println!("cargo:warning=Downloading wheel from: {}", wheel_url); - let response = ureq::get(&download_url).call()?; - let status = response.status(); - if status < 200 || status >= 300 { - return Err(format!("Failed to download library: HTTP {}", status).into()); - } + let response = ureq::get(&wheel_url).call()?; + let status = response.status(); + if !(200..300).contains(&status) { + return Err(format!("Failed to download wheel: HTTP {}", status).into()); + } - io::copy(&mut response.into_reader(), &mut dest)?; + // Extract filename from URL + let wheel_name = wheel_url + .split('/') + .last() + .unwrap_or("lightgbm.whl") + .to_string(); + let wheel_path = out_dir.join(&wheel_name); + + // Download wheel + let mut wheel_file = fs::File::create(&wheel_path)?; + io::copy(&mut response.into_reader(), &mut wheel_file)?; + drop(wheel_file); + + println!("cargo:warning=✓ Downloaded wheel"); + + // Extract the library from the wheel + println!("cargo:warning=Extracting library from wheel"); + + let file = fs::File::open(&wheel_path)?; + let mut archive = zip::ZipArchive::new(file)?; + + // Determine the library extension + let lib_extension = match os.as_str() { + "darwin" => "dylib", + "linux" => "so", + "windows" => "dll", + _ => return Err(format!("Unsupported OS: {}", os).into()), + }; + + // Look for the library in the wheel + let mut found = false; + for i in 0..archive.len() { + let mut file = archive.by_index(i)?; + let name = file.name().to_string(); + + if name.contains("lib_lightgbm") && name.ends_with(lib_extension) { + println!("cargo:warning=Found library at: {}", name); + + let lib_path = lib_dir.join(format!("lib_lightgbm.{}", lib_extension)); + let mut outfile = fs::File::create(&lib_path)?; + io::copy(&mut file, &mut outfile)?; + + println!( + "cargo:warning=✓ Successfully extracted LightGBM library to: {}", + lib_dir.display() + ); + found = true; + break; + } + } - println!( - "cargo:warning=Downloaded LightGBM library to: {}", - lib_path.display() - ); + if !found { + return Err(format!( + "Could not find lib_lightgbm.{} in wheel", + lib_extension + ) + .into()); + } - Ok(()) + Ok(()) + } + Err(pypi_err) => { + // Strategy 2: Fallback to GitHub releases + println!( + "cargo:warning=PyPI wheel not found ({}), trying GitHub releases as fallback...", + pypi_err + ); + match try_github_release(&version, &os, &arch, &lib_dir) { + Ok(_) => Ok(()), + Err(github_err) => { + // Strategy 3: Final error with helpful message + Err(format!( + "Failed to download LightGBM library:\n\ + - PyPI: {}\n\ + - GitHub releases: {}\n\ + \n\ + Please try:\n\ + 1. Using a different LightGBM version (set LIGHTGBM_VERSION env var)\n\ + 2. Building LightGBM from source and setting LIGHTGBM_LIB_DIR env var\n\ + 3. Checking https://pypi.org/project/lightgbm/{}/", + pypi_err, github_err, version + ) + .into()) + } + } + } + } } fn main() { @@ -224,7 +401,6 @@ fn main() { .blocklist_type(".*_Tp.*") .blocklist_type(".*_Pred.*") .size_t_is_usize(true) - .rustfmt_bindings(true) .generate() .expect("Unable to generate bindings."); @@ -253,8 +429,7 @@ fn main() { .join(env::var("PROFILE").unwrap()); let lib_dest_path = target_dir.join(lib_filename); - fs::copy(&lib_source_path, &lib_dest_path) - .expect("Failed to copy library to target directory"); + fs::copy(&lib_source_path, &lib_dest_path).expect("Failed to copy library to target directory"); // Set the library search path for the build-time linker let lib_search_path = out_dir.join("libs"); @@ -269,21 +444,33 @@ fn main() { // For macOS, add multiple rpath entries for IDE compatibility println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path"); println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../.."); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_search_path.display()); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}", + lib_search_path.display() + ); // Add the target directory to rpath as well if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { - println!("cargo:rustc-link-arg=-Wl,-rpath,{}/debug", target_root.display()); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}/release", target_root.display()); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/debug", + target_root.display() + ); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/release", + target_root.display() + ); } - }, + } "linux" => { // For Linux, use $ORIGIN println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN"); println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../.."); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_search_path.display()); - }, + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}", + lib_search_path.display() + ); + } _ => {} // No rpath needed for Windows } - println!("cargo:rustc-link-lib=dylib=lib_lightgbm"); + println!("cargo:rustc-link-lib=dylib=_lightgbm"); } diff --git a/examples/advanced_usage.rs b/examples/advanced_usage.rs index 6c2b70d..2638e08 100644 --- a/examples/advanced_usage.rs +++ b/examples/advanced_usage.rs @@ -1,4 +1,4 @@ -use lightgbm_rust::{Booster, predict_type}; +use lightgbm_rust::{predict_type, Booster}; fn main() -> Result<(), Box> { // Load a trained LightGBM model @@ -16,10 +16,7 @@ fn main() -> Result<(), Box> { println!(" Classes: {}", num_classes); // Example data with f32 (more memory efficient for large datasets) - let data_f32: Vec = vec![ - 1.0, 2.0, 3.0, 4.0, 5.0, - 2.0, 3.0, 4.0, 5.0, 6.0, - ]; + let data_f32: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let num_rows = 2; let num_cols = 5; @@ -32,11 +29,13 @@ fn main() -> Result<(), Box> { println!("Raw scores: {:?}", raw_scores); println!("\n--- Leaf Index Prediction ---"); - let leaf_indices = booster.predict_f32(&data_f32, num_rows, num_cols, predict_type::LEAF_INDEX)?; + let leaf_indices = + booster.predict_f32(&data_f32, num_rows, num_cols, predict_type::LEAF_INDEX)?; println!("Leaf indices: {:?}", leaf_indices); println!("\n--- Feature Contribution (SHAP) ---"); - let contributions = booster.predict_f32(&data_f32, num_rows, num_cols, predict_type::CONTRIB)?; + let contributions = + booster.predict_f32(&data_f32, num_rows, num_cols, predict_type::CONTRIB)?; println!("Feature contributions: {:?}", contributions); Ok(()) diff --git a/examples/basic_usage.rs b/examples/basic_usage.rs index f4c3037..5cc36f8 100644 --- a/examples/basic_usage.rs +++ b/examples/basic_usage.rs @@ -1,4 +1,4 @@ -use lightgbm_rust::{Booster, predict_type}; +use lightgbm_rust::{predict_type, Booster}; fn main() -> Result<(), Box> { // Load a trained LightGBM model @@ -29,15 +29,16 @@ fn main() -> Result<(), Box> { // Example: Predict for multiple samples (batch prediction) let batch_data = vec![ - 1.0, 2.0, 3.0, 4.0, // Sample 1 - 2.0, 3.0, 4.0, 5.0, // Sample 2 - 3.0, 4.0, 5.0, 6.0, // Sample 3 + 1.0, 2.0, 3.0, 4.0, // Sample 1 + 2.0, 3.0, 4.0, 5.0, // Sample 2 + 3.0, 4.0, 5.0, 6.0, // Sample 3 ]; let num_rows = 3; let num_cols = 4; println!("\nMaking batch prediction..."); - let batch_predictions = booster.predict(&batch_data, num_rows, num_cols, predict_type::NORMAL)?; + let batch_predictions = + booster.predict(&batch_data, num_rows, num_cols, predict_type::NORMAL)?; println!("Batch predictions: {:?}", batch_predictions); diff --git a/src/error.rs b/src/error.rs index 37793a2..3e78ca8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ +use crate::sys; use std::ffi::CStr; use std::fmt; -use crate::sys; pub type LightGBMResult = std::result::Result; diff --git a/src/lib.rs b/src/lib.rs index 03f7484..aee86e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,11 @@ pub use crate::error::{LightGBMError, LightGBMResult}; mod model; pub use crate::model::Booster; +#[cfg(feature = "polars")] +pub mod polars_ext; +#[cfg(feature = "polars")] +pub use crate::polars_ext::BoosterPolarsExt; + // Re-export prediction type constants for convenience pub mod predict_type { /// Normal prediction diff --git a/src/model.rs b/src/model.rs index 4df0943..20ff678 100644 --- a/src/model.rs +++ b/src/model.rs @@ -49,14 +49,12 @@ pub struct Booster { impl Booster { /// Load a model from a file pub fn load>(path: P) -> LightGBMResult { - let path_str = path.as_ref().to_str() - .ok_or_else(|| LightGBMError { - description: "Path contains invalid UTF-8 characters".to_string(), - })?; - let path_c_str = CString::new(path_str) - .map_err(|e| LightGBMError { - description: format!("Path contains NUL byte: {}", e), - })?; + let path_str = path.as_ref().to_str().ok_or_else(|| LightGBMError { + description: "Path contains invalid UTF-8 characters".to_string(), + })?; + let path_c_str = CString::new(path_str).map_err(|e| LightGBMError { + description: format!("Path contains NUL byte: {}", e), + })?; let mut handle: sys::BoosterHandle = ptr::null_mut(); let mut num_iterations = 0i32; @@ -85,10 +83,9 @@ impl Booster { /// let booster = Booster::load_from_string(&model_string).unwrap(); /// ``` pub fn load_from_string(model_str: &str) -> LightGBMResult { - let model_c_str = CString::new(model_str) - .map_err(|e| LightGBMError { - description: format!("Model string contains NUL byte: {}", e), - })?; + let model_c_str = CString::new(model_str).map_err(|e| LightGBMError { + description: format!("Model string contains NUL byte: {}", e), + })?; let mut handle: sys::BoosterHandle = ptr::null_mut(); let mut num_iterations = 0i32; @@ -118,10 +115,9 @@ impl Booster { /// ``` pub fn load_from_buffer(buffer: &[u8]) -> LightGBMResult { // Convert bytes to string (LightGBM models are text-based) - let model_str = std::str::from_utf8(buffer) - .map_err(|e| LightGBMError { - description: format!("Invalid UTF-8 in model buffer: {}", e), - })?; + let model_str = std::str::from_utf8(buffer).map_err(|e| LightGBMError { + description: format!("Invalid UTF-8 in model buffer: {}", e), + })?; Self::load_from_string(model_str) } @@ -173,7 +169,10 @@ impl Booster { return Err(LightGBMError { description: format!( "Input data size mismatch: expected {} elements ({}×{}), got {}", - expected_len, num_rows, num_cols, data.len() + expected_len, + num_rows, + num_cols, + data.len() ), }); } @@ -243,7 +242,10 @@ impl Booster { return Err(LightGBMError { description: format!( "Input data size mismatch: expected {} elements ({}×{}), got {}", - expected_len, num_rows, num_cols, data.len() + expected_len, + num_rows, + num_cols, + data.len() ), }); } diff --git a/src/polars_ext.rs b/src/polars_ext.rs new file mode 100644 index 0000000..c1652ce --- /dev/null +++ b/src/polars_ext.rs @@ -0,0 +1,110 @@ +use crate::error::{LightGBMError, LightGBMResult}; +use crate::Booster; +use polars::prelude::*; + +/// Extension trait for LightGBM Booster to support Polars DataFrames +pub trait BoosterPolarsExt { + /// Predict using a Polars DataFrame as input + /// + /// This method efficiently converts the DataFrame to the format LightGBM expects + /// and runs prediction. All numeric columns will be used as features. + /// + /// # Arguments + /// * `df` - Input DataFrame with numeric features + /// * `predict_type` - Type of prediction (see `predict_type` module) + /// + /// # Returns + /// A vector of prediction values + /// + /// # Example + /// ```no_run + /// # use lightgbm_rust::{Booster, BoosterPolarsExt, predict_type}; + /// # use polars::prelude::*; + /// let booster = Booster::load("model.txt").unwrap(); + /// + /// let df = df! { + /// "feature1" => [1.0f32, 2.0, 3.0], + /// "feature2" => [4.0f32, 5.0, 6.0], + /// }.unwrap(); + /// + /// let predictions = booster.predict_dataframe(&df, predict_type::NORMAL).unwrap(); + /// ``` + fn predict_dataframe(&self, df: &DataFrame, predict_type: i32) -> LightGBMResult>; + + /// Predict using specific columns from a Polars DataFrame + /// + /// # Arguments + /// * `df` - Input DataFrame + /// * `columns` - Column names to use as features (in order) + /// * `predict_type` - Type of prediction + fn predict_dataframe_with_columns( + &self, + df: &DataFrame, + columns: &[&str], + predict_type: i32, + ) -> LightGBMResult>; +} + +impl BoosterPolarsExt for Booster { + fn predict_dataframe(&self, df: &DataFrame, predict_type: i32) -> LightGBMResult> { + let (data, num_rows, num_cols) = dataframe_to_dense(df)?; + self.predict(&data, num_rows, num_cols, predict_type) + } + + fn predict_dataframe_with_columns( + &self, + df: &DataFrame, + columns: &[&str], + predict_type: i32, + ) -> LightGBMResult> { + let column_names: Vec = columns.iter().map(|s| s.to_string()).collect(); + let selected = df.select(column_names).map_err(|e| LightGBMError { + description: format!("Failed to select columns: {}", e), + })?; + + let (data, num_rows, num_cols) = dataframe_to_dense(&selected)?; + self.predict(&data, num_rows, num_cols, predict_type) + } +} + +/// Convert a Polars DataFrame to dense f64 data in row-major format +/// +/// Optimized column-by-column conversion using Polars' cast for simplicity and speed. +fn dataframe_to_dense(df: &DataFrame) -> LightGBMResult<(Vec, i32, i32)> { + let num_rows = df.height(); + let num_cols = df.width(); + + if num_rows == 0 || num_cols == 0 { + return Err(LightGBMError { + description: "DataFrame has zero rows or columns".to_string(), + }); + } + + // Pre-allocate with exact size + let total_elements = num_rows * num_cols; + let mut data = vec![0.0f64; total_elements]; + + // Process column by column - cast to Float64 for simplicity and speed + for (col_idx, column) in df.get_columns().iter().enumerate() { + let series = column.as_materialized_series(); + + // Cast to Float64 - Polars handles all type conversions efficiently + let f64_series = series.cast(&DataType::Float64).map_err(|e| LightGBMError { + description: format!("Failed to cast column to f64: {}", e), + })?; + + let ca = f64_series.f64().map_err(|e| LightGBMError { + description: format!("Failed to get f64 array: {}", e), + })?; + + for (row_idx, opt_val) in ca.iter().enumerate() { + let val = opt_val.ok_or_else(|| LightGBMError { + description: format!("Null value at row {}, col {}", row_idx, col_idx), + })?; + data[row_idx * num_cols + col_idx] = val; + } + } + + Ok((data, num_rows as i32, num_cols as i32)) +} +