Skip to content

Commit 1540128

Browse files
authored
feat(transform): add Rust discrete-mixture transform and wire it into Python API (#38)
* first pass on mixture rv * add error handling * fix summing * refactor of index * tweaks to mixture * add tests for mixture_discrete * add mixture rv to python api * add discrete mixture to lib * patch in rust mixture; fix weight sum bug
1 parent cee7496 commit 1540128

5 files changed

Lines changed: 232 additions & 23 deletions

File tree

applpy/transform.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -522,29 +522,16 @@ def _mixture_discrete_functional(mixture_pdf_random_variables):
522522

523523

524524
def _mixture_discrete(mix_parameters, mixture_pdf_random_variables):
525-
# Compute the mixture rv by summing over the weights
526-
mixture_support = []
527-
mixture_functions = []
528-
for i in range(len(mixture_pdf_random_variables)):
529-
for j in range(len(mixture_pdf_random_variables[i].support)):
530-
if mixture_pdf_random_variables[i].support[j] not in mixture_support:
531-
mixture_support.append(mixture_pdf_random_variables[i].support[j])
532-
mixture_functions.append(
533-
mixture_pdf_random_variables[i].func[j] * mix_parameters[i]
534-
)
535-
else:
536-
support_index = mixture_support.index(mixture_pdf_random_variables[i].support[j])
537-
weighted_value = mixture_pdf_random_variables[i].func[j] * mix_parameters[i]
538-
mixture_functions[support_index] += weighted_value
539-
# Sort the values
540-
sorted_support_function_pairs = list(zip(mixture_support, mixture_functions))
541-
sorted_support_function_pairs.sort()
542-
mixture_functions = []
543-
mixture_support = []
544-
for i in range(len(sorted_support_function_pairs)):
545-
mixture_functions.append(sorted_support_function_pairs[i][1])
546-
mixture_support.append(sorted_support_function_pairs[i][0])
547-
return RV(mixture_functions, mixture_support, ["discrete", "pdf"])
525+
fast_rv = applpy_rust.mixture_discrete(
526+
random_variables=mixture_pdf_random_variables,
527+
mix_weights=mix_parameters,
528+
)
529+
return RV(
530+
func=fast_rv.function,
531+
support=fast_rv.support,
532+
functional_form=fast_rv.functional_form,
533+
domain_type=fast_rv.domain_type,
534+
)
548535

549536

550537
# Backward-compatible aliases for legacy APPLPy function names.

rust/src/algorithms/number.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#![allow(dead_code)]
22

33
use std::fmt;
4+
use std::iter::Sum;
45
use std::ops::{Add, AddAssign, Div, Mul, Sub};
56
use std::{collections::BTreeMap, f64::consts::E};
67

@@ -533,6 +534,12 @@ impl Div for Number {
533534
}
534535
}
535536

