From 44985438aef9220f06b7c97ea12f6bd0cbc28f0d Mon Sep 17 00:00:00 2001 From: OceanLi <122793010+ohdearquant@users.noreply.github.com> Date: Fri, 22 May 2026 12:25:07 -0400 Subject: [PATCH] feat(fold): add ObjectiveRegistry for dynamic objective dispatch Ports ObjectiveRegistry from khive-internal, adapted to OSS Selection.item API. Provides thread-safe registration and lookup of named objective functions for fold composition. Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/khive-fold/src/objective/mod.rs | 2 + crates/khive-fold/src/objective/registry.rs | 275 ++++++++++++++++++++ 2 files changed, 277 insertions(+) create mode 100644 crates/khive-fold/src/objective/registry.rs diff --git a/crates/khive-fold/src/objective/mod.rs b/crates/khive-fold/src/objective/mod.rs index c4504982..e2040fb6 100644 --- a/crates/khive-fold/src/objective/mod.rs +++ b/crates/khive-fold/src/objective/mod.rs @@ -4,11 +4,13 @@ pub mod builtin; pub mod compose; mod context; pub mod error; +pub mod registry; mod selection; mod traits; pub use context::ObjectiveContext; pub use error::{ObjectiveError, ObjectiveResult}; +pub use registry::{ObjectiveRegistry, RegisteredObjective}; pub use selection::Selection; pub use traits::{objective_fn, DeterministicObjective, Objective}; diff --git a/crates/khive-fold/src/objective/registry.rs b/crates/khive-fold/src/objective/registry.rs new file mode 100644 index 00000000..4ce97815 --- /dev/null +++ b/crates/khive-fold/src/objective/registry.rs @@ -0,0 +1,275 @@ +//! Objective registry for dynamic dispatch. + +use std::collections::HashMap; +use std::sync::Arc; + +use parking_lot::RwLock; + +use crate::{Objective, ObjectiveContext, ObjectiveError, ObjectiveResult, Selection}; + +/// A type-erased objective wrapper. +pub struct RegisteredObjective { + /// Name of the objective + pub name: String, + /// Description + pub description: Option, + /// The objective implementation + objective: Box>, +} + +impl RegisteredObjective { + /// Create a new registered objective + pub fn new(name: impl Into, objective: Box>) -> Self { + Self { + name: name.into(), + description: None, + objective, + } + } + + /// Add a description + pub fn with_description(mut self, desc: impl Into) -> Self { + self.description = Some(desc.into()); + self + } + + /// Score a candidate + pub fn score(&self, candidate: &T, context: &ObjectiveContext) -> f64 { + self.objective.score(candidate, context) + } + + /// Select from candidates + pub fn select<'a>( + &self, + candidates: &'a [T], + context: &ObjectiveContext, + ) -> ObjectiveResult> { + self.objective.select(candidates, context) + } +} + +/// Registry of named objectives. +pub struct ObjectiveRegistry { + objectives: RwLock>>>, + default: RwLock>, +} + +impl Default for ObjectiveRegistry { + fn default() -> Self { + Self::new() + } +} + +impl ObjectiveRegistry { + /// Create a new empty registry + pub fn new() -> Self { + Self { + objectives: RwLock::new(HashMap::new()), + default: RwLock::new(None), + } + } + + /// Register an objective. + /// + /// Returns the previously registered objective if one existed with the same name. + pub fn register( + &self, + name: impl Into, + objective: Box>, + ) -> Option>> { + let name = name.into(); + let registered = Arc::new(RegisteredObjective::new(name.clone(), objective)); + + let mut objectives = self.objectives.write(); + objectives.insert(name, registered) + } + + /// Register an objective with description. + /// + /// Returns the previously registered objective if one existed with the same name. + pub fn register_with_desc( + &self, + name: impl Into, + description: impl Into, + objective: Box>, + ) -> Option>> { + let name = name.into(); + let registered = Arc::new( + RegisteredObjective::new(name.clone(), objective).with_description(description), + ); + + let mut objectives = self.objectives.write(); + objectives.insert(name, registered) + } + + /// Set the default objective + pub fn set_default(&self, name: impl Into) -> ObjectiveResult<()> { + let name = name.into(); + + let objectives = self.objectives.read(); + if !objectives.contains_key(&name) { + return Err(ObjectiveError::NotFound(name)); + } + drop(objectives); + + let mut default = self.default.write(); + *default = Some(name); + Ok(()) + } + + /// Get an objective by name + pub fn get(&self, name: &str) -> ObjectiveResult>> { + let objectives = self.objectives.read(); + objectives + .get(name) + .cloned() + .ok_or_else(|| ObjectiveError::NotFound(name.to_string())) + } + + /// Get the default objective + pub fn get_default(&self) -> ObjectiveResult>> { + let default = self.default.read(); + match default.as_ref() { + Some(name) => { + let name: String = name.clone(); + drop(default); + self.get(&name) + } + None => Err(ObjectiveError::NotFound("No default set".to_string())), + } + } + + /// List all registered objective names. + /// + /// Returns names in sorted order for deterministic output. + pub fn list(&self) -> Vec { + let objectives = self.objectives.read(); + let mut names: Vec = objectives.keys().cloned().collect(); + names.sort(); + names + } + + /// Check if an objective is registered + pub fn contains(&self, name: &str) -> bool { + let objectives = self.objectives.read(); + objectives.contains_key(name) + } + + /// Score using a named objective + pub fn score( + &self, + name: &str, + candidate: &T, + context: &ObjectiveContext, + ) -> ObjectiveResult { + let objective = self.get(name)?; + Ok(objective.score(candidate, context)) + } + + /// Select using a named objective + pub fn select<'a>( + &self, + name: &str, + candidates: &'a [T], + context: &ObjectiveContext, + ) -> ObjectiveResult> { + let objective = self.get(name)?; + objective.select(candidates, context) + } + + /// Select using the default objective + pub fn select_default<'a>( + &self, + candidates: &'a [T], + context: &ObjectiveContext, + ) -> ObjectiveResult> { + let objective = self.get_default()?; + objective.select(candidates, context) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::objective_fn; + + #[test] + fn test_register_and_get() { + let registry: ObjectiveRegistry = ObjectiveRegistry::new(); + + let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64); + let old = registry.register("max", Box::new(obj)); + + assert!(old.is_none()); + assert!(registry.contains("max")); + assert!(!registry.contains("min")); + } + + #[test] + fn test_register_overwrites() { + let registry: ObjectiveRegistry = ObjectiveRegistry::new(); + + let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64); + let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64)); + + let old1 = registry.register("test", Box::new(obj1)); + assert!(old1.is_none()); + + let old2 = registry.register("test", Box::new(obj2)); + assert!(old2.is_some()); + + let candidates = vec![1, 5, 3]; + let selection = registry + .select("test", &candidates, &ObjectiveContext::new()) + .unwrap(); + assert_eq!(*selection.item, 1); + } + + #[test] + fn test_select_by_name() { + let registry: ObjectiveRegistry = ObjectiveRegistry::new(); + + let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64); + registry.register("max", Box::new(obj)); + + let candidates = vec![1, 5, 3]; + let selection = registry + .select("max", &candidates, &ObjectiveContext::new()) + .unwrap(); + + assert_eq!(*selection.item, 5); + } + + #[test] + fn test_default_objective() { + let registry: ObjectiveRegistry = ObjectiveRegistry::new(); + + let obj = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64); + registry.register("max", Box::new(obj)); + registry.set_default("max").unwrap(); + + let candidates = vec![1, 5, 3]; + let selection = registry + .select_default(&candidates, &ObjectiveContext::new()) + .unwrap(); + + assert_eq!(*selection.item, 5); + } + + #[test] + fn test_list_objectives_sorted() { + let registry: ObjectiveRegistry = ObjectiveRegistry::new(); + + let obj1 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| *n as f64); + let obj2 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| -(*n as f64)); + let obj3 = objective_fn(|n: &i32, _ctx: &ObjectiveContext| (*n as f64).abs()); + + registry.register("zebra", Box::new(obj1)); + registry.register("alpha", Box::new(obj2)); + registry.register("middle", Box::new(obj3)); + + let names = registry.list(); + assert_eq!(names.len(), 3); + assert_eq!(names, vec!["alpha", "middle", "zebra"]); + } +}