diff --git a/generate/src/generation/mod.rs b/generate/src/generation/mod.rs index 801b237..0d4ad3e 100644 --- a/generate/src/generation/mod.rs +++ b/generate/src/generation/mod.rs @@ -709,28 +709,32 @@ impl GenerationCtx { let target_bb = self.add_new_bb(); targets.push(target_bb); - let target_discr = match discr_val { - Literal::Uint(i, _) => i, - Literal::Int(i, _) => i as u128, - // Literal::Bool(b) => b as u128, - // Literal::Char(c) => c as u128, - _ => unreachable!("invalid switchint discriminant"), - }; - let branches: Vec<(u128, BasicBlock)> = targets - .iter() - .enumerate() - .filter_map(|(i, &bb)| { - if bb == target_bb { - Some((target_discr, bb)) - } else if i as u128 == target_discr { - // Prevent duplicate - None - } else { - Some((i as u128, bb)) + let mut branches = Vec::new(); + for bb in targets { + if bb == target_bb { + branches.push((discr_val, bb)); + } else { + // If this is a decoy bb, try a few times to generate a literal that is not already + // in the targets array. + for _ in 0..8 { + let val = self + .rng + .borrow_mut() + .gen_literal(discr_val.ty(), &self.tcx) + .unwrap(); + // We cannot reuse the non-decoy value, even if it's not in the array because + // it might be on a later iteration. + if val == discr_val { + continue; + } + if !branches.iter().any(|(arm_val, _bb)| *arm_val == val) { + branches.push((val, bb)); + break; + } } - }) - .collect(); + } + } let term = Terminator::SwitchInt { discr: Operand::Copy(discr), diff --git a/mir/src/serialize.rs b/mir/src/serialize.rs index dfd8834..c204e48 100644 --- a/mir/src/serialize.rs +++ b/mir/src/serialize.rs @@ -306,7 +306,7 @@ impl Terminator { } } Terminator::SwitchInt { discr, targets } => { - let arms = targets.match_arms(); + let arms = targets.match_arms(tcx); format!("match {} {{\n{}\n}}", discr.serialize(tcx), arms) } Terminator::Hole => unreachable!("hole"), diff --git a/mir/src/syntax.rs b/mir/src/syntax.rs index 3edf1d7..8a44210 100644 --- a/mir/src/syntax.rs +++ b/mir/src/syntax.rs @@ -3,6 +3,7 @@ use std::num::TryFromIntError; use index_vec::{IndexVec, define_index_type}; use smallvec::SmallVec; +use crate::serialize::Serialize; use crate::tyctxt::TyCtxt; #[derive(Clone)] @@ -195,7 +196,7 @@ pub enum Terminator { #[derive(Clone)] pub struct SwitchTargets { - pub branches: Vec<(u128, BasicBlock)>, + pub branches: Vec<(Literal, BasicBlock)>, pub otherwise: BasicBlock, } @@ -225,7 +226,7 @@ pub enum AggregateKind { Adt(TyId, VariantIdx), } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum Literal { Uint(u128, UintTy), Int(i128, IntTy), @@ -790,11 +791,11 @@ impl Function { } impl SwitchTargets { - pub fn match_arms(&self) -> String { + pub fn match_arms(&self, tcx: &TyCtxt) -> String { let mut arms: String = self .branches .iter() - .map(|(val, bb)| format!("{val} => {},\n", bb.identifier())) + .map(|(val, bb)| format!("{} => {},\n", val.serialize(tcx), bb.identifier())) .collect(); arms.push_str(&format!("_ => {}", self.otherwise.identifier())); arms