From 2120a30e61c4ec489ac0605eadd7384301ad0e35 Mon Sep 17 00:00:00 2001 From: Wenyao Chen Date: Fri, 8 May 2026 17:47:52 +1000 Subject: [PATCH] add rceus into RUPTA --- src/builder/fpag_builder.rs | 1 - src/builder/special_function_handler.rs | 50 ++ src/graph/call_graph.rs | 36 +- src/lib.rs | 4 +- src/mir/context.rs | 4 + src/pre_analysis/mod.rs | 2 + .../func_pointer_flow_analysis.rs | 567 ++++++++++++++++++ .../mod.rs | 2 + .../precision_critical_func_identification.rs | 175 ++++++ src/{ => pre_analysis}/rta/body_visitor.rs | 10 +- src/{ => pre_analysis}/rta/mod.rs | 2 +- src/{ => pre_analysis}/rta/rta.rs | 12 +- src/pta/andersen.rs | 3 +- src/pta/context_sensitive.rs | 64 +- src/pta/mod.rs | 15 +- src/pta/propagator/propagator.rs | 29 +- src/pta/strategies/context_strategy.rs | 189 +++++- src/pta/strategies/stack_filtering.rs | 18 +- src/util/options.rs | 20 +- 19 files changed, 1117 insertions(+), 86 deletions(-) create mode 100644 src/pre_analysis/mod.rs create mode 100644 src/pre_analysis/precision_critical_func_identification/func_pointer_flow_analysis.rs create mode 100644 src/pre_analysis/precision_critical_func_identification/mod.rs create mode 100644 src/pre_analysis/precision_critical_func_identification/precision_critical_func_identification.rs rename src/{ => pre_analysis}/rta/body_visitor.rs (98%) rename src/{ => pre_analysis}/rta/mod.rs (99%) rename src/{ => pre_analysis}/rta/rta.rs (99%) diff --git a/src/builder/fpag_builder.rs b/src/builder/fpag_builder.rs index d07b3b4..79b92cd 100644 --- a/src/builder/fpag_builder.rs +++ b/src/builder/fpag_builder.rs @@ -23,7 +23,6 @@ use rustc_middle::ty; use rustc_middle::ty::{Const, Ty, TyCtxt, TyKind, GenericArgsRef}; use rustc_span::source_map::Spanned; use rustc_target::abi::FieldIdx; - use crate::builder::{call_graph_builder, special_function_handler}; use crate::graph::func_pag::FuncPAG; use crate::graph::pag::PAGEdgeEnum; diff --git a/src/builder/special_function_handler.rs b/src/builder/special_function_handler.rs index 21748fa..756901f 100644 --- a/src/builder/special_function_handler.rs +++ b/src/builder/special_function_handler.rs @@ -80,12 +80,62 @@ lazy_static! { }; } + +lazy_static! { + static ref PRECISION_CRITICAL_FUNCTIONS: HashSet = { + let mut set = HashSet::new(); + set.insert(KnownNames::StdIntrinsicsTransmute); + set.insert(KnownNames::StdIntrinsicsOffset); + set.insert(KnownNames::StdIntrinsicsArithOffset); + set.insert(KnownNames::StdPtrConstPtrCast); + set.insert(KnownNames::StdPtrConstPtrAdd); + set.insert(KnownNames::StdPtrConstPtrSub); + set.insert(KnownNames::StdPtrConstPtrOffset); + set.insert(KnownNames::StdPtrConstPtrByteAdd); + set.insert(KnownNames::StdPtrConstPtrByteSub); + set.insert(KnownNames::StdPtrConstPtrByteOffset); + set.insert(KnownNames::StdPtrConstPtrWrappingAdd); + set.insert(KnownNames::StdPtrConstPtrWrappingSub); + set.insert(KnownNames::StdPtrConstPtrWrappingOffset); + set.insert(KnownNames::StdPtrConstPtrWrappingByteAdd); + set.insert(KnownNames::StdPtrConstPtrWrappingByteSub); + set.insert(KnownNames::StdPtrConstPtrWrappingByteOffset); + set.insert(KnownNames::StdPtrMutPtrCast); + set.insert(KnownNames::StdPtrMutPtrAdd); + set.insert(KnownNames::StdPtrMutPtrSub); + set.insert(KnownNames::StdPtrMutPtrOffset); + set.insert(KnownNames::StdPtrMutPtrByteAdd); + set.insert(KnownNames::StdPtrMutPtrByteSub); + set.insert(KnownNames::StdPtrMutPtrByteOffset); + set.insert(KnownNames::StdPtrMutPtrWrappingAdd); + set.insert(KnownNames::StdPtrMutPtrWrappingSub); + set.insert(KnownNames::StdPtrMutPtrWrappingOffset); + set.insert(KnownNames::StdPtrMutPtrWrappingByteAdd); + set.insert(KnownNames::StdPtrMutPtrWrappingByteSub); + set.insert(KnownNames::StdPtrMutPtrWrappingByteOffset); + set.insert(KnownNames::StdPtrNonNullAsPtr); + set.insert(KnownNames::StdPtrUniqueNewUnchecked); + set.insert(KnownNames::StdResultMapErr); + set.insert(KnownNames::RustRealloc); + set.insert(KnownNames::StdAllocRealloc); + set.insert(KnownNames::StdAllocAllocatorGrow); + set.insert(KnownNames::StdAllocAllocatorGrowZeroed); + set.insert(KnownNames::StdAllocAllocatorShrink); + set + }; +} + /// Returns true if the function with `def_id` is specially handled. pub fn is_specially_handled_function(acx: &mut AnalysisContext, def_id: DefId) -> bool { let known_name = acx.get_known_name_for(def_id); SPECIALLY_HANDLED_FUNCTIONS.contains(&known_name) } + +pub fn is_specially_handled_precision_critical_function(acx: &mut AnalysisContext, def_id: DefId) -> bool { + let known_name = acx.get_known_name_for(def_id); + PRECISION_CRITICAL_FUNCTIONS.contains(&known_name) +} /// Handling calls to special functions. /// /// Returns true if this callee function is handled as a special function. diff --git a/src/graph/call_graph.rs b/src/graph/call_graph.rs index 8d79d94..b70e7ed 100644 --- a/src/graph/call_graph.rs +++ b/src/graph/call_graph.rs @@ -9,13 +9,12 @@ use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug}; use std::hash::Hash; - use crate::mir::analysis_context::AnalysisContext; use crate::mir::call_site::{BaseCallSite, CallType, CSBaseCallSite}; use crate::mir::function::{FuncId, CSFuncId}; +use rustc_middle::mir::Location; use crate::util::chunked_queue::{self, ChunkedQueue}; use crate::util::dot::Dot; - /// Unique identifiers for call graph nodes. pub type CGNodeId = NodeIndex; /// Unique identifiers for call graph edges. @@ -26,6 +25,8 @@ pub type CSCallGraph = CallGraph; pub trait CGFunction: Copy + Clone + PartialEq + Eq + Hash + Debug { fn dot_fmt(&self, acx: &AnalysisContext, f: &mut fmt::Formatter) -> fmt::Result; + + fn get_func_id(&self) -> FuncId; } impl CGFunction for FuncId { @@ -35,6 +36,10 @@ impl CGFunction for FuncId { acx.get_function_reference(*self).to_string() )) } + + fn get_func_id(&self) -> FuncId { + *self + } } impl CGFunction for CSFuncId { @@ -44,32 +49,52 @@ impl CGFunction for CSFuncId { acx.get_function_reference(self.func_id).to_string(), )) } + + fn get_func_id(&self) -> FuncId { + self.func_id + } } pub trait CGCallSite: Copy + Clone + PartialEq + Eq + Hash + Debug { fn dot_fmt(&self, f: &mut fmt::Formatter) -> fmt::Result; + + fn get_location(&self) -> &Location; } impl CGCallSite for BaseCallSite { + fn dot_fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_fmt(format_args!("{:?}", self.location)) } + + fn get_location(&self) -> &Location { + &self.location + } } impl CGCallSite for CSBaseCallSite { + fn dot_fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_fmt(format_args!("{:?}", self.location)) } + + fn get_location(&self) -> &Location { + &self.location + } } #[derive(Debug)] pub struct CallGraphNode { pub(crate) func: F, + pub(crate) req_cs: bool, } impl CallGraphNode { pub fn new(func: F) -> Self { - CallGraphNode { func } + CallGraphNode { + func, + req_cs: false, + } } } @@ -120,7 +145,7 @@ impl CallGraph { /// Helper function to get a node or insert a new /// node if it does not exist in the map. - fn get_or_insert_node(&mut self, func: F) -> CGNodeId { + pub fn get_or_insert_node(&mut self, func: F) -> CGNodeId { match self.func_nodes.entry(func) { Entry::Occupied(o) => o.get().to_owned(), Entry::Vacant(v) => { @@ -207,6 +232,7 @@ impl CallGraph { self.reach_funcs.iter_copied() } + /// Produce a dot file representation of the call graph /// for displaying with Graphviz. pub fn to_dot(&self, acx: &AnalysisContext, dot_path: &std::path::Path) { @@ -226,4 +252,4 @@ impl CallGraph { Err(e) => panic!("Failed to write dot file output: {:?}", e), }; } -} +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 065d273..796893b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ min_specialization, // for rustc_index::newtype_index type_alias_impl_trait, // for impl Trait in trait definition, eg crate::mir::utils trait_alias, + let_chains, // for let chains in match arms )] #![allow( clippy::single_match, @@ -31,11 +32,12 @@ extern crate rustc_serialize; extern crate rustc_session; extern crate rustc_span; extern crate rustc_target; +extern crate rustc_attr; pub mod builder; pub mod graph; pub mod mir; pub mod pta; -pub mod rta; +pub mod pre_analysis; pub mod pts_set; pub mod util; diff --git a/src/mir/context.rs b/src/mir/context.rs index 1c902e0..da09224 100644 --- a/src/mir/context.rs +++ b/src/mir/context.rs @@ -129,6 +129,10 @@ impl ContextCache { pub fn context_list(&self) -> &IndexVec>> { &self.context_list } + + pub fn get_context_iter(&self) -> std::collections::hash_map::Iter<'_, Rc>, ContextId> { + self.context_to_index_map.iter() + } } diff --git a/src/pre_analysis/mod.rs b/src/pre_analysis/mod.rs new file mode 100644 index 0000000..6abd0cd --- /dev/null +++ b/src/pre_analysis/mod.rs @@ -0,0 +1,2 @@ +pub mod rta; +pub mod precision_critical_func_identification; diff --git a/src/pre_analysis/precision_critical_func_identification/func_pointer_flow_analysis.rs b/src/pre_analysis/precision_critical_func_identification/func_pointer_flow_analysis.rs new file mode 100644 index 0000000..65a474f --- /dev/null +++ b/src/pre_analysis/precision_critical_func_identification/func_pointer_flow_analysis.rs @@ -0,0 +1,567 @@ +use petgraph::graph::{DefaultIx, NodeIndex}; +use petgraph::Graph; +use rustc_middle::mir; +use rustc_hir::def_id::DefId; +use rustc_middle::ty::{Ty, TyKind, GenericArgsRef}; +use std::rc::Rc; +use petgraph::visit::EdgeRef; +use std::collections::{HashSet, VecDeque, HashMap}; +use crate::builder::substs_specializer::SubstsSpecializer; +use crate::mir::analysis_context::AnalysisContext; +use crate::mir::function::FuncId; +use rustc_middle::mir::Location; +use crate::pre_analysis::rta::rta::RapidTypeAnalysis; +use crate::mir::path::PathEnum; +use rustc_span::source_map::Spanned; + + +use crate::mir::path::Path; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct DFGNode { + pub path: Rc, +} + +impl DFGNode { + pub fn new(path: Rc) -> Self { + DFGNode { path } + } + + /// Returns the path of the node. + pub fn path(&self) -> &Rc { + &self.path + } + +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct PFGEdge { + pub kind: PFGEdgeEnum, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum PFGEdgeEnum { + IntraPFGEdge, + CallPFGEdge(Location) +} + +// === Pointer Flow Graph for a function === +#[derive(Clone, Debug)] +pub struct FuncPFG { + pub(crate) graph: Graph, + edges: HashSet<(NodeIndex, NodeIndex)>, + pub func_id: FuncId, + // whether there is any argument-to-return flow (precision critical) + pub has_arg_to_return_flow: bool, + // call-site to (arg paths, return path) mapping + pub callsite_to_locals: HashMap)>, Rc)>, + pub static_callsites: HashSet, + // record of function pointer (FnPtr) and definition (FnDef) call-sites + pub fn_ptr_def_callsites: HashSet, + // record of closure and dynamic dispatch call-sites + pub closure_dyn_callsites: HashSet, + // call-sites with arg->return path + pub cs_callsites: HashSet, + // parameter with flow to return + pub param_with_flow: HashSet, +} + + +impl FuncPFG { + pub fn new(func_id: FuncId) -> Self { + FuncPFG { + graph: Graph::new(), + edges: HashSet::new(), + func_id, + has_arg_to_return_flow: false, + callsite_to_locals: HashMap::new(), + static_callsites: HashSet::new(), + fn_ptr_def_callsites: HashSet::new(), + closure_dyn_callsites: HashSet::new(), + cs_callsites: HashSet::new(), + param_with_flow: HashSet::new(), + } + } + + pub fn add_node(&mut self, path: Rc) -> NodeIndex { + self.graph.add_node(DFGNode::new(path)) + } + + pub fn get_or_insert_node(&mut self, path: Rc) -> NodeIndex { + if let Some(node_index) = self.graph.node_indices().find(|&n| self.graph[n].path == ::clone(&(*path)).into()) { + node_index + } else { + self.add_node(path) + } + } + + pub fn add_edge(&mut self, src: Rc, dst: Rc, kind: PFGEdgeEnum) -> bool { + let src_index = self.get_or_insert_node(src); + let dst_index = self.get_or_insert_node(dst); + if self.edges.insert((src_index, dst_index)) { + self.graph.add_edge(src_index, dst_index, PFGEdge { kind: kind.clone() }); + return true; + } + false + } + + pub fn get_node(&self, path: &Path) -> Option<&DFGNode> { + let node = self.graph.node_indices().find(|&n| self.graph[n].path.as_ref() == path); + if let Some(node_index) = node { + Some(&self.graph[node_index]) + } else { + None + } + } + pub fn nodes(&self) -> impl Iterator { + self.graph.node_weights() + } + + // check loc ⇒ 𝑓 in CTXFunc + pub fn is_cs_callsite(&self, loc: &Location) -> bool { + self.cs_callsites.contains(loc) + } + + /// Solve the reachability in the pointer flow graph. + pub fn solve_graph_reachability(&mut self) { + // self.print_graph(); + self.has_arg_to_return_flow = false; + self.cs_callsites.clear(); + + let param_nodes: Vec = self + .graph + .node_indices() + .filter(|&n| matches!(self.graph[n].path.value, PathEnum::Parameter{..})) + .collect(); + + let return_idx = self.get_or_insert_node(Path::new_return_value(self.func_id)); + + // --- nodes that can reach return (reverse BFS) --- + let mut reaches_ret: HashSet = HashSet::new(); + { + let mut q = VecDeque::new(); + q.push_back(return_idx); + while let Some(n) = q.pop_front() { + if !reaches_ret.insert(n) { continue; } + for e in self.graph.edges_directed(n, petgraph::Direction::Incoming) { + q.push_back(e.source()); + } + } + } + + + // --- F: nodes reachable from any parameter (forward BFS) --- + let mut from_param: HashSet = HashSet::new(); + { + let mut q: VecDeque = param_nodes.iter().copied().collect(); + while let Some(n) = q.pop_front() { + if !from_param.insert(n) { continue; } + for v in self.graph.edges_directed(n, petgraph::Direction::Outgoing) { + q.push_back(v.target()); + } + } + } + + // --- Any arg → return path? --- + let has_arg_to_return_flow = param_nodes.iter().any(|p| reaches_ret.contains(p)); + self.has_arg_to_return_flow = has_arg_to_return_flow; + + if !has_arg_to_return_flow { + self.cs_callsites.clear(); + return; + } + + // --- Collect only parameters that lie on arg→return path --- + self.param_with_flow.clear(); + for &p in ¶m_nodes { + if reaches_ret.contains(&p) { + let path = &self.graph[p].path; + let idx = match path.value { + PathEnum::Parameter { ordinal, .. } => ordinal, + _ => continue, + }; + self.param_with_flow.insert(idx); + } + } + + // --- Collect only call-sites that lie on some arg→return path --- + for e in self.graph.edge_references() { + let u = e.source(); + let v = e.target(); + if from_param.contains(&u) && reaches_ret.contains(&v) { + if let PFGEdgeEnum::CallPFGEdge(loc) = &e.weight().kind { + self.cs_callsites.insert(*loc); + } + } + } + } + + pub fn print_graph(&self) { + println!("--- Pointer Flow Graph for function {:?} ---", self.func_id); + for n in self.graph.node_indices() { + let node = &self.graph[n]; + println!(" Node {:?}", node.path); + // println!(" Node {:?}: from_non_param={:?}", node.path, node.from_non_param); + } + for e in self.graph.edge_references() { + let src = &self.graph[e.source()]; + let dst = &self.graph[e.target()]; + println!( + " {:?} --{:?}--> {:?}", + src.path, + e.weight().kind, + dst.path + ); + } + } +} + + +pub struct FuncPointerFlowAnalysis<'a, 'rta, 'tcx, 'compilation> { + pub(crate) rta: &'rta mut RapidTypeAnalysis<'a, 'tcx, 'compilation>, + pub(crate) func_id: FuncId, + pub(crate) mir: &'tcx mir::Body<'tcx>, + pub (crate) substs_specializer: SubstsSpecializer<'tcx>, + pub pfg: FuncPFG, +} + +impl<'a, 'rta, 'tcx, 'compilation> FuncPointerFlowAnalysis<'a, 'rta, 'tcx, 'compilation> { + pub fn new( + rta: &'rta mut RapidTypeAnalysis<'a, 'tcx, 'compilation>, + func_id: FuncId, + mir: &'tcx mir::Body<'tcx>, + ) -> FuncPointerFlowAnalysis<'a, 'rta, 'tcx, 'compilation> { + let func_ref = rta.acx.get_function_reference(func_id); + let substs_specializer = SubstsSpecializer::new( + rta.acx.tcx, + func_ref.generic_args.clone() + ); + + FuncPointerFlowAnalysis { + rta, + func_id, + mir, + substs_specializer, + pfg: FuncPFG::new(func_id), + } + } + + #[inline] + fn acx(&mut self) -> &mut AnalysisContext<'tcx, 'compilation> { + self.rta.acx + } + + /// Construct the intra dataflow graph and calculate the dataflow information. + pub fn calculate_intra_pointer_flow(&mut self) -> FuncPFG { + + // Check if any argument or return type contains pointer or reference + let mut arg_has_ptr = false; + for arg in self.mir.args_iter() { + let decl = &self.mir.local_decls[arg]; + let arg_ty = self.substs_specializer.specialize_generic_argument_type(decl.ty); + if self.type_contains_pointer_or_ref(arg_ty) { + arg_has_ptr = true; + break; + } + } + let return_local = &self.mir.local_decls[mir::Local::from_u32(0)]; + let return_ty = self.substs_specializer.specialize_generic_argument_type(return_local.ty); + let return_has_ptr = self.type_contains_pointer_or_ref(return_ty); + // we construct the pointer flow graph only if there exist pointer/ref in args and return type + if arg_has_ptr && return_has_ptr { + self.visit_body(); + self.solve_graph_initial(); + } else { + self.pfg.has_arg_to_return_flow = false; + self.pfg.cs_callsites.clear(); + } + + return self.pfg.clone(); + } + + + fn solve_graph_initial(&mut self) { + let ret_path = Path::new_local_parameter_or_result(self.func_id, 0, self.mir.arg_count); + let _return_idx = self.pfg.get_or_insert_node(ret_path); + self.pfg.solve_graph_reachability(); + + } + + fn visit_body(&mut self){ + for bb in self.mir.basic_blocks.indices() { + self.visit_basic_block(bb); + } + } + + fn visit_basic_block(&mut self, bb: mir::BasicBlock,) { + let mir::BasicBlockData { + ref statements, + ref terminator, + .. + } = &self.mir[bb]; + let mut location = bb.start_location(); + let terminator_index = statements.len(); + + while location.statement_index < terminator_index { + self.visit_statement(location, &statements[location.statement_index]); + location.statement_index += 1; + } + + if let Some(mir::Terminator { + ref source_info, + ref kind, + }) = *terminator + { + self.visit_terminator(location, kind, *source_info); + } + } + + /// Calls a specialized visitor for each kind of statement. + fn visit_statement(&mut self, _location: mir::Location, statement: &mir::Statement<'tcx>) { + let mir::Statement {kind, source_info: _} = statement; + match kind { + mir::StatementKind::Assign(box (place, rvalue)) => { + + let place_ty = self.substs_specializer.specialize_generic_argument_type( + self.mir.local_decls[place.local].ty + ); + + if self.type_contains_pointer_or_ref(place_ty) { + self.visit_assign(place, rvalue); + } + + } + _ => (), + } + } + + fn visit_assign(&mut self, lplace: &mir::Place<'tcx>, rvalue: &mir::Rvalue<'tcx>) { + match rvalue { + mir::Rvalue::Use(operand) | mir::Rvalue::Repeat(operand, _) => { + self.visit_use(lplace, operand); + } + mir::Rvalue::Ref(_, _, place) | mir::Rvalue::AddressOf(_, place) => { + self.add_ref_pfg_edge(lplace, place); + } + mir::Rvalue::Cast(_cast_kind, operand, _ty) => { + self.visit_use(lplace, operand); + } + mir::Rvalue::Aggregate(_ ,operands) => { + for (_i, operand) in operands.iter().enumerate() { + self.visit_use(lplace, operand); + } + } + mir::Rvalue::BinaryOp(_, box (loperand, _roperand)) => { + self.visit_use(lplace, loperand); + } + mir::Rvalue::ShallowInitBox(operand, _) => { + self.visit_use(lplace, operand); + } + mir::Rvalue::CopyForDeref(place) => { + self.add_ref_pfg_edge(lplace, place); + } + _ => { + // println!("Skipping rvalue kind: {:?}", rvalue); + } + } + } + + fn visit_use(&mut self, lplace: &mir::Place<'tcx>, operand: &mir::Operand<'tcx>) { + match operand { + mir::Operand::Copy(place) | mir::Operand::Move(place) => { + + // For Copy and Move operands, + // we add an edge from rplace to lplace only if rplace contains pointer or reference. + let local_decls = &self.mir.local_decls; + let place_ty = place.ty(local_decls, self.acx().tcx); + let place_ty = self.substs_specializer.specialize_generic_argument_type(place_ty.ty); + + if self.type_contains_pointer_or_ref(place_ty) { + let lpath = Path::new_local_parameter_or_result(self.func_id, lplace.local.as_usize(), self.mir.arg_count); + let rpath = Path::new_local_parameter_or_result(self.func_id, place.local.as_usize(), self.mir.arg_count); + let edge_kind = PFGEdgeEnum::IntraPFGEdge; + self.pfg.add_edge(rpath, lpath, edge_kind); + } + } + _ => {} // We do not consider constant operands for pointer flow analysis. + } + + } + + fn add_ref_pfg_edge(&mut self, lplace: &mir::Place<'tcx>, rplace: &mir::Place<'tcx>) { + // For Ref and AddressOf rvalues, + // we add an edge from rplace to lplace only if rplace contains pointer or reference. + let rplace_ty = self.substs_specializer.specialize_generic_argument_type( + self.mir.local_decls[rplace.local].ty + ); + + if self.type_contains_pointer_or_ref(rplace_ty) { + let lpath = Path::new_local_parameter_or_result(self.func_id, lplace.local.as_usize(), self.mir.arg_count); + let rpath = Path::new_local_parameter_or_result(self.func_id, rplace.local.as_usize(), self.mir.arg_count); + let edge_kind = PFGEdgeEnum::IntraPFGEdge; + self.pfg.add_edge(rpath, lpath, edge_kind); + } + } + + fn visit_terminator( + &mut self, + location: mir::Location, + kind: &mir::TerminatorKind<'tcx>, + _source_info: mir::SourceInfo, + ) { + match kind { + mir::TerminatorKind::Call { + func, + args, + destination, + target: _, + unwind: _, + call_source: _, + fn_span: _, + } => self.visit_call(func, args, destination, location), + mir::TerminatorKind::InlineAsm { + template: _, + operands: _, + destination: _, + .. + } => {} + _ => {} + } + } + + + /// visit call for collecting callsite information + fn visit_call( + &mut self, + func: &mir::Operand<'tcx>, + args: &Vec>>, + destination: &mir::Place<'tcx>, + location: mir::Location, + ) { + let destination_ty = self.substs_specializer.specialize_generic_argument_type( + self.mir.local_decls[destination.local].ty + ); + if !self.type_contains_pointer_or_ref(destination_ty) { + return; + } + // collect callsite information + let destination_path = Path::new_local_parameter_or_result(self.func_id, destination.local.as_usize(), self.mir.arg_count); + let args_paths = self.visit_args(args).into_iter().collect(); + + self.pfg.callsite_to_locals.insert(location, (args_paths, destination_path.clone())); + match func { + mir::Operand::Copy(place) | mir::Operand::Move(place) => { + let fn_item_ty = self.substs_specializer.specialize_generic_argument_type( + self.mir.local_decls[place.local].ty + ); + match fn_item_ty.kind() { + TyKind::Closure(callee_def_id, gen_args) + | TyKind::FnDef(callee_def_id, gen_args) + | TyKind::Coroutine(callee_def_id, gen_args) => { + self.resolve_call(callee_def_id, gen_args, args, destination, location) + } + TyKind::FnPtr(_) => { + // consider as static call for PAG only + self.pfg.static_callsites.insert(location); + } + _ => {} + } + } + mir::Operand::Constant(box constant) => { + match constant.ty().kind() { + TyKind::Closure(callee_def_id, gen_args) + | TyKind::FnDef(callee_def_id, gen_args) + | TyKind::Coroutine(callee_def_id, gen_args) => { + self.resolve_call(callee_def_id, gen_args, args, destination, location) + } + TyKind::FnPtr(_) => { + // consider as static call for PAG only + self.pfg.static_callsites.insert(location); + } + _ => {} + } + + } + } + + } + + + /// collecting all dynamic dispatch callsites information + fn resolve_call( + &mut self, + callee_def_id: &DefId, + gen_args: &GenericArgsRef<'tcx>, + _args: &Vec>>, + _destination: &mir::Place<'tcx>, + location: mir::Location, + ) { + + if !self.acx().is_std_ops_fntrait_call(*callee_def_id) { + self.pfg.static_callsites.insert(location); + return; + } + + let mut first_subst_ty = match gen_args.types().next() { + Some(ty) => ty, + None => return, + }; + first_subst_ty = self.substs_specializer.specialize_generic_argument_type(first_subst_ty); + // Determine the callsite type based on the first generic argument type + match first_subst_ty.kind() { + TyKind::FnPtr(_) => { + self.pfg.fn_ptr_def_callsites.insert(location); + } + TyKind::FnDef(_, _) => { + self.pfg.fn_ptr_def_callsites.insert(location); + } + TyKind::Closure(_, _) | TyKind::Coroutine(_, _) => { + // The dispatch callee of Closure and Coroutine contains + // reference to the closure/coroutine as first argument. + self.pfg.closure_dyn_callsites.insert(location); + } + TyKind::Dynamic(_, _, _) => { + // The dispatch callee of Dynamic contains self reference as first argument. + self.pfg.closure_dyn_callsites.insert(location); + } + _ => {} + } + } + + // Collect argument paths + fn visit_args(&mut self, args: &Vec>>,) -> Vec<(usize, Rc)> { + let mut idx = 0; + let mut args_paths = Vec::new(); + for arg in args { + idx += 1; + match &arg.node { + mir::Operand::Copy(place) | mir::Operand::Move(place) => { + let arg_ty = self.substs_specializer.specialize_generic_argument_type( + self.mir.local_decls[place.local].ty + ); + + if self.type_contains_pointer_or_ref(arg_ty) { + let arg_path = Path::new_local_parameter_or_result(self.func_id, place.local.as_usize(), self.mir.arg_count); + args_paths.push((idx, arg_path)); + } + } + _ => {} + } + } + args_paths + } + + fn type_contains_pointer_or_ref(&mut self, ty: Ty<'tcx>) -> bool { + if ty.is_any_ptr() { + return true; + } else { + let ptr_projs = self.acx().get_pointer_projections(ty); + if !ptr_projs.is_empty() { + return true; + } else { + return false; + } + } + } + +} \ No newline at end of file diff --git a/src/pre_analysis/precision_critical_func_identification/mod.rs b/src/pre_analysis/precision_critical_func_identification/mod.rs new file mode 100644 index 0000000..8c4b8d3 --- /dev/null +++ b/src/pre_analysis/precision_critical_func_identification/mod.rs @@ -0,0 +1,2 @@ +pub mod func_pointer_flow_analysis; +pub mod precision_critical_func_identification; diff --git a/src/pre_analysis/precision_critical_func_identification/precision_critical_func_identification.rs b/src/pre_analysis/precision_critical_func_identification/precision_critical_func_identification.rs new file mode 100644 index 0000000..ef02e0e --- /dev/null +++ b/src/pre_analysis/precision_critical_func_identification/precision_critical_func_identification.rs @@ -0,0 +1,175 @@ +use std::collections::{HashMap, HashSet, VecDeque}; +use std::time::{Duration, Instant}; + +use petgraph::visit::EdgeRef; + +use crate::graph::call_graph::{CGCallSite, CGNodeId}; +use crate::mir::function::FuncId; +use crate::pre_analysis::rta::rta::RapidTypeAnalysis; + +use super::func_pointer_flow_analysis::{FuncPFG, FuncPointerFlowAnalysis, PFGEdgeEnum}; + +pub struct PrecCritFnIdent<'r, 'a, 'tcx, 'compilation> { + pub rta: &'r mut RapidTypeAnalysis<'a, 'tcx, 'compilation>, + + pub func_pfg_map: HashMap, + pub cs_funcs: HashSet, + + worklist: VecDeque, + + pub analysis_time: Duration, +} + +impl<'r, 'a, 'tcx, 'compilation> PrecCritFnIdent<'r, 'a, 'tcx, 'compilation> { + pub fn new(rta: &'r mut RapidTypeAnalysis<'a, 'tcx, 'compilation>) -> Self { + PrecCritFnIdent { + rta, + func_pfg_map: HashMap::new(), + cs_funcs: HashSet::new(), + worklist: VecDeque::new(), + analysis_time: Duration::ZERO, + } + } + + pub fn analyze(&mut self) { + let now = Instant::now(); + + let funcs: Vec = self.rta.call_graph.func_nodes.keys().copied().collect(); + for func_id in funcs { + if self.rta.specially_handled_functions.contains(&func_id) { + continue; + } + let def_id = self.rta.acx.get_function_reference(func_id).def_id; + if !self.rta.acx.tcx.is_mir_available(def_id) { + continue; + } + let mir = self.rta.acx.tcx.optimized_mir(def_id); + let pfg = FuncPointerFlowAnalysis::new(self.rta, func_id, mir).calculate_intra_pointer_flow(); + if pfg.has_arg_to_return_flow { + if let Some(&node_id) = self.rta.call_graph.func_nodes.get(&func_id) { + self.worklist.push_front(node_id); + } + } + self.func_pfg_map.insert(func_id, pfg); + } + + let specially_handled_pc: Vec = + self.rta.specially_handled_precision_critical_functions.iter().copied().collect(); + for f in specially_handled_pc { + if let Some(&node_id) = self.rta.call_graph.func_nodes.get(&f) { + self.worklist.push_front(node_id); + } + } + + self.precision_critical_func_identification(); + + for node_id in self.rta.call_graph.graph.node_indices() { + if let Some(node) = self.rta.call_graph.graph.node_weight(node_id) { + if node.req_cs { + self.cs_funcs.insert(node.func); + } + } + } + + self.analysis_time = now.elapsed(); + println!( + "Precision-critical function identification time: {}", + humantime::format_duration(self.analysis_time).to_string() + ); + } + + // Worklist algorithm for context sensitivity identification. + // The initial worklist contains functions with intra-procedural argument-to-return flow. + // For each function in the worklist, we propagate the pointer flow to its callers, + // and add the callers into the worklist if a caller now has argument-to-return flow. + fn precision_critical_func_identification(&mut self) { + // Split the borrows up-front so the inner loop can independently mutate + // func_pfg_map, worklist, and the call graph. + let call_graph = &mut self.rta.call_graph; + let func_pfg_map = &mut self.func_pfg_map; + let worklist = &mut self.worklist; + + let mut visited = HashSet::new(); + + while let Some(node_id) = worklist.pop_front() { + let node = match call_graph.graph.node_weight(node_id) { + Some(n) => n, + None => continue, + }; + let func_id = node.func; + visited.insert(node_id); + + // Extracting the parameters that have return flow in the callee. + // Here, arguments to all special handled callsites (without pag) will be connected to return directly. + let param_with_flow = match func_pfg_map.get(&func_id) { + Some(pfg) => Some(pfg.param_with_flow.clone()), + None => None, + }; + + // Collect incoming edges first to release the iterator's borrow on the graph + // before we mutate func_pfg_map / worklist below. + let incoming: Vec<(CGNodeId, rustc_middle::mir::Location)> = call_graph + .graph + .edges_directed(node_id, petgraph::Direction::Incoming) + .map(|e| (e.source(), *e.weight().callsite.get_location())) + .collect(); + + for (src_node_id, location) in incoming { + let src_node = match call_graph.graph.node_weight(src_node_id) { + Some(n) => n, + None => continue, + }; + + let src_func_id = src_node.func; + let src_func_pfg = match func_pfg_map.get_mut(&src_func_id) { + Some(pfg) => pfg, + None => continue, // This should not happen, caller always has pfg + }; + + // propagate pointer flow from callee to caller + let mut new_edges_added = false; + if src_func_pfg.callsite_to_locals.contains_key(&location) { + let (args, destination) = src_func_pfg.callsite_to_locals[&location].clone(); + let is_static_call = src_func_pfg.static_callsites.contains(&location); + let is_closure_or_dyn_call = src_func_pfg.closure_dyn_callsites.contains(&location); + for arg in args { + let (arg_idx, arg_path) = arg.clone(); + // For static calls, only consider the parameters that have flow to the callee. + // For closure/dyn calls, the second argument will be dispatched to parameter 2..N in callee, + // we connect arg 2 to return if any of parameter 2..N has flow in callee. + // for other dynamic resolved calls (fnptr), + // the second argument will be dispatched to parameter 1..N in callee, + // we connect arg 2 to return if any of parameter 1..N has flow in callee. + if let Some(param_with_flow) = ¶m_with_flow { + if is_static_call { + if !param_with_flow.contains(&arg_idx) { + continue; + } + } + if is_closure_or_dyn_call { + if !param_with_flow.contains(&arg_idx) && arg_idx == 1 { + continue; + } + } + } + let edge_kind = PFGEdgeEnum::CallPFGEdge(location); + new_edges_added |= src_func_pfg.add_edge(arg_path.clone(), destination.clone(), edge_kind); + } + } + + if new_edges_added { + src_func_pfg.solve_graph_reachability(); + if src_func_pfg.has_arg_to_return_flow { + worklist.push_back(src_node_id); + } + } + } + } + + for node_id in &visited { + if let Some(node) = call_graph.graph.node_weight_mut(*node_id) { + node.req_cs = true; + } + } + } +} diff --git a/src/rta/body_visitor.rs b/src/pre_analysis/rta/body_visitor.rs similarity index 98% rename from src/rta/body_visitor.rs rename to src/pre_analysis/rta/body_visitor.rs index f55907d..d443278 100644 --- a/src/rta/body_visitor.rs +++ b/src/pre_analysis/rta/body_visitor.rs @@ -24,7 +24,6 @@ pub struct BodyVisitor<'a, 'rta, 'tcx, 'compilation> { pub(crate) rta: &'rta mut RapidTypeAnalysis<'a, 'tcx, 'compilation>, pub(crate) func_id: FuncId, pub mir: &'tcx mir::Body<'tcx>, - /// For specializing the generic type in the method. substs_specializer: SubstsSpecializer<'tcx>, encountered_statics: HashSet, @@ -67,7 +66,7 @@ impl<'a, 'rta, 'tcx, 'compilation> BodyVisitor<'a, 'rta, 'tcx, 'compilation> { self.visit_baisc_block(bb); } } - + fn visit_baisc_block(&mut self, bb: mir::BasicBlock,) { let mir::BasicBlockData { ref statements, @@ -258,6 +257,9 @@ impl<'a, 'rta, 'tcx, 'compilation> BodyVisitor<'a, 'rta, 'tcx, 'compilation> { debug!("Call func {:?}, generic_args: {:?}", callee_def_id, gen_args); if special_function_handler::is_specially_handled_function(self.acx(), *callee_def_id) { + let is_precision_critical = special_function_handler::is_specially_handled_precision_critical_function(self.acx(), *callee_def_id); + + let callsite = BaseCallSite::new(self.func_id, location); // Special handlings for thread spawn functions @@ -278,7 +280,9 @@ impl<'a, 'rta, 'tcx, 'compilation> BodyVisitor<'a, 'rta, 'tcx, 'compilation> { self.rta.add_static_callsite(callsite); self.rta.add_call_edge(callsite, callee_func_id); self.rta.specially_handled_functions.insert(callee_func_id); - + if is_precision_critical { + self.rta.specially_handled_precision_critical_functions.insert(callee_func_id); + } return; } diff --git a/src/rta/mod.rs b/src/pre_analysis/rta/mod.rs similarity index 99% rename from src/rta/mod.rs rename to src/pre_analysis/rta/mod.rs index 8e83f4c..e1fdb7f 100644 --- a/src/rta/mod.rs +++ b/src/pre_analysis/rta/mod.rs @@ -1,5 +1,5 @@ pub mod body_visitor; -pub mod rta; +pub mod rta; use log::*; use rustc_driver::Compilation; diff --git a/src/rta/rta.rs b/src/pre_analysis/rta/rta.rs similarity index 99% rename from src/rta/rta.rs rename to src/pre_analysis/rta/rta.rs index f0b794f..864dbf5 100644 --- a/src/rta/rta.rs +++ b/src/pre_analysis/rta/rta.rs @@ -14,7 +14,6 @@ use crate::util::{type_util, chunked_queue, results_dumper}; use super::body_visitor::BodyVisitor; - pub struct RapidTypeAnalysis<'a, 'tcx, 'compilation> { /// The analysis context pub(crate) acx: &'a mut AnalysisContext<'tcx, 'compilation>, @@ -27,6 +26,8 @@ pub struct RapidTypeAnalysis<'a, 'tcx, 'compilation> { /// Records the functions that have been visited pub(crate) visited_functions: HashSet, pub(crate) specially_handled_functions: HashSet, + pub(crate) specially_handled_precision_critical_functions: HashSet, + pub static_callsites: HashSet, pub dyn_callsites: HashMap, HashSet<(BaseCallSite, DefId, GenericArgsRef<'tcx>)>>, @@ -52,6 +53,7 @@ impl<'a, 'tcx, 'compilation> RapidTypeAnalysis<'a, 'tcx, 'compilation> { rf_iter, visited_functions: HashSet::new(), specially_handled_functions: HashSet::new(), + specially_handled_precision_critical_functions: HashSet::new(), static_callsites: HashSet::new(), dyn_callsites: HashMap::new(), dyn_fntrait_callsites: HashMap::new(), @@ -78,8 +80,6 @@ impl<'a, 'tcx, 'compilation> RapidTypeAnalysis<'a, 'tcx, 'compilation> { FunctionReference::new_function_reference(entry_point, vec![]) ); self.call_graph.add_node(entry_func_id); - - // process terminators of reachable functions self.iteratively_process_reachable_functions(); self.analysis_time = now.elapsed(); @@ -88,6 +88,7 @@ impl<'a, 'tcx, 'compilation> RapidTypeAnalysis<'a, 'tcx, 'compilation> { "Rapid Type Analysis time: {}", humantime::format_duration(self.analysis_time).to_string() ); + } fn iteratively_process_reachable_functions(&mut self) { @@ -108,8 +109,6 @@ impl<'a, 'tcx, 'compilation> RapidTypeAnalysis<'a, 'tcx, 'compilation> { let func_ref = self.acx.get_function_reference(func_id); let def_id = func_ref.def_id; let generic_args = &func_ref.generic_args; - - // We don't count specially handled functions as we do not process them in pta if self.specially_handled_functions.contains(&func_id) { self.visited_functions.insert(func_id); continue; @@ -120,9 +119,10 @@ impl<'a, 'tcx, 'compilation> RapidTypeAnalysis<'a, 'tcx, 'compilation> { self.visited_functions.insert(func_id); continue; } - + self.promote_constants(def_id, generic_args); let mir = self.tcx().optimized_mir(def_id); + let mut bv = BodyVisitor::new(self, func_id, mir); bv.visit_body(); self.visited_functions.insert(func_id); diff --git a/src/pta/andersen.rs b/src/pta/andersen.rs index df58983..0ca692e 100644 --- a/src/pta/andersen.rs +++ b/src/pta/andersen.rs @@ -20,7 +20,7 @@ use crate::mir::function::FuncId; use crate::mir::analysis_context::AnalysisContext; use crate::mir::path::Path; use crate::pta::*; -use crate::rta::rta::RapidTypeAnalysis; +use crate::pre_analysis::rta::rta::RapidTypeAnalysis; use crate::util::chunked_queue; use crate::util::pta_statistics::AndersenStat; use crate::util::results_dumper; @@ -173,6 +173,7 @@ impl<'pta, 'tcx, 'compilation> AndersenPTA<'pta, 'tcx, 'compilation> { } } + // Add new call edges to pag fn process_new_calls(&mut self, new_calls: &Vec<(Rc, FuncId)>) { for (callsite, callee_id) in new_calls { diff --git a/src/pta/context_sensitive.rs b/src/pta/context_sensitive.rs index ed08cab..5a9cacd 100644 --- a/src/pta/context_sensitive.rs +++ b/src/pta/context_sensitive.rs @@ -11,7 +11,6 @@ use std::time::Duration; use itertools::Itertools; use log::*; use rustc_middle::ty::TyCtxt; - use super::*; use super::strategies::context_strategy::{ContextStrategy, KObjectSensitive}; use super::strategies::stack_filtering::StackFilter; @@ -24,10 +23,12 @@ use crate::mir::context::{Context, ContextId}; use crate::mir::function::{FuncId, CSFuncId}; use crate::mir::analysis_context::AnalysisContext; use crate::mir::path::{Path, CSPath, PathEnum}; -use crate::rta::rta::RapidTypeAnalysis; +use crate::pre_analysis::precision_critical_func_identification::precision_critical_func_identification::PrecCritFnIdent; +use crate::pre_analysis::rta::rta::RapidTypeAnalysis; use crate::util::pta_statistics::ContextSensitiveStat; use crate::util::{self, chunked_queue, results_dumper}; + pub type CallSiteSensitivePTA<'pta, 'tcx, 'compilation> = ContextSensitivePTA<'pta, 'tcx, 'compilation, KCallSiteSensitive>; /// The object-sensitive pointer analysis for Rust has not been throughly evaluated so far. pub type ObjectSensitivePTA<'pta, 'tcx, 'compilation> = ContextSensitivePTA<'pta, 'tcx, 'compilation, KObjectSensitive>; @@ -60,6 +61,7 @@ pub struct ContextSensitivePTA<'pta, 'tcx, 'compilation, S: ContextStrategy> { ctx_strategy: S, pub stack_filter: Option>, + pub pre_analysis_time: Duration, } @@ -223,12 +225,10 @@ impl<'pta, 'tcx, 'compilation, S: ContextStrategy> ContextSensitivePTA<'pta, 'tc /// Process a resolved call according to the call type fn process_new_call(&mut self, callsite: &Rc, callee: &FuncId) { let callee_def_id = self.acx.get_function_reference(*callee).def_id; - // an instance call if util::has_self_parameter(self.tcx(), callee_def_id) { - // borrow self (&self or &mut self) if util::has_self_ref_parameter(self.tcx(), callee_def_id) { - // the instance should be the pointed-to object of the self pointer - if let Some(callee_cid) = self.ctx_strategy.new_instance_call_context(callsite, None) { + // borrow self (&self or &mut self) — the instance is the pointed-to object of the self pointer + if let Some(callee_cid) = self.ctx_strategy.new_instance_call_context(callsite, None, *callee) { let cs_callee = CSFuncId::new(callee_cid, *callee); self.add_call_edge(callsite, &cs_callee); } @@ -237,21 +237,21 @@ impl<'pta, 'tcx, 'compilation, S: ContextStrategy> ContextSensitivePTA<'pta, 'tc self.assoc_calls.add_static_dispatch_instance_call(self_ref_id, callsite.clone(), *callee); } else { // move self let instance = callsite.args.get(0).expect("invalid arguments"); - if let Some(callee_cid) = self.ctx_strategy.new_instance_call_context(callsite, Some(instance)) { + if let Some(callee_cid) = self.ctx_strategy.new_instance_call_context(callsite, Some(instance), *callee) { let cs_callee = CSFuncId::new(callee_cid, *callee); self.add_call_edge(callsite, &cs_callee); } - } + } } else { - let callee_cid = self.ctx_strategy.new_static_call_context(callsite); + let callee_cid = self.ctx_strategy.new_static_call_context(callsite, *callee); let cs_callee = CSFuncId::new(callee_cid, *callee); self.add_call_edge(callsite, &cs_callee); } } - fn special_callsite_context(&mut self, callsite: &Rc, _callee: &FuncId) -> ContextId { - // Currently we treat all special callsites as statical callsites - self.ctx_strategy.new_static_call_context(callsite) + fn special_callsite_context(&mut self, callsite: &Rc, callee: &FuncId) -> ContextId { + // Currently we treat all special callsites as static callsites + self.ctx_strategy.new_static_call_context(callsite, *callee) } // Add new call edges to pag @@ -262,9 +262,11 @@ impl<'pta, 'tcx, 'compilation, S: ContextStrategy> ContextSensitivePTA<'pta, 'tc self.process_reach_funcs(); } + fn process_new_call_instances(&mut self, new_call_instances: &Vec<(Rc, Rc, FuncId)>) { + for (callsite, instance, callee_id) in new_call_instances { - if let Some(callee_cid) = self.ctx_strategy.new_instance_call_context(callsite, Some(instance)) { + if let Some(callee_cid) = self.ctx_strategy.new_instance_call_context(callsite, Some(instance), *callee_id) { let cs_callee = CSFuncId::new(callee_cid, *callee_id); self.add_call_edge(callsite, &cs_callee); } @@ -347,16 +349,32 @@ impl<'pta, 'tcx, 'compilation, S: ContextStrategy> PointerAnalysis<'tcx, 'compil for ContextSensitivePTA<'pta, 'tcx, 'compilation, S> { fn pre_analysis(&mut self) { - if !self.acx.analysis_options.stack_filtering { + let stack_filtering = self.acx.analysis_options.stack_filtering; + let rceus = self.acx.analysis_options.rceus; + if !stack_filtering && !rceus { return; } info!("Start pre-analysis"); + println!("Starting Rapid Type Analysis..."); let mut rta = RapidTypeAnalysis::new(&mut self.acx); rta.analyze(); self.pre_analysis_time += rta.analysis_time; - self.stack_filter = Some(StackFilter::new(rta.call_graph)); - self.ctx_strategy.with_stack_filter(self.stack_filter.as_mut().unwrap()); - self.pre_analysis_time += self.stack_filter.as_ref().unwrap().fra_time(); + + if rceus { + let mut pcfi = PrecCritFnIdent::new(&mut rta); + pcfi.analyze(); + self.pre_analysis_time += pcfi.analysis_time; + let cs_funcs = std::mem::take(&mut pcfi.cs_funcs); + let func_pfg_map = std::mem::take(&mut pcfi.func_pfg_map); + self.ctx_strategy.set_prec_crit_fn_ident_data(cs_funcs, func_pfg_map); + } + + if stack_filtering { + self.stack_filter = Some(StackFilter::new(rta.call_graph)); + self.ctx_strategy.with_stack_filter(self.stack_filter.as_mut().unwrap()); + self.pre_analysis_time += self.stack_filter.as_ref().unwrap().fra_time(); + } + println!("Pre-analysis time {}", humantime::format_duration(self.pre_analysis_time).to_string() ); @@ -369,7 +387,7 @@ impl<'pta, 'tcx, 'compilation, S: ContextStrategy> PointerAnalysis<'tcx, 'compil let empty_context_id = self.get_empty_context_id(); let entry_func_id = self.acx.get_func_id(entry_point, self.tcx().mk_args(&[])); self.call_graph.add_node(CSFuncId::new(empty_context_id, entry_func_id)); - + // process statements of reachable functions self.process_reach_funcs(); } @@ -397,8 +415,12 @@ impl<'pta, 'tcx, 'compilation, S: ContextStrategy> PointerAnalysis<'tcx, 'compil if new_calls.is_empty() && new_call_instances.is_empty() { break; } else { + // Note: RCEUS is for call-site sensitivity only, + // so new call instances become normal new calls there. self.process_new_calls(&new_calls); - self.process_new_call_instances(&new_call_instances); + if !self.acx.analysis_options.rceus { + self.process_new_call_instances(&new_call_instances); + } } } } @@ -407,9 +429,9 @@ impl<'pta, 'tcx, 'compilation, S: ContextStrategy> PointerAnalysis<'tcx, 'compil fn finalize(&self) { // dump call graph, points-to results results_dumper::dump_results(self.acx, &self.call_graph, &self.pt_data, &self.pag); - + // dump pta statistics let pta_stat = ContextSensitiveStat::new(self); pta_stat.dump_stats(); } -} +} \ No newline at end of file diff --git a/src/pta/mod.rs b/src/pta/mod.rs index 94a957a..5537d0f 100644 --- a/src/pta/mod.rs +++ b/src/pta/mod.rs @@ -12,7 +12,7 @@ use rustc_middle::ty::TyCtxt; use self::andersen::AndersenPTA; use self::context_sensitive::ContextSensitivePTA; -use self::strategies::context_strategy::KCallSiteSensitive; +use self::strategies::context_strategy::{KCallSiteSensitive, RCEUSCallSiteSensitive}; use crate::graph::pag::*; use crate::mir::function::FuncId; use crate::mir::analysis_context::AnalysisContext; @@ -61,7 +61,6 @@ pub trait PointerAnalysis<'tcx, 'compilation> { "Analysis time: {}", humantime::format_duration(elapsed).to_string() ); - self.finalize(); } } @@ -89,12 +88,12 @@ impl PTACallbacks { if let Some(mut acx) = AnalysisContext::new(&compiler.sess, tcx, self.options.clone()) { let mut pta: Box = match self.options.pta_type { PTAType::CallSiteSensitive => { - Box::new( - ContextSensitivePTA::new( - &mut acx, - KCallSiteSensitive::new(self.options.context_depth as usize) - ), - ) + let k = self.options.context_depth as usize; + if self.options.rceus { + Box::new(ContextSensitivePTA::new(&mut acx, RCEUSCallSiteSensitive::new(k))) + } else { + Box::new(ContextSensitivePTA::new(&mut acx, KCallSiteSensitive::new(k))) + } } PTAType::Andersen => Box::new(AndersenPTA::new(&mut acx)), }; diff --git a/src/pta/propagator/propagator.rs b/src/pta/propagator/propagator.rs index 7f45deb..9db4302 100644 --- a/src/pta/propagator/propagator.rs +++ b/src/pta/propagator/propagator.rs @@ -73,7 +73,7 @@ pub struct Propagator<'pta, 'tcx, 'compilation, F, P: PAGPath> { } impl<'pta, 'tcx, 'compilation, F, P> Propagator<'pta, 'tcx, 'compilation, F, P> where - F: Copy + Into + std::cmp::Eq + std::hash::Hash + SFReachable, + F: Copy + Into + std::cmp::Eq + std::hash::Hash + SFReachable + std::fmt::Debug, P: PAGPath, { /// Constructor @@ -439,8 +439,14 @@ impl<'pta, 'tcx, 'compilation, F, P> Propagator<'pta, 'tcx, 'compilation, F, P> replaced_args, ) { let func_id = self.acx.get_func_id(callee_def_id, gen_args); - // self.add_new_call(&dyn_callsite, &func_id); - self.add_new_call_instance(&dyn_callsite, &pointee_path, &func_id); + if self.acx.analysis_options.rceus { + // Since Rceus is for callsite-sensitive analyses only, + // we only add the callsite for dynamic dispatch calls. + self.add_new_call(&dyn_callsite, &func_id); + } else { + // For other context-insensitive or object-sensitive analyses, we add the call instance + self.add_new_call_instance(&dyn_callsite, &pointee_path, &func_id); + } } else { warn!( "Could not resolve function: {:?}, {:?}", @@ -672,8 +678,14 @@ impl<'pta, 'tcx, 'compilation, F, P> Propagator<'pta, 'tcx, 'compilation, F, P> if self.tcx().is_mir_available(resolved_def_id) { // The pointee type cannot be FnDef, FnPtr, Closure, therefore its mir is supposed to be available let func_id = self.acx.get_func_id(resolved_def_id, instance_args); - // self.add_new_call(&dynamic_fntrait_callsite, &func_id); - self.add_new_call_instance(&dynamic_fntrait_callsite, &pointee_path, &func_id); + if self.acx.analysis_options.rceus { + // Since Rceus is for callsite-sensitive analyses only, + // we only add the callsite for dynamic dispatch calls. + self.add_new_call(&dynamic_fntrait_callsite, &func_id); + } else { + // For other context-insensitive or object-sensitive analyses, we add the call instance + self.add_new_call_instance(&dynamic_fntrait_callsite, &pointee_path, &func_id); + } } else { warn!("Unavailable mir for def_id: {:?}", resolved_def_id); } @@ -871,6 +883,7 @@ impl<'pta, 'tcx, 'compilation, F, P> Propagator<'pta, 'tcx, 'compilation, F, P> continue; } } + if matches!(regularized_path.value(), PathEnum::HeapObj { .. }) { // For heap objects that have a concretized type, we do not let it been cast from // a simple type to other incompatible types. @@ -1153,10 +1166,10 @@ impl<'pta, 'tcx, 'compilation, F, P> Propagator<'pta, 'tcx, 'compilation, F, P> } if let Some(sf) = stack_filter { if let Some(&edge_func) = sf.get_container_func_of_edge(&edge_id) { - return !sf.is_potentially_alive(acx, edge_func, pointee); + let rceus = acx.analysis_options.rceus; + return !sf.is_potentially_alive(acx, edge_func, pointee, rceus); } } - return false; } } @@ -1174,4 +1187,4 @@ impl<'pta, 'tcx, 'compilation, F, P> Propagator<'pta, 'tcx, 'compilation, F, P> } } } -} +} \ No newline at end of file diff --git a/src/pta/strategies/context_strategy.rs b/src/pta/strategies/context_strategy.rs index e8fae04..4ee649f 100644 --- a/src/pta/strategies/context_strategy.rs +++ b/src/pta/strategies/context_strategy.rs @@ -9,13 +9,17 @@ //! Only k-callsite-sensitive pointer analyses have been thoroughly evaluated so far. use std::rc::Rc; - +use std::collections::HashSet; +use std::collections::hash_map::Iter; +use std::collections::HashMap; use crate::mir::call_site::{BaseCallSite, CSCallSite}; use crate::mir::context::{Context, ContextCache, ContextElement, ContextId, HybridCtxElem}; use crate::mir::function::FuncId; use crate::mir::path::{CSPath, Path}; use crate::rustc_index::Idx; use super::stack_filtering::{StackFilter, SFReachable}; +use crate::pre_analysis::precision_critical_func_identification::func_pointer_flow_analysis::FuncPFG; + pub trait ContextStrategy { type E: ContextElement; @@ -23,12 +27,28 @@ pub trait ContextStrategy { fn get_empty_context_id(&mut self) -> ContextId; fn get_context_id(&mut self, context: &Rc>) -> ContextId; fn get_context_by_id(&self, context_id: ContextId) -> Rc>; - fn new_instance_call_context(&mut self, callsite: &Rc, receiver: Option<&Rc>) -> Option; - fn new_static_call_context(&mut self, callsite: &Rc) -> ContextId; - fn with_stack_filter(&mut self, _stack_filter: &mut StackFilter) - where - F: Copy + Into + std::cmp::Eq + std::hash::Hash, + fn get_context_iter(&self) -> Option>, ContextId>> { + None + } + fn new_instance_call_context( + &mut self, + callsite: &Rc, + receiver: Option<&Rc>, + callee: FuncId, + ) -> Option; + + fn new_static_call_context(&mut self, callsite: &Rc, callee: FuncId) -> ContextId; + + fn with_stack_filter(&mut self, _stack_filter: &mut StackFilter) + where + F: Copy + Into + std::cmp::Eq + std::hash::Hash, {} + + fn set_prec_crit_fn_ident_data( + &mut self, + _cs_funcs: HashSet, + _func_pfg_map: HashMap, + ) {} } pub struct ContextInsensitive {} @@ -53,19 +73,24 @@ impl ContextStrategy for ContextInsensitive { self.empty_context() } - fn new_instance_call_context(&mut self, _callsite: &Rc, _receiver: Option<&Rc>) -> Option { + fn new_instance_call_context( + &mut self, + _callsite: &Rc, + _receiver: Option<&Rc>, + _callee: FuncId, + ) -> Option { Some(ContextId::new(0)) } - fn new_static_call_context(&mut self, _callsite: &Rc) -> ContextId { + fn new_static_call_context(&mut self, _callsite: &Rc, _callee: FuncId) -> ContextId { ContextId::new(0) } } pub struct KCallSiteSensitive { /// Context length limit for methods - k: usize, - pub(crate) ctx_cache: ContextCache, + pub(crate) k: usize, + pub ctx_cache: ContextCache, } impl KCallSiteSensitive { @@ -83,7 +108,7 @@ impl KCallSiteSensitive { &caller_ctx, callsite.into(), self.k, - ); + ); let callee_ctx_id = self.ctx_cache.get_context_id(&callee_ctx); callee_ctx_id } @@ -108,16 +133,25 @@ impl ContextStrategy for KCallSiteSensitive { self.get_context_id(&Context::new_empty()) } - fn new_instance_call_context(&mut self, callsite: &Rc, _receiver: Option<&Rc>) -> Option { + fn get_context_iter(&self) -> Option>, ContextId>> { + Some(self.ctx_cache.get_context_iter()) + } + + fn new_instance_call_context( + &mut self, + callsite: &Rc, + _receiver: Option<&Rc>, + _callee: FuncId, + ) -> Option { Some(self.new_context(callsite)) } - fn new_static_call_context(&mut self, callsite: &Rc) -> ContextId { - self.new_context(callsite) + fn new_static_call_context(&mut self, callsite: &Rc, _callee: FuncId) -> ContextId { + self.new_context(callsite) } - fn with_stack_filter(&mut self, stack_filter: &mut StackFilter) - where + fn with_stack_filter(&mut self, stack_filter: &mut StackFilter) + where F: Copy + Into + std::cmp::Eq + std::hash::Hash, { stack_filter.with_kcs_context_strategy(self); @@ -171,7 +205,12 @@ impl ContextStrategy for KObjectSensitive { self.get_context_id(&Context::new_empty()) } - fn new_instance_call_context(&mut self, _callsite: &Rc, receiver: Option<&Rc>) -> Option { + fn new_instance_call_context( + &mut self, + _callsite: &Rc, + receiver: Option<&Rc>, + _callee: FuncId, + ) -> Option { if let Some(cs_path) = receiver { Some(self.new_context(cs_path.clone())) } else { @@ -179,7 +218,7 @@ impl ContextStrategy for KObjectSensitive { } } - fn new_static_call_context(&mut self, callsite: &Rc) -> ContextId { + fn new_static_call_context(&mut self, callsite: &Rc, _callee: FuncId) -> ContextId { // use the same context as the caller function callsite.func.cid } @@ -225,6 +264,7 @@ impl SimpleHybridContextSensitive { let callee_ctx_id = self.ctx_cache.get_context_id(&callee_ctx); callee_ctx_id } + } impl ContextStrategy for SimpleHybridContextSensitive { @@ -246,7 +286,12 @@ impl ContextStrategy for SimpleHybridContextSensitive { self.get_context_id(&Context::new_empty()) } - fn new_instance_call_context(&mut self, _callsite: &Rc, receiver: Option<&Rc>) -> Option { + fn new_instance_call_context( + &mut self, + _callsite: &Rc, + receiver: Option<&Rc>, + _callee: FuncId, + ) -> Option { if let Some(cs_path) = receiver { Some(self.new_instance_call_context(cs_path.clone())) } else { @@ -254,8 +299,112 @@ impl ContextStrategy for SimpleHybridContextSensitive { } } - fn new_static_call_context(&mut self, callsite: &Rc) -> ContextId { + fn new_static_call_context(&mut self, callsite: &Rc, _callee: FuncId) -> ContextId { // use the same context as the caller function self.new_static_call_context(callsite) } +} + + +/// Call-site-sensitive context strategy that applies the RCEUS context-augmentation +/// algorithm to precision-critical callees, and falls back to plain k-cfa for +/// non-PC callees (or PC callees whose caller has no PFG entry). +pub struct RCEUSCallSiteSensitive { + inner: KCallSiteSensitive, + cs_funcs: HashSet, + func_pfg_map: HashMap, +} + +impl RCEUSCallSiteSensitive { + pub fn new(k: usize) -> Self { + Self { + inner: KCallSiteSensitive::new(k), + cs_funcs: HashSet::new(), + func_pfg_map: HashMap::new(), + } + } + + /// Apply the RCEUS context-augmentation algorithm using the caller's PFG. + /// The first context element is forced to a "flow-entry" callsite, then the + /// k-limited tail is appended. + /// + /// Takes `&mut KCallSiteSensitive` rather than `&mut self` so callers can + /// hold an immutable borrow of `self.func_pfg_map` simultaneously (disjoint + /// field borrow). + fn rceus_context( + inner: &mut KCallSiteSensitive, + callsite: &Rc, + caller_pfg: &FuncPFG, + ) -> ContextId { + let caller_ctx = inner.get_context_by_id(callsite.func.cid); + let caller_ctx_elem = &caller_ctx.context_elems; + let callsite_location = callsite.location; + + let flow_entry = if !caller_pfg.is_cs_callsite(&callsite_location) { + // callsite_location ⇒ 𝑓 ∉ CTXFuncs — this callsite is the flow entry. + callsite.into() + } else { + // The first element of the caller context is always the flow-entry + // callsite from RCEUS; the unwrap is safe because any caller reaching + // here has a non-empty context. + caller_ctx_elem.first().unwrap().clone() + }; + + let mut new_callee_ctx_elem = vec![flow_entry]; + let callee_ctx = Context::new_k_limited_context(&caller_ctx, callsite.into(), inner.k); + new_callee_ctx_elem.extend(callee_ctx.context_elems.iter().cloned()); + let new_callee_ctx = Rc::new(Context { context_elems: new_callee_ctx_elem }); + inner.ctx_cache.get_context_id(&new_callee_ctx) + } +} + +impl ContextStrategy for RCEUSCallSiteSensitive { + type E = BaseCallSite; + + fn empty_context(&self) -> Rc> { self.inner.empty_context() } + fn get_empty_context_id(&mut self) -> ContextId { self.inner.get_empty_context_id() } + fn get_context_id(&mut self, context: &Rc>) -> ContextId { + self.inner.get_context_id(context) + } + fn get_context_by_id(&self, context_id: ContextId) -> Rc> { + self.inner.get_context_by_id(context_id) + } + fn get_context_iter(&self) -> Option>, ContextId>> { + self.inner.get_context_iter() + } + + fn new_static_call_context(&mut self, callsite: &Rc, callee: FuncId) -> ContextId { + if self.cs_funcs.contains(&callee) { + if let Some(caller_pfg) = self.func_pfg_map.get(&callsite.func.func_id) { + return Self::rceus_context(&mut self.inner, callsite, caller_pfg); + } + } + self.inner.new_static_call_context(callsite, callee) + } + + fn new_instance_call_context( + &mut self, + callsite: &Rc, + receiver: Option<&Rc>, + callee: FuncId, + ) -> Option { + if self.cs_funcs.contains(&callee) { + if let Some(caller_pfg) = self.func_pfg_map.get(&callsite.func.func_id) { + return Some(Self::rceus_context(&mut self.inner, callsite, caller_pfg)); + } + } + self.inner.new_instance_call_context(callsite, receiver, callee) + } + + fn with_stack_filter(&mut self, stack_filter: &mut StackFilter) + where + F: Copy + Into + std::cmp::Eq + std::hash::Hash, + { + self.inner.with_stack_filter(stack_filter); + } + + fn set_prec_crit_fn_ident_data(&mut self, cs_funcs: HashSet, func_pfg_map: HashMap) { + self.cs_funcs = cs_funcs; + self.func_pfg_map = func_pfg_map; + } } \ No newline at end of file diff --git a/src/pta/strategies/stack_filtering.rs b/src/pta/strategies/stack_filtering.rs index 72db716..c3261f5 100644 --- a/src/pta/strategies/stack_filtering.rs +++ b/src/pta/strategies/stack_filtering.rs @@ -19,7 +19,6 @@ use crate::pts_set::points_to::PointsToSet; use super::context_strategy::{KCallSiteSensitive, ContextStrategy}; - pub trait RowRelation: std::fmt::Debug + std::clone::Clone + std::marker::Send { fn new_empty() -> Self; fn with_capacity(capacity: usize) -> Self; @@ -214,8 +213,7 @@ impl StackFilter where { pub fn new(call_graph: CallGraph) -> Self { let now = Instant::now(); - let reach_relation = - FunctionReachabilityAnalysis::compute_func_reach_relations_mt(&call_graph); + let reach_relation = FunctionReachabilityAnalysis::compute_func_reach_relations_mt(&call_graph); let fra_time = now.elapsed(); StackFilter { call_graph, @@ -279,7 +277,8 @@ impl StackFilter where &self, acx: &AnalysisContext, current_func: F, - target_path: &P + target_path: &P, + rceus: bool ) -> bool { match target_path.value() { PathEnum::HeapObj { .. } => { return true; } @@ -306,6 +305,10 @@ impl StackFilter where return true; } + // We apply only naive reachability relation for Rceus analysis. + if rceus { + return self.naive_reachability_relation(path_container_func, current_func); + } return current_func.is_reachable_from(&path_container_func, self); } return true; @@ -384,7 +387,12 @@ impl SFReachable for CSFuncId { if match_suffix_and_prefix(&to_call_chain, &from_call_chain) { return true; - } + } + // else { + // println!("Call chain from {:?} to {:?} does not match, checking reachability relation", from, self); + // println!(" From call chain: {:?}", from_call_chain); + // println!(" To call chain: {:?}", to_call_chain); + // } } let from_id = stack_filter.call_graph.func_nodes.get(&(*from).into()).unwrap(); diff --git a/src/util/options.rs b/src/util/options.rs index b98036a..4280a66 100644 --- a/src/util/options.rs +++ b/src/util/options.rs @@ -51,8 +51,7 @@ fn make_options_parser() -> Command<'static> { .long("context-depth") .takes_value(true) .value_parser(clap::value_parser!(u32)) - .default_value("1") - .help("The context depth limit for a context-sensitive pointer analysis.")) + .help("The context depth limit for a context-sensitive pointer analysis. Defaults to 1, or 0 when --rceus is set.")) .arg(Arg::new("no-cast-constraint") .long("no-cast-constraint") .takes_value(false) @@ -62,6 +61,10 @@ fn make_options_parser() -> Command<'static> { .long("stack-filtering") .takes_value(false) .help("Enable stack filtering in pointer analysis.")) + .arg(Arg::new("rceus") + .long("rceus") + .takes_value(false) + .help("Enable rceus in pointer analysis.")) .arg(Arg::new("dump-stats") .long("dump-stats") .takes_value(false) @@ -111,7 +114,8 @@ pub struct AnalysisOptions { // options for handling cast propagation pub cast_constraint: bool, pub stack_filtering: bool, - + pub rceus: bool, + pub dump_stats: bool, pub call_graph_output: Option, pub pts_output: Option, @@ -131,6 +135,7 @@ impl Default for AnalysisOptions { context_depth: 1, cast_constraint: true, stack_filtering: false, + rceus: false, dump_stats: true, call_graph_output: None, pts_output: None, @@ -209,13 +214,16 @@ impl AnalysisOptions { } } + self.cast_constraint = !matches.contains_id("no-cast-constraint"); + self.stack_filtering = matches.contains_id("stack-filtering"); + self.rceus = matches.contains_id("rceus"); + if let Some(depth) = matches.get_one::("context-depth") { self.context_depth = *depth; + } else if self.rceus { + self.context_depth = 0; } - self.cast_constraint = !matches.contains_id("no-cast-constraint"); - self.stack_filtering = matches.contains_id("stack-filtering"); - self.dump_stats = matches.contains_id("dump-stats"); self.call_graph_output = matches.get_one::("call-graph-output").cloned(); self.pts_output = matches.get_one::("pts-output").cloned();