Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 10 additions & 23 deletions applpy/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,29 +522,16 @@ def _mixture_discrete_functional(mixture_pdf_random_variables):


def _mixture_discrete(mix_parameters, mixture_pdf_random_variables):
# Compute the mixture rv by summing over the weights
mixture_support = []
mixture_functions = []
for i in range(len(mixture_pdf_random_variables)):
for j in range(len(mixture_pdf_random_variables[i].support)):
if mixture_pdf_random_variables[i].support[j] not in mixture_support:
mixture_support.append(mixture_pdf_random_variables[i].support[j])
mixture_functions.append(
mixture_pdf_random_variables[i].func[j] * mix_parameters[i]
)
else:
support_index = mixture_support.index(mixture_pdf_random_variables[i].support[j])
weighted_value = mixture_pdf_random_variables[i].func[j] * mix_parameters[i]
mixture_functions[support_index] += weighted_value
# Sort the values
sorted_support_function_pairs = list(zip(mixture_support, mixture_functions))
sorted_support_function_pairs.sort()
mixture_functions = []
mixture_support = []
for i in range(len(sorted_support_function_pairs)):
mixture_functions.append(sorted_support_function_pairs[i][1])
mixture_support.append(sorted_support_function_pairs[i][0])
return RV(mixture_functions, mixture_support, ["discrete", "pdf"])
fast_rv = applpy_rust.mixture_discrete(
random_variables=mixture_pdf_random_variables,
mix_weights=mix_parameters,
)
return RV(
func=fast_rv.function,
support=fast_rv.support,
functional_form=fast_rv.functional_form,
domain_type=fast_rv.domain_type,
)


# Backward-compatible aliases for legacy APPLPy function names.
Expand Down
7 changes: 7 additions & 0 deletions rust/src/algorithms/number.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![allow(dead_code)]

use std::fmt;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Div, Mul, Sub};
use std::{collections::BTreeMap, f64::consts::E};

Expand Down Expand Up @@ -533,6 +534,12 @@ impl Div for Number {
}
}

impl Sum for Number {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Number::default(), |acc, x| acc + x)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
189 changes: 189 additions & 0 deletions rust/src/algorithms/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,120 @@ pub fn truncate_discrete(
Ok(truncated_rv)
}

/// Computes the mixture of two random variables
///
/// # Arguments
/// * `random_variables` - a list of random variables to mix
/// * `mix_weights` - the weight for each random variable. Must sum to 1
///
/// # Returns
/// * `mixed_rv` - the weighted mixture of the random variables
///
/// # Examples
/// ```
/// use applpy_rust::algorithms::number::Number;
/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
/// use applpy_rust::algorithms::transform::mixture_discrete;
/// use num_rational::Rational64;
///
/// let rv1 = RandomVariable {
/// function: vec![
/// Number::Rational(Rational64::new(1, 2)),
/// Number::Rational(Rational64::new(1, 2)),
/// ],
/// support: vec![Number::Integer(1), Number::Integer(2)],
/// functional_form: FunctionalForm::Pdf,
/// domain_type: DomainType::Discrete,
/// };
///
/// let rv2 = RandomVariable {
/// function: vec![
/// Number::Rational(Rational64::new(1, 4)),
/// Number::Rational(Rational64::new(3, 4)),
/// ],
/// support: vec![Number::Integer(2), Number::Integer(3)],
/// functional_form: FunctionalForm::Pdf,
/// domain_type: DomainType::Discrete,
/// };
///
/// let mixed = mixture_discrete(
/// &[&rv1, &rv2],
/// &[Number::Float(0.25), Number::Float(0.75)],
/// )
/// .unwrap();
///
/// assert_eq!(
/// mixed.support,
/// vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)]
/// );
/// assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12);
/// assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12);
/// assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12);
/// assert!(matches!(mixed.functional_form, FunctionalForm::Pdf));
/// assert!(matches!(mixed.domain_type, DomainType::Discrete));
/// ```
pub fn mixture_discrete(
random_variables: &[&RandomVariable],
mix_weights: &[Number],
) -> Result<RandomVariable, String> {
if random_variables.len() != mix_weights.len() {
return Err("the number of random variables and mix weights must be equal".to_string());
}

let weight_sum: f64 = mix_weights.iter().copied().sum::<Number>().to_f64();
let tolerance = 1e-6;
let upper_bound = 1.0 + tolerance;
let lower_bound = 1.0 - tolerance;
if weight_sum < lower_bound || weight_sum > upper_bound {
return Err("the mix weights must sum to one".to_string());
}

let mut raw_mixture_function = Vec::new();
let mut raw_mixture_support = Vec::new();

for (&random_variable, &mix_weight) in random_variables.iter().zip(mix_weights.iter()) {
let function = &random_variable.to_pdf()?.function;
let support = &random_variable.to_pdf()?.support;

for (&function_value, &support_value) in function.iter().zip(support.iter()) {
let partial_probability = function_value * mix_weight;

let support_index = raw_mixture_support.iter().position(|&x| x == support_value);

match support_index {
Some(idx) => {
raw_mixture_function[idx] += partial_probability;
}
None => {
raw_mixture_support.push(support_value);
raw_mixture_function.push(partial_probability);
}
}
}
}

let mut raw_mixture_pair: Vec<_> = raw_mixture_support
.into_iter()
.zip(raw_mixture_function)
.collect();

raw_mixture_pair.sort_by(|a, b| {
let first_value = a.0.to_f64();
let second_value = b.0.to_f64();
first_value.total_cmp(&second_value)
});

let (mixture_support, mixture_function) = raw_mixture_pair.into_iter().unzip();

let mix_rv = RandomVariable {
function: mixture_function,
support: mixture_support,
functional_form: FunctionalForm::Pdf,
domain_type: DomainType::Discrete,
};
Ok(mix_rv)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -136,6 +250,30 @@ mod tests {
}
}

