From f00f33e6336e4a36cd5110b3d0aad31dd890d812 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 21:45:09 -0400 Subject: [PATCH 1/9] first pass on mixture rv --- rust/src/algorithms/transform.rs | 56 ++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index 8283c79..0cc9fd6 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -112,6 +112,62 @@ pub fn truncate_discrete( Ok(truncated_rv) } +/// Computes the mixture of two random variables +/// +/// # Arguments +/// * `random_variables` - a list of random variables to mix +/// * `mix_weights` - the weight for each random variable. Must sum to 1 +/// +/// # Returns +/// * `mixed_rv` - the weighted mixture of the random variables +pub fn mixture_discrete( + random_variables: &[&RandomVariable], + mix_weights: &[Number], +) -> Result { + + let mut raw_mixture_function = Vec::new(); + let mut raw_mixture_support = Vec::new(); + + for (&random_variable, &mix_weight) in random_variables.iter().zip(mix_weights.iter()) { + let function = &random_variable.function; + let support = &random_variable.support; + + for (&function_value, &support_value) in function.iter().zip(support.iter()) { + let partial_probability = function_value * mix_weight; + if raw_mixture_support.contains(&support_value) { + raw_mixture_support.push(support_value); + raw_mixture_function.push(partial_probability); + } else { + let support_index = support.iter().position(|&x| x == support_value) + .expect("support value not found in mixture support"); + raw_mixture_function[support_index] += partial_probability; + } + } + } + + let mut raw_mixture_pair: Vec<_> = raw_mixture_support.into_iter() + .zip(raw_mixture_function) + .collect(); + + raw_mixture_pair.sort_by(|a, b| { + let first_value = a.0.to_f64(); + let second_value = b.0.to_f64(); + first_value.total_cmp(&second_value) + }); + + let (mixture_support, mixture_function) = raw_mixture_pair + .into_iter() + .unzip(); + + let mix_rv = RandomVariable { + function: mixture_function, + support: mixture_support, + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + Ok(mix_rv) +} + #[cfg(test)] mod tests { use super::*; From 13611a5511d48e1860e0a169f205e53d2961cfca Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 21:55:17 -0400 Subject: [PATCH 2/9] add error handling --- rust/src/algorithms/transform.rs | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index 0cc9fd6..8079c93 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -124,6 +124,19 @@ pub fn mixture_discrete( random_variables: &[&RandomVariable], mix_weights: &[Number], ) -> Result { + if random_variables.len() != mix_weights.len() { + return Err("the number of random variables and mix weights must be equal".to_string()); + } + + let weight_sum = mix_weights + .iter() + .fold(Number::default(), |acc, x| acc + *x); + + let one = Number::Integer(1); + let tolerance = Number::Float(1e-6); + if weight_sum < one - tolerance || weight_sum > one + tolerance { + return Err("the mix weights must sum to one".to_string()); + } let mut raw_mixture_function = Vec::new(); let mut raw_mixture_support = Vec::new(); @@ -138,14 +151,17 @@ pub fn mixture_discrete( raw_mixture_support.push(support_value); raw_mixture_function.push(partial_probability); } else { - let support_index = support.iter().position(|&x| x == support_value) + let support_index = support + .iter() + .position(|&x| x == support_value) .expect("support value not found in mixture support"); raw_mixture_function[support_index] += partial_probability; } } } - let mut raw_mixture_pair: Vec<_> = raw_mixture_support.into_iter() + let mut raw_mixture_pair: Vec<_> = raw_mixture_support + .into_iter() .zip(raw_mixture_function) .collect(); @@ -155,9 +171,7 @@ pub fn mixture_discrete( first_value.total_cmp(&second_value) }); - let (mixture_support, mixture_function) = raw_mixture_pair - .into_iter() - .unzip(); + let (mixture_support, mixture_function) = raw_mixture_pair.into_iter().unzip(); let mix_rv = RandomVariable { function: mixture_function, From 4e94a2b38d0b8bf9a9d49a1410c379b046d8679d Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 21:57:46 -0400 Subject: [PATCH 3/9] fix summing --- rust/src/algorithms/number.rs | 7 +++++++ rust/src/algorithms/transform.rs | 5 +---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/rust/src/algorithms/number.rs b/rust/src/algorithms/number.rs index d2ce18e..33dd504 100644 --- a/rust/src/algorithms/number.rs +++ b/rust/src/algorithms/number.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] use std::fmt; +use std::iter::Sum; use std::ops::{Add, AddAssign, Div, Mul, Sub}; use std::{collections::BTreeMap, f64::consts::E}; @@ -533,6 +534,12 @@ impl Div for Number { } } +impl Sum for Number { + fn sum>(iter: I) -> Self { + iter.fold(Number::default(), |acc, x| acc + x) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index 8079c93..5306bf2 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -128,10 +128,7 @@ pub fn mixture_discrete( return Err("the number of random variables and mix weights must be equal".to_string()); } - let weight_sum = mix_weights - .iter() - .fold(Number::default(), |acc, x| acc + *x); - + let weight_sum: Number = mix_weights.iter().copied().sum(); let one = Number::Integer(1); let tolerance = Number::Float(1e-6); if weight_sum < one - tolerance || weight_sum > one + tolerance { From 1eabc6655e4f4428eabf0c012b7b519bbd44965f Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 22:03:31 -0400 Subject: [PATCH 4/9] refactor of index --- rust/src/algorithms/transform.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index 5306bf2..60c7e6e 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -144,15 +144,13 @@ pub fn mixture_discrete( for (&function_value, &support_value) in function.iter().zip(support.iter()) { let partial_probability = function_value * mix_weight; - if raw_mixture_support.contains(&support_value) { + if let Some(support_index) = + raw_mixture_support.iter().position(|&x| x == support_value) + { + raw_mixture_function[support_index] += partial_probability; + } else { raw_mixture_support.push(support_value); raw_mixture_function.push(partial_probability); - } else { - let support_index = support - .iter() - .position(|&x| x == support_value) - .expect("support value not found in mixture support"); - raw_mixture_function[support_index] += partial_probability; } } } From 7badf71336fad03f253ff985f73d0df8e3bea8f9 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 22:10:07 -0400 Subject: [PATCH 5/9] tweaks to mixture --- rust/src/algorithms/transform.rs | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index 60c7e6e..eaa6a57 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -139,18 +139,22 @@ pub fn mixture_discrete( let mut raw_mixture_support = Vec::new(); for (&random_variable, &mix_weight) in random_variables.iter().zip(mix_weights.iter()) { - let function = &random_variable.function; - let support = &random_variable.support; + let function = &random_variable.to_pdf()?.function; + let support = &random_variable.to_pdf()?.support; for (&function_value, &support_value) in function.iter().zip(support.iter()) { let partial_probability = function_value * mix_weight; - if let Some(support_index) = - raw_mixture_support.iter().position(|&x| x == support_value) - { - raw_mixture_function[support_index] += partial_probability; - } else { - raw_mixture_support.push(support_value); - raw_mixture_function.push(partial_probability); + + let support_index = raw_mixture_support.iter().position(|&x| x == support_value); + + match support_index { + Some(idx) => { + raw_mixture_function[idx] += partial_probability; + } + None => { + raw_mixture_support.push(support_value); + raw_mixture_function.push(partial_probability); + } } } } From feb95d24e2af42b9b0331b7925708c41a1a1c391 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 22:13:22 -0400 Subject: [PATCH 6/9] add tests for mixture_discrete --- rust/src/algorithms/transform.rs | 119 +++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index eaa6a57..d79131d 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -120,6 +120,50 @@ pub fn truncate_discrete( /// /// # Returns /// * `mixed_rv` - the weighted mixture of the random variables +/// +/// # Examples +/// ``` +/// use applpy_rust::algorithms::number::Number; +/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +/// use applpy_rust::algorithms::transform::mixture_discrete; +/// use num_rational::Rational64; +/// +/// let rv1 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 2)), +/// ], +/// support: vec![Number::Integer(1), Number::Integer(2)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let rv2 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 4)), +/// Number::Rational(Rational64::new(3, 4)), +/// ], +/// support: vec![Number::Integer(2), Number::Integer(3)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let mixed = mixture_discrete( +/// &[&rv1, &rv2], +/// &[Number::Float(0.25), Number::Float(0.75)], +/// ) +/// .unwrap(); +/// +/// assert_eq!( +/// mixed.support, +/// vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)] +/// ); +/// assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12); +/// assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12); +/// assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12); +/// assert!(matches!(mixed.functional_form, FunctionalForm::Pdf)); +/// assert!(matches!(mixed.domain_type, DomainType::Discrete)); +/// ``` pub fn mixture_discrete( random_variables: &[&RandomVariable], mix_weights: &[Number], @@ -205,6 +249,30 @@ mod tests { } } + fn sample_mixture_rv_1() -> RandomVariable { + RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + } + } + + fn sample_mixture_rv_2() -> RandomVariable { + RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(3, 4)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + } + } + #[test] fn truncate_discrete_renormalizes_probabilities_within_range() { let rv = sample_discrete_rv(); @@ -235,4 +303,55 @@ mod tests { Err(msg) if msg == "min support must be greater than or equal to the lowest support value" )); } + + #[test] + fn mixture_discrete_combines_duplicate_support_and_sorts_output() { + let rv1 = sample_mixture_rv_1(); + let rv2 = sample_mixture_rv_2(); + + let mixed = + mixture_discrete(&[&rv1, &rv2], &[Number::Float(0.25), Number::Float(0.75)]).unwrap(); + + assert_eq!( + mixed.support, + vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)] + ); + assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12); + assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12); + assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12); + assert!(matches!(mixed.functional_form, FunctionalForm::Pdf)); + assert!(matches!(mixed.domain_type, DomainType::Discrete)); + } + + #[test] + fn mixture_discrete_returns_error_when_lengths_do_not_match() { + let rv1 = sample_mixture_rv_1(); + let rv2 = sample_mixture_rv_2(); + + let result = mixture_discrete(&[&rv1, &rv2], &[Number::Rational(Rational64::new(1, 1))]); + + assert!(matches!( + result, + Err(msg) if msg == "the number of random variables and mix weights must be equal" + )); + } + + #[test] + fn mixture_discrete_returns_error_when_weights_do_not_sum_to_one() { + let rv1 = sample_mixture_rv_1(); + let rv2 = sample_mixture_rv_2(); + + let result = mixture_discrete( + &[&rv1, &rv2], + &[ + Number::Rational(Rational64::new(1, 3)), + Number::Rational(Rational64::new(1, 3)), + ], + ); + + assert!(matches!( + result, + Err(msg) if msg == "the mix weights must sum to one" + )); + } } From 4dcf0dea565ae8fae7ea026bbc830eef98bf6182 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Wed, 25 Mar 2026 14:13:51 -0400 Subject: [PATCH 7/9] add mixture rv to python api --- rust/src/python/api.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/rust/src/python/api.rs b/rust/src/python/api.rs index 37c5ce5..95e5908 100644 --- a/rust/src/python/api.rs +++ b/rust/src/python/api.rs @@ -169,6 +169,31 @@ pub fn truncate_discrete_py( )) } +#[pyfunction(name = "mixture_discrete", signature = (random_variables, mix_weights))] +pub fn mixture_discrete_py( + random_variables: Vec>, + mix_weights: Vec, +) -> PyResult { + let extracted: PyResult> = random_variables + .into_iter() + .map(|rv| rv.extract::()) + .collect(); + let extracted: Vec = extracted?; + + let extracted_rvs: Vec<&RandomVariable> = + extracted.iter().map(|fast_rv| &fast_rv.inner).collect(); + + let mixed_rv = + transform::mixture_discrete(&extracted_rvs, &mix_weights).map_err(PyValueError::new_err)?; + + Ok(FastRV::new( + mixed_rv.function, + mixed_rv.support, + mixed_rv.functional_form, + mixed_rv.domain_type, + )) +} + #[pyclass] pub struct FastRV { inner: RandomVariable, From 055b335a1fc35eda47956b80a94fd0893540783c Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Wed, 25 Mar 2026 14:15:36 -0400 Subject: [PATCH 8/9] add discrete mixture to lib --- rust/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 814fa18..9b51c78 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -37,6 +37,7 @@ fn applpy_rust(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> { // transformation functions module.add_function(wrap_pyfunction!(python::api::truncate_discrete_py, module)?)?; + module.add_function(wrap_pyfunction!(python::api::mixture_discrete_py, module)?)?; // dummy function to validate imports module.add_function(wrap_pyfunction!(dummy_ping, module)?)?; From 0e1dadd4d79e36949f704bcc6af2868f08279921 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Wed, 25 Mar 2026 14:33:23 -0400 Subject: [PATCH 9/9] patch in rust mixture; fix weight sum bug --- applpy/transform.py | 33 ++++++++++---------------------- rust/src/algorithms/transform.rs | 9 +++++---- 2 files changed, 15 insertions(+), 27 deletions(-) diff --git a/applpy/transform.py b/applpy/transform.py index 404b6a1..ace8f5d 100644 --- a/applpy/transform.py +++ b/applpy/transform.py @@ -522,29 +522,16 @@ def _mixture_discrete_functional(mixture_pdf_random_variables): def _mixture_discrete(mix_parameters, mixture_pdf_random_variables): - # Compute the mixture rv by summing over the weights - mixture_support = [] - mixture_functions = [] - for i in range(len(mixture_pdf_random_variables)): - for j in range(len(mixture_pdf_random_variables[i].support)): - if mixture_pdf_random_variables[i].support[j] not in mixture_support: - mixture_support.append(mixture_pdf_random_variables[i].support[j]) - mixture_functions.append( - mixture_pdf_random_variables[i].func[j] * mix_parameters[i] - ) - else: - support_index = mixture_support.index(mixture_pdf_random_variables[i].support[j]) - weighted_value = mixture_pdf_random_variables[i].func[j] * mix_parameters[i] - mixture_functions[support_index] += weighted_value - # Sort the values - sorted_support_function_pairs = list(zip(mixture_support, mixture_functions)) - sorted_support_function_pairs.sort() - mixture_functions = [] - mixture_support = [] - for i in range(len(sorted_support_function_pairs)): - mixture_functions.append(sorted_support_function_pairs[i][1]) - mixture_support.append(sorted_support_function_pairs[i][0]) - return RV(mixture_functions, mixture_support, ["discrete", "pdf"]) + fast_rv = applpy_rust.mixture_discrete( + random_variables=mixture_pdf_random_variables, + mix_weights=mix_parameters, + ) + return RV( + func=fast_rv.function, + support=fast_rv.support, + functional_form=fast_rv.functional_form, + domain_type=fast_rv.domain_type, + ) # Backward-compatible aliases for legacy APPLPy function names. diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index d79131d..a65a844 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -172,10 +172,11 @@ pub fn mixture_discrete( return Err("the number of random variables and mix weights must be equal".to_string()); } - let weight_sum: Number = mix_weights.iter().copied().sum(); - let one = Number::Integer(1); - let tolerance = Number::Float(1e-6); - if weight_sum < one - tolerance || weight_sum > one + tolerance { + let weight_sum: f64 = mix_weights.iter().copied().sum::().to_f64(); + let tolerance = 1e-6; + let upper_bound = 1.0 + tolerance; + let lower_bound = 1.0 - tolerance; + if weight_sum < lower_bound || weight_sum > upper_bound { return Err("the mix weights must sum to one".to_string()); }