Skip to content

Commit 13611a5

Browse files
committed
add error handling
1 parent f00f33e commit 13611a5

1 file changed

Lines changed: 19 additions & 5 deletions

File tree

rust/src/algorithms/transform.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,19 @@ pub fn mixture_discrete(
124124
random_variables: &[&RandomVariable],
125125
mix_weights: &[Number],
126126
) -> Result<RandomVariable, String> {
127+
if random_variables.len() != mix_weights.len() {
128+
return Err("the number of random variables and mix weights must be equal".to_string());
129+
}
130+
131+
let weight_sum = mix_weights
132+
.iter()
133+
.fold(Number::default(), |acc, x| acc + *x);
134+
135+
let one = Number::Integer(1);
136+
let tolerance = Number::Float(1e-6);
137+
if weight_sum < one - tolerance || weight_sum > one + tolerance {
138+
return Err("the mix weights must sum to one".to_string());
139+
}
127140

128141
let mut raw_mixture_function = Vec::new();
129142
let mut raw_mixture_support = Vec::new();
@@ -138,14 +151,17 @@ pub fn mixture_discrete(
138151
raw_mixture_support.push(support_value);
139152
raw_mixture_function.push(partial_probability);
140153
} else {
141-
let support_index = support.iter().position(|&x| x == support_value)
154+
let support_index = support
155+
.iter()
156+
.position(|&x| x == support_value)
142157
.expect("support value not found in mixture support");
143158
raw_mixture_function[support_index] += partial_probability;
144159
}
145160
}
146161
}
147162

148-
let mut raw_mixture_pair: Vec<_> = raw_mixture_support.into_iter()
163+
let mut raw_mixture_pair: Vec<_> = raw_mixture_support
164+
.into_iter()
149165
.zip(raw_mixture_function)
150166
.collect();
151167

@@ -155,9 +171,7 @@ pub fn mixture_discrete(
155171
first_value.total_cmp(&second_value)
156172
});
157173

158-
let (mixture_support, mixture_function) = raw_mixture_pair
159-
.into_iter()
160-
.unzip();
174+
let (mixture_support, mixture_function) = raw_mixture_pair.into_iter().unzip();
161175

162176
let mix_rv = RandomVariable {
163177
function: mixture_function,

0 commit comments

Comments
 (0)