@@ -120,6 +120,50 @@ pub fn truncate_discrete(
120120///
121121/// # Returns
122122/// * `mixed_rv` - the weighted mixture of the random variables
123+ ///
124+ /// # Examples
125+ /// ```
126+ /// use applpy_rust::algorithms::number::Number;
127+ /// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
128+ /// use applpy_rust::algorithms::transform::mixture_discrete;
129+ /// use num_rational::Rational64;
130+ ///
131+ /// let rv1 = RandomVariable {
132+ /// function: vec![
133+ /// Number::Rational(Rational64::new(1, 2)),
134+ /// Number::Rational(Rational64::new(1, 2)),
135+ /// ],
136+ /// support: vec![Number::Integer(1), Number::Integer(2)],
137+ /// functional_form: FunctionalForm::Pdf,
138+ /// domain_type: DomainType::Discrete,
139+ /// };
140+ ///
141+ /// let rv2 = RandomVariable {
142+ /// function: vec![
143+ /// Number::Rational(Rational64::new(1, 4)),
144+ /// Number::Rational(Rational64::new(3, 4)),
145+ /// ],
146+ /// support: vec![Number::Integer(2), Number::Integer(3)],
147+ /// functional_form: FunctionalForm::Pdf,
148+ /// domain_type: DomainType::Discrete,
149+ /// };
150+ ///
151+ /// let mixed = mixture_discrete(
152+ /// &[&rv1, &rv2],
153+ /// &[Number::Float(0.25), Number::Float(0.75)],
154+ /// )
155+ /// .unwrap();
156+ ///
157+ /// assert_eq!(
158+ /// mixed.support,
159+ /// vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)]
160+ /// );
161+ /// assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12);
162+ /// assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12);
163+ /// assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12);
164+ /// assert!(matches!(mixed.functional_form, FunctionalForm::Pdf));
165+ /// assert!(matches!(mixed.domain_type, DomainType::Discrete));
166+ /// ```
123167pub fn mixture_discrete (
124168 random_variables : & [ & RandomVariable ] ,
125169 mix_weights : & [ Number ] ,
@@ -205,6 +249,30 @@ mod tests {
205249 }
206250 }
207251
252+ fn sample_mixture_rv_1 ( ) -> RandomVariable {
253+ RandomVariable {
254+ function : vec ! [
255+ Number :: Rational ( Rational64 :: new( 1 , 2 ) ) ,
256+ Number :: Rational ( Rational64 :: new( 1 , 2 ) ) ,
257+ ] ,
258+ support : vec ! [ Number :: Integer ( 1 ) , Number :: Integer ( 2 ) ] ,
259+ functional_form : FunctionalForm :: Pdf ,
260+ domain_type : DomainType :: Discrete ,
261+ }
262+ }
263+
264+ fn sample_mixture_rv_2 ( ) -> RandomVariable {
265+ RandomVariable {
266+ function : vec ! [
267+ Number :: Rational ( Rational64 :: new( 1 , 4 ) ) ,
268+ Number :: Rational ( Rational64 :: new( 3 , 4 ) ) ,
269+ ] ,
270+ support : vec ! [ Number :: Integer ( 2 ) , Number :: Integer ( 3 ) ] ,
271+ functional_form : FunctionalForm :: Pdf ,
272+ domain_type : DomainType :: Discrete ,
273+ }
274+ }
275+
208276 #[ test]
209277 fn truncate_discrete_renormalizes_probabilities_within_range ( ) {
210278 let rv = sample_discrete_rv ( ) ;
@@ -235,4 +303,55 @@ mod tests {
235303 Err ( msg) if msg == "min support must be greater than or equal to the lowest support value"
236304 ) ) ;
237305 }
306+
307+ #[ test]
308+ fn mixture_discrete_combines_duplicate_support_and_sorts_output ( ) {
309+ let rv1 = sample_mixture_rv_1 ( ) ;
310+ let rv2 = sample_mixture_rv_2 ( ) ;
311+
312+ let mixed =
313+ mixture_discrete ( & [ & rv1, & rv2] , & [ Number :: Float ( 0.25 ) , Number :: Float ( 0.75 ) ] ) . unwrap ( ) ;
314+
315+ assert_eq ! (
316+ mixed. support,
317+ vec![ Number :: Integer ( 1 ) , Number :: Integer ( 2 ) , Number :: Integer ( 3 ) ]
318+ ) ;
319+ assert ! ( ( mixed. function[ 0 ] . to_f64( ) - 0.125 ) . abs( ) < 1e-12 ) ;
320+ assert ! ( ( mixed. function[ 1 ] . to_f64( ) - 0.3125 ) . abs( ) < 1e-12 ) ;
321+ assert ! ( ( mixed. function[ 2 ] . to_f64( ) - 0.5625 ) . abs( ) < 1e-12 ) ;
322+ assert ! ( matches!( mixed. functional_form, FunctionalForm :: Pdf ) ) ;
323+ assert ! ( matches!( mixed. domain_type, DomainType :: Discrete ) ) ;
324+ }
325+
326+ #[ test]
327+ fn mixture_discrete_returns_error_when_lengths_do_not_match ( ) {
328+ let rv1 = sample_mixture_rv_1 ( ) ;
329+ let rv2 = sample_mixture_rv_2 ( ) ;
330+
331+ let result = mixture_discrete ( & [ & rv1, & rv2] , & [ Number :: Rational ( Rational64 :: new ( 1 , 1 ) ) ] ) ;
332+
333+ assert ! ( matches!(
334+ result,
335+ Err ( msg) if msg == "the number of random variables and mix weights must be equal"
336+ ) ) ;
337+ }
338+
339+ #[ test]
340+ fn mixture_discrete_returns_error_when_weights_do_not_sum_to_one ( ) {
341+ let rv1 = sample_mixture_rv_1 ( ) ;
342+ let rv2 = sample_mixture_rv_2 ( ) ;
343+
344+ let result = mixture_discrete (
345+ & [ & rv1, & rv2] ,
346+ & [
347+ Number :: Rational ( Rational64 :: new ( 1 , 3 ) ) ,
348+ Number :: Rational ( Rational64 :: new ( 1 , 3 ) ) ,
349+ ] ,
350+ ) ;
351+
352+ assert ! ( matches!(
353+ result,
354+ Err ( msg) if msg == "the mix weights must sum to one"
355+ ) ) ;
356+ }
238357}
0 commit comments