@@ -112,6 +112,120 @@ pub fn truncate_discrete(
112112 Ok ( truncated_rv)
113113}
114114
115+ /// Computes the mixture of two random variables
116+ ///
117+ /// # Arguments
118+ /// * `random_variables` - a list of random variables to mix
119+ /// * `mix_weights` - the weight for each random variable. Must sum to 1
120+ ///
121+ /// # Returns
122+ /// * `mixed_rv` - the weighted mixture of the random variables
123+ ///
124+ /// # Examples
125+ /// ```
126+ /// use applpy_rust::algorithms::number::Number;
127+ /// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
128+ /// use applpy_rust::algorithms::transform::mixture_discrete;
129+ /// use num_rational::Rational64;
130+ ///
131+ /// let rv1 = RandomVariable {
132+ /// function: vec![
133+ /// Number::Rational(Rational64::new(1, 2)),
134+ /// Number::Rational(Rational64::new(1, 2)),
135+ /// ],
136+ /// support: vec![Number::Integer(1), Number::Integer(2)],
137+ /// functional_form: FunctionalForm::Pdf,
138+ /// domain_type: DomainType::Discrete,
139+ /// };
140+ ///
141+ /// let rv2 = RandomVariable {
142+ /// function: vec![
143+ /// Number::Rational(Rational64::new(1, 4)),
144+ /// Number::Rational(Rational64::new(3, 4)),
145+ /// ],
146+ /// support: vec![Number::Integer(2), Number::Integer(3)],
147+ /// functional_form: FunctionalForm::Pdf,
148+ /// domain_type: DomainType::Discrete,
149+ /// };
150+ ///
151+ /// let mixed = mixture_discrete(
152+ /// &[&rv1, &rv2],
153+ /// &[Number::Float(0.25), Number::Float(0.75)],
154+ /// )
155+ /// .unwrap();
156+ ///
157+ /// assert_eq!(
158+ /// mixed.support,
159+ /// vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)]
160+ /// );
161+ /// assert!((mixed.function[0].to_f64() - 0.125).abs() < 1e-12);
162+ /// assert!((mixed.function[1].to_f64() - 0.3125).abs() < 1e-12);
163+ /// assert!((mixed.function[2].to_f64() - 0.5625).abs() < 1e-12);
164+ /// assert!(matches!(mixed.functional_form, FunctionalForm::Pdf));
165+ /// assert!(matches!(mixed.domain_type, DomainType::Discrete));
166+ /// ```
167+ pub fn mixture_discrete (
168+ random_variables : & [ & RandomVariable ] ,
169+ mix_weights : & [ Number ] ,
170+ ) -> Result < RandomVariable , String > {
171+ if random_variables. len ( ) != mix_weights. len ( ) {
172+ return Err ( "the number of random variables and mix weights must be equal" . to_string ( ) ) ;
173+ }
174+
175+ let weight_sum: f64 = mix_weights. iter ( ) . copied ( ) . sum :: < Number > ( ) . to_f64 ( ) ;
176+ let tolerance = 1e-6 ;
177+ let upper_bound = 1.0 + tolerance;
178+ let lower_bound = 1.0 - tolerance;
179+ if weight_sum < lower_bound || weight_sum > upper_bound {
180+ return Err ( "the mix weights must sum to one" . to_string ( ) ) ;
181+ }
182+
183+ let mut raw_mixture_function = Vec :: new ( ) ;
184+ let mut raw_mixture_support = Vec :: new ( ) ;
185+
186+ for ( & random_variable, & mix_weight) in random_variables. iter ( ) . zip ( mix_weights. iter ( ) ) {
187+ let function = & random_variable. to_pdf ( ) ?. function ;
188+ let support = & random_variable. to_pdf ( ) ?. support ;
189+
190+ for ( & function_value, & support_value) in function. iter ( ) . zip ( support. iter ( ) ) {
191+ let partial_probability = function_value * mix_weight;
192+
193+ let support_index = raw_mixture_support. iter ( ) . position ( |& x| x == support_value) ;
194+
195+ match support_index {
196+ Some ( idx) => {
197+ raw_mixture_function[ idx] += partial_probability;
198+ }
199+ None => {
200+ raw_mixture_support. push ( support_value) ;
201+ raw_mixture_function. push ( partial_probability) ;
202+ }
203+ }
204+ }
205+ }
206+
207+ let mut raw_mixture_pair: Vec < _ > = raw_mixture_support
208+ . into_iter ( )
209+ . zip ( raw_mixture_function)
210+ . collect ( ) ;
211+
212+ raw_mixture_pair. sort_by ( |a, b| {
213+ let first_value = a. 0 . to_f64 ( ) ;
214+ let second_value = b. 0 . to_f64 ( ) ;
215+ first_value. total_cmp ( & second_value)
216+ } ) ;
217+
218+ let ( mixture_support, mixture_function) = raw_mixture_pair. into_iter ( ) . unzip ( ) ;
219+
220+ let mix_rv = RandomVariable {
221+ function : mixture_function,
222+ support : mixture_support,
223+ functional_form : FunctionalForm :: Pdf ,
224+ domain_type : DomainType :: Discrete ,
225+ } ;
226+ Ok ( mix_rv)
227+ }
228+
115229#[ cfg( test) ]
116230mod tests {
117231 use super :: * ;
@@ -136,6 +250,30 @@ mod tests {
136250 }
137251 }
138252
253+ fn sample_mixture_rv_1 ( ) -> RandomVariable {
254+ RandomVariable {
255+ function : vec ! [
256+ Number :: Rational ( Rational64 :: new( 1 , 2 ) ) ,
257+ Number :: Rational ( Rational64 :: new( 1 , 2 ) ) ,
258+ ] ,
259+ support : vec ! [ Number :: Integer ( 1 ) , Number :: Integer ( 2 ) ] ,
260+ functional_form : FunctionalForm :: Pdf ,
261+ domain_type : DomainType :: Discrete ,
262+ }
263+ }
264+
265+ fn sample_mixture_rv_2 ( ) -> RandomVariable {
266+ RandomVariable {
267+ function : vec ! [
268+ Number :: Rational ( Rational64 :: new( 1 , 4 ) ) ,
269+ Number :: Rational ( Rational64 :: new( 3 , 4 ) ) ,
270+ ] ,
271+ support : vec ! [ Number :: Integer ( 2 ) , Number :: Integer ( 3 ) ] ,
272+ functional_form : FunctionalForm :: Pdf ,
273+ domain_type : DomainType :: Discrete ,
274+ }
275+ }
276+
139277 #[ test]
140278 fn truncate_discrete_renormalizes_probabilities_within_range ( ) {
141279 let rv = sample_discrete_rv ( ) ;
@@ -166,4 +304,55 @@ mod tests {
166304 Err ( msg) if msg == "min support must be greater than or equal to the lowest support value"
167305 ) ) ;
168306 }
307+
308+ #[ test]
309+ fn mixture_discrete_combines_duplicate_support_and_sorts_output ( ) {
310+ let rv1 = sample_mixture_rv_1 ( ) ;
311+ let rv2 = sample_mixture_rv_2 ( ) ;
312+
313+ let mixed =
314+ mixture_discrete ( & [ & rv1, & rv2] , & [ Number :: Float ( 0.25 ) , Number :: Float ( 0.75 ) ] ) . unwrap ( ) ;
315+
316+ assert_eq ! (
317+ mixed. support,
318+ vec![ Number :: Integer ( 1 ) , Number :: Integer ( 2 ) , Number :: Integer ( 3 ) ]
319+ ) ;
320+ assert ! ( ( mixed. function[ 0 ] . to_f64( ) - 0.125 ) . abs( ) < 1e-12 ) ;
321+ assert ! ( ( mixed. function[ 1 ] . to_f64( ) - 0.3125 ) . abs( ) < 1e-12 ) ;
322+ assert ! ( ( mixed. function[ 2 ] . to_f64( ) - 0.5625 ) . abs( ) < 1e-12 ) ;
323+ assert ! ( matches!( mixed. functional_form, FunctionalForm :: Pdf ) ) ;
324+ assert ! ( matches!( mixed. domain_type, DomainType :: Discrete ) ) ;
325+ }
326+
327+ #[ test]
328+ fn mixture_discrete_returns_error_when_lengths_do_not_match ( ) {
329+ let rv1 = sample_mixture_rv_1 ( ) ;
330+ let rv2 = sample_mixture_rv_2 ( ) ;
331+
332+ let result = mixture_discrete ( & [ & rv1, & rv2] , & [ Number :: Rational ( Rational64 :: new ( 1 , 1 ) ) ] ) ;
333+
334+ assert ! ( matches!(
335+ result,
336+ Err ( msg) if msg == "the number of random variables and mix weights must be equal"
337+ ) ) ;
338+ }
339+
340+ #[ test]
341+ fn mixture_discrete_returns_error_when_weights_do_not_sum_to_one ( ) {
342+ let rv1 = sample_mixture_rv_1 ( ) ;
343+ let rv2 = sample_mixture_rv_2 ( ) ;
344+
345+ let result = mixture_discrete (
346+ & [ & rv1, & rv2] ,
347+ & [
348+ Number :: Rational ( Rational64 :: new ( 1 , 3 ) ) ,
349+ Number :: Rational ( Rational64 :: new ( 1 , 3 ) ) ,
350+ ] ,
351+ ) ;
352+
353+ assert ! ( matches!(
354+ result,
355+ Err ( msg) if msg == "the mix weights must sum to one"
356+ ) ) ;
357+ }
169358}
0 commit comments