537+
impl Sum for Number {
538+
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
539+
iter.fold(Number::default(), |acc, x| acc + x)
540+
}
541+
}
542+
536543
#[cfg(test)]
537544
mod tests {
538545
use super::*;

rust/src/algorithms/transform.rs

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,120 @@ pub fn truncate_discrete(
112112
Ok(truncated_rv)
113113
}
114114

115+
/// Computes the mixture of two random variables
116+
///
117+
/// # Arguments
118+
/// * `random_variables` - a list of random variables to mix
119+
/// * `mix_weights` - the weight for each random variable. Must sum to 1
120+
///
121+
/// # Returns
122+
/// * `mixed_rv` - the weighted mixture of the random variables
123+
///
124+
/// # Examples
125+
/// ```
126+
/// use applpy_rust::algorithms::number::Number;
127+
/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
128+
/// use applpy_rust::algorithms::transform::mixture_discrete;
129+
/// use num_rational::Rational64;
130+
///
131+
/// let rv1 = RandomVariable {
132+
/// function: vec![
133+
/// Number::Rational(Rational64::new(1, 2)),
134+
/// Number::Rational(Rational64::new(1, 2)),
135+
/// ],
136+
/// support: vec![Number::Integer(1), Number::Integer(2)],
137+
/// functional_form: FunctionalForm::Pdf,
138+
/// domain_type: DomainType::Discrete,
139+
/// };
140+
///
141+
/// let rv2 = RandomVariable {
142+
/// function: vec![
143+
/// Number::Rational(Rational64::new(1, 4)),
144+
/// Number::Rational(Rational64::new(3, 4)),
145+
/// ],
146+
/// support: vec![Number::Integer(2), Number::Integer(3)],
147+
/// functional_form: FunctionalForm::Pdf,
148+
/// domain_type: DomainType::Discrete,
149+
/// };
150+
///
151+
/// let mixed = mixture_discrete(
152+
/// &[&rv1, &rv2],
153+
/// &[Number::Float(0.25), Number::Float(0.75)],
154+
/// )
155+
/// .unwrap();
156+
///
157+
/// assert_eq!(
158+
/// mixed.support,
159+
/// vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)]
160+
/// );
161+
/// assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12);
162+
/// assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12);
163+
/// assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12);
164+
/// assert!(matches!(mixed.functional_form, FunctionalForm::Pdf));
165+
/// assert!(matches!(mixed.domain_type, DomainType::Discrete));
166+
/// ```
167+
pub fn mixture_discrete(
168+
random_variables: &[&RandomVariable],
169+
mix_weights: &[Number],
170+
) -> Result<RandomVariable, String> {
171+
if random_variables.len() != mix_weights.len() {
172+
return Err("the number of random variables and mix weights must be equal".to_string());
173+
}
174+
175+
let weight_sum: f64 = mix_weights.iter().copied().sum::<Number>().to_f64();
176+
let tolerance = 1e-6;
177+
let upper_bound = 1.0 + tolerance;
178+
let lower_bound = 1.0 - tolerance;
179+
if weight_sum < lower_bound || weight_sum > upper_bound {
180+
return Err("the mix weights must sum to one".to_string());
181+
}
182+
183+
let mut raw_mixture_function = Vec::new();
184+
let mut raw_mixture_support = Vec::new();
185+
186+
for (&random_variable, &mix_weight) in random_variables.iter().zip(mix_weights.iter()) {
187+
let function = &random_variable.to_pdf()?.function;
188+
let support = &random_variable.to_pdf()?.support;
189+
190+
for (&function_value, &support_value) in function.iter().zip(support.iter()) {
191+
let partial_probability = function_value * mix_weight;
192+
193+
let support_index = raw_mixture_support.iter().position(|&x| x == support_value);
194+
195+
match support_index {
196+
Some(idx) => {
197+
raw_mixture_function[idx] += partial_probability;
198+
}
199+
None => {
200+
raw_mixture_support.push(support_value);
201+
raw_mixture_function.push(partial_probability);
202+
}
203+
}
204+
}
205+
}
206+
207+
let mut raw_mixture_pair: Vec<_> = raw_mixture_support
208+
.into_iter()
209+
.zip(raw_mixture_function)
210+
.collect();
211+
212+
raw_mixture_pair.sort_by(|a, b| {
213+
let first_value = a.0.to_f64();
214+
let second_value = b.0.to_f64();
215+
first_value.total_cmp(&second_value)
216+
});
217+
218+
let (mixture_support, mixture_function) = raw_mixture_pair.into_iter().unzip();
219+
220+
let mix_rv = RandomVariable {
221+
function: mixture_function,
222+
support: mixture_support,
223+
functional_form: FunctionalForm::Pdf,
224+
domain_type: DomainType::Discrete,
225+
};
226+
Ok(mix_rv)
227+
}
228+
115229
#[cfg(test)]
116230
mod tests {
117231
use super::*;
@@ -136,6 +250,30 @@ mod tests {
136250
}
137251
}
138252

253+
fn sample_mixture_rv_1() -> RandomVariable {
254+
RandomVariable {
255+
function: vec![
256+
Number::Rational(Rational64::new(1, 2)),
257+
Number::Rational(Rational64::new(1, 2)),
258+
],
259+
support: vec![Number::Integer(1), Number::Integer(2)],
260+
functional_form: FunctionalForm::Pdf,
261+
domain_type: DomainType::Discrete,
262+
}
263+
}
264+
265+
fn sample_mixture_rv_2() -> RandomVariable {
266+
RandomVariable {
267+
function: vec![
268+
Number::Rational(Rational64::new(1, 4)),
269+
Number::Rational(Rational64::new(3, 4)),
270+
],
271+
support: vec![Number::Integer(2), Number::Integer(3)],
272+
functional_form: FunctionalForm::Pdf,
273+
domain_type: DomainType::Discrete,
274+
}
275+
}
276+
139277
#[test]
140278
fn truncate_discrete_renormalizes_probabilities_within_range() {
141279
let rv = sample_discrete_rv();
@@ -166,4 +304,55 @@ mod tests {
166304
Err(msg) if msg == "min support must be greater than or equal to the lowest support value"
167305
));
168306
}
307+
308+
#[test]
309+
fn mixture_discrete_combines_duplicate_support_and_sorts_output() {
310+
let rv1 = sample_mixture_rv_1();
311+
let rv2 = sample_mixture_rv_2();
312+
313+
let mixed =
314+
mixture_discrete(&[&rv1, &rv2], &[Number::Float(0.25), Number::Float(0.75)]).unwrap();
315+
316+
assert_eq!(
317+
mixed.support,
318+
vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)]
319+
);
320+
assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12);
321+
assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12);
322+
assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12);
323+
assert!(matches!(mixed.functional_form, FunctionalForm::Pdf));
324+
assert!(matches!(mixed.domain_type, DomainType::Discrete));
325+
}
326+
327+
#[test]
328+
fn mixture_discrete_returns_error_when_lengths_do_not_match() {
329+
let rv1 = sample_mixture_rv_1();
330+
let rv2 = sample_mixture_rv_2();
331+
332+
let result = mixture_discrete(&[&rv1, &rv2], &[Number::Rational(Rational64::new(1, 1))]);
333+
334+
assert!(matches!(
335+
result,
336+
Err(msg) if msg == "the number of random variables and mix weights must be equal"
337+
));
338+
}
339+
340+
#[test]
341+
fn mixture_discrete_returns_error_when_weights_do_not_sum_to_one() {
342+
let rv1 = sample_mixture_rv_1();
343+
let rv2 = sample_mixture_rv_2();
344+
345+
let result = mixture_discrete(
346+
&[&rv1, &rv2],
347+
&[
348+
Number::Rational(Rational64::new(1, 3)),
349+
Number::Rational(Rational64::new(1, 3)),
350+
],
351+
);
352+
353+
assert!(matches!(
354+
result,
355+
Err(msg) if msg == "the mix weights must sum to one"
356+
));
357+
}
169358
}

