diff --git a/applpy/transform.py b/applpy/transform.py index 404b6a1..ace8f5d 100644 --- a/applpy/transform.py +++ b/applpy/transform.py @@ -522,29 +522,16 @@ def _mixture_discrete_functional(mixture_pdf_random_variables): def _mixture_discrete(mix_parameters, mixture_pdf_random_variables): - # Compute the mixture rv by summing over the weights - mixture_support = [] - mixture_functions = [] - for i in range(len(mixture_pdf_random_variables)): - for j in range(len(mixture_pdf_random_variables[i].support)): - if mixture_pdf_random_variables[i].support[j] not in mixture_support: - mixture_support.append(mixture_pdf_random_variables[i].support[j]) - mixture_functions.append( - mixture_pdf_random_variables[i].func[j] * mix_parameters[i] - ) - else: - support_index = mixture_support.index(mixture_pdf_random_variables[i].support[j]) - weighted_value = mixture_pdf_random_variables[i].func[j] * mix_parameters[i] - mixture_functions[support_index] += weighted_value - # Sort the values - sorted_support_function_pairs = list(zip(mixture_support, mixture_functions)) - sorted_support_function_pairs.sort() - mixture_functions = [] - mixture_support = [] - for i in range(len(sorted_support_function_pairs)): - mixture_functions.append(sorted_support_function_pairs[i][1]) - mixture_support.append(sorted_support_function_pairs[i][0]) - return RV(mixture_functions, mixture_support, ["discrete", "pdf"]) + fast_rv = applpy_rust.mixture_discrete( + random_variables=mixture_pdf_random_variables, + mix_weights=mix_parameters, + ) + return RV( + func=fast_rv.function, + support=fast_rv.support, + functional_form=fast_rv.functional_form, + domain_type=fast_rv.domain_type, + ) # Backward-compatible aliases for legacy APPLPy function names. diff --git a/rust/src/algorithms/number.rs b/rust/src/algorithms/number.rs index d2ce18e..33dd504 100644 --- a/rust/src/algorithms/number.rs +++ b/rust/src/algorithms/number.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] use std::fmt; +use std::iter::Sum; use std::ops::{Add, AddAssign, Div, Mul, Sub}; use std::{collections::BTreeMap, f64::consts::E}; @@ -533,6 +534,12 @@ impl Div for Number { } } +impl Sum for Number { + fn sum>(iter: I) -> Self { + iter.fold(Number::default(), |acc, x| acc + x) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index 8283c79..a65a844 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -112,6 +112,120 @@ pub fn truncate_discrete( Ok(truncated_rv) } +/// Computes the mixture of two random variables +/// +/// # Arguments +/// * `random_variables` - a list of random variables to mix +/// * `mix_weights` - the weight for each random variable. Must sum to 1 +/// +/// # Returns +/// * `mixed_rv` - the weighted mixture of the random variables +/// +/// # Examples +/// ``` +/// use applpy_rust::algorithms::number::Number; +/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +/// use applpy_rust::algorithms::transform::mixture_discrete; +/// use num_rational::Rational64; +/// +/// let rv1 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 2)), +/// ], +/// support: vec![Number::Integer(1), Number::Integer(2)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let rv2 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 4)), +/// Number::Rational(Rational64::new(3, 4)), +/// ], +/// support: vec![Number::Integer(2), Number::Integer(3)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let mixed = mixture_discrete( +/// &[&rv1, &rv2], +/// &[Number::Float(0.25), Number::Float(0.75)], +/// ) +/// .unwrap(); +/// +/// assert_eq!( +/// mixed.support, +/// vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)] +/// ); +/// assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12); +/// assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12); +/// assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12); +/// assert!(matches!(mixed.functional_form, FunctionalForm::Pdf)); +/// assert!(matches!(mixed.domain_type, DomainType::Discrete)); +/// ``` +pub fn mixture_discrete( + random_variables: &[&RandomVariable], + mix_weights: &[Number], +) -> Result { + if random_variables.len() != mix_weights.len() { + return Err("the number of random variables and mix weights must be equal".to_string()); + } + + let weight_sum: f64 = mix_weights.iter().copied().sum::().to_f64(); + let tolerance = 1e-6; + let upper_bound = 1.0 + tolerance; + let lower_bound = 1.0 - tolerance; + if weight_sum < lower_bound || weight_sum > upper_bound { + return Err("the mix weights must sum to one".to_string()); + } + + let mut raw_mixture_function = Vec::new(); + let mut raw_mixture_support = Vec::new(); + + for (&random_variable, &mix_weight) in random_variables.iter().zip(mix_weights.iter()) { + let function = &random_variable.to_pdf()?.function; + let support = &random_variable.to_pdf()?.support; + + for (&function_value, &support_value) in function.iter().zip(support.iter()) { + let partial_probability = function_value * mix_weight; + + let support_index = raw_mixture_support.iter().position(|&x| x == support_value); + + match support_index { + Some(idx) => { + raw_mixture_function[idx] += partial_probability; + } + None => { + raw_mixture_support.push(support_value); + raw_mixture_function.push(partial_probability); + } + } + } + } + + let mut raw_mixture_pair: Vec<_> = raw_mixture_support + .into_iter() + .zip(raw_mixture_function) + .collect(); + + raw_mixture_pair.sort_by(|a, b| { + let first_value = a.0.to_f64(); + let second_value = b.0.to_f64(); + first_value.total_cmp(&second_value) + }); + + let (mixture_support, mixture_function) = raw_mixture_pair.into_iter().unzip(); + + let mix_rv = RandomVariable { + function: mixture_function, + support: mixture_support, + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + Ok(mix_rv) +} + #[cfg(test)] mod tests { use super::*; @@ -136,6 +250,30 @@ mod tests { } } + fn sample_mixture_rv_1() -> RandomVariable { + RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + } + } + + fn sample_mixture_rv_2() -> RandomVariable { + RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(3, 4)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + } + } + #[test] fn truncate_discrete_renormalizes_probabilities_within_range() { let rv = sample_discrete_rv(); @@ -166,4 +304,55 @@ mod tests { Err(msg) if msg == "min support must be greater than or equal to the lowest support value" )); } + + #[test] + fn mixture_discrete_combines_duplicate_support_and_sorts_output() { + let rv1 = sample_mixture_rv_1(); + let rv2 = sample_mixture_rv_2(); + + let mixed = + mixture_discrete(&[&rv1, &rv2], &[Number::Float(0.25), Number::Float(0.75)]).unwrap(); + + assert_eq!( + mixed.support, + vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)] + ); + assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12); + assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12); + assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12); + assert!(matches!(mixed.functional_form, FunctionalForm::Pdf)); + assert!(matches!(mixed.domain_type, DomainType::Discrete)); + } + + #[test] + fn mixture_discrete_returns_error_when_lengths_do_not_match() { + let rv1 = sample_mixture_rv_1(); + let rv2 = sample_mixture_rv_2(); + + let result = mixture_discrete(&[&rv1, &rv2], &[Number::Rational(Rational64::new(1, 1))]); + + assert!(matches!( + result, + Err(msg) if msg == "the number of random variables and mix weights must be equal" + )); + } + + #[test] + fn mixture_discrete_returns_error_when_weights_do_not_sum_to_one() { + let rv1 = sample_mixture_rv_1(); + let rv2 = sample_mixture_rv_2(); + + let result = mixture_discrete( + &[&rv1, &rv2], + &[ + Number::Rational(Rational64::new(1, 3)), + Number::Rational(Rational64::new(1, 3)), + ], + ); + + assert!(matches!( + result, + Err(msg) if msg == "the mix weights must sum to one" + )); + } } diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 814fa18..9b51c78 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -37,6 +37,7 @@ fn applpy_rust(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> { // transformation functions module.add_function(wrap_pyfunction!(python::api::truncate_discrete_py, module)?)?; + module.add_function(wrap_pyfunction!(python::api::mixture_discrete_py, module)?)?; // dummy function to validate imports module.add_function(wrap_pyfunction!(dummy_ping, module)?)?; diff --git a/rust/src/python/api.rs b/rust/src/python/api.rs index 37c5ce5..95e5908 100644 --- a/rust/src/python/api.rs +++ b/rust/src/python/api.rs @@ -169,6 +169,31 @@ pub fn truncate_discrete_py( )) } +#[pyfunction(name = "mixture_discrete", signature = (random_variables, mix_weights))] +pub fn mixture_discrete_py( + random_variables: Vec>, + mix_weights: Vec, +) -> PyResult { + let extracted: PyResult> = random_variables + .into_iter() + .map(|rv| rv.extract::()) + .collect(); + let extracted: Vec = extracted?; + + let extracted_rvs: Vec<&RandomVariable> = + extracted.iter().map(|fast_rv| &fast_rv.inner).collect(); + + let mixed_rv = + transform::mixture_discrete(&extracted_rvs, &mix_weights).map_err(PyValueError::new_err)?; + + Ok(FastRV::new( + mixed_rv.function, + mixed_rv.support, + mixed_rv.functional_form, + mixed_rv.domain_type, + )) +} + #[pyclass] pub struct FastRV { inner: RandomVariable,