Skip to content

Commit feb95d2

Browse files
committed
add tests for mixture_discrete
1 parent 7badf71 commit feb95d2

1 file changed

Lines changed: 119 additions & 0 deletions

File tree

rust/src/algorithms/transform.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,50 @@ pub fn truncate_discrete(
120120
///
121121
/// # Returns
122122
/// * `mixed_rv` - the weighted mixture of the random variables
123+
///
124+
/// # Examples
125+
/// ```
126+
/// use applpy_rust::algorithms::number::Number;
127+
/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
128+
/// use applpy_rust::algorithms::transform::mixture_discrete;
129+
/// use num_rational::Rational64;
130+
///
131+
/// let rv1 = RandomVariable {
132+
/// function: vec![
133+
/// Number::Rational(Rational64::new(1, 2)),
134+
/// Number::Rational(Rational64::new(1, 2)),
135+
/// ],
136+
/// support: vec![Number::Integer(1), Number::Integer(2)],
137+
/// functional_form: FunctionalForm::Pdf,
138+
/// domain_type: DomainType::Discrete,
139+
/// };
140+
///
141+
/// let rv2 = RandomVariable {
142+
/// function: vec![
143+
/// Number::Rational(Rational64::new(1, 4)),
144+
/// Number::Rational(Rational64::new(3, 4)),
145+
/// ],
146+
/// support: vec![Number::Integer(2), Number::Integer(3)],
147+
/// functional_form: FunctionalForm::Pdf,
148+
/// domain_type: DomainType::Discrete,
149+
/// };
150+
///
151+
/// let mixed = mixture_discrete(
152+
/// &[&rv1, &rv2],
153+
/// &[Number::Float(0.25), Number::Float(0.75)],
154+
/// )
155+
/// .unwrap();
156+
///
157+
/// assert_eq!(
158+
/// mixed.support,
159+
/// vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)]
160+
/// );
161+
/// assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12);
162+
/// assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12);
163+
/// assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12);
164+
/// assert!(matches!(mixed.functional_form, FunctionalForm::Pdf));
165+
/// assert!(matches!(mixed.domain_type, DomainType::Discrete));
166+
/// ```
123167
pub fn mixture_discrete(
124168
random_variables: &[&RandomVariable],
125169
mix_weights: &[Number],
@@ -205,6 +249,30 @@ mod tests {
205249
}
206250
}
207251

252+
fn sample_mixture_rv_1() -> RandomVariable {
253+
RandomVariable {
254+
function: vec![
255+
Number::Rational(Rational64::new(1, 2)),
256+
Number::Rational(Rational64::new(1, 2)),
257+
],
258+
support: vec![Number::Integer(1), Number::Integer(2)],
259+
functional_form: FunctionalForm::Pdf,
260+
domain_type: DomainType::Discrete,
261+
}
262+
}
263+
264+
fn sample_mixture_rv_2() -> RandomVariable {
265+
RandomVariable {
266+
function: vec![
267+
Number::Rational(Rational64::new(1, 4)),
268+
Number::Rational(Rational64::new(3, 4)),
269+
],
270+
support: vec![Number::Integer(2), Number::Integer(3)],
271+
functional_form: FunctionalForm::Pdf,
272+
domain_type: DomainType::Discrete,
273+
}
274+
}
275+
208276
#[test]
209277
fn truncate_discrete_renormalizes_probabilities_within_range() {
210278
let rv = sample_discrete_rv();
@@ -235,4 +303,55 @@ mod tests {
235303
Err(msg) if msg == "min support must be greater than or equal to the lowest support value"
236304
));
237305
}
306+
307+
#[test]
308+
fn mixture_discrete_combines_duplicate_support_and_sorts_output() {
309+
let rv1 = sample_mixture_rv_1();
310+
let rv2 = sample_mixture_rv_2();
311+
312+
let mixed =
313+
mixture_discrete(&[&rv1, &rv2], &[Number::Float(0.25), Number::Float(0.75)]).unwrap();
314+
315+
assert_eq!(
316+
mixed.support,
317+
vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)]
318+
);
319+
assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12);
320+
assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12);
321+
assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12);
322+
assert!(matches!(mixed.functional_form, FunctionalForm::Pdf));
323+
assert!(matches!(mixed.domain_type, DomainType::Discrete));
324+
}
325+
326+
#[test]
327+
fn mixture_discrete_returns_error_when_lengths_do_not_match() {
328+
let rv1 = sample_mixture_rv_1();
329+
let rv2 = sample_mixture_rv_2();
330+
331+
let result = mixture_discrete(&[&rv1, &rv2], &[Number::Rational(Rational64::new(1, 1))]);
332+
333+
assert!(matches!(
334+
result,
335+
Err(msg) if msg == "the number of random variables and mix weights must be equal"
336+
));
337+
}
338+
339+
#[test]
340+
fn mixture_discrete_returns_error_when_weights_do_not_sum_to_one() {
341+
let rv1 = sample_mixture_rv_1();
342+
let rv2 = sample_mixture_rv_2();
343+
344+
let result = mixture_discrete(
345+
&[&rv1, &rv2],
346+
&[
347+
Number::Rational(Rational64::new(1, 3)),
348+
Number::Rational(Rational64::new(1, 3)),
349+
],
350+
);
351+
352+
assert!(matches!(
353+
result,
354+
Err(msg) if msg == "the mix weights must sum to one"
355+
));
356+
}
238357
}

0 commit comments

Comments
 (0)