fn sample_mixture_rv_1() -> RandomVariable {
RandomVariable {
function: vec![
Number::Rational(Rational64::new(1, 2)),
Number::Rational(Rational64::new(1, 2)),
],
support: vec![Number::Integer(1), Number::Integer(2)],
functional_form: FunctionalForm::Pdf,
domain_type: DomainType::Discrete,
}
}

fn sample_mixture_rv_2() -> RandomVariable {
RandomVariable {
function: vec![
Number::Rational(Rational64::new(1, 4)),
Number::Rational(Rational64::new(3, 4)),
],
support: vec![Number::Integer(2), Number::Integer(3)],
functional_form: FunctionalForm::Pdf,
domain_type: DomainType::Discrete,
}
}

#[test]
fn truncate_discrete_renormalizes_probabilities_within_range() {
let rv = sample_discrete_rv();
Expand Down Expand Up @@ -166,4 +304,55 @@ mod tests {
Err(msg) if msg == "min support must be greater than or equal to the lowest support value"
));
}

#[test]
fn mixture_discrete_combines_duplicate_support_and_sorts_output() {
let rv1 = sample_mixture_rv_1();
let rv2 = sample_mixture_rv_2();

let mixed =
mixture_discrete(&[&rv1, &rv2], &[Number::Float(0.25), Number::Float(0.75)]).unwrap();

assert_eq!(
mixed.support,
vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)]
);
assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12);
assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12);
assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12);
assert!(matches!(mixed.functional_form, FunctionalForm::Pdf));
assert!(matches!(mixed.domain_type, DomainType::Discrete));
}

#[test]
fn mixture_discrete_returns_error_when_lengths_do_not_match() {
let rv1 = sample_mixture_rv_1();
let rv2 = sample_mixture_rv_2();

let result = mixture_discrete(&[&rv1, &rv2], &[Number::Rational(Rational64::new(1, 1))]);

assert!(matches!(
result,
Err(msg) if msg == "the number of random variables and mix weights must be equal"
));
}

#[test]
fn mixture_discrete_returns_error_when_weights_do_not_sum_to_one() {
let rv1 = sample_mixture_rv_1();
let rv2 = sample_mixture_rv_2();

let result = mixture_discrete(
&[&rv1, &rv2],
&[
Number::Rational(Rational64::new(1, 3)),
Number::Rational(Rational64::new(1, 3)),
],
);

assert!(matches!(
result,
Err(msg) if msg == "the mix weights must sum to one"
));
}
}
1 change: 1 addition & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ fn applpy_rust(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {

// transformation functions
module.add_function(wrap_pyfunction!(python::api::truncate_discrete_py, module)?)?;
module.add_function(wrap_pyfunction!(python::api::mixture_discrete_py, module)?)?;

// dummy function to validate imports
module.add_function(wrap_pyfunction!(dummy_ping, module)?)?;
Expand Down
25 changes: 25 additions & 0 deletions rust/src/python/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,31 @@ pub fn truncate_discrete_py(
))
}

#[pyfunction(name = "mixture_discrete", signature = (random_variables, mix_weights))]
pub fn mixture_discrete_py(
random_variables: Vec<Bound<'_, PyAny>>,
mix_weights: Vec<Number>,
) -> PyResult<FastRV> {
let extracted: PyResult<Vec<FastRV>> = random_variables
.into_iter()
.map(|rv| rv.extract::<FastRV>())
.collect();
let extracted: Vec<FastRV> = extracted?;

let extracted_rvs: Vec<&RandomVariable> =
extracted.iter().map(|fast_rv| &fast_rv.inner).collect();

let mixed_rv =
transform::mixture_discrete(&extracted_rvs, &mix_weights).map_err(PyValueError::new_err)?;

Ok(FastRV::new(
mixed_rv.function,
mixed_rv.support,
mixed_rv.functional_form,
mixed_rv.domain_type,
))
}

#[pyclass]
pub struct FastRV {
inner: RandomVariable,
Expand Down
Loading