rust/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ fn applpy_rust(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {
3737

3838
// transformation functions
3939
module.add_function(wrap_pyfunction!(python::api::truncate_discrete_py, module)?)?;
40+
module.add_function(wrap_pyfunction!(python::api::mixture_discrete_py, module)?)?;
4041

4142
// dummy function to validate imports
4243
module.add_function(wrap_pyfunction!(dummy_ping, module)?)?;

rust/src/python/api.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,31 @@ pub fn truncate_discrete_py(
169169
))
170170
}
171171

172+
#[pyfunction(name = "mixture_discrete", signature = (random_variables, mix_weights))]
173+
pub fn mixture_discrete_py(
174+
random_variables: Vec<Bound<'_, PyAny>>,
175+
mix_weights: Vec<Number>,
176+
) -> PyResult<FastRV> {
177+
let extracted: PyResult<Vec<FastRV>> = random_variables
178+
.into_iter()
179+
.map(|rv| rv.extract::<FastRV>())
180+
.collect();
181+
let extracted: Vec<FastRV> = extracted?;
182+
183+
let extracted_rvs: Vec<&RandomVariable> =
184+
extracted.iter().map(|fast_rv| &fast_rv.inner).collect();
185+
186+
let mixed_rv =
187+
transform::mixture_discrete(&extracted_rvs, &mix_weights).map_err(PyValueError::new_err)?;
188+
189+
Ok(FastRV::new(
190+
mixed_rv.function,
191+
mixed_rv.support,
192+
mixed_rv.functional_form,
193+
mixed_rv.domain_type,
194+
))
195+
}
196+
172197
#[pyclass]
173198
pub struct FastRV {
174199
inner: RandomVariable,

0 commit comments

Comments
 (0)