From 287de76e9c5ab73420cadb955614469f1f475e28 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 01:05:11 +0200 Subject: [PATCH 1/7] support all verions. --- build.rs | 116 +++++++++++++++++++++++++++++++++++++++------------ src/model.rs | 81 +++++++++++++++++++++++++++-------- 2 files changed, 153 insertions(+), 44 deletions(-) diff --git a/build.rs b/build.rs index 52e8380..0444d74 100644 --- a/build.rs +++ b/build.rs @@ -69,37 +69,75 @@ fn download_compiled_library(out_dir: &Path) -> Result<(), Box ( - "libcatboostmodel.so".to_string(), // The correct library name for the linker - format!( - "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-x86_64-{}.so", - version,version + // Parse version to determine URL format + // v1.0.x - v1.1.x use simple filenames + // v1.2+ use platform-specific versioned filenames + let version_parts: Vec<&str> = version.split('.').collect(); + let major: u32 = version_parts.get(0).and_then(|s| s.parse().ok()).unwrap_or(1); + let minor: u32 = version_parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0); + + let use_new_format = major > 1 || (major == 1 && minor >= 2); + + // Determine download URL based on version and platform + let (lib_filename, download_url) = if use_new_format { + // v1.2+ format with platform and version in filename + match (os.as_str(), arch.as_str()) { + ("linux", "x86_64") => ( + "libcatboostmodel.so".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-x86_64-{}.so", + version, version + ), ), - ), - ("linux", "aarch64") => ( - "libcatboostmodel.so".to_string(), - format!( - "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-aarch64-{}.so", - version, version + ("linux", "aarch64") => ( + "libcatboostmodel.so".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-aarch64-{}.so", + version, version + ), ), - ), - ("darwin", "x86_64") | ("darwin", "aarch64") => ( - "libcatboostmodel.dylib".to_string(), // The correct library name for macOS - format!( - "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-darwin-universal2-{}.dylib", - version, version + ("darwin", "x86_64") | ("darwin", "aarch64") => ( + "libcatboostmodel.dylib".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-darwin-universal2-{}.dylib", + version, version + ), ), - ), - ("windows", "x86_64") => ( - "catboostmodel.dll".to_string(), // The correct library name for Windows - format!( - "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll", - version + ("windows", "x86_64") => ( + "catboostmodel.dll".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll", + version + ), + ), + _ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()), + } + } else { + // v1.0.x - v1.1.x format with simple filenames + match os.as_str() { + "linux" => ( + "libcatboostmodel.so".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel.so", + version + ), ), - ), - _ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()), + "darwin" => ( + "libcatboostmodel.dylib".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel.dylib", + version + ), + ), + "windows" => ( + "catboostmodel.dll".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll", + version + ), + ), + _ => return Err(format!("Unsupported platform: {}", os).into()), + } }; println!( @@ -136,6 +174,30 @@ fn main() { let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); let cb_model_interface_root = out_dir.join("libs/model_interface"); + // Parse version for feature detection + let version = get_catboost_version(); + let version_parts: Vec<&str> = version.split('.').collect(); + let major: u32 = version_parts.get(0).and_then(|s| s.parse().ok()).unwrap_or(1); + let minor: u32 = version_parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0); + let patch: u32 = version_parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0); + + // Emit cfg flags for version-specific features + // v1.1.1+: Embedding features support + if major > 1 || (major == 1 && minor > 1) || (major == 1 && minor == 1 && patch >= 1) { + println!("cargo:rustc-cfg=catboost_embeddings"); + } + + // v1.2+: Text features count function + if major > 1 || (major == 1 && minor >= 2) { + println!("cargo:rustc-cfg=catboost_text_count"); + } + + // v1.2.3+: Staged predictions and feature indices + if major > 1 || (major == 1 && minor > 2) || (major == 1 && minor == 2 && patch >= 3) { + println!("cargo:rustc-cfg=catboost_staged_prediction"); + println!("cargo:rustc-cfg=catboost_feature_indices"); + } + // Download the model interface headers if let Err(e) = download_model_interface_headers(&out_dir) { eprintln!("Failed to download model interface headers: {}", e); diff --git a/src/model.rs b/src/model.rs index 8a9835c..9b19180 100644 --- a/src/model.rs +++ b/src/model.rs @@ -183,23 +183,54 @@ impl Model { .collect::>(); let mut prediction = vec![0.0; object_count.unwrap() * self.get_dimensions_count()]; - CatBoostError::check_return_value(unsafe { - sys::CalcModelPredictionWithHashedCatFeaturesAndTextAndEmbeddingFeatures( - self.handle, - object_count.unwrap(), - float_features_ptr.as_mut_ptr(), - if features.float_features.as_ref().is_empty() { 0 } else { features.float_features.as_ref()[0].as_ref().len() }, - hashed_cat_features_ptr.as_mut_ptr(), - if features.cat_features.as_ref().is_empty() { 0 } else { features.cat_features.as_ref()[0].as_ref().len() }, - text_features_ptr.as_mut_ptr(), - if features.text_features.as_ref().is_empty() { 0 } else { features.text_features.as_ref()[0].as_ref().len() }, - embedding_features_ptr.as_mut_ptr(), - embedding_dimensions.as_mut_ptr(), - embedding_dimensions.len(), - prediction.as_mut_ptr(), - prediction.len(), - ) - })?; + + #[cfg(catboost_embeddings)] + { + // v1.1.1+: Use function with embedding support + CatBoostError::check_return_value(unsafe { + sys::CalcModelPredictionWithHashedCatFeaturesAndTextAndEmbeddingFeatures( + self.handle, + object_count.unwrap(), + float_features_ptr.as_mut_ptr(), + if features.float_features.as_ref().is_empty() { 0 } else { features.float_features.as_ref()[0].as_ref().len() }, + hashed_cat_features_ptr.as_mut_ptr(), + if features.cat_features.as_ref().is_empty() { 0 } else { features.cat_features.as_ref()[0].as_ref().len() }, + text_features_ptr.as_mut_ptr(), + if features.text_features.as_ref().is_empty() { 0 } else { features.text_features.as_ref()[0].as_ref().len() }, + embedding_features_ptr.as_mut_ptr(), + embedding_dimensions.as_mut_ptr(), + embedding_dimensions.len(), + prediction.as_mut_ptr(), + prediction.len(), + ) + })?; + } + + #[cfg(not(catboost_embeddings))] + { + // v1.0.x: Use function without embedding support (embeddings will be ignored) + if !features.embedding_features.as_ref().is_empty() { + return Err(CatBoostError { + description: "Embedding features are not supported in this CatBoost version. Please use v1.1.1 or later.".to_string() + }); + } + + CatBoostError::check_return_value(unsafe { + sys::CalcModelPredictionWithHashedCatFeaturesAndTextFeatures( + self.handle, + object_count.unwrap(), + float_features_ptr.as_mut_ptr(), + if features.float_features.as_ref().is_empty() { 0 } else { features.float_features.as_ref()[0].as_ref().len() }, + hashed_cat_features_ptr.as_mut_ptr(), + if features.cat_features.as_ref().is_empty() { 0 } else { features.cat_features.as_ref()[0].as_ref().len() }, + text_features_ptr.as_mut_ptr(), + if features.text_features.as_ref().is_empty() { 0 } else { features.text_features.as_ref()[0].as_ref().len() }, + prediction.as_mut_ptr(), + prediction.len(), + ) + })?; + } + Ok(prediction) } @@ -237,15 +268,31 @@ impl Model { } /// Get expected text feature count for model + /// Only available in CatBoost v1.2+ + #[cfg(catboost_text_count)] pub fn get_text_features_count(&self) -> usize { unsafe { sys::GetTextFeaturesCount(self.handle) } } + /// Get expected text feature count for model (returns 0 for older versions) + #[cfg(not(catboost_text_count))] + pub fn get_text_features_count(&self) -> usize { + 0 + } + /// Get expected embedding feature count for model + /// Only available in CatBoost v1.1.1+ + #[cfg(catboost_embeddings)] pub fn get_embedding_features_count(&self) -> usize { unsafe { sys::GetEmbeddingFeaturesCount(self.handle) } } + /// Get expected embedding feature count for model (returns 0 for older versions) + #[cfg(not(catboost_embeddings))] + pub fn get_embedding_features_count(&self) -> usize { + 0 + } + /// Get number of trees in model pub fn get_tree_count(&self) -> usize { unsafe { sys::GetTreeCount(self.handle)} From 14a4f87aafa0c89a34a6328eca70586ca53c9d42 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 09:16:45 +0200 Subject: [PATCH 2/7] change winbdows download url. --- build.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/build.rs b/build.rs index 0444d74..4a0658e 100644 --- a/build.rs +++ b/build.rs @@ -106,8 +106,15 @@ fn download_compiled_library(out_dir: &Path) -> Result<(), Box ( "catboostmodel.dll".to_string(), format!( - "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll", - version + "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel-windows-x86_64-{}.dll", + version, version + ), + ), + ("windows", "aarch64") => ( + "catboostmodel.dll".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel-windows-aarch64-{}.dll", + version, version ), ), _ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()), From 514b3683cd0f353355692ec42f95a0ceed520599 Mon Sep 17 00:00:00 2001 From: aryehklein-rise Date: Thu, 16 Apr 2026 17:02:44 +0300 Subject: [PATCH 3/7] fix arm ci and also zero copy tests. --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5247e5e..1b93162 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -40,12 +40,12 @@ jobs: - os: linux arch: arm64 - runner: ubuntu-latest + runner: ubuntu-24.04-arm catboost_version: "1.2.7" - os: linux arch: arm64 - runner: ubuntu-latest + runner: ubuntu-24.04-arm catboost_version: "1.2.10" # Windows x86_64 From 2af8a738bcca5c5f73d2dde3c31d3d9ce240b60a Mon Sep 17 00:00:00 2001 From: aryehklein-rise Date: Thu, 16 Apr 2026 17:04:13 +0300 Subject: [PATCH 4/7] fix arm ci and also zero copy tests. --- src/model.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/model.rs b/src/model.rs index 1db07af..8a96025 100644 --- a/src/model.rs +++ b/src/model.rs @@ -375,11 +375,12 @@ mod tests { #[cfg(catboost_zero_copy)] #[test] fn test_load_buffer_zero_copy() { - let model_path = "tmp/model.cbm"; - if !std::path::Path::new(model_path).exists() { - eprintln!("Skipping test: {} not found", model_path); - return; - } + let model_path = "tmp/model.bin"; + assert!( + std::path::Path::new(model_path).exists(), + "Test fixture missing: {}", + model_path + ); let model_from_file = Model::load(model_path).unwrap(); let buffer = std::fs::read(model_path).unwrap(); From 510872e1dc168a92480d6b95052b0a4940ba3e71 Mon Sep 17 00:00:00 2001 From: aryehklein-rise Date: Thu, 16 Apr 2026 17:04:49 +0300 Subject: [PATCH 5/7] remove test. --- src/model.rs | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/src/model.rs b/src/model.rs index 8a96025..cbd106b 100644 --- a/src/model.rs +++ b/src/model.rs @@ -367,40 +367,3 @@ impl Drop for Model { unsafe { sys::ModelCalcerDelete(self.handle) }; } } - -#[cfg(test)] -mod tests { - use super::*; - - #[cfg(catboost_zero_copy)] - #[test] - fn test_load_buffer_zero_copy() { - let model_path = "tmp/model.bin"; - assert!( - std::path::Path::new(model_path).exists(), - "Test fixture missing: {}", - model_path - ); - - let model_from_file = Model::load(model_path).unwrap(); - let buffer = std::fs::read(model_path).unwrap(); - let model_from_buffer = Model::load_buffer_zero_copy(buffer).unwrap(); - - assert_eq!( - model_from_file.get_float_features_count(), - model_from_buffer.get_float_features_count() - ); - assert_eq!( - model_from_file.get_cat_features_count(), - model_from_buffer.get_cat_features_count() - ); - assert_eq!( - model_from_file.get_tree_count(), - model_from_buffer.get_tree_count() - ); - assert_eq!( - model_from_file.get_dimensions_count(), - model_from_buffer.get_dimensions_count() - ); - } -} From de68fccefc0a36f09ec2080bca591d086c0c1933 Mon Sep 17 00:00:00 2001 From: aryehklein-rise Date: Thu, 16 Apr 2026 17:15:13 +0300 Subject: [PATCH 6/7] add zerocopy load in examples. --- examples/advanced_usage.rs | 9 ++++++++- examples/basic_usage.rs | 12 ++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/examples/advanced_usage.rs b/examples/advanced_usage.rs index 7a32291..55f8a85 100644 --- a/examples/advanced_usage.rs +++ b/examples/advanced_usage.rs @@ -17,7 +17,14 @@ fn main() -> Result<(), CatBoostError> { // Load the model println!("Loading model from {}...", model_path); - let model = Model::load(model_path)?; + let buffer_res = fs::read(model_path); + if buffer_res.is_err() { + return Err(CatBoostError { + description: "could not read file into memory".to_string(), + }); + } + + let model = Model::load_buffer_zero_copy(buffer_res.unwrap())?; println!("Model loaded successfully!"); diff --git a/examples/basic_usage.rs b/examples/basic_usage.rs index feaf181..4aa0f9c 100644 --- a/examples/basic_usage.rs +++ b/examples/basic_usage.rs @@ -13,12 +13,20 @@ fn main() -> Result<(), CatBoostError> { model_path ); create_simple_example()?; - return Ok(()); + return Err(CatBoostError { + description: "No model file found at {}. Creating a simple example..".to_string(), + }); } // Load the model println!("Loading model from {}...", model_path); - let model = Model::load(model_path)?; + let buffer_res = fs::read(model_path); + if buffer_res.is_err() { + return Err(CatBoostError { + description: "could not read file into memory".to_string(), + }); + } + let model = Model::load_buffer_zero_copy(buffer_res.unwrap())?; println!("Model loaded successfully!"); println!("Model info:"); From 876253e96ee5836a2439efaed6cae1a7cd1f8d4e Mon Sep 17 00:00:00 2001 From: aryehklein-rise Date: Thu, 16 Apr 2026 18:11:06 +0300 Subject: [PATCH 7/7] add zerocopy load in examples. --- README.md | 14 ++++++++++++++ examples/advanced_usage.rs | 26 +++++++++++++++++--------- examples/basic_usage.rs | 25 +++++++++++++++++-------- examples/gpu_usage.rs | 17 +++++++++++++++-- 4 files changed, 63 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 0770879..a53c637 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,20 @@ let features = ObjectsOrderFeatures::new() let predictions = model.predict(features)?; ``` +### Zero-Copy Buffer Loading (Recommended) + +`Model::load_buffer_zero_copy` is the recommended way to load models from memory. Unlike `load_buffer`, it avoids copying the model data, resulting in lower memory usage, faster loading, and no internal memory pool leaks. Requires CatBoost v1.2.9+ (the default). + +```rust +use catboost_rust::Model; +use std::fs; + +let buffer = fs::read("model.cbm")?; +let model = Model::load_buffer_zero_copy(buffer)?; +``` + +The buffer is owned by the `Model` and freed automatically when it is dropped. + ## Configuration ### CatBoost Version diff --git a/examples/advanced_usage.rs b/examples/advanced_usage.rs index 55f8a85..8365194 100644 --- a/examples/advanced_usage.rs +++ b/examples/advanced_usage.rs @@ -15,16 +15,9 @@ fn main() -> Result<(), CatBoostError> { return Ok(()); } - // Load the model + // Load the model (prefer zero-copy when available) println!("Loading model from {}...", model_path); - let buffer_res = fs::read(model_path); - if buffer_res.is_err() { - return Err(CatBoostError { - description: "could not read file into memory".to_string(), - }); - } - - let model = Model::load_buffer_zero_copy(buffer_res.unwrap())?; + let model = load_model(model_path)?; println!("Model loaded successfully!"); @@ -121,6 +114,21 @@ fn main() -> Result<(), CatBoostError> { Ok(()) } +#[cfg(catboost_zero_copy)] +fn load_model(path: &str) -> Result { + println!(" (using zero-copy buffer loading)"); + let buffer = fs::read(path).map_err(|e| CatBoostError { + description: format!("could not read file into memory: {}", e), + })?; + Model::load_buffer_zero_copy(buffer) +} + +#[cfg(not(catboost_zero_copy))] +fn load_model(path: &str) -> Result { + println!(" (using file loading - zero-copy not available in this CatBoost version)"); + Model::load(path) +} + fn display_model_info(model: &Model) -> Result<(), CatBoostError> { println!("Model Information:"); println!( diff --git a/examples/basic_usage.rs b/examples/basic_usage.rs index 4aa0f9c..a9d7619 100644 --- a/examples/basic_usage.rs +++ b/examples/basic_usage.rs @@ -18,15 +18,9 @@ fn main() -> Result<(), CatBoostError> { }); } - // Load the model + // Load the model (prefer zero-copy when available) println!("Loading model from {}...", model_path); - let buffer_res = fs::read(model_path); - if buffer_res.is_err() { - return Err(CatBoostError { - description: "could not read file into memory".to_string(), - }); - } - let model = Model::load_buffer_zero_copy(buffer_res.unwrap())?; + let model = load_model(model_path)?; println!("Model loaded successfully!"); println!("Model info:"); @@ -96,6 +90,21 @@ fn main() -> Result<(), CatBoostError> { Ok(()) } +#[cfg(catboost_zero_copy)] +fn load_model(path: &str) -> Result { + println!(" (using zero-copy buffer loading)"); + let buffer = fs::read(path).map_err(|e| CatBoostError { + description: format!("could not read file into memory: {}", e), + })?; + Model::load_buffer_zero_copy(buffer) +} + +#[cfg(not(catboost_zero_copy))] +fn load_model(path: &str) -> Result { + println!(" (using file loading - zero-copy not available in this CatBoost version)"); + Model::load(path) +} + fn create_simple_example() -> Result<(), CatBoostError> { println!("Since no model file is available, here's how you would use the library:"); println!(); diff --git a/examples/gpu_usage.rs b/examples/gpu_usage.rs index 3948fda..5d36cee 100644 --- a/examples/gpu_usage.rs +++ b/examples/gpu_usage.rs @@ -1,12 +1,25 @@ use catboost_rust::{Model, ObjectsOrderFeatures}; +#[cfg(catboost_zero_copy)] +fn load_model(path: &str) -> Result> { + println!(" (using zero-copy buffer loading)"); + let buffer = std::fs::read(path)?; + Ok(Model::load_buffer_zero_copy(buffer)?) +} + +#[cfg(not(catboost_zero_copy))] +fn load_model(path: &str) -> Result> { + println!(" (using file loading - zero-copy not available in this CatBoost version)"); + Ok(Model::load(path)?) +} + fn main() -> Result<(), Box> { println!("CatBoost Rust Example - GPU Usage"); println!("=================================="); - // Load a model + // Load a model (prefer zero-copy when available) println!("Loading model from tmp/model.bin..."); - let model = Model::load("tmp/model.bin")?; + let model = load_model("tmp/model.bin")?; println!("Model loaded successfully!"); // Display model information