@@ -124,6 +124,19 @@ pub fn mixture_discrete(
124124 random_variables : & [ & RandomVariable ] ,
125125 mix_weights : & [ Number ] ,
126126) -> Result < RandomVariable , String > {
127+ if random_variables. len ( ) != mix_weights. len ( ) {
128+ return Err ( "the number of random variables and mix weights must be equal" . to_string ( ) ) ;
129+ }
130+
131+ let weight_sum = mix_weights
132+ . iter ( )
133+ . fold ( Number :: default ( ) , |acc, x| acc + * x) ;
134+
135+ let one = Number :: Integer ( 1 ) ;
136+ let tolerance = Number :: Float ( 1e-6 ) ;
137+ if weight_sum < one - tolerance || weight_sum > one + tolerance {
138+ return Err ( "the mix weights must sum to one" . to_string ( ) ) ;
139+ }
127140
128141 let mut raw_mixture_function = Vec :: new ( ) ;
129142 let mut raw_mixture_support = Vec :: new ( ) ;
@@ -138,14 +151,17 @@ pub fn mixture_discrete(
138151 raw_mixture_support. push ( support_value) ;
139152 raw_mixture_function. push ( partial_probability) ;
140153 } else {
141- let support_index = support. iter ( ) . position ( |& x| x == support_value)
154+ let support_index = support
155+ . iter ( )
156+ . position ( |& x| x == support_value)
142157 . expect ( "support value not found in mixture support" ) ;
143158 raw_mixture_function[ support_index] += partial_probability;
144159 }
145160 }
146161 }
147162
148- let mut raw_mixture_pair: Vec < _ > = raw_mixture_support. into_iter ( )
163+ let mut raw_mixture_pair: Vec < _ > = raw_mixture_support
164+ . into_iter ( )
149165 . zip ( raw_mixture_function)
150166 . collect ( ) ;
151167
@@ -155,9 +171,7 @@ pub fn mixture_discrete(
155171 first_value. total_cmp ( & second_value)
156172 } ) ;
157173
158- let ( mixture_support, mixture_function) = raw_mixture_pair
159- . into_iter ( )
160- . unzip ( ) ;
174+ let ( mixture_support, mixture_function) = raw_mixture_pair. into_iter ( ) . unzip ( ) ;
161175
162176 let mix_rv = RandomVariable {
163177 function : mixture_function,
0 commit comments